فهرست منبع

【新增】AI:conversation 发送消息时,增加上下文

YunaiV 11 ماه پیش
والد
کامیت
276ef98ff1

+ 2 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java

@@ -19,10 +19,10 @@ public class AiChatMessageRespVO {
     private String type; // 参见 MessageType 枚举类
 
     @Schema(description = "用户编号", example = "4096")
-    private Long userId; // 仅当 user 发送时非空
+    private Long userId;
 
     @Schema(description = "角色编号", example = "888")
-    private Long roleId; // 仅当 assistant 回复时非空
+    private Long roleId;
 
     @Schema(description = "模型标志", requiredMode = Schema.RequiredMode.REQUIRED, example = "gpt-3.5-turbo")
     private String model; // 参见 AiOpenAiModelEnum 枚举类

+ 0 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java

@@ -47,16 +47,12 @@ public class AiChatMessageDO extends BaseDO {
     /**
      * 用户编号
      *
-     * 仅当 user 发送时非空
-     *
      * 关联 AdminUserDO 的 userId 字段
      */
     private Long userId;
     /**
      * 角色编号
      *
-     * 仅当 assistant 回复时非空
-     *
      * 关联 {@link AiChatRoleDO#getId()} 字段
      */
     private Long roleId;

+ 94 - 90
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java

@@ -1,23 +1,22 @@
 package cn.iocoder.yudao.module.ai.service.impl;
 
-import cn.hutool.core.exceptions.ExceptionUtil;
 import cn.hutool.core.util.ObjUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
+import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
-import org.springframework.ai.chat.ChatClient;
 import org.springframework.ai.chat.ChatResponse;
 import org.springframework.ai.chat.StreamingChatClient;
-import org.springframework.ai.chat.messages.MessageType;
+import org.springframework.ai.chat.messages.*;
+import org.springframework.ai.chat.prompt.ChatOptions;
+import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
 import org.springframework.ai.chat.prompt.Prompt;
-import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
 import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
 import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
 import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
@@ -30,10 +29,7 @@ import org.springframework.transaction.annotation.Transactional;
 import reactor.core.publisher.Flux;
 
 import java.time.LocalDateTime;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.atomic.AtomicInteger;
+import java.util.*;
 import java.util.stream.Collectors;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@@ -53,64 +49,49 @@ public class AiChatServiceImpl implements AiChatService {
 
     private final AiChatClientFactory chatClientFactory;
 
-    private final AiChatMessageMapper aiChatMessageMapper;
+    private final AiChatMessageMapper chatMessageMapper;
+
     private final AiChatConversationService chatConversationService;
     private final AiChatModelService chatModalService;
     private final AiChatRoleService chatRoleService;
 
     @Transactional(rollbackFor = Exception.class)
     public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
-        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
-        // 查询对话
-        AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId());
-        // 获取对话模型
-        AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId());
-        // 获取角色信息
-        AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
-        // 获取 client 类型
-        AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
-        // 保存 chat message
-        insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
-                chatModel.getModel(), chatModel.getId(), req.getContent());
-        String content = null;
-        int tokens = 0;
-        try {
-            // 创建 chat 需要的 Prompt
-            Prompt prompt = new Prompt(req.getContent());
-            // TODO @芋艿 @范 看要不要支持这些
-//            req.setTopK(req.getTopK());
-//            req.setTopP(req.getTopP());
-//            req.setTemperature(req.getTemperature());
-            // 发送 call 调用
-            ChatClient chatClient = chatClientFactory.getChatClient(platformEnum);
-            ChatResponse call = chatClient.call(prompt);
-            content = call.getResult().getOutput().getContent();
-            tokens = call.getResults().size();
-            // 更新 conversation
-        } catch (Exception e) {
-            content = ExceptionUtil.getMessage(e);
-        } finally {
-            // 保存 chat message
-            insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
-                    chatModel.getModel(), chatModel.getId(), content);
-        }
-        return new AiChatMessageRespVO().setContent(content);
-    }
-
-    private AiChatMessageDO insertChatMessage(Long conversationId, MessageType messageType, Long loginUserId, Long roleId,
-                                              String model, Long modelId, String content) {
-        AiChatMessageDO insertChatMessageDO = new AiChatMessageDO()
-                .setConversationId(conversationId)
-                .setType(messageType.getValue())
-                .setUserId(loginUserId)
-                .setRoleId(roleId)
-                .setModel(model)
-                .setModelId(modelId)
-                .setContent(content);
-        insertChatMessageDO.setCreateTime(LocalDateTime.now());
-        // 增加 chat message 记录
-        aiChatMessageMapper.insert(insertChatMessageDO);
-        return insertChatMessageDO;
+         return null; // TODO 芋艿:一起改
+//        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
+//        // 查询对话
+//        AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId());
+//        // 获取对话模型
+//        AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId());
+//        // 获取角色信息
+//        AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
+//        // 获取 client 类型
+//        AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
+//        // 保存 chat message
+//        createChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
+//                chatModel.getModel(), chatModel.getId(), req.getContent());
+//        String content = null;
+//        int tokens = 0;
+//        try {
+//            // 创建 chat 需要的 Prompt
+//            Prompt prompt = new Prompt(req.getContent());
+//            // TODO @芋艿 @范 看要不要支持这些
+////            req.setTopK(req.getTopK());
+////            req.setTopP(req.getTopP());
+////            req.setTemperature(req.getTemperature());
+//            // 发送 call 调用
+//            ChatClient chatClient = chatClientFactory.getChatClient(platformEnum);
+//            ChatResponse call = chatClient.call(prompt);
+//            content = call.getResult().getOutput().getContent();
+//            // 更新 conversation
+//        } catch (Exception e) {
+//            content = ExceptionUtil.getMessage(e);
+//        } finally {
+//            // 保存 chat message
+//            createChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
+//                    chatModel.getModel(), chatModel.getId(), content);
+//        }
+//        return new AiChatMessageRespVO().setContent(content);
     }
 
     @Override
@@ -120,55 +101,78 @@ public class AiChatServiceImpl implements AiChatService {
         if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
             throw exception(CHAT_CONVERSATION_NOT_EXISTS); // TODO 芋艿:异常情况的对接;
         }
+        List<AiChatMessageDO> historyMessages = chatMessageMapper.selectByConversationId(conversation.getId());
         // 1.2 校验模型
         AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
 
         // 2. 插入 user 发送消息
-        AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, userId, conversation.getRoleId(),
-                conversation.getModel(), conversation.getId(), sendReqVO.getContent());
+        AiChatMessageDO userMessage = createChatMessage(conversation.getId(), model,
+                userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent());
+
+        // 3.1 插入 assistant 接收消息
+        AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), model,
+                userId, conversation.getRoleId(), MessageType.ASSISTANT, "");
 
-        // 3.1 插入 system 接收消息
-        AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, userId, conversation.getRoleId(),
-                conversation.getModel(), conversation.getId(), conversation.getSystemMessage());
         // 3.2 创建 chat 需要的 Prompt
         // TODO 消息上下文
-        Prompt prompt = new Prompt(sendReqVO.getContent());
-//        ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build()
+        Prompt prompt = buildPrompt(conversation, historyMessages, sendReqVO);
         Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
-        // 3.3 转换 flex AiChatMessageRespVO
+
+        // 3.3 流式返回
         StringBuffer contentBuffer = new StringBuffer();
-        return streamResponse.map(res -> {
-            contentBuffer.append(res.getResult().getOutput().getContent());
-            AiChatMessageSendRespVO.Message send = new AiChatMessageSendRespVO.Message().setId(userMessage.getId())
-                    .setType(MessageType.USER.getValue()).setCreateTime(userMessage.getCreateTime())
-                    .setContent(sendReqVO.getContent());
-            AiChatMessageSendRespVO.Message receive = new AiChatMessageSendRespVO.Message().setId(systemMessage.getId())
-                    .setType(MessageType.SYSTEM.getValue()).setCreateTime(systemMessage.getCreateTime())
-                    .setContent(res.getResult().getOutput().getContent());
-            return new AiChatMessageSendRespVO().setSend(send).setReceive(receive);
+        return streamResponse.map(response -> {
+            String newContent = response.getResult().getOutput().getContent();
+            contentBuffer.append(newContent);
+            // 响应结果
+            return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
+                    .setReceive(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
         }).doOnComplete(() -> {
-            log.info("发送完成!");
-            // 保存 chat message
-            aiChatMessageMapper.updateById(new AiChatMessageDO()
-                    .setId(systemMessage.getId())
-                    .setContent(contentBuffer.toString()));
+            chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString()));
         }).doOnError(throwable -> {
-            log.error("发送错误 {}!", throwable.getMessage());
-            // 更新错误信息 TODO 貌似不应该更新异常
-            aiChatMessageMapper.updateById(new AiChatMessageDO()
-                    .setId(systemMessage.getId())
-                    .setContent(throwable.getMessage()));
+            log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
+            chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage()));
         });
     }
 
+    private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, AiChatMessageSendReqVO sendReqVO) {
+        // TODO 芋艿:1)保留 n 个上下文;2)每一轮 token 数量
+//        if (conversation.getMaxContexts() != null && messages.size() > conversation.getMaxContexts()) {
+//
+//        }
+        // 1. 构建 Prompt Message 列表
+        List<Message> chatMessages = new ArrayList<>();
+        // 1.1 system context 角色设定
+        chatMessages.add(new SystemMessage(conversation.getSystemMessage()));
+        // 1.2 history message 历史消息
+        messages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent())));
+        // 1.3 user message 新发送消息
+        chatMessages.add(new UserMessage(sendReqVO.getContent()));
+
+        // 2. 构建 ChatOptions 对象
+        ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build();
+        return new Prompt(chatMessages, chatOptions);
+    }
+
+    private AiChatMessageDO createChatMessage(Long conversationId, AiChatModelDO model,
+                                              Long userId, Long roleId,
+                                              MessageType messageType, String content) {
+        AiChatMessageDO message = new AiChatMessageDO()
+                .setConversationId(conversationId).setModel(model.getModel()).setModelId(model.getId())
+                .setUserId(userId).setRoleId(roleId)
+                .setType(messageType.getValue()).setContent(content);
+        message.setCreateTime(LocalDateTime.now());
+        chatMessageMapper.insert(message);
+        return message;
+    }
+
     @Override
     public List<AiChatMessageRespVO> getMessageListByConversationId(Long conversationId) {
         // 校验对话是否存在
         chatConversationService.validateExists(conversationId);
         // 获取对话所有 message
-        List<AiChatMessageDO> aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId);
+        List<AiChatMessageDO> aiChatMessageDOList = chatMessageMapper.selectByConversationId(conversationId);
         // 获取模型信息
         Set<Long> modalIds = aiChatMessageDOList.stream().map(AiChatMessageDO::getModelId).collect(Collectors.toSet());
         List<AiChatModelDO> modalList = chatModalService.getModalByIds(modalIds);
@@ -187,7 +191,7 @@ public class AiChatServiceImpl implements AiChatService {
 
     @Override
     public Boolean deleteMessage(Long id) {
-        return aiChatMessageMapper.deleteById(id) > 0;
+        return chatMessageMapper.deleteById(id) > 0;
     }
 
 }

+ 2 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java

@@ -15,13 +15,13 @@ import lombok.Getter;
 public enum AiPlatformEnum {
 
     OPENAI("OpenAI", "OpenAI"),
-    OLLAMA("dall", "dall"),
+    OLLAMA("Ollama", "Ollama"),
 
     YI_YAN("yiyan", "一言"),
     QIAN_WEN("qianwen", "千问"),
     XING_HUO("xinghuo", "星火"),
     OPEN_AI_DALL("dall", "dall"),
-    MIDJOURNEY("Ollama", "Ollama"),
+    MIDJOURNEY("midjourney", "midjourney"),
 
     ;