浏览代码

增加创建role对话

cherishsince 1 年之前
父节点
当前提交
1ab1538afe

+ 3 - 2
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ChatConversationTypeEnum.java

@@ -15,8 +15,9 @@ import lombok.Getter;
 @Getter
 public enum ChatConversationTypeEnum {
 
-    NEW("new", "新建对话"),
-    CONTINUE("continue", "继续对话"),
+    // roleChat、userChat
+    ROLE_CHAT("roleChat", "角色对话"),
+    USER_CHAT("userChat", "用户对话"),
 
     ;
 

+ 12 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/ChatConversationController.java

@@ -2,7 +2,8 @@ package cn.iocoder.yudao.module.ai.controller;
 
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.module.ai.service.ChatConversationService;
-import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateReq;
+import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateRoleReq;
+import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateUserReq;
 import cn.iocoder.yudao.module.ai.vo.ChatConversationListReq;
 import cn.iocoder.yudao.module.ai.vo.ChatConversationRes;
 import io.swagger.v3.oas.annotations.Operation;
@@ -30,10 +31,16 @@ public class ChatConversationController {
 
     private final ChatConversationService chatConversationService;
 
-    @Operation(summary = "创建 - 对话")
-    @PostMapping("/create")
-    public CommonResult<ChatConversationRes> create(@RequestBody @Validated ChatConversationCreateReq req) {
-        return CommonResult.success(chatConversationService.create(req));
+    @Operation(summary = "创建 - 对话普通对话")
+    @PostMapping("/createConversation")
+    public CommonResult<ChatConversationRes> createConversation(@RequestBody @Validated ChatConversationCreateUserReq req) {
+        return CommonResult.success(chatConversationService.createConversation(req));
+    }
+
+    @Operation(summary = "创建 - 对话角色对话")
+    @PostMapping("/createRoleConversation")
+    public CommonResult<ChatConversationRes> createRoleConversation(@RequestBody @Validated ChatConversationCreateRoleReq req) {
+        return CommonResult.success(chatConversationService.createRoleConversation(req));
     }
 
     @Operation(summary = "获取 - 获取对话")

+ 5 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/mapper/AiChatConversationMapper.java

@@ -24,7 +24,11 @@ import java.util.List;
 @Mapper
 public interface AiChatConversationMapper extends BaseMapperX<AiChatConversationDO> {
 
-
+    /**
+     * 更新 - chat count
+     *
+     * @param id
+     */
     @Update("update ai_chat_conversation set chat_count = chat_count + 1 where id = #{id}")
     void updateIncrChatCount(@Param("id") Long id);
 

+ 14 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/ChatConversationService.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.module.ai.service;
 
-import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateReq;
+import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateRoleReq;
+import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateUserReq;
 import cn.iocoder.yudao.module.ai.vo.ChatConversationListReq;
 import cn.iocoder.yudao.module.ai.vo.ChatConversationRes;
 
@@ -15,12 +16,21 @@ import java.util.List;
 public interface ChatConversationService {
 
     /**
-     * 对话 - 创建
+     * 对话 - 创建普通对话
      *
      * @param req
      * @return
      */
-    ChatConversationRes create(ChatConversationCreateReq req);
+    ChatConversationRes createConversation(ChatConversationCreateUserReq req);
+
+    /**
+     * 对话 - 创建role对话
+     *
+     * @param req
+     * @return
+     */
+    ChatConversationRes createRoleConversation(ChatConversationCreateRoleReq req);
+
 
     /**
      * 获取 - 对话
@@ -44,4 +54,5 @@ public interface ChatConversationService {
      * @param id
      */
     void delete(Long id);
+
 }

+ 43 - 9
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatConversationServiceImpl.java

@@ -5,13 +5,18 @@ import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
 import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.convert.ChatConversationConvert;
 import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatConversationDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatRoleDO;
+import cn.iocoder.yudao.module.ai.enums.ChatConversationTypeEnum;
 import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
+import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
 import cn.iocoder.yudao.module.ai.service.ChatConversationService;
-import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateReq;
+import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateRoleReq;
+import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateUserReq;
 import cn.iocoder.yudao.module.ai.vo.ChatConversationListReq;
 import cn.iocoder.yudao.module.ai.vo.ChatConversationRes;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
+import org.jetbrains.annotations.NotNull;
 import org.springframework.stereotype.Service;
 
 import java.util.List;
@@ -27,10 +32,11 @@ import java.util.List;
 @AllArgsConstructor
 public class ChatConversationServiceImpl implements ChatConversationService {
 
+    private final AiChatRoleMapper aiChatRoleMapper;
     private final AiChatConversationMapper aiChatConversationMapper;
 
     @Override
-    public ChatConversationRes create(ChatConversationCreateReq req) {
+    public ChatConversationRes createConversation(ChatConversationCreateUserReq req) {
         // 获取用户id
         Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
         // 查询最新的对话
@@ -40,17 +46,45 @@ public class ChatConversationServiceImpl implements ChatConversationService {
             return ChatConversationConvert.INSTANCE.covnertChatConversationRes(latestConversation);
         }
         // 创建新的 Conversation
+        AiChatConversationDO insertConversation = saveConversation(req.getTitle(), loginUserId,
+                null, null, ChatConversationTypeEnum.USER_CHAT);
+        // 转换 res
+        return ChatConversationConvert.INSTANCE.covnertChatConversationRes(insertConversation);
+    }
+
+    @Override
+    public ChatConversationRes createRoleConversation(ChatConversationCreateRoleReq req) {
+        // 获取用户id
+        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
+        // 查询最新的对话
+        AiChatConversationDO latestConversation = aiChatConversationMapper.selectLatestConversation(loginUserId);
+        // 如果有对话没有被使用过,那就返回这个
+        if (latestConversation != null && latestConversation.getChatCount() <= 0) {
+            return ChatConversationConvert.INSTANCE.covnertChatConversationRes(latestConversation);
+        }
+        AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(req.getChatRoleId());
+        // 创建新的 Conversation
+        AiChatConversationDO insertConversation = saveConversation(req.getTitle(), loginUserId,
+                req.getChatRoleId(), aiChatRoleDO.getRoleName(), ChatConversationTypeEnum.ROLE_CHAT);
+        // 转换 res
+        return ChatConversationConvert.INSTANCE.covnertChatConversationRes(insertConversation);
+    }
+
+    private @NotNull AiChatConversationDO saveConversation(String title,
+                                                           Long userId,
+                                                           Long chatRoleId,
+                                                           String chatRoleName,
+                                                           ChatConversationTypeEnum typeEnum) {
         AiChatConversationDO insertConversation = new AiChatConversationDO();
         insertConversation.setId(null);
-        insertConversation.setUserId(loginUserId);
-        insertConversation.setChatRoleId(null);
-        insertConversation.setChatRoleName(null);
-        insertConversation.setTitle(null);
+        insertConversation.setUserId(userId);
+        insertConversation.setChatRoleId(chatRoleId);
+        insertConversation.setChatRoleName(chatRoleName);
+        insertConversation.setTitle(title);
         insertConversation.setChatCount(0);
-        insertConversation.setType(req.getChatType());
+        insertConversation.setType(typeEnum.getType());
         aiChatConversationMapper.insert(insertConversation);
-        // 转换 res
-        return ChatConversationConvert.INSTANCE.covnertChatConversationRes(insertConversation);
+        return insertConversation;
     }
 
     @Override

+ 10 - 41
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java

@@ -5,15 +5,10 @@ import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
 import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
 import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
 import cn.iocoder.yudao.framework.ai.config.AiClient;
-import cn.iocoder.yudao.framework.common.exception.ServerException;
 import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
-import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
-import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatConversationDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatMessageDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
-import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum;
 import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
 import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper;
 import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
@@ -49,7 +44,6 @@ public class ChatServiceImpl implements ChatService {
     private final AiChatConversationMapper aiChatConversationMapper;
     private final ChatConversationService chatConversationService;
 
-
     /**
      * chat
      *
@@ -64,7 +58,7 @@ public class ChatServiceImpl implements ChatService {
         // 获取对话信息
         ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId());
         // 保存 chat message
-        saveChatMessage(req, conversationRes.getId(), loginUserId);
+        saveChatMessage(req, conversationRes, loginUserId);
         String content = null;
         try {
             // 创建 chat 需要的 Prompt
@@ -75,16 +69,19 @@ public class ChatServiceImpl implements ChatService {
             // 发送 call 调用
             ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
             content = call.getResult().getOutput().getContent();
+            // 更新 conversation
+
         } catch (Exception e) {
             content = ExceptionUtil.getMessage(e);
         } finally {
             // 保存 chat message
-            saveSystemChatMessage(req, conversationRes.getId(), loginUserId, content);
+            saveSystemChatMessage(req, conversationRes, loginUserId, content);
         }
         return content;
     }
 
-    private void saveChatMessage(ChatReq req, Long chatConversationId, Long loginUserId) {
+    private void saveChatMessage(ChatReq req, ChatConversationRes conversationRes, Long loginUserId) {
+        Long chatConversationId = conversationRes.getId();
         // 增加 chat message 记录
         aiChatMessageMapper.insert(
                 new AiChatMessageDO()
@@ -97,12 +94,12 @@ public class ChatServiceImpl implements ChatService {
                         .setTopP(req.getTopP())
                         .setTemperature(req.getTemperature())
         );
-
         // chat count 先+1
         aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
     }
 
-    public void saveSystemChatMessage(ChatReq req, Long chatConversationId, Long loginUserId, String systemPrompts) {
+    public void saveSystemChatMessage(ChatReq req, ChatConversationRes conversationRes, Long loginUserId, String systemPrompts) {
+        Long chatConversationId = conversationRes.getId();
         // 增加 chat message 记录
         aiChatMessageMapper.insert(
                 new AiChatMessageDO()
@@ -120,34 +117,6 @@ public class ChatServiceImpl implements ChatService {
         aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
     }
 
-    private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) {
-        // 获取 chat 角色
-        String chatRoleName = null;
-        ChatTypeEnum chatTypeEnum = null;
-        Long chatRoleId = req.getChatRoleId();
-        if (req.getChatRoleId() != null) {
-            AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(chatRoleId);
-            if (aiChatRoleDO == null) {
-                throw new ServerException(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXISTENT);
-            }
-            chatTypeEnum = ChatTypeEnum.ROLE_CHAT;
-            chatRoleName = aiChatRoleDO.getRoleName();
-        } else {
-            chatTypeEnum = ChatTypeEnum.USER_CHAT;
-        }
-        //
-        AiChatConversationDO insertChatConversation = new AiChatConversationDO()
-                .setId(null)
-                .setUserId(loginUserId)
-                .setChatRoleId(req.getChatRoleId())
-                .setChatRoleName(chatRoleName)
-                .setType(chatTypeEnum.getType())
-                .setChatCount(1)
-                .setTitle(req.getPrompt().substring(0, 20) + "...");
-        aiChatConversationMapper.insert(insertChatConversation);
-        return insertChatConversation;
-    }
-
     /**
      * chat stream
      *
@@ -168,7 +137,7 @@ public class ChatServiceImpl implements ChatService {
         req.setTopP(req.getTopP());
         req.setTemperature(req.getTemperature());
         // 保存 chat message
-        saveChatMessage(req, conversationRes.getId(), loginUserId);
+        saveChatMessage(req, conversationRes, loginUserId);
         Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName());
 
         StringBuffer contentBuffer = new StringBuffer();
@@ -195,7 +164,7 @@ public class ChatServiceImpl implements ChatService {
                     log.info("发送完成!");
                     sseEmitter.complete();
                     // 保存 chat message
-                    saveSystemChatMessage(req, conversationRes.getId(), loginUserId, contentBuffer.toString());
+                    saveSystemChatMessage(req, conversationRes, loginUserId, contentBuffer.toString());
                 }
         );
     }

+ 26 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/vo/ChatConversationCreateRoleReq.java

@@ -0,0 +1,26 @@
+package cn.iocoder.yudao.module.ai.vo;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.constraints.NotNull;
+import lombok.Data;
+import lombok.experimental.Accessors;
+
+/**
+ * 聊天对话
+ *
+ * @author fansili
+ * @time 2024/4/18 16:24
+ * @since 1.0
+ */
+@Data
+@Accessors(chain = true)
+public class ChatConversationCreateRoleReq {
+
+    @Schema(description = "chat角色Id")
+    @NotNull(message = "聊天角色id不能为空!")
+    private Long chatRoleId;
+
+    @Schema(description = "标题(有程序自动生成)")
+    @NotNull(message = "标题不能为空!")
+    private String title;
+}

+ 4 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/vo/ChatConversationCreateReq.java → yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/vo/ChatConversationCreateUserReq.java

@@ -14,10 +14,9 @@ import lombok.experimental.Accessors;
  */
 @Data
 @Accessors(chain = true)
-public class ChatConversationCreateReq {
-
-    @Schema(description = "对话类型(roleChat、userChat)")
-    @NotNull(message = "聊天类型不能为空!")
-    private String chatType;
+public class ChatConversationCreateUserReq {
 
+    @Schema(description = "对话标题")
+    @NotNull(message = "标题不能为空!")
+    private String title;
 }

+ 6 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/resources/http/chat-conversation.http

@@ -16,7 +16,12 @@ GET {{baseUrl}}/ai/chat/conversation/1781604279872581644
 Authorization: {{token}}
 
 
-### 对话 - id获取
+### 对话 - list
+GET {{baseUrl}}/ai/chat/conversation/list
+Authorization: {{token}}
+
+
+### 对话 - 删除
 DELETE {{baseUrl}}/ai/chat/conversation/1781604279872581644
 Authorization: {{token}}