Sfoglia il codice sorgente

stream 保存聊天记录

cherishsince 1 anno fa
parent
commit
f84d25d3b7

+ 1 - 26
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/ChatController.java

@@ -2,7 +2,6 @@ package cn.iocoder.yudao.module.ai.controller;
 
 import cn.hutool.core.exceptions.ExceptionUtil;
 import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
-import cn.iocoder.yudao.framework.ai.config.AiClient;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.module.ai.service.ChatService;
 import cn.iocoder.yudao.module.ai.vo.ChatReq;
@@ -38,7 +37,6 @@ import java.util.function.Consumer;
 public class ChatController {
 
     @Autowired
-    private AiClient aiClient;
     private final ChatService chatService;
 
     @Operation(summary = "聊天-chat", description = "这个一般等待时间比较久,需要全部完成才会返回!")
@@ -52,30 +50,7 @@ public class ChatController {
     @GetMapping(value = "/chatStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
     public SseEmitter chatStream(@Validated @ModelAttribute ChatReq req) {
         Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
-        Flux<ChatResponse> streamResponse = chatService.chatStream(req);
-        streamResponse.subscribe(
-                new Consumer<ChatResponse>() {
-                    @Override
-                    public void accept(ChatResponse chatResponse) {
-                        String content = chatResponse.getResults().get(0).getOutput().getContent();
-                        try {
-                            sseEmitter.send(content, MediaType.APPLICATION_JSON);
-                        } catch (IOException e) {
-                            log.error("发送异常{}", ExceptionUtil.getMessage(e));
-                            // 如果不是因为关闭而抛出异常,则重新连接
-                            sseEmitter.completeWithError(e);
-                        }
-                    }
-                },
-                error -> {
-                    //
-                    log.error("subscribe错误 {}", ExceptionUtil.getMessage(error));
-                },
-                () -> {
-                    log.info("发送完成!");
-                    sseEmitter.complete();
-                }
-        );
+        chatService.chatStream(req, sseEmitter);
         return sseEmitter;
     }
 }

+ 3 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/ChatService.java

@@ -1,9 +1,7 @@
 package cn.iocoder.yudao.module.ai.service;
 
-import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
-import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
+import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
 import cn.iocoder.yudao.module.ai.vo.ChatReq;
-import reactor.core.publisher.Flux;
 
 /**
  * 聊天 chat
@@ -26,7 +24,8 @@ public interface ChatService {
      * chat stream
      *
      * @param req
+     * @param sseEmitter
      * @return
      */
-    Flux<ChatResponse> chatStream(ChatReq req);
+    void chatStream(ChatReq req, Utf8SseEmitter sseEmitter);
 }

+ 88 - 17
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java

@@ -8,6 +8,7 @@ import cn.iocoder.yudao.framework.ai.config.AiClient;
 import cn.iocoder.yudao.framework.common.exception.ServerException;
 import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
 import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
+import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
 import cn.iocoder.yudao.module.ai.dataobject.AiChatConversationDO;
 import cn.iocoder.yudao.module.ai.dataobject.AiChatMessageDO;
 import cn.iocoder.yudao.module.ai.dataobject.AiChatRoleDO;
@@ -21,10 +22,14 @@ import cn.iocoder.yudao.module.ai.service.ChatService;
 import cn.iocoder.yudao.module.ai.vo.ChatReq;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
+import org.springframework.http.MediaType;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 import reactor.core.publisher.Flux;
 
+import java.io.IOException;
+import java.util.function.Consumer;
+
 /**
  * 聊天 service
  *
@@ -51,25 +56,17 @@ public class ChatServiceImpl implements ChatService {
      */
     @Transactional(rollbackFor = Exception.class)
     public String chat(ChatReq req) {
+        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
         // 获取 client 类型
         AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
         // 获取 对话类型(新建还是继续)
         ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
+        AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
 
-        AiChatConversationDO aiChatConversationDO;
-        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
-        if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) {
-            // 创建一个新的对话
-            aiChatConversationDO = createNewChatConversation(req, loginUserId);
-        } else {
-            // 继续对话
-            if (req.getConversationId() == null) {
-                throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL);
-            }
-            aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId());
-        }
+        // 保存 chat message
+        saveChatMessage(req, aiChatConversationDO.getId(), loginUserId);
 
-        String content;
+        String content = null;
         try {
             // 创建 chat 需要的 Prompt
             Prompt prompt = new Prompt(req.getPrompt());
@@ -81,13 +78,19 @@ public class ChatServiceImpl implements ChatService {
             content = call.getResult().getOutput().getContent();
         } catch (Exception e) {
             content = ExceptionUtil.getMessage(e);
+        } finally {
+            // 保存 chat message
+            saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, content);
         }
+        return content;
+    }
 
+    private void saveChatMessage(ChatReq req, Long chatConversationId, Long loginUserId) {
         // 增加 chat message 记录
         aiChatMessageMapper.insert(
                 new AiChatMessageDO()
                         .setId(null)
-                        .setChatConversationId(aiChatConversationDO.getId())
+                        .setChatConversationId(chatConversationId)
                         .setUserId(loginUserId)
                         .setMessage(req.getPrompt())
                         .setMessageType(MessageType.USER.getValue())
@@ -98,7 +101,39 @@ public class ChatServiceImpl implements ChatService {
 
         // chat count 先+1
         aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
-        return content;
+    }
+
+    public void saveSystemChatMessage(ChatReq req, Long chatConversationId, Long loginUserId, String systemPrompts) {
+        // 增加 chat message 记录
+        aiChatMessageMapper.insert(
+                new AiChatMessageDO()
+                        .setId(null)
+                        .setChatConversationId(chatConversationId)
+                        .setUserId(loginUserId)
+                        .setMessage(systemPrompts)
+                        .setMessageType(MessageType.SYSTEM.getValue())
+                        .setTopK(req.getTopK())
+                        .setTopP(req.getTopP())
+                        .setTemperature(req.getTemperature())
+        );
+
+        // chat count 先+1
+        aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
+    }
+
+    private AiChatConversationDO getChatConversationNoExistToCreate(ChatReq req, ChatConversationTypeEnum chatConversationTypeEnum, Long loginUserId) {
+        AiChatConversationDO aiChatConversationDO;
+        if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) {
+            // 创建一个新的对话
+            aiChatConversationDO = createNewChatConversation(req, loginUserId);
+        } else {
+            // 继续对话
+            if (req.getConversationId() == null) {
+                throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL);
+            }
+            aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId());
+        }
+        return aiChatConversationDO;
     }
 
     private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) {
@@ -133,16 +168,52 @@ public class ChatServiceImpl implements ChatService {
      * chat stream
      *
      * @param req
+     * @param sseEmitter
      * @return
      */
     @Override
-    public Flux<ChatResponse> chatStream(ChatReq req) {
+    public void chatStream(ChatReq req, Utf8SseEmitter sseEmitter) {
+        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
+        // 获取 client 类型
         AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
+        // 获取 对话类型(新建还是继续)
+        ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
+        AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId);
         // 创建 chat 需要的 Prompt
         Prompt prompt = new Prompt(req.getPrompt());
         req.setTopK(req.getTopK());
         req.setTopP(req.getTopP());
         req.setTemperature(req.getTemperature());
-        return aiClient.stream(prompt, clientNameEnum.getName());
+        // 保存 chat message
+        saveChatMessage(req, aiChatConversationDO.getId(), loginUserId);
+        Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName());
+
+        StringBuffer contentBuffer = new StringBuffer();
+        streamResponse.subscribe(
+                new Consumer<ChatResponse>() {
+                    @Override
+                    public void accept(ChatResponse chatResponse) {
+                        String content = chatResponse.getResults().get(0).getOutput().getContent();
+                        try {
+                            contentBuffer.append(content);
+                            sseEmitter.send(content, MediaType.APPLICATION_JSON);
+                        } catch (IOException e) {
+                            log.error("发送异常{}", ExceptionUtil.getMessage(e));
+                            // 如果不是因为关闭而抛出异常,则重新连接
+                            sseEmitter.completeWithError(e);
+                        }
+                    }
+                },
+                error -> {
+                    //
+                    log.error("subscribe错误 {}", ExceptionUtil.getMessage(error));
+                },
+                () -> {
+                    log.info("发送完成!");
+                    sseEmitter.complete();
+                    // 保存 chat message
+                    saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, contentBuffer.toString());
+                }
+        );
     }
 }