|
@@ -1,5 +1,9 @@
|
|
|
package cn.iocoder.yudao.module.ai.service.impl;
|
|
|
|
|
|
+import cn.hutool.core.collection.CollUtil;
|
|
|
+import cn.hutool.core.collection.ListUtil;
|
|
|
+import cn.hutool.core.util.ArrayUtil;
|
|
|
+import cn.hutool.core.util.BooleanUtil;
|
|
|
import cn.hutool.core.util.ObjUtil;
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
|
@@ -109,15 +113,14 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
|
|
|
|
|
|
// 2. 插入 user 发送消息
|
|
|
- AiChatMessageDO userMessage = createChatMessage(conversation.getId(), model,
|
|
|
- userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent());
|
|
|
+ AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
|
|
|
+ userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext());
|
|
|
|
|
|
// 3.1 插入 assistant 接收消息
|
|
|
- AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), model,
|
|
|
- userId, conversation.getRoleId(), MessageType.ASSISTANT, "");
|
|
|
+ AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
|
|
+ userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
|
|
|
|
|
// 3.2 创建 chat 需要的 Prompt
|
|
|
- // TODO 消息上下文
|
|
|
Prompt prompt = buildPrompt(conversation, historyMessages, sendReqVO);
|
|
|
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
|
|
|
|
|
@@ -139,32 +142,66 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
}
|
|
|
|
|
|
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, AiChatMessageSendReqVO sendReqVO) {
|
|
|
- // TODO 芋艿:1)保留 n 个上下文;2)每一轮 token 数量
|
|
|
-// if (conversation.getMaxContexts() != null && messages.size() > conversation.getMaxContexts()) {
|
|
|
-//
|
|
|
-// }
|
|
|
// 1. 构建 Prompt Message 列表
|
|
|
List<Message> chatMessages = new ArrayList<>();
|
|
|
// 1.1 system context 角色设定
|
|
|
chatMessages.add(new SystemMessage(conversation.getSystemMessage()));
|
|
|
// 1.2 history message 历史消息
|
|
|
- messages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent())));
|
|
|
+ List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
|
|
|
+ contextMessages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent())));
|
|
|
// 1.3 user message 新发送消息
|
|
|
chatMessages.add(new UserMessage(sendReqVO.getContent()));
|
|
|
|
|
|
// 2. 构建 ChatOptions 对象 TODO 芋艿:临时注释掉;等文心一言兼容了;
|
|
|
+ // TODO 每一轮 token 数量
|
|
|
// ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build();
|
|
|
// return new Prompt(chatMessages, null);
|
|
|
return new Prompt(chatMessages);
|
|
|
}
|
|
|
|
|
|
- private AiChatMessageDO createChatMessage(Long conversationId, AiChatModelDO model,
|
|
|
- Long userId, Long roleId,
|
|
|
- MessageType messageType, String content) {
|
|
|
- AiChatMessageDO message = new AiChatMessageDO()
|
|
|
- .setConversationId(conversationId).setModel(model.getModel()).setModelId(model.getId())
|
|
|
- .setUserId(userId).setRoleId(roleId)
|
|
|
- .setType(messageType.getValue()).setContent(content);
|
|
|
+ /**
|
|
|
+ * 从历史消息中,获得倒序的 n 组消息作为消息上下文
|
|
|
+ *
|
|
|
+ * n 组:指的是 user + assistant 形成一组
|
|
|
+ *
|
|
|
+ * @param messages 消息列表
|
|
|
+ * @param conversation 会话
|
|
|
+ * @param sendReqVO 发送请求
|
|
|
+ * @return 消息上下文
|
|
|
+ */
|
|
|
+ private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages, AiChatConversationDO conversation, AiChatMessageSendReqVO sendReqVO) {
|
|
|
+ if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) {
|
|
|
+ return Collections.emptyList();
|
|
|
+ }
|
|
|
+ List<AiChatMessageDO> contextMessages = new ArrayList<>(conversation.getMaxContexts() * 2);
|
|
|
+ for (int i = messages.size() - 1; i >= 0; i--) {
|
|
|
+ AiChatMessageDO assistantMessage = CollUtil.get(messages, i);
|
|
|
+ if (assistantMessage == null || assistantMessage.getReplyId() == null) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ AiChatMessageDO userMessage = CollUtil.get(messages, i - 1);
|
|
|
+ if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
|
|
|
+ || StrUtil.isEmpty(assistantMessage.getContent())) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ // 由于后续要 reverse 反转,所以先添加 assistantMessage
|
|
|
+ contextMessages.add(assistantMessage);
|
|
|
+ contextMessages.add(userMessage);
|
|
|
+ // 超过最大上下文,结束
|
|
|
+ if (contextMessages.size() >= conversation.getMaxContexts() * 2) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ Collections.reverse(contextMessages);
|
|
|
+ return contextMessages;
|
|
|
+ }
|
|
|
+
|
|
|
+ private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
|
|
|
+ AiChatModelDO model, Long userId, Long roleId,
|
|
|
+ MessageType messageType, String content, Boolean useContext) {
|
|
|
+ AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
|
|
|
+ .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
|
|
|
+ .setType(messageType.getValue()).setContent(content).setUseContext(useContext);
|
|
|
message.setCreateTime(LocalDateTime.now());
|
|
|
chatMessageMapper.insert(message);
|
|
|
return message;
|