浏览代码

【调整】调整AI对话模块

cherishsince 1 年之前
父节点
当前提交
f86e24bb86

+ 16 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/AiCommonConstants.java

@@ -0,0 +1,16 @@
+package cn.iocoder.yudao.module.ai;
+
+/**
+ * ai 常用的常量
+ *
+ * @author fansili
+ * @time 2024/5/7 09:29
+ * @since 1.0
+ */
+public class AiCommonConstants {
+
+    /**
+     * 对话 - 默认 title
+     */
+    public static final String CONVERSATION_DEFAULT_TITLE = "新增对话";
+}

+ 4 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java

@@ -15,6 +15,10 @@ public interface ErrorCodeConstants {
 
     ErrorCode AI_MODULE_NOT_SUPPORTED = new ErrorCode(1_022_000_000, "AI 模型暂不支持!");
     ErrorCode AI_CHAT_ROLE_NOT_EXISTENT = new ErrorCode(1_022_000_001, "AI Role 不存在!");;
+
+    // conversation
+
+    ErrorCode AI_CONVERSATION_NOT_EXISTS = new ErrorCode(1_022_000_002, "AI 对话不存在!");;
     ErrorCode AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL = new ErrorCode(1_022_000_002, "chat 继续对话,对话 id 不能为空!");;
     ErrorCode AI_CHAT_CONTINUE_NOT_EXIST = new ErrorCode(1_022_000_020, "chat 对话不存在!");
     ErrorCode AI_CHAT_CONVERSATION_NOT_YOURS = new ErrorCode(1_022_000_021, "这条 chat 对话不是你的!");

+ 16 - 10
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatConversationController.java

@@ -2,12 +2,15 @@ package cn.iocoder.yudao.module.ai.controller.admin.chat;
 
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationListReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateReqVO;
+import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.Parameter;
 import io.swagger.v3.oas.annotations.tags.Tag;
 import jakarta.validation.Valid;
+import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.security.access.prepost.PreAuthorize;
 import org.springframework.web.bind.annotation.*;
@@ -16,33 +19,36 @@ import java.util.List;
 
 import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
 
+@Slf4j
 @Tag(name = "管理后台 - 聊天会话")
 @RestController
 @RequestMapping("/ai/chat/conversation")
-@Slf4j
+@AllArgsConstructor
 public class AiChatConversationController {
 
-    // TODO @fan:实现一下
+    private final AiChatConversationService aiChatConversationService;
+
+    // TODO done @fan:实现一下
     @PostMapping("/create")
     @Operation(summary = "创建聊天会话")
     @PreAuthorize("@ss.hasPermission('ai:chat-conversation:create')")
     public CommonResult<Long> createConversation(@RequestBody @Valid AiChatConversationCreateReqVO createReqVO) {
-        return success(1L);
+        return success(aiChatConversationService.createConversation(createReqVO));
     }
 
-    // TODO @fan:实现一下
+    // TODO done @fan:实现一下
     @PutMapping("/update")
     @Operation(summary = "更新聊天会话")
     @PreAuthorize("@ss.hasPermission('ai:chat-conversation:create')")
     public CommonResult<Boolean> updateConversation(@RequestBody @Valid AiChatConversationUpdateReqVO updateReqVO) {
-        return success(true);
+        return success(aiChatConversationService.updateConversation(updateReqVO));
     }
 
-    // TODO @fan:实现一下
+    // TODO done @fan:实现一下
     @GetMapping("/list")
     @Operation(summary = "获得聊天会话列表")
-    public CommonResult<List<AiChatConversationRespVO>> getConversationList() {
-        return success(null);
+    public CommonResult<List<AiChatConversationRespVO>> getConversationList(@ModelAttribute AiChatConversationListReqVO listReqVO) {
+        return success(aiChatConversationService.listConversation(listReqVO));
     }
 
     // TODO @fan:实现一下
@@ -50,7 +56,7 @@ public class AiChatConversationController {
     @Operation(summary = "获得聊天会话")
     @Parameter(name = "id", required = true, description = "会话编号", example = "1024")
     public CommonResult<AiChatConversationRespVO> getConversation(@RequestParam("id") Long id) {
-        return success(null);
+        return success(aiChatConversationService.getConversationOfValidate(id));
     }
 
     // TODO @fan:实现一下
@@ -58,7 +64,7 @@ public class AiChatConversationController {
     @Operation(summary = "删除聊天会话")
     @Parameter(name = "id", required = true, description = "会话编号", example = "1024")
     public CommonResult<Boolean> deleteConversation(@RequestParam("id") Long id) {
-        return success(null);
+        return success(aiChatConversationService.deleteConversation(id));
     }
 
 }

+ 13 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationListReqVO.java

@@ -0,0 +1,13 @@
+package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import lombok.Data;
+
+@Schema(description = "管理后台 - AI 聊天会话 Response VO")
+@Data
+public class AiChatConversationListReqVO {
+
+    @Schema(description = "会话标题", requiredMode = Schema.RequiredMode.REQUIRED, example = "我是一个标题")
+    private String title;
+
+}

+ 8 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatConversationConvert.java

@@ -1,5 +1,6 @@
 package cn.iocoder.yudao.module.ai.convert;
 
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
 import org.mapstruct.Mapper;
@@ -34,4 +35,11 @@ public interface AiChatConversationConvert {
      * @return
      */
     AiChatConversationRespVO covnertChatConversationRes(AiChatConversationDO aiChatConversationDO);
+
+    /**
+     * 转换 - AiChatConversationDO
+     *
+     * @param updateReqVO
+     */
+    AiChatConversationDO convertAiChatConversationDO(AiChatConversationUpdateReqVO updateReqVO);
 }

+ 1 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatModalMapper.java

@@ -5,6 +5,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
 import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
 import org.apache.ibatis.annotations.Mapper;
 import org.springframework.stereotype.Repository;
 

+ 11 - 19
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatConversationService.java

@@ -1,7 +1,9 @@
 package cn.iocoder.yudao.module.ai.service;
 
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationListReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateReqVO;
 
 import java.util.List;
 
@@ -14,29 +16,20 @@ import java.util.List;
 public interface AiChatConversationService {
 
     /**
-     * 对话 - 创建普通对话
+     * 对话 - 创建对话
      *
      * @param req
      * @return
      */
-    AiChatConversationRespVO createConversation(AiChatConversationCreateUserReq req);
+    Long createConversation(AiChatConversationCreateReqVO req);
 
     /**
-     * 对话 - 创建role对话
+     * 对话 - 更新对话
      *
-     * @param req
-     * @return
-     */
-    AiChatConversationRespVO createRoleConversation(AiChatConversationCreateReqVO req);
-
-
-    /**
-     * 获取 - 对话
-     *
-     * @param id
+     * @param updateReqVO
      * @return
      */
-    AiChatConversationRespVO getConversation(Long id);
+    Boolean updateConversation(AiChatConversationUpdateReqVO updateReqVO);
 
     /**
      * 获取 - 对话列表
@@ -44,22 +37,21 @@ public interface AiChatConversationService {
      * @param req
      * @return
      */
-    List<AiChatConversationRespVO> listConversation(AiChatConversationListReq req);
+    List<AiChatConversationRespVO> listConversation(AiChatConversationListReqVO req);
 
     /**
-     * 更新 - 更新模型
+     * 获取 - 对话
      *
      * @param id
-     * @param modalId
      * @return
      */
-    void updateModal(Long id, Long modalId);
+    AiChatConversationRespVO getConversationOfValidate(Long id);
 
     /**
      * 删除 - 根据id
      *
      * @param id
      */
-    void delete(Long id);
+    Boolean deleteConversation(Long id);
 
 }

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModalService.java

@@ -51,5 +51,5 @@ public interface AiChatModalService {
      * @param modalId
      * @return
      */
-    AiChatModalRes getChatModal(Long modalId);
+    AiChatModalRes getChatModalOfValidate(Long modalId);
 }

+ 56 - 77
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatConversationServiceImpl.java

@@ -2,17 +2,20 @@ package cn.iocoder.yudao.module.ai.service.impl;
 
 import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
 import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
+import cn.iocoder.yudao.module.ai.AiCommonConstants;
 import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationListReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiChatModalRes;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.role.AiChatRoleRes;
 import cn.iocoder.yudao.module.ai.convert.AiChatConversationConvert;
-import cn.iocoder.yudao.module.ai.enums.AiChatConversationTypeEnum;
-import cn.iocoder.yudao.module.ai.enums.AiChatModalDisableEnum;
+import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModalMapper;
-import cn.iocoder.yudao.module.ai.dal.mysql.AiChatRoleMapper;
+import cn.iocoder.yudao.module.ai.enums.AiChatModalDisableEnum;
 import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
 import cn.iocoder.yudao.module.ai.service.AiChatModalService;
 import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
@@ -34,119 +37,95 @@ import java.util.List;
 @AllArgsConstructor
 public class AiChatConversationServiceImpl implements AiChatConversationService {
 
-    private final AiChatRoleMapper aiChatRoleMapper;
     private final AiChatModalMapper aiChatModalMapper;
-    private final AiChatConversationMapper aiChatConversationMapper;
     private final AiChatModalService aiChatModalService;
     private final AiChatRoleService aiChatRoleService;
+    private final AiChatConversationMapper aiChatConversationMapper;
 
     @Override
-    public AiChatConversationRespVO createConversation(AiChatConversationCreateUserReq req) {
+    public Long createConversation(AiChatConversationCreateReqVO req) {
         // 获取用户id
         Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
-        // 查询最新的对话
-        AiChatConversationDO latestConversation = aiChatConversationMapper.selectLatestConversation(loginUserId);
-        // 如果有对话没有被使用过,那就返回这个
-        if (latestConversation != null && latestConversation.getChatCount() <= 0) {
-            return AiChatConversationConvert.INSTANCE.covnertChatConversationRes(latestConversation);
-        }
-        // 获取第一个模型
+        // 默认使用 sort 排序第一个模型
         AiChatModalDO aiChatModalDO = aiChatModalMapper.selectFirstModal();
-        // 创建新的 Conversation
-        AiChatConversationDO insertConversation = saveConversation(req.getTitle(), loginUserId,
-                null, null, AiChatConversationTypeEnum.USER_CHAT,
-                aiChatModalDO.getId(), aiChatModalDO.getModal());
-        // 转换 res
-        return AiChatConversationConvert.INSTANCE.covnertChatConversationRes(insertConversation);
-    }
-
-    @Override
-    public AiChatConversationRespVO createRoleConversation(AiChatConversationCreateReqVO req) {
-        // 获取用户id
-        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
-        // 查询最新的对话
-//        AiChatConversationDO latestConversation = aiChatConversationMapper.selectLatestConversation(loginUserId);
-//        // 如果有对话没有被使用过,那就返回这个
-//        if (latestConversation != null && latestConversation.getChatCount() <= 0) {
-//            return AiChatConversationConvert.INSTANCE.covnertChatConversationRes(latestConversation);
-//        }
         // 查询角色
-        AiChatRoleRes chatRoleRes = aiChatRoleService.getChatRole(req.getRoleId());
-        // 获取第一个模型
-        AiChatModalDO aiChatModalDO = aiChatModalMapper.selectFirstModal();
+        AiChatRoleRes chatRoleRes = null;
+        if (req.getRoleId() != null) {
+            chatRoleRes = aiChatRoleService.getChatRole(req.getRoleId());
+        }
+        Long chatRoleId = chatRoleRes != null ? chatRoleRes.getId() : null;
         // 创建新的 Conversation
-        AiChatConversationDO insertConversation = saveConversation(req.getTitle(), loginUserId,
-                req.getRoleId(), chatRoleRes.getName(), AiChatConversationTypeEnum.ROLE_CHAT,
-                aiChatModalDO.getId(), aiChatModalDO.getModal());
-        // 转换 res
-        return AiChatConversationConvert.INSTANCE.covnertChatConversationRes(insertConversation);
+        AiChatConversationDO insertConversation = saveConversation(AiCommonConstants.CONVERSATION_DEFAULT_TITLE,
+                loginUserId, chatRoleId, aiChatModalDO.getId(), aiChatModalDO.getModel()
+        );
+        // 返回对话id
+        return insertConversation.getId();
     }
 
     private @NotNull AiChatConversationDO saveConversation(String title,
                                                            Long userId,
                                                            Long roleId,
-                                                           String roleName,
-                                                           AiChatConversationTypeEnum typeEnum,
                                                            Long modalId,
-                                                           String modal) {
+                                                           String model) {
         AiChatConversationDO insertConversation = new AiChatConversationDO();
         insertConversation.setId(null);
         insertConversation.setUserId(userId);
-        insertConversation.setRoleId(roleId);
-        insertConversation.setRoleName(roleName);
         insertConversation.setTitle(title);
-        insertConversation.setChatCount(0);
-        insertConversation.setType(typeEnum.getType());
-        insertConversation.setModalId(modalId);
-        insertConversation.setModal(modal);
+        insertConversation.setPinned(false);
+
+        insertConversation.setRoleId(roleId);
+        insertConversation.setModelId(modalId);
+        insertConversation.setModel(model);
+
+        insertConversation.setTemperature(null);
+        insertConversation.setMaxTokens(null);
+        insertConversation.setMaxContexts(null);
         aiChatConversationMapper.insert(insertConversation);
         return insertConversation;
     }
 
     @Override
-    public AiChatConversationRespVO getConversation(Long id) {
-        AiChatConversationDO aiChatConversationDO = validateExists(id);
-        return AiChatConversationConvert.INSTANCE.covnertChatConversationRes(aiChatConversationDO);
-    }
-
-    private @NotNull AiChatConversationDO validateExists(Long id) {
-        AiChatConversationDO aiChatConversationDO = aiChatConversationMapper.selectById(id);
-        if (aiChatConversationDO == null) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_CONTINUE_NOT_EXIST);
+    public Boolean updateConversation(AiChatConversationUpdateReqVO updateReqVO) {
+        // 校验对话是否存在
+        validateExists(updateReqVO.getId());
+        // 获取模型信息并验证
+        AiChatModalRes chatModal = aiChatModalService.getChatModalOfValidate(updateReqVO.getModelId());
+        // 校验modal是否可用
+        if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) {
+            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
         }
-        return aiChatConversationDO;
+        // 更新对话信息
+        AiChatConversationDO updateAiChatConversationDO
+                = AiChatConversationConvert.INSTANCE.convertAiChatConversationDO(updateReqVO);
+        return aiChatConversationMapper.updateById(updateAiChatConversationDO) > 0;
     }
 
     @Override
-    public List<AiChatConversationRespVO> listConversation(AiChatConversationListReq req) {
+    public List<AiChatConversationRespVO> listConversation(AiChatConversationListReqVO listReqVO) {
         // 获取用户id
         Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
         // 查询前100对话
         List<AiChatConversationDO> top100Conversation
-                = aiChatConversationMapper.selectTop100Conversation(loginUserId, req.getSearch());
+                = aiChatConversationMapper.selectTop100Conversation(loginUserId, listReqVO.getTitle());
         return AiChatConversationConvert.INSTANCE.covnertChatConversationResList(top100Conversation);
     }
 
     @Override
-    public void updateModal(Long id, Long modalId) {
-        // 校验对话是否存在
-        validateExists(id);
-        // 获取模型
-        AiChatModalRes chatModal = aiChatModalService.getChatModal(modalId);
-        // 判断模型是否禁用
-        if (AiChatModalDisableEnum.YES.getValue().equals(chatModal.getDisable())) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
-        }
-        // 更新对话
-        aiChatConversationMapper.updateById(new AiChatConversationDO()
-                .setId(id)
-                .setModalId(chatModal.getId())
-                .setModal(chatModal.getModal())
-        );
+    public AiChatConversationRespVO getConversationOfValidate(Long id) {
+        AiChatConversationDO aiChatConversationDO = validateExists(id);
+        return AiChatConversationConvert.INSTANCE.covnertChatConversationRes(aiChatConversationDO);
     }
 
     @Override
-    public void delete(Long id) {
-        aiChatConversationMapper.deleteById(id);
+    public Boolean deleteConversation(Long id) {
+        return aiChatConversationMapper.deleteById(id) > 0;
+    }
+
+    private @NotNull AiChatConversationDO validateExists(Long id) {
+        AiChatConversationDO aiChatConversationDO = aiChatConversationMapper.selectById(id);
+        if (aiChatConversationDO == null) {
+            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CONVERSATION_NOT_EXISTS);
+        }
+        return aiChatConversationDO;
     }
 }

+ 2 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java

@@ -12,6 +12,7 @@ import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
 import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
 import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.convert.AiChatModalConvert;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModalDO;
 import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalChatConfigVO;
 import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalConfigVO;
 import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalDallConfigVO;
@@ -109,7 +110,7 @@ public class AiChatModalServiceImpl implements AiChatModalService {
     }
 
     @Override
-    public AiChatModalRes getChatModal(Long modalId) {
+    public AiChatModalRes getChatModalOfValidate(Long modalId) {
         // 检查 modal 是否存在
         AiChatModalDO aiChatModalDO = validateChatModalExists(modalId);
         return AiChatModalConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO);

+ 2 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java

@@ -58,7 +58,7 @@ public class AiChatServiceImpl implements AiChatService {
         // 获取 client 类型
         AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
         // 获取对话信息
-        AiChatConversationRespVO conversationRes = chatConversationService.getConversation(req.getConversationId());
+        AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId());
         // 保存 chat message
         saveChatMessage(req, conversationRes, loginUserId);
         String content = null;
@@ -133,7 +133,7 @@ public class AiChatServiceImpl implements AiChatService {
         // 获取 client 类型
         AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
         // 获取对话信息
-        AiChatConversationRespVO conversationRes = chatConversationService.getConversation(req.getConversationId());
+        AiChatConversationRespVO conversationRes = chatConversationService.getConversationOfValidate(req.getConversationId());
         // 创建 chat 需要的 Prompt
         Prompt prompt = new Prompt(req.getPrompt());
         req.setTopK(req.getTopK());