Bladeren bron

【调整】调整AI聊天模块

cherishsince 1 jaar geleden
bovenliggende
commit
424210066f

+ 1 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java

@@ -30,6 +30,7 @@ public interface ErrorCodeConstants {
     // role
 
     ErrorCode AI_CHAT_ROLE_NOT_EXIST = new ErrorCode(1_022_000_060, "chatRole 不存在!");
+    ErrorCode AI_CHAT_ROLE_NOT_PUBLIC = new ErrorCode(1_022_000_060, "AI 角色未公开!");
 
     // modal
 

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java

@@ -32,7 +32,7 @@ public class AiChatMessageController {
     @PostMapping("/send")
     public CommonResult<AiChatMessageRespVO> sendMessage(@Validated @ModelAttribute AiChatMessageSendReqVO sendReqVO) {
         // TODO @fan:使用 static import;这样就 success 就行了;
-        return success(null);
+        return success(chatService.chat(sendReqVO));
     }
 
     // TODO @芋艿:调用这个方法异常,Unable to handle the Spring Security Exception because the response is already committed.;可以再试试

+ 3 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationRespVO.java

@@ -21,6 +21,9 @@ public class AiChatConversationRespVO {
     @Schema(description = "是否置顶", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
     private Boolean pinned;
 
+    @Schema(description = "角色编号", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1")
+    private Long roleId;
+
     @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
     private Long modelId;
 

+ 7 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModalService.java

@@ -52,4 +52,11 @@ public interface AiChatModalService {
      * @return
      */
     AiChatModalRes getChatModalOfValidate(Long modalId);
+
+    /**
+     * 校验 - 校验是否可用
+     *
+     * @param chatModal
+     */
+    void validateAvailable(AiChatModalRes chatModal);
 }

+ 16 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatRoleService.java

@@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service;
 
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.role.*;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 
 /**
  * chat 角色
@@ -58,4 +59,19 @@ public interface AiChatRoleService {
      * @return
      */
     AiChatRoleRes getChatRole(Long roleId);
+
+    /**
+     * 校验 - 角色是否存在
+     *
+     * @param id
+     * @return
+     */
+    AiChatRoleDO validateExists(Long id);
+
+    /**
+     * 校验 - 角色是否公开
+     *
+     * @param aiChatRoleDO
+     */
+    void validateIsPublic(AiChatRoleDO aiChatRoleDO);
 }

+ 3 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.module.ai.service;
 
 import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
 
 /**
@@ -15,10 +16,10 @@ public interface AiChatService {
     /**
      * chat
      *
-     * @param req
+     * @param sendReqVO
      * @return
      */
-    String chat(AiChatMessageSendReqVO req);
+    AiChatMessageRespVO chat(AiChatMessageSendReqVO sendReqVO);
 
     /**
      * chat stream

+ 1 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatConversationServiceImpl.java

@@ -15,7 +15,6 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModalMapper;
-import cn.iocoder.yudao.module.ai.enums.AiChatModalDisableEnum;
 import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
 import cn.iocoder.yudao.module.ai.service.AiChatModalService;
 import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
@@ -91,9 +90,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
         // 获取模型信息并验证
         AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(updateReqVO.getModelId());
         // 校验modal是否可用
-        if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
-        }
+        aiChatModalService.validateAvailable(chatModal);
         // 更新对话信息
         AiChatConversationDO updateAiChatConversationDO
                 = AiChatConversationConvert.INSTANCE.convertAiChatConversationDO(updateReqVO);

+ 8 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java

@@ -116,6 +116,14 @@ public class AiChatModalServiceImpl implements AiChatModalService {
         return AiChatModalConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO);
     }
 
+    @Override
+    public void validateAvailable(AiChatModalRes chatModal) {
+        // 对话模型是否可用
+        if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) {
+            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
+        }
+    }
+
     private AiChatModalDO validateChatModalExists(Long id) {
         AiChatModalDO aiChatModalDO = aiChatModalMapper.selectById(id);
         if (aiChatModalDO == null) {

+ 15 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatRoleServiceImpl.java

@@ -70,7 +70,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
         AiChatRoleClassifyEnum.valueOfClassify(req.getClassify());
         AiChatRoleEnableEnum.valueOfType(req.getEnable());
         // 检查角色是否存在
-        validateChatRoleExists(id);
+        validateExists(id);
         // 转换do
         AiChatRoleDO updateChatRole = AiChatRoleConvert.INSTANCE.convertAiChatRoleDO(req);
         updateChatRole.setId(id);
@@ -83,7 +83,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
         // 转换enum,并校验enum
         AiChatRoleEnableEnum.valueOfType(req.getEnable());
         // 检查角色是否存在
-        validateChatRoleExists(id);
+        validateExists(id);
         // 更新
         aiChatRoleMapper.updateById(new AiChatRoleDO()
                 .setId(id)
@@ -94,7 +94,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
     @Override
     public void delete(Long chatRoleId) {
         // 检查角色是否存在
-        validateChatRoleExists(chatRoleId);
+        validateExists(chatRoleId);
         // 删除
         aiChatRoleMapper.deleteById(chatRoleId);
     }
@@ -102,15 +102,25 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
     @Override
     public AiChatRoleRes getChatRole(Long roleId) {
         // 检查角色是否存在
-        AiChatRoleDO aiChatRoleDO = validateChatRoleExists(roleId);
+        AiChatRoleDO aiChatRoleDO = validateExists(roleId);
         return AiChatRoleConvert.INSTANCE.convertAiChatRoleRes(aiChatRoleDO);
     }
 
-    private AiChatRoleDO validateChatRoleExists(Long id) {
+    public AiChatRoleDO validateExists(Long id) {
         AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(id);
         if (aiChatRoleDO == null) {
             throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXIST);
         }
         return aiChatRoleDO;
     }
+
+    public void validateIsPublic(AiChatRoleDO aiChatRoleDO) {
+        if (aiChatRoleDO == null) {
+            return;
+        }
+        if (!aiChatRoleDO.getPublicStatus()) {
+            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_ROLE_NOT_PUBLIC);
+        }
+    }
 }
+

+ 83 - 68
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java

@@ -10,14 +10,19 @@ import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
 import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
 import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
 import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatRoleMapper;
 import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
+import cn.iocoder.yudao.module.ai.service.AiChatModalService;
+import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
 import cn.iocoder.yudao.module.ai.service.AiChatService;
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.http.MediaType;
@@ -45,29 +50,39 @@ public class AiChatServiceImpl implements AiChatService {
     private final AiChatMessageMapper aiChatMessageMapper;
     private final AiChatConversationMapper aiChatConversationMapper;
     private final AiChatConversationService chatConversationService;
+    private final AiChatModalService aiChatModalService;
+    private final AiChatRoleService aiChatRoleService;
 
-    /**
-     * chat
-     *
-     * @param req
-     * @return
-     */
     @Transactional(rollbackFor = Exception.class)
-    public String chat(AiChatMessageSendReqVO req) {
+    public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
         Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
+        // 查询对话
+        AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
+        // 获取对话模型
+        AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId());
+        // 对话模型是否可用
+        aiChatModalService.validateAvailable(chatModal);
+        // 获取角色信息
+        AiChatRoleDO aiChatRoleDO = null;
+        if (conversation.getRoleId() != null) {
+            aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
+        }
+        // 校验角色是否公开
+        aiChatRoleService.validateIsPublic(aiChatRoleDO);
         // 获取 client 类型
-        AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
-        // 获取对话信息
-        AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId());
+        AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal());
         // 保存 chat message
-        saveChatMessage(req, conversationRes, loginUserId);
+        insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
+                chatModal.getModal(), chatModal.getId(), req.getContent(),
+                null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
         String content = null;
         try {
             // 创建 chat 需要的 Prompt
-            Prompt prompt = new Prompt(req.getPrompt());
-            req.setTopK(req.getTopK());
-            req.setTopP(req.getTopP());
-            req.setTemperature(req.getTemperature());
+            Prompt prompt = new Prompt(req.getContent());
+            // TODO @芋艿 @范 看要不要支持这些
+//            req.setTopK(req.getTopK());
+//            req.setTopP(req.getTopP());
+//            req.setTemperature(req.getTemperature());
             // 发送 call 调用
             ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum);
             ChatResponse call = chatClient.call(prompt);
@@ -78,69 +93,66 @@ public class AiChatServiceImpl implements AiChatService {
             content = ExceptionUtil.getMessage(e);
         } finally {
             // 保存 chat message
-            saveSystemChatMessage(req, conversationRes, loginUserId, content);
+            insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
+                    chatModal.getModal(), chatModal.getId(), req.getContent(),
+                    null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
         }
-        return content;
+        return new AiChatMessageRespVO().setContent(content);
     }
 
-    private void saveChatMessage(AiChatMessageSendReqVO req, AiChatConversationRespVO conversationRes, Long loginUserId) {
-        Long chatConversationId = conversationRes.getId();
-        // 增加 chat message 记录
-        aiChatMessageMapper.insert(
-                new AiChatMessageDO()
-                        .setId(null)
-                        .setConversationId(chatConversationId)
-                        .setUserId(loginUserId)
-                        .setMessage(req.getPrompt())
-                        .setMessageType(MessageType.USER.getValue())
-                        .setTopK(req.getTopK())
-                        .setTopP(req.getTopP())
-                        .setTemperature(req.getTemperature())
-        );
-        // chat count 先+1
-        aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
-    }
+    private AiChatMessageDO insertChatMessage(Long conversationId, MessageType messageType, Long loginUserId, Long roleId,
+                                              String model, Long modelId, String content, Integer tokens, Double temperature,
+                                              Integer maxTokens, Integer maxContexts) {
+        AiChatMessageDO insertChatMessageDO = new AiChatMessageDO()
+                .setId(null)
+                .setConversationId(conversationId)
+                .setType(messageType.getValue())
+                .setUserId(loginUserId)
+                .setRoleId(roleId)
+                .setModel(model)
+                .setModelId(modelId)
+                .setContent(content)
+                .setTokens(tokens)
 
-    public void saveSystemChatMessage(AiChatMessageSendReqVO req, AiChatConversationRespVO conversationRes, Long loginUserId, String systemPrompts) {
-        Long chatConversationId = conversationRes.getId();
+                .setTemperature(temperature)
+                .setMaxTokens(maxTokens)
+                .setMaxContexts(maxContexts);
         // 增加 chat message 记录
-        aiChatMessageMapper.insert(
-                new AiChatMessageDO()
-                        .setId(null)
-                        .setConversationId(chatConversationId)
-                        .setUserId(loginUserId)
-                        .setMessage(systemPrompts)
-                        .setMessageType(MessageType.SYSTEM.getValue())
-                        .setTopK(req.getTopK())
-                        .setTopP(req.getTopP())
-                        .setTemperature(req.getTemperature())
-        );
-
+        aiChatMessageMapper.insert(insertChatMessageDO);
         // chat count 先+1
-        aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
+        aiChatConversationMapper.updateIncrChatCount(conversationId);
+        return insertChatMessageDO;
     }
 
-    /**
-     * chat stream
-     *
-     * @param req
-     * @param sseEmitter
-     * @return
-     */
     @Override
     public void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter) {
         Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
-        // 获取 client 类型
-        AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
-        // 获取对话信息
-        AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId());
+        // 查询对话
+        AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
+        // 获取对话模型
+        AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(conversation.getModelId());
+        // 对话模型是否可用
+        aiChatModalService.validateAvailable(chatModal);
+        // 获取角色信息
+        AiChatRoleDO aiChatRoleDO = null;
+        if (conversation.getRoleId() != null) {
+            aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
+        }
+        // 校验角色是否公开
+        aiChatRoleService.validateIsPublic(aiChatRoleDO);
         // 创建 chat 需要的 Prompt
-        Prompt prompt = new Prompt(req.getPrompt());
-        req.setTopK(req.getTopK());
-        req.setTopP(req.getTopP());
-        req.setTemperature(req.getTemperature());
+        Prompt prompt = new Prompt(req.getContent());
+//        req.setTopK(req.getTopK());
+//        req.setTopP(req.getTopP());
+//        req.setTemperature(req.getTemperature());
         // 保存 chat message
-        saveChatMessage(req, conversationRes, loginUserId);
+        // 保存 chat message
+        insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
+                chatModal.getModal(), chatModal.getId(), req.getContent(),
+                null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
+
+        // 获取 client 类型
+        AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getModal());
         StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
         Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
 
@@ -168,7 +180,10 @@ public class AiChatServiceImpl implements AiChatService {
                     log.info("发送完成!");
                     sseEmitter.complete();
                     // 保存 chat message
-                    saveSystemChatMessage(req, conversationRes, loginUserId, contentBuffer.toString());
+                    insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
+                            chatModal.getModal(), chatModal.getId(), req.getContent(),
+                            null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
+
                 }
         );
     }