|
@@ -12,7 +12,6 @@ import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
|
|
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRespVO;
|
|
|
import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
|
@@ -20,7 +19,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
|
|
|
import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
|
|
|
-import cn.iocoder.yudao.module.ai.service.AiChatModelService;
|
|
|
+import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
|
|
import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
|
|
|
import cn.iocoder.yudao.module.ai.service.AiChatService;
|
|
|
import lombok.AllArgsConstructor;
|
|
@@ -61,9 +60,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
// 查询对话
|
|
|
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
|
|
|
// 获取对话模型
|
|
|
- AiChatModalRespVO chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId());
|
|
|
- // 对话模型是否可用
|
|
|
- aiChatModalService.validateAvailable(chatModal);
|
|
|
+ AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
|
|
|
// 获取角色信息
|
|
|
AiChatRoleDO aiChatRoleDO = null;
|
|
|
if (conversation.getRoleId() != null) {
|
|
@@ -72,10 +69,10 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
// 校验角色是否公开
|
|
|
aiChatRoleService.validateIsPublic(aiChatRoleDO);
|
|
|
// 获取 client 类型
|
|
|
- AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getPlatform());
|
|
|
+ AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
|
|
|
// 保存 chat message
|
|
|
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
|
|
- chatModal.getModel(), chatModal.getId(), req.getContent(),
|
|
|
+ chatModel.getModel(), chatModel.getId(), req.getContent(),
|
|
|
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
String content = null;
|
|
|
int tokens = 0;
|
|
@@ -97,7 +94,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
} finally {
|
|
|
// 保存 chat message
|
|
|
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
|
- chatModal.getModel(), chatModal.getId(), content,
|
|
|
+ chatModel.getModel(), chatModel.getId(), content,
|
|
|
tokens, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
}
|
|
|
return new AiChatMessageRespVO().setContent(content);
|
|
@@ -132,9 +129,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
// 查询对话
|
|
|
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
|
|
|
// 获取对话模型
|
|
|
- AiChatModalRespVO chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId());
|
|
|
- // 对话模型是否可用
|
|
|
- aiChatModalService.validateAvailable(chatModal);
|
|
|
+ AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
|
|
|
// 获取角色信息
|
|
|
AiChatRoleDO aiChatRoleDO = null;
|
|
|
if (conversation.getRoleId() != null) {
|
|
@@ -149,10 +144,10 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
// req.setTemperature(req.getTemperature());
|
|
|
// 保存 chat message
|
|
|
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
|
|
- chatModal.getModel(), chatModal.getId(), req.getContent(),
|
|
|
+ chatModel.getModel(), chatModel.getId(), req.getContent(),
|
|
|
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
// 获取 client 类型
|
|
|
- AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getPlatform());
|
|
|
+ AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
|
|
|
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
|
|
|
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
|
|
|
// 转换 flex AiChatMessageRespVO
|
|
@@ -171,7 +166,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
log.info("发送完成!");
|
|
|
// 保存 chat message
|
|
|
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
|
- chatModal.getModel(), chatModal.getId(), contentBuffer.toString(),
|
|
|
+ chatModel.getModel(), chatModel.getId(), contentBuffer.toString(),
|
|
|
tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
}
|
|
|
}).doOnError(new Consumer<Throwable>() {
|
|
@@ -180,7 +175,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
log.error("发送错误 {}!", throwable.getMessage());
|
|
|
// 保存 chat message
|
|
|
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
|
- chatModal.getModel(), chatModal.getId(), throwable.getMessage(),
|
|
|
+ chatModel.getModel(), chatModel.getId(), throwable.getMessage(),
|
|
|
tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
}
|
|
|
});
|