Ver código fonte

【新增】AI:聊天对话的新建(80%)

YunaiV 11 meses atrás
pai
commit
69fa98792c

+ 2 - 1
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java

@@ -25,7 +25,8 @@ public interface ErrorCodeConstants {
 
     // ========== API 聊天会话 1-040-003-000 ==========
 
-    ErrorCode CHAT_CONVERSATION_NOT_EXISTS = new ErrorCode(1_040_003_000, "AI 对话不存在!");;
+    ErrorCode CHAT_CONVERSATION_NOT_EXISTS = new ErrorCode(1_040_003_000, "AI 对话不存在!");
+    ErrorCode CHAT_CONVERSATION_MODEL_ERROR = new ErrorCode(1_040_003_001, "操作失败,该聊天模型的配置不完整");
 
     // chat
     ErrorCode AI_CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_022_000_100, "AI 提问的 MessageId 不存在!");

+ 6 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatConversationController.java

@@ -1,9 +1,11 @@
 package cn.iocoder.yudao.module.ai.controller.admin.chat;
 
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
+import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
 import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.Parameter;
@@ -40,11 +42,11 @@ public class AiChatConversationController {
         return success(true);
     }
 
-    // TODO done @fan:实现一下
     @GetMapping("/my-list")
-    @Operation(summary = "获得聊天会话列表")
-    public CommonResult<List<AiChatConversationRespVO>> getConversationList() {
-        return success(chatConversationService.listConversation());
+    @Operation(summary = "获得【我的】聊天会话列表")
+    public CommonResult<List<AiChatConversationRespVO>> getChatConversationMyList() {
+        List<AiChatConversationDO> list = chatConversationService.getChatConversationListByUserId(getLoginUserId());
+        return success(BeanUtils.toBean(list, AiChatConversationRespVO.class));
     }
 
     // TODO @fan:实现一下

+ 23 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationRespVO.java

@@ -1,5 +1,11 @@
 package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation;
 
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelRespVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
+import com.fhs.core.trans.anno.Trans;
+import com.fhs.core.trans.constant.TransType;
+import com.fhs.core.trans.vo.VO;
 import io.swagger.v3.oas.annotations.media.Schema;
 import jakarta.validation.constraints.NotNull;
 import lombok.Data;
@@ -7,7 +13,7 @@ import lombok.experimental.Accessors;
 
 @Schema(description = "管理后台 - AI 聊天会话 Response VO")
 @Data
-public class AiChatConversationRespVO {
+public class AiChatConversationRespVO implements VO {
 
     @Schema(description = "会话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
     private Long id;
@@ -22,9 +28,12 @@ public class AiChatConversationRespVO {
     private Boolean pinned;
 
     @Schema(description = "角色编号", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1")
+    @Trans(type = TransType.SIMPLE, target = AiChatRoleDO.class, fields = "avatar", ref = "roleAvatar")
     private Long roleId;
 
     @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
+    @Trans(type = TransType.SIMPLE, target = AiChatModelDO.class, fields = {"maxTokens", "maxContexts"},
+            refs = {"modelMaxTokens", "modelMaxContexts"})
     private Long modelId;
 
     @Schema(description = "模型标志", requiredMode = Schema.RequiredMode.REQUIRED, example = "ERNIE-Bot-turbo-0922")
@@ -39,4 +48,17 @@ public class AiChatConversationRespVO {
     @Schema(description = "上下文的最大 Message 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "10")
     private Integer maxContexts;
 
+    // ========== 关联 role 信息 ==========
+
+    @Schema(description = "角色头像", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://www.iocoder.cn/1.png")
+    private String roleAvatar;
+
+    // ========== 关联 model 信息 ==========
+
+    @Schema(description = "模型的单条回复的最大 Token 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "4096")
+    private Integer modelMaxTokens;
+
+    @Schema(description = "模型的上下文的最大 Message 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "10")
+    private Integer modelMaxContexts;
+
 }

+ 0 - 65
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatConversationMapper.java

@@ -1,65 +0,0 @@
-package cn.iocoder.yudao.module.ai.dal.mysql;
-
-import cn.hutool.core.collection.CollUtil;
-import cn.hutool.core.util.StrUtil;
-import cn.iocoder.yudao.framework.common.pojo.PageParam;
-import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
-import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
-import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
-import org.apache.ibatis.annotations.Mapper;
-import org.apache.ibatis.annotations.Param;
-import org.apache.ibatis.annotations.Update;
-import org.springframework.stereotype.Repository;
-
-import java.util.List;
-
-/**
- * message mapper
- *
- * @fansili
- * @since v1.0
- */
-@Repository
-@Mapper
-public interface AiChatConversationMapper extends BaseMapperX<AiChatConversationDO> {
-
-    /**
-     * 更新 - chat count
-     *
-     * @param id
-     */
-    @Update("update ai_chat_conversation set chat_count = chat_count + 1 where id = #{id}")
-    void updateIncrChatCount(@Param("id") Long id);
-
-    /**
-     * 查询 - 最新的对话
-     *
-     * @param loginUserId
-     */
-    default AiChatConversationDO selectLatestConversation(Long loginUserId) {
-        PageResult<AiChatConversationDO> pageResult = selectPage(new PageParam().setPageNo(1).setPageSize(1),
-                new LambdaQueryWrapper<AiChatConversationDO>()
-                        .eq(AiChatConversationDO::getUserId, loginUserId)
-                        .orderByDesc(AiChatConversationDO::getId));
-        if (CollUtil.isEmpty(pageResult.getList())) {
-            return null;
-        }
-        return pageResult.getList().get(0);
-    }
-
-    /**
-     * 查询 - 前100
-     *
-     * @param search
-     */
-    default List<AiChatConversationDO> selectTop100Conversation(Long loginUserId, String search) {
-        LambdaQueryWrapper<AiChatConversationDO> queryWrapper
-                = new LambdaQueryWrapper<AiChatConversationDO>().eq(AiChatConversationDO::getUserId, loginUserId);
-        if (!StrUtil.isBlank(search)) {
-            queryWrapper.like(AiChatConversationDO::getTitle, search);
-        }
-        queryWrapper.orderByDesc(AiChatConversationDO::getId);
-        return selectPage(new PageParam().setPageNo(1).setPageSize(100), queryWrapper).getList();
-    }
-}

+ 2 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatMessageMapper.java

@@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.mysql;
 
 import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
 import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
+import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
 import org.apache.ibatis.annotations.Mapper;
 import org.springframework.stereotype.Repository;
@@ -30,4 +31,5 @@ public interface AiChatMessageMapper extends BaseMapperX<AiChatMessageDO> {
                         .orderByAsc(AiChatMessageDO::getId)
         );
     }
+
 }

+ 22 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/chat/AiChatConversationMapper.java

@@ -0,0 +1,22 @@
+package cn.iocoder.yudao.module.ai.dal.mysql.chat;
+
+import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
+import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
+import org.apache.ibatis.annotations.Mapper;
+
+import java.util.List;
+
+/**
+ * AI 聊天对话 Mapper
+ *
+ * @author 芋道源码
+ */
+@Mapper
+public interface AiChatConversationMapper extends BaseMapperX<AiChatConversationDO> {
+
+    default List<AiChatConversationDO> selectListByUserId(Long userId) {
+        return selectList(AiChatConversationDO::getUserId, userId);
+    }
+
+}

+ 5 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationService.java

@@ -19,7 +19,7 @@ public interface AiChatConversationService {
      *
      * @param createReqVO 创建信息
      * @param userId 用户编号
-     * @return 聊天会话
+     * @return 编号
      */
     Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId);
 
@@ -32,11 +32,12 @@ public interface AiChatConversationService {
     void updateChatConversationMy(AiChatConversationUpdateMyReqVO updateReqVO, Long userId);
 
     /**
-     * 获取 - 对话列表
+     * 获得【我的】聊天会话列表
      *
-     * @return
+     * @param userId 用户编号
+     * @return 聊天会话列表
      */
-    List<AiChatConversationRespVO> listConversation();
+    List<AiChatConversationDO> getChatConversationListByUserId(Long userId);
 
     /**
      * 获取 - 对话

+ 17 - 11
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java

@@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service.chat;
 
 import cn.hutool.core.lang.Assert;
 import cn.hutool.core.util.ObjUtil;
+import cn.hutool.core.util.ObjectUtil;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO;
@@ -11,7 +12,7 @@ import cn.iocoder.yudao.module.ai.convert.AiChatConversationConvert;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
-import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
+import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper;
 import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
 import jakarta.annotation.Resource;
@@ -22,6 +23,7 @@ import org.springframework.validation.annotation.Validated;
 import java.util.List;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.*;
+import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_MODEL_ERROR;
 import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
 
 /**
@@ -49,9 +51,10 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
                 : chatRoleService.getRequiredDefaultChatRole();
         Assert.notNull(role, "必须找到聊天角色");
         // 1.2 获得 AiChatModelDO 聊天模型
-        AiChatModelDO model = role.getModelId() != null ? chatModalService.validateChatModel(role.getId())
+        AiChatModelDO model = role.getModelId() != null ? chatModalService.validateChatModel(role.getModelId())
                 : chatModalService.getRequiredDefaultChatModel();
-        Assert.notNull(role, "必须找到默认模型");
+        Assert.notNull(model, "必须找到默认模型");
+        validateChatModel(model);
 
         // 2. 创建 AiChatConversationDO 聊天对话
         AiChatConversationDO conversation = new AiChatConversationDO()
@@ -70,9 +73,10 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
             throw exception(CHAT_CONVERSATION_NOT_EXISTS);
         }
         // 1.2 校验模型是否存在
-        AiChatModelDO model = null;
+        AiChatModelDO model;
         if (updateReqVO.getModelId() != null) {
             model = chatModalService.validateChatModel(updateReqVO.getModelId());
+            Assert.notNull(model, "必须找到默认模型");
         }
         // 1.3 校验温度参数、Token 数量、消息数量 TODO
 
@@ -81,13 +85,15 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
     }
 
     @Override
-    public List<AiChatConversationRespVO> listConversation() {
-        // 获取用户id
-        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
-        // 查询前100对话
-        List<AiChatConversationDO> top100Conversation
-                = chatConversationMapper.selectTop100Conversation(loginUserId, null);
-        return AiChatConversationConvert.INSTANCE.covnertChatConversationResList(top100Conversation);
+    public List<AiChatConversationDO> getChatConversationListByUserId(Long userId) {
+        return chatConversationMapper.selectListByUserId(userId);
+    }
+
+    private void validateChatModel(AiChatModelDO model) {
+        if (ObjectUtil.isAllNotEmpty(model.getTemperature(), model.getMaxTokens(), model.getMaxContexts())) {
+            return;
+        }
+        throw exception(CHAT_CONVERSATION_MODEL_ERROR);
     }
 
     @Override

+ 1 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java

@@ -20,7 +20,7 @@ import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
-import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
+import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
 import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
@@ -119,8 +119,6 @@ public class AiChatServiceImpl implements AiChatService {
                 .setMaxContexts(maxContexts);
         // 增加 chat message 记录
         aiChatMessageMapper.insert(insertChatMessageDO);
-        // chat count 先+1
-        aiChatConversationMapper.updateIncrChatCount(conversationId);
         return insertChatMessageDO;
     }