|
@@ -19,7 +19,6 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
|
|
|
-import cn.iocoder.yudao.module.ai.dal.mysql.AiChatRoleMapper;
|
|
|
import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
|
|
|
import cn.iocoder.yudao.module.ai.service.AiChatModelService;
|
|
|
import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
|
|
@@ -48,7 +47,6 @@ import java.util.function.Consumer;
|
|
|
public class AiChatServiceImpl implements AiChatService {
|
|
|
|
|
|
private final AiChatClientFactory aiChatClientFactory;
|
|
|
- private final AiChatRoleMapper aiChatRoleMapper;
|
|
|
private final AiChatMessageMapper aiChatMessageMapper;
|
|
|
private final AiChatConversationMapper aiChatConversationMapper;
|
|
|
private final AiChatConversationService chatConversationService;
|
|
@@ -72,7 +70,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
// 校验角色是否公开
|
|
|
aiChatRoleService.validateIsPublic(aiChatRoleDO);
|
|
|
// 获取 client 类型
|
|
|
- AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModel());
|
|
|
+ AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getPlatform());
|
|
|
// 保存 chat message
|
|
|
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
|
|
chatModal.getModel(), chatModal.getId(), req.getContent(),
|
|
@@ -90,13 +88,12 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
ChatResponse call = chatClient.call(prompt);
|
|
|
content = call.getResult().getOutput().getContent();
|
|
|
// 更新 conversation
|
|
|
-
|
|
|
} catch (Exception e) {
|
|
|
content = ExceptionUtil.getMessage(e);
|
|
|
} finally {
|
|
|
// 保存 chat message
|
|
|
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
|
- chatModal.getModel(), chatModal.getId(), req.getContent(),
|
|
|
+ chatModal.getModel(), chatModal.getId(), content,
|
|
|
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
}
|
|
|
return new AiChatMessageRespVO().setContent(content);
|
|
@@ -154,7 +151,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
|
|
|
// 获取 client 类型
|
|
|
- AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModel());
|
|
|
+ AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getPlatform());
|
|
|
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
|
|
|
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
|
|
|
|
|
@@ -166,7 +163,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
String content = chatResponse.getResults().get(0).getOutput().getContent();
|
|
|
try {
|
|
|
contentBuffer.append(content);
|
|
|
- sseEmitter.send(content, MediaType.APPLICATION_JSON);
|
|
|
+ sseEmitter.send(new AiChatMessageRespVO().setContent(content), MediaType.APPLICATION_JSON);
|
|
|
} catch (IOException e) {
|
|
|
log.error("发送异常{}", ExceptionUtil.getMessage(e));
|
|
|
// 如果不是因为关闭而抛出异常,则重新连接
|
|
@@ -183,7 +180,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
sseEmitter.complete();
|
|
|
// 保存 chat message
|
|
|
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
|
- chatModal.getModel(), chatModal.getId(), req.getContent(),
|
|
|
+ chatModal.getModel(), chatModal.getId(), contentBuffer.toString(),
|
|
|
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
|
|
|
}
|