Kaynağa Gözat

调整 chat 和 steamChat 获取对话逻辑

cherishsince 1 yıl önce
ebeveyn
işleme
25523fd53e

+ 11 - 12
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java

@@ -18,7 +18,9 @@ import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum;
 import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
 import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper;
 import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
+import cn.iocoder.yudao.module.ai.service.ChatConversationService;
 import cn.iocoder.yudao.module.ai.service.ChatService;
+import cn.iocoder.yudao.module.ai.vo.ChatConversationRes;
 import cn.iocoder.yudao.module.ai.vo.ChatReq;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
@@ -46,6 +48,7 @@ public class ChatServiceImpl implements ChatService {
     private final AiChatRoleMapper aiChatRoleMapper;
     private final AiChatMessageMapper aiChatMessageMapper;
     private final AiChatConversationMapper aiChatConversationMapper;
+    private final ChatConversationService chatConversationService;
 
 
     /**
@@ -59,13 +62,10 @@ public class ChatServiceImpl implements ChatService {
         Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
         // 获取 client 类型
         AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
-        // 获取 对话类型(新建还是继续)
-        ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
-        AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
-
+        // 获取对话信息
+        ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId());
         // 保存 chat message
-        saveChatMessage(req, aiChatConversationDO.getId(), loginUserId);
-
+        saveChatMessage(req, conversationRes.getId(), loginUserId);
         String content = null;
         try {
             // 创建 chat 需要的 Prompt
@@ -80,7 +80,7 @@ public class ChatServiceImpl implements ChatService {
             content = ExceptionUtil.getMessage(e);
         } finally {
             // 保存 chat message
-            saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, content);
+            saveSystemChatMessage(req, conversationRes.getId(), loginUserId, content);
         }
         return content;
     }
@@ -176,16 +176,15 @@ public class ChatServiceImpl implements ChatService {
         Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
         // 获取 client 类型
         AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
-        // 获取 对话类型(新建还是继续)
-        ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
-        AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
+        // 获取对话信息
+        ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId());
         // 创建 chat 需要的 Prompt
         Prompt prompt = new Prompt(req.getPrompt());
         req.setTopK(req.getTopK());
         req.setTopP(req.getTopP());
         req.setTemperature(req.getTemperature());
         // 保存 chat message
-        saveChatMessage(req, aiChatConversationDO.getId(), loginUserId);
+        saveChatMessage(req, conversationRes.getId(), loginUserId);
         Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName());
 
         StringBuffer contentBuffer = new StringBuffer();
@@ -212,7 +211,7 @@ public class ChatServiceImpl implements ChatService {
                     log.info("发送完成!");
                     sseEmitter.complete();
                     // 保存 chat message
-                    saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, contentBuffer.toString());
+                    saveSystemChatMessage(req, conversationRes.getId(), loginUserId, contentBuffer.toString());
                 }
         );
     }