|
@@ -10,14 +10,19 @@ import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
|
|
|
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
|
|
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
|
|
|
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
|
|
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatRoleMapper;
|
|
|
import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
|
|
|
+import cn.iocoder.yudao.module.ai.service.AiChatModalService;
|
|
|
+import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
|
|
|
import cn.iocoder.yudao.module.ai.service.AiChatService;
|
|
|
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
|
|
|
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
|
|
import lombok.AllArgsConstructor;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.springframework.http.MediaType;
|
|
@@ -45,29 +50,39 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
private final AiChatMessageMapper aiChatMessageMapper;
|
|
|
private final AiChatConversationMapper aiChatConversationMapper;
|
|
|
private final AiChatConversationService chatConversationService;
|
|
|
+ private final AiChatModalService aiChatModalService;
|
|
|
+ private final AiChatRoleService aiChatRoleService;
|
|
|
|
|
|
- /**
|
|
|
- * chat
|
|
|
- *
|
|
|
- * @param req
|
|
|
- * @return
|
|
|
- */
|
|
|
@Transactional(rollbackFor = Exception.class)
|
|
|
- public String chat(AiChatMessageSendReqVO req) {
|
|
|
+ public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
|
|
|
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
|
|
+ // 查询对话
|
|
|
+ AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
|
|
|
+ // 获取对话模型
|
|
|
+ AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId());
|
|
|
+ // 对话模型是否可用
|
|
|
+ aiChatModalService.validateAvailable(chatModal);
|
|
|
+ // 获取角色信息
|
|
|
+ AiChatRoleDO aiChatRoleDO = null;
|
|
|
+ if (conversation.getRoleId() != null) {
|
|
|
+ aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
|
|
|
+ }
|
|
|
+ // 校验角色是否公开
|
|
|
+ aiChatRoleService.validateIsPublic(aiChatRoleDO);
|
|
|
// 获取 client 类型
|
|
|
- AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
|
|
|
- // 获取对话信息
|
|
|
- AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId());
|
|
|
+ AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal());
|
|
|
// 保存 chat message
|
|
|
- saveChatMessage(req, conversationRes, loginUserId);
|
|
|
+ insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
|
|
+ chatModal.getModal(), chatModal.getId(), req.getContent(),
|
|
|
+ null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
String content = null;
|
|
|
try {
|
|
|
// 创建 chat 需要的 Prompt
|
|
|
- Prompt prompt = new Prompt(req.getPrompt());
|
|
|
- req.setTopK(req.getTopK());
|
|
|
- req.setTopP(req.getTopP());
|
|
|
- req.setTemperature(req.getTemperature());
|
|
|
+ Prompt prompt = new Prompt(req.getContent());
|
|
|
+ // TODO @芋艿 @范 看要不要支持这些
|
|
|
+// req.setTopK(req.getTopK());
|
|
|
+// req.setTopP(req.getTopP());
|
|
|
+// req.setTemperature(req.getTemperature());
|
|
|
// 发送 call 调用
|
|
|
ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum);
|
|
|
ChatResponse call = chatClient.call(prompt);
|
|
@@ -78,69 +93,66 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
content = ExceptionUtil.getMessage(e);
|
|
|
} finally {
|
|
|
// 保存 chat message
|
|
|
- saveSystemChatMessage(req, conversationRes, loginUserId, content);
|
|
|
+ insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
|
+ chatModal.getModal(), chatModal.getId(), req.getContent(),
|
|
|
+ null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
}
|
|
|
- return content;
|
|
|
+ return new AiChatMessageRespVO().setContent(content);
|
|
|
}
|
|
|
|
|
|
- private void saveChatMessage(AiChatMessageSendReqVO req, AiChatConversationRespVO conversationRes, Long loginUserId) {
|
|
|
- Long chatConversationId = conversationRes.getId();
|
|
|
- // 增加 chat message 记录
|
|
|
- aiChatMessageMapper.insert(
|
|
|
- new AiChatMessageDO()
|
|
|
- .setId(null)
|
|
|
- .setConversationId(chatConversationId)
|
|
|
- .setUserId(loginUserId)
|
|
|
- .setMessage(req.getPrompt())
|
|
|
- .setMessageType(MessageType.USER.getValue())
|
|
|
- .setTopK(req.getTopK())
|
|
|
- .setTopP(req.getTopP())
|
|
|
- .setTemperature(req.getTemperature())
|
|
|
- );
|
|
|
- // chat count 先+1
|
|
|
- aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
|
|
|
- }
|
|
|
+ private AiChatMessageDO insertChatMessage(Long conversationId, MessageType messageType, Long loginUserId, Long roleId,
|
|
|
+ String model, Long modelId, String content, Integer tokens, Double temperature,
|
|
|
+ Integer maxTokens, Integer maxContexts) {
|
|
|
+ AiChatMessageDO insertChatMessageDO = new AiChatMessageDO()
|
|
|
+ .setId(null)
|
|
|
+ .setConversationId(conversationId)
|
|
|
+ .setType(messageType.getValue())
|
|
|
+ .setUserId(loginUserId)
|
|
|
+ .setRoleId(roleId)
|
|
|
+ .setModel(model)
|
|
|
+ .setModelId(modelId)
|
|
|
+ .setContent(content)
|
|
|
+ .setTokens(tokens)
|
|
|
|
|
|
- public void saveSystemChatMessage(AiChatMessageSendReqVO req, AiChatConversationRespVO conversationRes, Long loginUserId, String systemPrompts) {
|
|
|
- Long chatConversationId = conversationRes.getId();
|
|
|
+ .setTemperature(temperature)
|
|
|
+ .setMaxTokens(maxTokens)
|
|
|
+ .setMaxContexts(maxContexts);
|
|
|
// 增加 chat message 记录
|
|
|
- aiChatMessageMapper.insert(
|
|
|
- new AiChatMessageDO()
|
|
|
- .setId(null)
|
|
|
- .setConversationId(chatConversationId)
|
|
|
- .setUserId(loginUserId)
|
|
|
- .setMessage(systemPrompts)
|
|
|
- .setMessageType(MessageType.SYSTEM.getValue())
|
|
|
- .setTopK(req.getTopK())
|
|
|
- .setTopP(req.getTopP())
|
|
|
- .setTemperature(req.getTemperature())
|
|
|
- );
|
|
|
-
|
|
|
+ aiChatMessageMapper.insert(insertChatMessageDO);
|
|
|
// chat count 先+1
|
|
|
- aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
|
|
|
+ aiChatConversationMapper.updateIncrChatCount(conversationId);
|
|
|
+ return insertChatMessageDO;
|
|
|
}
|
|
|
|
|
|
- /**
|
|
|
- * chat stream
|
|
|
- *
|
|
|
- * @param req
|
|
|
- * @param sseEmitter
|
|
|
- * @return
|
|
|
- */
|
|
|
@Override
|
|
|
public void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter) {
|
|
|
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
|
|
- // 获取 client 类型
|
|
|
- AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
|
|
|
- // 获取对话信息
|
|
|
- AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId());
|
|
|
+ // 查询对话
|
|
|
+ AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
|
|
|
+ // 获取对话模型
|
|
|
+ AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId());
|
|
|
+ // 对话模型是否可用
|
|
|
+ aiChatModalService.validateAvailable(chatModal);
|
|
|
+ // 获取角色信息
|
|
|
+ AiChatRoleDO aiChatRoleDO = null;
|
|
|
+ if (conversation.getRoleId() != null) {
|
|
|
+ aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
|
|
|
+ }
|
|
|
+ // 校验角色是否公开
|
|
|
+ aiChatRoleService.validateIsPublic(aiChatRoleDO);
|
|
|
// 创建 chat 需要的 Prompt
|
|
|
- Prompt prompt = new Prompt(req.getPrompt());
|
|
|
- req.setTopK(req.getTopK());
|
|
|
- req.setTopP(req.getTopP());
|
|
|
- req.setTemperature(req.getTemperature());
|
|
|
+ Prompt prompt = new Prompt(req.getContent());
|
|
|
+// req.setTopK(req.getTopK());
|
|
|
+// req.setTopP(req.getTopP());
|
|
|
+// req.setTemperature(req.getTemperature());
|
|
|
// 保存 chat message
|
|
|
- saveChatMessage(req, conversationRes, loginUserId);
|
|
|
+ // 保存 chat message
|
|
|
+ insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
|
|
+ chatModal.getModal(), chatModal.getId(), req.getContent(),
|
|
|
+ null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
+
|
|
|
+ // 获取 client 类型
|
|
|
+ AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal());
|
|
|
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
|
|
|
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
|
|
|
|
|
@@ -168,7 +180,10 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
log.info("发送完成!");
|
|
|
sseEmitter.complete();
|
|
|
// 保存 chat message
|
|
|
- saveSystemChatMessage(req, conversationRes, loginUserId, contentBuffer.toString());
|
|
|
+ insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
|
+ chatModal.getModal(), chatModal.getId(), req.getContent(),
|
|
|
+ null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
+
|
|
|
}
|
|
|
);
|
|
|
}
|