|
@@ -9,7 +9,6 @@ import org.springframework.ai.chat.ChatClient;
|
|
|
import org.springframework.ai.chat.ChatResponse;
|
|
|
import org.springframework.ai.chat.StreamingChatClient;
|
|
|
import org.springframework.ai.chat.messages.MessageType;
|
|
|
-import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
|
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
|
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
|
|
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
|
|
@@ -72,8 +71,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
|
|
|
// 保存 chat message
|
|
|
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
|
|
- chatModel.getModel(), chatModel.getId(), req.getContent(),
|
|
|
- null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
+ chatModel.getModel(), chatModel.getId(), req.getContent());
|
|
|
String content = null;
|
|
|
int tokens = 0;
|
|
|
try {
|
|
@@ -94,28 +92,21 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
} finally {
|
|
|
// 保存 chat message
|
|
|
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
|
- chatModel.getModel(), chatModel.getId(), content,
|
|
|
- tokens, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
+ chatModel.getModel(), chatModel.getId(), content);
|
|
|
}
|
|
|
return new AiChatMessageRespVO().setContent(content);
|
|
|
}
|
|
|
|
|
|
private AiChatMessageDO insertChatMessage(Long conversationId, MessageType messageType, Long loginUserId, Long roleId,
|
|
|
- String model, Long modelId, String content, Integer tokens, Double temperature,
|
|
|
- Integer maxTokens, Integer maxContexts) {
|
|
|
+ String model, Long modelId, String content) {
|
|
|
AiChatMessageDO insertChatMessageDO = new AiChatMessageDO()
|
|
|
- .setId(null)
|
|
|
.setConversationId(conversationId)
|
|
|
.setType(messageType.getValue())
|
|
|
.setUserId(loginUserId)
|
|
|
.setRoleId(roleId)
|
|
|
.setModel(model)
|
|
|
.setModelId(modelId)
|
|
|
- .setContent(content)
|
|
|
- .setTokens(tokens)
|
|
|
- .setTemperature(temperature)
|
|
|
- .setMaxTokens(maxTokens)
|
|
|
- .setMaxContexts(maxContexts);
|
|
|
+ .setContent(content);
|
|
|
insertChatMessageDO.setCreateTime(LocalDateTime.now());
|
|
|
// 增加 chat message 记录
|
|
|
aiChatMessageMapper.insert(insertChatMessageDO);
|
|
@@ -134,15 +125,13 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
|
|
StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
|
|
|
|
|
|
- // 2. 插入 user 发送消息 TODO tokens 计算
|
|
|
+ // 2. 插入 user 发送消息
|
|
|
AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, userId, conversation.getRoleId(),
|
|
|
- conversation.getModel(), conversation.getId(), sendReqVO.getContent(),
|
|
|
- null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
+ conversation.getModel(), conversation.getId(), sendReqVO.getContent());
|
|
|
|
|
|
// 3.1 插入 system 接收消息
|
|
|
AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, userId, conversation.getRoleId(),
|
|
|
- conversation.getModel(), conversation.getId(), conversation.getSystemMessage(),
|
|
|
- 0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
+ conversation.getModel(), conversation.getId(), conversation.getSystemMessage());
|
|
|
// 3.2 创建 chat 需要的 Prompt
|
|
|
// TODO 消息上下文
|
|
|
Prompt prompt = new Prompt(sendReqVO.getContent());
|
|
@@ -150,11 +139,8 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
|
|
|
// 3.3 转换 flex AiChatMessageRespVO
|
|
|
StringBuffer contentBuffer = new StringBuffer();
|
|
|
- AtomicInteger tokens = new AtomicInteger(0); // TODO token 计算不对;
|
|
|
return streamResponse.map(res -> {
|
|
|
contentBuffer.append(res.getResult().getOutput().getContent());
|
|
|
- tokens.incrementAndGet();
|
|
|
-
|
|
|
AiChatMessageSendRespVO.Message send = new AiChatMessageSendRespVO.Message().setId(userMessage.getId())
|
|
|
.setType(MessageType.USER.getValue()).setCreateTime(userMessage.getCreateTime())
|
|
|
.setContent(sendReqVO.getContent());
|
|
@@ -167,17 +153,13 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
// 保存 chat message
|
|
|
aiChatMessageMapper.updateById(new AiChatMessageDO()
|
|
|
.setId(systemMessage.getId())
|
|
|
- .setContent(contentBuffer.toString())
|
|
|
- .setTokens(tokens.get())
|
|
|
- );
|
|
|
+ .setContent(contentBuffer.toString()));
|
|
|
}).doOnError(throwable -> {
|
|
|
log.error("发送错误 {}!", throwable.getMessage());
|
|
|
// 更新错误信息 TODO 貌似不应该更新异常
|
|
|
aiChatMessageMapper.updateById(new AiChatMessageDO()
|
|
|
.setId(systemMessage.getId())
|
|
|
- .setContent(throwable.getMessage())
|
|
|
- .setTokens(tokens.get())
|
|
|
- );
|
|
|
+ .setContent(throwable.getMessage()));
|
|
|
});
|
|
|
}
|
|
|
|