瀏覽代碼

【新增】AI:发送消息时,增加上下文

YunaiV 11 月之前
父節點
當前提交
802dee2fc3

+ 3 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendReqVO.java

@@ -19,4 +19,7 @@ public class AiChatMessageSendReqVO {
     @NotEmpty(message = "聊天内容不能为空")
     private String content;
 
+    @Schema(description = "是否携带上下文", example = "true")
+    private Boolean useContext;
+
 }

+ 15 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java

@@ -1,5 +1,6 @@
 package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
 
+import com.baomidou.mybatisplus.annotation.TableId;
 import org.springframework.ai.chat.messages.MessageType;
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
@@ -27,14 +28,23 @@ public class AiChatMessageDO extends BaseDO {
     /**
      * 编号,作为每条聊天记录的唯一标识符
      */
+    @TableId
     private Long id;
 
     /**
      * 会话编号
      *
-     * 关联 {@link AiChatConversationDO#getId()}
+     * 关联 {@link AiChatConversationDO#getId()} 字段
      */
     private Long conversationId;
+    /**
+     * 回复消息编号
+     *
+     * 关联 {@link #id} 字段
+     *
+     * 大模型回复的消息编号,用于“问答”的关联
+     */
+    private Long replyId;
 
     /**
      * 消息类型
@@ -75,6 +85,9 @@ public class AiChatMessageDO extends BaseDO {
      */
     private String content;
 
-    // TODO 芋艿:是否作为上下文语料?use_context,待定
+    /**
+     * 是否携带上下文
+     */
+    private Boolean useContext;
 
 }

+ 54 - 17
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java

@@ -1,5 +1,9 @@
 package cn.iocoder.yudao.module.ai.service.impl;
 
+import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.collection.ListUtil;
+import cn.hutool.core.util.ArrayUtil;
+import cn.hutool.core.util.BooleanUtil;
 import cn.hutool.core.util.ObjUtil;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
@@ -109,15 +113,14 @@ public class AiChatServiceImpl implements AiChatService {
         StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
 
         // 2. 插入 user 发送消息
-        AiChatMessageDO userMessage = createChatMessage(conversation.getId(), model,
-                userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent());
+        AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
+                userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext());
 
         // 3.1 插入 assistant 接收消息
-        AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), model,
-                userId, conversation.getRoleId(), MessageType.ASSISTANT, "");
+        AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
+                userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
 
         // 3.2 创建 chat 需要的 Prompt
-        // TODO 消息上下文
         Prompt prompt = buildPrompt(conversation, historyMessages, sendReqVO);
         Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
 
@@ -139,32 +142,66 @@ public class AiChatServiceImpl implements AiChatService {
     }
 
     private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, AiChatMessageSendReqVO sendReqVO) {
-        // TODO 芋艿:1)保留 n 个上下文;2)每一轮 token 数量
-//        if (conversation.getMaxContexts() != null && messages.size() > conversation.getMaxContexts()) {
-//
-//        }
         // 1. 构建 Prompt Message 列表
         List<Message> chatMessages = new ArrayList<>();
         // 1.1 system context 角色设定
         chatMessages.add(new SystemMessage(conversation.getSystemMessage()));
         // 1.2 history message 历史消息
-        messages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent())));
+        List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
+        contextMessages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent())));
         // 1.3 user message 新发送消息
         chatMessages.add(new UserMessage(sendReqVO.getContent()));
 
         // 2. 构建 ChatOptions 对象 TODO 芋艿:临时注释掉;等文心一言兼容了;
+        // TODO 每一轮 token 数量
 //        ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build();
 //        return new Prompt(chatMessages, null);
         return new Prompt(chatMessages);
     }
 
-    private AiChatMessageDO createChatMessage(Long conversationId, AiChatModelDO model,
-                                              Long userId, Long roleId,
-                                              MessageType messageType, String content) {
-        AiChatMessageDO message = new AiChatMessageDO()
-                .setConversationId(conversationId).setModel(model.getModel()).setModelId(model.getId())
-                .setUserId(userId).setRoleId(roleId)
-                .setType(messageType.getValue()).setContent(content);
+    /**
+     * 从历史消息中,获得倒序的 n 组消息作为消息上下文
+     *
+     * n 组:指的是 user + assistant 形成一组
+     *
+     * @param messages 消息列表
+     * @param conversation 会话
+     * @param sendReqVO 发送请求
+     * @return 消息上下文
+     */
+    private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages, AiChatConversationDO conversation, AiChatMessageSendReqVO sendReqVO) {
+        if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) {
+            return Collections.emptyList();
+        }
+        List<AiChatMessageDO> contextMessages = new ArrayList<>(conversation.getMaxContexts() * 2);
+        for (int i = messages.size() - 1; i >= 0; i--) {
+            AiChatMessageDO assistantMessage = CollUtil.get(messages, i);
+            if (assistantMessage == null || assistantMessage.getReplyId() == null) {
+                continue;
+            }
+            AiChatMessageDO userMessage = CollUtil.get(messages, i - 1);
+            if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
+                || StrUtil.isEmpty(assistantMessage.getContent())) {
+                continue;
+            }
+            // 由于后续要 reverse 反转,所以先添加 assistantMessage
+            contextMessages.add(assistantMessage);
+            contextMessages.add(userMessage);
+            // 超过最大上下文,结束
+            if (contextMessages.size() >= conversation.getMaxContexts() * 2) {
+                break;
+            }
+        }
+        Collections.reverse(contextMessages);
+        return contextMessages;
+    }
+
+    private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
+                                              AiChatModelDO model, Long userId, Long roleId,
+                                              MessageType messageType, String content, Boolean useContext) {
+        AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
+                .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
+                .setType(messageType.getValue()).setContent(content).setUseContext(useContext);
         message.setCreateTime(LocalDateTime.now());
         chatMessageMapper.insert(message);
         return message;