瀏覽代碼

【解决todo】AI 写作、脑图:model、systemMessage获取逻辑调整

xiaoxin 8 月之前
父節點
當前提交
29e421432d

+ 2 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java

@@ -12,8 +12,7 @@ import lombok.Data;
  *
  * @author xiaoxin
  */
-// TODO @xin:如果没 typehandler 的需求,autoResultMap 可以去掉哈
-@TableName(value = "ai_mind_map", autoResultMap = true)
+@TableName(value = "ai_mind_map")
 @Data
 public class AiMindMapDO extends BaseDO {
 
@@ -25,7 +24,7 @@ public class AiMindMapDO extends BaseDO {
 
     /**
      * 用户编号
-     *
+     * <p>
      * 关联 AdminUserDO 的 userId 字段
      */
     private Long userId;

+ 35 - 18
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.module.ai.service.mindmap;
 
 import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.lang.Assert;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@@ -57,33 +58,25 @@ public class AiMindMapServiceImpl implements AiMindMapService {
 
     @Override
     public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
-        // 1.1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型
+        // 1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型
         AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
-        AiChatModelDO model;
-        String systemMessage;
-        if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) {
-            model = chatModalService.getChatModel(mindMapRole.getModelId());
-            systemMessage = mindMapRole.getSystemMessage();
-        } else {
-            model = chatModalService.getRequiredDefaultChatModel();
-            systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
-        }
-
+        // 1.1 获取脑图执行模型
+        AiChatModelDO model = getModel(mindMapRole);
+        // 1.2 获取角色设定消息
+        String systemMessage = Objects.nonNull(mindMapRole) && StrUtil.isNotBlank(mindMapRole.getSystemMessage())
+                ? mindMapRole.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
+        // 1.3 校验平台
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
 
         // 2 插入思维导图信息
-        AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
+        AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
+                mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
         mindMapMapper.insert(mindMapDO);
 
         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
         // 3.1 角色设定
-        List<Message> chatMessages = new ArrayList<>();
-        if (StrUtil.isNotBlank(systemMessage)) {
-            chatMessages.add(new SystemMessage(systemMessage));
-        }
-        // 3.2 用户输入
-        chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
+        List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
         // 3.3 构建提示词
         Prompt prompt = new Prompt(chatMessages, chatOptions);
 
@@ -109,4 +102,28 @@ public class AiMindMapServiceImpl implements AiMindMapService {
 
     }
 
+    private static List<Message> buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) {
+        List<Message> chatMessages = new ArrayList<>();
+        if (StrUtil.isNotBlank(systemMessage)) {
+            // 1.1 角色设定
+            chatMessages.add(new SystemMessage(systemMessage));
+        }
+        // 1.2 用户输入
+        chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
+        return chatMessages;
+    }
+
+    // TODO 芋艿:这里脑图、写作都用到了,是不是可以抽哪里去
+    private AiChatModelDO getModel(AiChatRoleDO chatRoleDO) {
+        AiChatModelDO model = null;
+        if (Objects.nonNull(chatRoleDO) && Objects.nonNull(chatRoleDO.getModelId())) {
+            model = chatModalService.getChatModel(chatRoleDO.getModelId());
+        }
+        if (Objects.isNull(model)) {
+            model = chatModalService.getRequiredDefaultChatModel();
+        }
+        Assert.notNull(model, "[AI] 获取不到模型");
+        return model;
+    }
+
 }

+ 35 - 21
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.module.ai.service.write;
 
 import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.lang.Assert;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@@ -67,19 +68,14 @@ public class AiWriteServiceImpl implements AiWriteService {
 
     @Override
     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
-        // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型
+        // 1 获取写作模型 尝试获取写作助手角色,没有则使用默认模型
         AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
-        // TODO @xin:如果有 writeRole,但是没 modeId,是不是也可以用 systemMessage 哈?建议的写法是:先通过 modelId 获取 model。如果 model == null,则 chatModalService.getRequiredDefaultChatModel();如果还是 null,则抛出异常;。。。。。。。。。。。。。。然后,systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理。
-        AiChatModelDO model;
-        String systemMessage;
-        if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
-            model = chatModalService.getChatModel(writeRole.getModelId());
-            systemMessage = writeRole.getSystemMessage();
-        } else {
-            model = chatModalService.getRequiredDefaultChatModel();
-            systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
-        }
-        // 1.2 校验平台
+        // 1.1 获取写作执行模型
+        AiChatModelDO model = getModel(writeRole);
+        // 1.2 获取角色设定消息
+        String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
+                ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
+        // 1.3 校验平台
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
 
@@ -90,16 +86,11 @@ public class AiWriteServiceImpl implements AiWriteService {
 
         // 3. 调用大模型,写作生成
         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
-        // 3.1 角色设定
-        // TODO @xin:要不把 90 到 97 这部分,合并到一个方法里。目的是:让这个方法的主干更明确
-        List<Message> chatMessages = new ArrayList<>();
-        if (StrUtil.isNotBlank(systemMessage)) {
-            chatMessages.add(new SystemMessage(systemMessage));
-        }
-        // 3.2 用户输入
-        chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
-        // 3.3 构建提示词
+        // 3.1 构建消息列表
+        List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
+        // 3.2 构建提示词
         Prompt prompt = new Prompt(chatMessages, chatOptions);
+        // 3.3 流式调用
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
         // 4. 流式返回
@@ -122,6 +113,29 @@ public class AiWriteServiceImpl implements AiWriteService {
         }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
     }
 
+    private AiChatModelDO getModel(AiChatRoleDO writeRole) {
+        AiChatModelDO model = null;
+        if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
+            model = chatModalService.getChatModel(writeRole.getModelId());
+        }
+        if (Objects.isNull(model)) {
+            model = chatModalService.getRequiredDefaultChatModel();
+        }
+        Assert.notNull(model, "[AI] 获取不到模型");
+        return model;
+    }
+
+    private List<Message> buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
+        List<Message> chatMessages = new ArrayList<>();
+        if (StrUtil.isNotBlank(systemMessage)) {
+            // 1.1 角色设定
+            chatMessages.add(new SystemMessage(systemMessage));
+        }
+        // 1.2 用户输入
+        chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
+        return chatMessages;
+    }
+
     private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
         String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
         String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());