Prechádzať zdrojové kódy

【新增】AI:流式发送消息的微调,统一成单接口

YunaiV 11 mesiacov pred
rodič
commit
b31e919d52

+ 3 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/resources/http/chat-message.http → yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.http

@@ -10,13 +10,13 @@ Authorization: {{token}}
 }
 
 
-### chat call
-POST {{baseUrl}}/admin-api/ai/chat/message/send-stream
+### 发送消息(流式)
+POST {{baseUrl}}/ai/chat/message/send-stream
 Content-Type: application/json
 Authorization: {{token}}
 
 {
-  "conversationId": "1781604279872581649",
+  "conversationId": "1781604279872581651",
   "content": "苹果是什么颜色?"
 }
 

+ 3 - 8
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java

@@ -17,6 +17,7 @@ import reactor.core.publisher.Flux;
 import java.util.List;
 
 import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
+import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
 
 @Tag(name = "管理后台 - 聊天消息")
 @RestController
@@ -36,14 +37,8 @@ public class AiChatMessageController {
     @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
     @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
     @PermitAll // Works around the final SSE response being intercepted with "Access Denied"
-    public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendStreamReqVO sendReqVO) {
-        return chatService.chatStream(sendReqVO);
-    }
-
-    @Operation(summary = "添加/提问", description = "先创建好 message 前端才好渲染")
-    @PostMapping(value = "/add")
-    public CommonResult<AiChatMessageRespVO> add(@Validated @RequestBody AiChatMessageAddReqVO req) {
-        return success(chatService.add(req));
+    public Flux<AiChatMessageSendRespVO> sendChatMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
+        return chatService.sendChatMessageStream(sendReqVO, getLoginUserId()); // single streaming endpoint: the service receives the request plus the id from SecurityFrameworkUtils.getLoginUserId()
     }
 
     @Operation(summary = "获得指定会话的消息列表")

+ 1 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java

@@ -44,4 +44,5 @@ public class AiChatMessageRespVO {
 
     @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51")
     private LocalDateTime createTime;
+
 }

+ 36 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java

@@ -0,0 +1,36 @@
+package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import lombok.Data;
+
+import java.time.LocalDateTime;
+
+@Schema(description = "管理后台 - AI 聊天消息发送 Response VO")
+@Data
+public class AiChatMessageSendRespVO { // Response VO for sending a chat message: pairs the sent message with the received one
+
+    @Schema(description = "发送消息", requiredMode = Schema.RequiredMode.REQUIRED)
+    private Message send; // the message that was sent
+
+    @Schema(description = "接收消息", requiredMode = Schema.RequiredMode.REQUIRED)
+    private Message receive; // the message that was received in reply
+
+    @Schema(description = "消息")
+    @Data
+    public static class Message { // one chat message as exposed in this response
+
+        @Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
+        private Long id; // message id
+
+        @Schema(description = "消息类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "role")
+        private String type; // see the MessageType enum
+
+        @Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "你好,你好啊")
+        private String content; // chat content
+
+        @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51")
+        private LocalDateTime createTime; // creation time
+
+    }
+
+}

+ 0 - 16
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendStreamReqVO.java

@@ -1,16 +0,0 @@
-package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
-
-import io.swagger.v3.oas.annotations.media.Schema;
-import jakarta.validation.constraints.NotEmpty;
-import jakarta.validation.constraints.NotNull;
-import lombok.Data;
-
-@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
-@Data
-public class AiChatMessageSendStreamReqVO {
-
-    @Schema(description = "提问的 messageId", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
-    @NotNull(message = "提问的 messageId 不能为空")
-    private Long id;
-
-}

+ 0 - 8
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatMessageConvert.java

@@ -27,12 +27,4 @@ public interface AiChatMessageConvert {
      */
     List<AiChatMessageRespVO> convertAiChatMessageRespVOList(List<AiChatMessageDO> aiChatMessageDOList);
 
-    /**
-     * 转换 - aiChatMessageDO
-     *
-     * @param aiChatMessageDO
-     * @return
-     */
-    AiChatMessageRespVO convertAiChatMessageRespVO(AiChatMessageDO aiChatMessageDO);
-
 }

+ 9 - 16
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java

@@ -22,22 +22,6 @@ public interface AiChatService {
      */
     AiChatMessageRespVO chat(AiChatMessageSendReqVO sendReqVO);
 
-    /**
-     * chat stream
-     *
-     * @param sendReqVO
-     * @return
-     */
-    Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO sendReqVO);
-
-    /**
-     * 添加 - message
-     *
-     * @param sendReqVO
-     * @return
-     */
-    AiChatMessageRespVO add(AiChatMessageAddReqVO sendReqVO);
-
     /**
      * 获取 - 获取对话 message list
      *
@@ -54,4 +38,13 @@ public interface AiChatService {
      */
     Boolean deleteMessage(Long id);
 
+    /**
+     * Sends a chat message and streams the response back.
+     *
+     * @param sendReqVO the send-message request
+     * @param userId    id of the user sending the message
+     * @return a {@link Flux} emitting {@link AiChatMessageSendRespVO} items
+     */
+    Flux<AiChatMessageSendRespVO> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId);
+
 }

+ 65 - 81
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java

@@ -1,27 +1,24 @@
 package cn.iocoder.yudao.module.ai.service.impl;
 
 import cn.hutool.core.exceptions.ExceptionUtil;
+import cn.hutool.core.util.ObjUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
 import org.springframework.ai.chat.ChatClient;
 import org.springframework.ai.chat.ChatResponse;
 import org.springframework.ai.chat.StreamingChatClient;
 import org.springframework.ai.chat.messages.MessageType;
+import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
 import org.springframework.ai.chat.prompt.Prompt;
-import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
 import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
-import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageAddReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendStreamReqVO;
 import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
-import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
 import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
@@ -33,13 +30,16 @@ import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 import reactor.core.publisher.Flux;
 
+import java.time.LocalDateTime;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
-import java.util.function.Consumer;
 import java.util.stream.Collectors;
 
+import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
+import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
+
 /**
  * 聊天 service
  *
@@ -52,11 +52,11 @@ import java.util.stream.Collectors;
 @AllArgsConstructor
 public class AiChatServiceImpl implements AiChatService {
 
-    private final AiChatClientFactory aiChatClientFactory;
+    private final AiChatClientFactory chatClientFactory;
 
     private final AiChatMessageMapper aiChatMessageMapper;
     private final AiChatConversationService chatConversationService;
-    private final AiChatModelService aiChatModalService;
+    private final AiChatModelService chatModalService;
     private final AiChatRoleService chatRoleService;
 
     @Transactional(rollbackFor = Exception.class)
@@ -65,7 +65,7 @@ public class AiChatServiceImpl implements AiChatService {
         // 查询对话
         AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId());
         // 获取对话模型
-        AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
+        AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId());
         // 获取角色信息
         AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
         // 获取 client 类型
@@ -84,7 +84,7 @@ public class AiChatServiceImpl implements AiChatService {
 //            req.setTopP(req.getTopP());
 //            req.setTemperature(req.getTemperature());
             // 发送 call 调用
-            ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum);
+            ChatClient chatClient = chatClientFactory.getChatClient(platformEnum);
             ChatResponse call = chatClient.call(prompt);
             content = call.getResult().getOutput().getContent();
             tokens = call.getResults().size();
@@ -113,88 +113,72 @@ public class AiChatServiceImpl implements AiChatService {
                 .setModelId(modelId)
                 .setContent(content)
                 .setTokens(tokens)
-
                 .setTemperature(temperature)
                 .setMaxTokens(maxTokens)
                 .setMaxContexts(maxContexts);
+        insertChatMessageDO.setCreateTime(LocalDateTime.now());
         // 增加 chat message 记录
         aiChatMessageMapper.insert(insertChatMessageDO);
         return insertChatMessageDO;
     }
 
-    public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO req) {
-        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
-        // 查询提问的 message
-        AiChatMessageDO aiChatMessageDO = aiChatMessageMapper.selectById(req.getId());
-        if (aiChatMessageDO == null) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST);
+    @Override
+    public Flux<AiChatMessageSendRespVO> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) {
+        // 1.1 校验对话存在
+        AiChatConversationDO conversation = chatConversationService.validateExists(sendReqVO.getConversationId());
+        if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
+            throw exception(CHAT_CONVERSATION_NOT_EXISTS);
         }
-        // 查询对话
-        AiChatConversationDO conversation = chatConversationService.validateExists(aiChatMessageDO.getConversationId());
-        // 获取对话模型
-        AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
-        // 获取角色信息
-        AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
-        // 创建 chat 需要的 Prompt
-        Prompt prompt = new Prompt(aiChatMessageDO.getContent());
-        // 提前创建一个 system message
-        AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
-                chatModel.getModel(), chatModel.getId(), "",
+        // 1.2 校验模型
+        AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
+        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
+        StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
+
+        // 2. 插入 user 发送消息 TODO tokens 计算
+        AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, userId, conversation.getRoleId(),
+                conversation.getModel(), conversation.getId(), sendReqVO.getContent(),
+                null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
+
+        // 3.1 插入 system 接收消息
+        AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, userId, conversation.getRoleId(),
+                conversation.getModel(), conversation.getId(), conversation.getSystemMessage(),
                 0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
-//        req.setTopK(req.getTopK());
-//        req.setTopP(req.getTopP());
-//        req.setTemperature(req.getTemperature());
-        // 获取 client 类型
-        AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
-        StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
-        Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
-        // 转换 flex AiChatMessageRespVO
+        // 3.2 创建 chat 需要的 Prompt
+        // TODO 消息上下文
+        Prompt prompt = new Prompt(sendReqVO.getContent());
+//        ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build()
+        Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
+        // 3.3 转换 flex AiChatMessageRespVO
         StringBuffer contentBuffer = new StringBuffer();
-        AtomicInteger tokens = new AtomicInteger(0);
+        AtomicInteger tokens = new AtomicInteger(0); // TODO token 计算不对;
         return streamResponse.map(res -> {
-                    AiChatMessageRespVO aiChatMessageRespVO =
-                            AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(systemMessage);
-                    aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent());
-                    contentBuffer.append(res.getResult().getOutput().getContent());
-                    tokens.incrementAndGet();
-                    return aiChatMessageRespVO;
-                }
-        ).doOnComplete(new Runnable() {
-            @Override
-            public void run() {
-                log.info("发送完成!");
-                // 保存 chat message
-                aiChatMessageMapper.updateById(new AiChatMessageDO()
-                        .setId(systemMessage.getId())
-                        .setContent(contentBuffer.toString())
-                        .setTokens(tokens.get())
-                );
-            }
-        }).doOnError(new Consumer<Throwable>() {
-            @Override
-            public void accept(Throwable throwable) {
-                log.error("发送错误 {}!", throwable.getMessage());
-                // 更新错误信息
-                aiChatMessageMapper.updateById(new AiChatMessageDO()
-                        .setId(systemMessage.getId())
-                        .setContent(throwable.getMessage())
-                        .setTokens(tokens.get())
-                );
-            }
-        });
-    }
+            contentBuffer.append(res.getResult().getOutput().getContent());
+            tokens.incrementAndGet();
 
-    @Override
-    public AiChatMessageRespVO add(AiChatMessageAddReqVO req) {
-        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
-        // 查询对话
-        AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId());
-        // 获取对话模型
-        AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
-        AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
-                chatModel.getModel(), chatModel.getId(), req.getContent(),
-                null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
-       return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(userMessage);
+            AiChatMessageSendRespVO.Message send = new AiChatMessageSendRespVO.Message().setId(userMessage.getId())
+                    .setType(MessageType.USER.getValue()).setCreateTime(userMessage.getCreateTime())
+                    .setContent(sendReqVO.getContent());
+            AiChatMessageSendRespVO.Message receive = new AiChatMessageSendRespVO.Message().setId(systemMessage.getId())
+                    .setType(MessageType.SYSTEM.getValue()).setCreateTime(systemMessage.getCreateTime())
+                    .setContent(res.getResult().getOutput().getContent());
+            return new AiChatMessageSendRespVO().setSend(send).setReceive(receive);
+        }).doOnComplete(() -> {
+            log.info("发送完成!");
+            // 保存 chat message
+            aiChatMessageMapper.updateById(new AiChatMessageDO()
+                    .setId(systemMessage.getId())
+                    .setContent(contentBuffer.toString())
+                    .setTokens(tokens.get())
+            );
+        }).doOnError(throwable -> {
+            log.error("发送错误 {}!", throwable.getMessage());
+            // 更新错误信息 TODO 貌似不应该更新异常
+            aiChatMessageMapper.updateById(new AiChatMessageDO()
+                    .setId(systemMessage.getId())
+                    .setContent(throwable.getMessage())
+                    .setTokens(tokens.get())
+            );
+        });
     }
 
     @Override
@@ -205,7 +189,7 @@ public class AiChatServiceImpl implements AiChatService {
         List<AiChatMessageDO> aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId);
         // 获取模型信息
         Set<Long> modalIds = aiChatMessageDOList.stream().map(AiChatMessageDO::getModelId).collect(Collectors.toSet());
-        List<AiChatModelDO> modalList = aiChatModalService.getModalByIds(modalIds);
+        List<AiChatModelDO> modalList = chatModalService.getModalByIds(modalIds);
         Map<Long, AiChatModelDO> modalIdMap = modalList.stream().collect(Collectors.toMap(AiChatModelDO::getId, o -> o));
         // 转换 AiChatMessageRespVO
         List<AiChatMessageRespVO> aiChatMessageRespList = AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList);

+ 4 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/models/yiyan/YiYanChatClient.java

@@ -94,7 +94,10 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
                 String a = ";";
             }
         });
-        return response.map(res -> new ChatResponse(List.of(new Generation(res.getResult()))));
+        return response.map(res -> {
+            // TODO @fan:这里缺少了 usage 的封装
+            return new ChatResponse(List.of(new Generation(res.getResult())));
+        });
     }
 
     private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {