Browse Source

【优化】聊天 event stream 改为 flex 返回更加的优雅

cherishsince 11 tháng trước cách đây
mục cha
commit
5579620140

+ 0 - 26
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/Utf8SseEmitter.java

@@ -1,26 +0,0 @@
-package cn.iocoder.yudao.module.ai.controller;
-
-import org.springframework.http.HttpHeaders;
-import org.springframework.http.MediaType;
-import org.springframework.http.server.ServerHttpResponse;
-import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
-
-import java.nio.charset.StandardCharsets;
-
-/**
- * 解决中文乱码
- *
- * @author fansili
- * @time 2024/4/14 15:13
- * @since 1.0
- */
-public class Utf8SseEmitter extends SseEmitter {
-
-    @Override
-    protected void extendResponse(ServerHttpResponse outputMessage) {
-        super.extendResponse(outputMessage);
-
-        HttpHeaders headers = outputMessage.getHeaders();
-        headers.setContentType(new MediaType(MediaType.TEXT_EVENT_STREAM, StandardCharsets.UTF_8));
-    }
-}

+ 4 - 7
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java

@@ -1,10 +1,9 @@
 package cn.iocoder.yudao.module.ai.controller.admin.chat;
 
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
-import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
-import cn.iocoder.yudao.module.ai.service.AiChatService;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
+import cn.iocoder.yudao.module.ai.service.AiChatService;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.Parameter;
 import io.swagger.v3.oas.annotations.tags.Tag;
@@ -13,7 +12,7 @@ import lombok.extern.slf4j.Slf4j;
 import org.springframework.http.MediaType;
 import org.springframework.validation.annotation.Validated;
 import org.springframework.web.bind.annotation.*;
-import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+import reactor.core.publisher.Flux;
 
 import java.util.List;
 
@@ -39,10 +38,8 @@ public class AiChatMessageController {
     // TODO @fan:要不要使用 Flux 来返回;可以使用 Flux<AiChatMessageRespVO>
     @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
     @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
-    public SseEmitter sendMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
-        Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
-        chatService.chatStream(sendReqVO, sseEmitter);
-        return sseEmitter;
+    public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
+        return chatService.chatStream(sendReqVO);
     }
 
     @Operation(summary = "获得指定会话的消息列表")

+ 6 - 7
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java

@@ -1,10 +1,9 @@
 package cn.iocoder.yudao.module.ai.controller.admin.image;
 
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
-import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
-import cn.iocoder.yudao.module.ai.service.AiImageService;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
+import cn.iocoder.yudao.module.ai.service.AiImageService;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.tags.Tag;
 import lombok.AllArgsConstructor;
@@ -14,7 +13,6 @@ import org.springframework.web.bind.annotation.PostMapping;
 import org.springframework.web.bind.annotation.RequestBody;
 import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RestController;
-import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
 // TODO @芋艿:整理接口定义
 /**
@@ -35,10 +33,11 @@ public class AiImageController {
 
     @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
     @PostMapping("/dallDrawing")
-    public SseEmitter dallDrawing(@Validated @RequestBody AiImageDallDrawingReq req) {
-        Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
-        aiImageService.dallDrawing(req, sseEmitter);
-        return sseEmitter;
+    public void dallDrawing(@Validated @RequestBody AiImageDallDrawingReq req) {
+//        Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
+//        aiImageService.dallDrawing(req, sseEmitter);
+//        return sseEmitter;
+
     }
 
     @Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")

+ 3 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java

@@ -1,8 +1,8 @@
 package cn.iocoder.yudao.module.ai.service;
 
-import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
+import reactor.core.publisher.Flux;
 
 import java.util.List;
 
@@ -26,11 +26,10 @@ public interface AiChatService {
     /**
      * chat stream
      *
-     * @param req
-     * @param sseEmitter
+     * @param sendReqVO
      * @return
      */
-    void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter);
+    Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO sendReqVO);
 
     /**
      * 获取 - 获取对话 message list

+ 1 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java

@@ -1,6 +1,5 @@
 package cn.iocoder.yudao.module.ai.service;
 
-import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
 
@@ -17,9 +16,8 @@ public interface AiImageService {
      * ai绘画 - dall2/dall3 绘画
      *
      * @param req
-     * @param sseEmitter
      */
-    void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter);
+    void dallDrawing(AiImageDallDrawingReq req);
 
     /**
      * midjourney 图片生成

+ 33 - 38
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java

@@ -9,7 +9,6 @@ import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
 import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
 import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
 import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
-import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
@@ -25,13 +24,12 @@ import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
 import cn.iocoder.yudao.module.ai.service.AiChatService;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
-import org.springframework.http.MediaType;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 import reactor.core.publisher.Flux;
 
-import java.io.IOException;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 
 /**
@@ -76,6 +74,7 @@ public class AiChatServiceImpl implements AiChatService {
                 chatModal.getModel(), chatModal.getId(), req.getContent(),
                 null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
         String content = null;
+        int tokens = 0;
         try {
             // 创建 chat 需要的 Prompt
             Prompt prompt = new Prompt(req.getContent());
@@ -87,6 +86,7 @@ public class AiChatServiceImpl implements AiChatService {
             ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum);
             ChatResponse call = chatClient.call(prompt);
             content = call.getResult().getOutput().getContent();
+            tokens = call.getResults().size();
             // 更新 conversation
         } catch (Exception e) {
             content = ExceptionUtil.getMessage(e);
@@ -94,7 +94,7 @@ public class AiChatServiceImpl implements AiChatService {
             // 保存 chat message
             insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
                     chatModal.getModel(), chatModal.getId(), content,
-                    null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
+                    tokens, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
         }
         return new AiChatMessageRespVO().setContent(content);
     }
@@ -123,8 +123,7 @@ public class AiChatServiceImpl implements AiChatService {
         return insertChatMessageDO;
     }
 
-    @Override
-    public void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter) {
+    public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO req) {
         Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
         // 查询对话
         AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
@@ -144,47 +143,43 @@ public class AiChatServiceImpl implements AiChatService {
 //        req.setTopK(req.getTopK());
 //        req.setTopP(req.getTopP());
 //        req.setTemperature(req.getTemperature());
-        // 保存 chat message
         // 保存 chat message
         insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
                 chatModal.getModel(), chatModal.getId(), req.getContent(),
                 null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
-
         // 获取 client 类型
         AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getPlatform());
         StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
         Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
-
+        // 转换 flex AiChatMessageRespVO
         StringBuffer contentBuffer = new StringBuffer();
-        streamResponse.subscribe(
-                new Consumer<ChatResponse>() {
-                    @Override
-                    public void accept(ChatResponse chatResponse) {
-                        String content = chatResponse.getResults().get(0).getOutput().getContent();
-                        try {
-                            contentBuffer.append(content);
-                            sseEmitter.send(new AiChatMessageRespVO().setContent(content), MediaType.APPLICATION_JSON);
-                        } catch (IOException e) {
-                            log.error("发送异常{}", ExceptionUtil.getMessage(e));
-                            // 如果不是因为关闭而抛出异常,则重新连接
-                            sseEmitter.completeWithError(e);
-                        }
-                    }
-                },
-                error -> {
-                    //
-                    log.error("subscribe错误 {}", ExceptionUtil.getMessage(error));
-                },
-                () -> {
-                    log.info("发送完成!");
-                    sseEmitter.complete();
-                    // 保存 chat message
-                    insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
-                            chatModal.getModel(), chatModal.getId(), contentBuffer.toString(),
-                            null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
-
+        AtomicInteger tokens = new AtomicInteger(0);
+        return streamResponse.map(res -> {
+                    AiChatMessageRespVO aiChatMessageRespVO = new AiChatMessageRespVO();
+                    aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent());
+                    contentBuffer.append(res.getResult().getOutput().getContent());
+                    tokens.incrementAndGet();
+                    return aiChatMessageRespVO;
                 }
-        );
+        ).doOnComplete(new Runnable() {
+            @Override
+            public void run() {
+                log.info("发送完成!");
+                // 保存 chat message
+                insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
+                        chatModal.getModel(), chatModal.getId(), contentBuffer.toString(),
+                        tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
+            }
+        }).doOnError(new Consumer<Throwable>() {
+            @Override
+            public void accept(Throwable throwable) {
+                log.error("发送错误 {}!", throwable.getMessage());
+                // 保存 chat message
+                insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
+                        chatModal.getModel(), chatModal.getId(), throwable.getMessage(),
+                        tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
+            }
+        });
     }
 
     @Override
@@ -194,7 +189,7 @@ public class AiChatServiceImpl implements AiChatService {
         // 获取对话所有 message
         List<AiChatMessageDO> aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId);
         // 转换 AiChatMessageRespVO
-       return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList);
+        return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList);
     }
 
     @Override

+ 17 - 21
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java

@@ -5,8 +5,8 @@ import cn.iocoder.yudao.framework.ai.image.ImageGeneration;
 import cn.iocoder.yudao.framework.ai.image.ImagePrompt;
 import cn.iocoder.yudao.framework.ai.image.ImageResponse;
 import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient;
-import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageModelEnum;
 import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
+import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageModelEnum;
 import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageStyleEnum;
 import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi;
 import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter;
@@ -14,22 +14,18 @@ import cn.iocoder.yudao.framework.ai.midjourney.webSocket.WssNotify;
 import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
 import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
 import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
-import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
-import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper;
+import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
 import cn.iocoder.yudao.module.ai.service.AiImageService;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
 import jakarta.annotation.PostConstruct;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
-import org.springframework.http.MediaType;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 
-import java.io.IOException;
-
 /**
  * ai 作图
  *
@@ -64,7 +60,7 @@ public class AiImageServiceImpl implements AiImageService {
     }
 
     @Override
-    public void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter) {
+    public void dallDrawing(AiImageDallDrawingReq req) {
         // 获取 model
         OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModal());
         OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
@@ -79,7 +75,7 @@ public class AiImageServiceImpl implements AiImageService {
             // 发送
             ImageGeneration imageGeneration = imageResponse.getResult();
             // 发送信息
-            sendSseEmitter(sseEmitter, imageGeneration);
+//            sendSseEmitter(sseEmitter, imageGeneration);
             // 保存数据库
             doSave(req.getPrompt(), req.getSize(), req.getModal(),
                     imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
@@ -88,7 +84,7 @@ public class AiImageServiceImpl implements AiImageService {
             doSave(req.getPrompt(), req.getSize(), req.getModal(),
                     null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
             // 发送错误信息
-            sendSseEmitter(sseEmitter, aiException.getMessage());
+//            sendSseEmitter(sseEmitter, aiException.getMessage());
         }
     }
 
@@ -105,16 +101,16 @@ public class AiImageServiceImpl implements AiImageService {
         }
     }
 
-    private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
-        try {
-            sseEmitter.send(object, MediaType.APPLICATION_JSON);
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        } finally {
-            // 发送 complete
-            sseEmitter.complete();
-        }
-    }
+//    private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
+//        try {
+//            sseEmitter.send(object, MediaType.APPLICATION_JSON);
+//        } catch (IOException e) {
+//            throw new RuntimeException(e);
+//        } finally {
+//            // 发送 complete
+//            sseEmitter.complete();
+//        }
+//    }
 
     private AiImageDO doSave(String prompt,
                         String size,

+ 6 - 2
yudao-server/src/main/resources/application-local.yaml

@@ -2,7 +2,6 @@ server:
   port: 48080
 
 --- #################### 数据库相关配置 ####################
-
 spring:
   # 数据源配置项
   autoconfigure:
@@ -79,7 +78,12 @@ spring:
       port: 6379 # 端口
       database: 0 # 数据库索引
 #    password: dev # 密码,建议生产环境开启
-
+server:
+  servlet:
+    encoding:
+      enabled: true
+      charset: UTF-8
+      force: true
 --- #################### 定时任务相关配置 ####################
 
 # Quartz 配置项,对应 QuartzProperties 配置类