Browse Source

【优化】优化 chat event stream 模式交互,增加 add message 优先记录

cherishsince 11 tháng trước cách đây
mục cha
commit
5a4162cdc1
13 tập tin đã thay đổi với 155 bổ sung233 xóa
  1. 4 0
      yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java
  2. 8 3
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java
  3. 20 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddReqVO.java
  4. 17 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddRespVO.java
  5. 16 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendStreamReqVO.java
  6. 9 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatMessageConvert.java
  7. 0 1
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatModelMapper.java
  8. 0 78
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModelService.java
  9. 10 3
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java
  10. 0 132
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java
  11. 53 16
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java
  12. 10 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java
  13. 8 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java

+ 4 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java

@@ -36,5 +36,9 @@ public interface ErrorCodeConstants {
     ErrorCode AI_CHAT_ROLE_NOT_EXIST = new ErrorCode(1_022_000_060, "AI 角色不存在!");
     ErrorCode AI_CHAT_ROLE_NOT_PUBLIC = new ErrorCode(1_022_000_060, "AI 角色未公开!");
 
+    // chat
+
+    ErrorCode AI_CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_022_000_100, "AI 提问的 MessageId 不存在!");
+
 
 }

+ 8 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java

@@ -1,8 +1,7 @@
 package cn.iocoder.yudao.module.ai.controller.admin.chat;
 
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*;
 import cn.iocoder.yudao.module.ai.service.AiChatService;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.Parameter;
@@ -38,10 +37,16 @@ public class AiChatMessageController {
     // TODO @fan:要不要使用 Flux 来返回;可以使用 Flux<AiChatMessageRespVO>
     @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
     @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
-    public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
+    public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendStreamReqVO sendReqVO) {
         return chatService.chatStream(sendReqVO);
     }
 
+    @Operation(summary = "添加/提问", description = "先创建好 message 前端才好渲染")
+    @PostMapping(value = "/add")
+    public CommonResult<AiChatMessageRespVO> add(@Validated @RequestBody AiChatMessageAddReqVO req) {
+        return success(chatService.add(req));
+    }
+
     @Operation(summary = "获得指定会话的消息列表")
     @GetMapping("/list-by-conversation-id")
     @Parameter(name = "conversationId", required = true, description = "会话编号", example = "1024")

+ 20 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddReqVO.java

@@ -0,0 +1,20 @@
+package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.constraints.NotEmpty;
+import jakarta.validation.constraints.NotNull;
+import lombok.Data;
+
+@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
+@Data
+public class AiChatMessageAddReqVO {
+
+    @Schema(description = "聊天对话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
+    @NotNull(message = "聊天对话编号不能为空")
+    private Long conversationId;
+
+    @Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "帮我写个 Java 算法")
+    @NotEmpty(message = "聊天内容不能为空")
+    private String content;
+
+}

+ 17 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddRespVO.java

@@ -0,0 +1,17 @@
+package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import lombok.Data;
+
+import java.time.LocalDateTime;
+
+@Schema(description = "管理后台 - AI 聊天消息 Add Response VO")
+@Data
+public class AiChatMessageAddRespVO {
+
+    @Schema(description = "用户信息")
+    private AiChatMessageRespVO userMessage;
+
+    @Schema(description = "系统信息")
+    private AiChatMessageRespVO systemMessage;
+}

+ 16 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendStreamReqVO.java

@@ -0,0 +1,16 @@
+package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.constraints.NotEmpty;
+import jakarta.validation.constraints.NotNull;
+import lombok.Data;
+
+@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
+@Data
+public class AiChatMessageSendStreamReqVO {
+
+    @Schema(description = "提问的 messageId", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
+    @NotNull(message = "提问的 messageId 不能为空")
+    private Long id;
+
+}

+ 9 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatMessageConvert.java

@@ -26,4 +26,13 @@ public interface AiChatMessageConvert {
      * @return
      */
     List<AiChatMessageRespVO> convertAiChatMessageRespVOList(List<AiChatMessageDO> aiChatMessageDOList);
+
+    /**
+     * 转换 - aiChatMessageDO
+     *
+     * @param aiChatMessageDO
+     * @return
+     */
+    AiChatMessageRespVO convertAiChatMessageRespVO(AiChatMessageDO aiChatMessageDO);
+
 }

+ 0 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatModelMapper.java

@@ -11,7 +11,6 @@ import org.apache.ibatis.annotations.Mapper;
 
 import java.util.Collection;
 import java.util.List;
-import java.util.Set;
 
 /**
  * API 聊天模型 Mapper

+ 0 - 78
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModelService.java

@@ -1,78 +0,0 @@
-package cn.iocoder.yudao.module.ai.service;
-
-import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.*;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
-
-import java.util.List;
-import java.util.Set;
-
-/**
- * ai modal
- *
- * @author fansili
- * @time 2024/4/24 19:42
- * @since 1.0
- */
-public interface AiChatModelService {
-
-    /**
-     * ai modal - 列表
-     *
-     * @param req
-     * @return
-     */
-    PageResult<AiChatModelListRespVO> list(AiChatModelListReqVO req);
-
-    /**
-     * ai modal - 添加
-     *
-     * @param req
-     */
-    void add(AiChatModelAddReqVO req);
-
-    /**
-     * ai modal - 更新
-     *
-     * @param req
-     */
-    void update(AiChatModelUpdateReqVO req);
-
-    /**
-     * ai modal - 删除
-     *
-     * @param id
-     */
-    void delete(Long id);
-
-    /**
-     * 获取 - 获取 modal
-     *
-     * @param modalId
-     * @return
-     */
-    AiChatModalRespVO getChatModalOfValidate(Long modalId);
-
-    /**
-     * 校验 - 是否存在
-     *
-     * @param id
-     * @return
-     */
-    AiChatModelDO validateExists(Long id);
-
-    /**
-     * 校验 - 校验是否可用
-     *
-     * @param chatModal
-     */
-    void validateAvailable(AiChatModalRespVO chatModal);
-
-    /**
-     * 获取 - 根据 ids 批量获取
-     *
-     * @param modalIds
-     * @return
-     */
-    List<AiChatModelDO> getModalByIds(Set<Long> modalIds);
-}

+ 10 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java

@@ -1,7 +1,6 @@
 package cn.iocoder.yudao.module.ai.service;
 
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*;
 import reactor.core.publisher.Flux;
 
 import java.util.List;
@@ -29,7 +28,15 @@ public interface AiChatService {
      * @param sendReqVO
      * @return
      */
-    Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO sendReqVO);
+    Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO sendReqVO);
+
+    /**
+     * 添加 - message
+     *
+     * @param sendReqVO
+     * @return
+     */
+    AiChatMessageRespVO add(AiChatMessageAddReqVO sendReqVO);
 
     /**
      * 获取 - 获取对话 message list

+ 0 - 132
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java

@@ -1,132 +0,0 @@
-package cn.iocoder.yudao.module.ai.service.impl;
-
-import cn.hutool.core.util.StrUtil;
-import cn.hutool.extra.validation.ValidationUtil;
-import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
-import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
-import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
-import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
-import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.*;
-import cn.iocoder.yudao.module.ai.convert.AiChatModelConvert;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
-import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModelMapper;
-import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalConfigVO;
-import cn.iocoder.yudao.module.ai.service.AiChatModelService;
-import jakarta.validation.ConstraintViolation;
-import lombok.AllArgsConstructor;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.stereotype.Service;
-
-import java.util.List;
-import java.util.Set;
-
-/**
- * ai 模型
- *
- * @author fansili
- * @time 2024/4/24 19:42
- * @since 1.0
- */
-@AllArgsConstructor
-@Service
-@Slf4j
-public class AiChatModalServiceImpl implements AiChatModelService {
-
-    private final AiChatModelMapper aiChatModelMapper;
-
-    @Override
-    public PageResult<AiChatModelListRespVO> list(AiChatModelListReqVO req) {
-        LambdaQueryWrapperX<AiChatModelDO> queryWrapperX = new LambdaQueryWrapperX<>();
-        // 查询的都是未禁用的模型
-        queryWrapperX.eq(AiChatModelDO::getStatus, CommonStatusEnum.ENABLE.getStatus());
-        // search
-        if (!StrUtil.isBlank(req.getSearch())) {
-            queryWrapperX.like(AiChatModelDO::getName, req.getSearch().trim());
-        }
-        // 默认排序
-        queryWrapperX.orderByAsc(AiChatModelDO::getSort);
-        // 查询
-        PageResult<AiChatModelDO> aiChatModalDOPageResult = aiChatModelMapper.selectPage(req, queryWrapperX);
-        // 转换 res
-        List<AiChatModelListRespVO> resList = AiChatModelConvert.INSTANCE.convertAiChatModalListRes(aiChatModalDOPageResult.getList());
-        return new PageResult<>(resList, aiChatModalDOPageResult.getTotal());
-    }
-
-    @Override
-    public void add(AiChatModelAddReqVO req) {
-        // 校验 platform、type
-        validatePlatform(req.getPlatform());
-        // 转换 do
-        AiChatModelDO insertChatModalDO = AiChatModelConvert.INSTANCE.convertAiChatModalDO(req);
-        // 设置默认属性
-        insertChatModalDO.setStatus(CommonStatusEnum.ENABLE.getStatus());
-        // 保存数据库
-        aiChatModelMapper.insert(insertChatModalDO);
-    }
-
-    @Override
-    public void update(AiChatModelUpdateReqVO req) {
-        // 校验 platform
-        validatePlatform(req.getPlatform());
-        // 校验模型是否存在
-        validateExists(req.getId());
-        // 转换 updateChatModalDO
-        AiChatModelDO updateChatModalDO = AiChatModelConvert.INSTANCE.convertAiChatModalDO(req);
-        updateChatModalDO.setId(req.getId());
-        // 更新数据库
-        aiChatModelMapper.updateById(updateChatModalDO);
-    }
-
-    @Override
-    public void delete(Long id) {
-        // 检查 modal 是否存在
-        validateExists(id);
-        // 删除 delete
-        aiChatModelMapper.deleteById(id);
-    }
-
-    @Override
-    public AiChatModalRespVO getChatModalOfValidate(Long modalId) {
-        // 检查 modal 是否存在
-        AiChatModelDO aiChatModalDO = validateExists(modalId);
-        return AiChatModelConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO);
-    }
-
-    @Override
-    public void validateAvailable(AiChatModalRespVO chatModal) {
-        // 对话模型是否可用
-        if (!CommonStatusEnum.ENABLE.getStatus().equals(chatModal.getStatus())) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
-        }
-    }
-
-    @Override
-    public List<AiChatModelDO> getModalByIds(Set<Long> modalIds) {
-        return aiChatModelMapper.selectByIds(modalIds);
-    }
-
-    public AiChatModelDO validateExists(Long id) {
-        AiChatModelDO aiChatModalDO = aiChatModelMapper.selectById(id);
-        if (aiChatModalDO == null) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_NOT_EXIST);
-        }
-        return aiChatModalDO;
-    }
-
-    private void validatePlatform(String platform) {
-        try {
-            AiPlatformEnum.valueOfPlatform(platform);
-        } catch (IllegalArgumentException e) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_PLATFORM_PARAMS_INCORRECT, e.getMessage());
-        }
-    }
-
-    private void validateModalConfig(AiChatModalConfigVO aiChatModalConfigVO) {
-        Set<ConstraintViolation<AiChatModalConfigVO>> validate = ValidationUtil.validate(aiChatModalConfigVO);
-        for (ConstraintViolation<AiChatModalConfigVO> constraintViolation : validate) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_CONFIG_PARAMS_INCORRECT, constraintViolation.getMessage());
-        }
-    }
-}

+ 53 - 16
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java

@@ -7,11 +7,15 @@ import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
 import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient;
 import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
 import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
+import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
 import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
+import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageAddReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendStreamReqVO;
 import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
@@ -19,11 +23,12 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
 import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
 import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
 import cn.iocoder.yudao.module.ai.service.AiChatService;
+import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
+import org.springframework.boot.autoconfigure.http.HttpMessageConverters;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 import reactor.core.publisher.Flux;
@@ -53,6 +58,7 @@ public class AiChatServiceImpl implements AiChatService {
     private final AiChatConversationService chatConversationService;
     private final AiChatModelService aiChatModalService;
     private final AiChatRoleService aiChatRoleService;
+    private final HttpMessageConverters messageConverters;
 
     @Transactional(rollbackFor = Exception.class)
     public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
@@ -124,10 +130,15 @@ public class AiChatServiceImpl implements AiChatService {
         return insertChatMessageDO;
     }
 
-    public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO req) {
+    public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO req) {
         Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
+        // 查询提问的 message
+        AiChatMessageDO aiChatMessageDO = aiChatMessageMapper.selectById(req.getId());
+        if (aiChatMessageDO == null) {
+            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST);
+        }
         // 查询对话
-        AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
+        AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(aiChatMessageDO.getConversationId());
         // 获取对话模型
         AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
         // 获取角色信息
@@ -138,14 +149,14 @@ public class AiChatServiceImpl implements AiChatService {
         // 校验角色是否公开
         aiChatRoleService.validateIsPublic(aiChatRoleDO);
         // 创建 chat 需要的 Prompt
-        Prompt prompt = new Prompt(req.getContent());
+        Prompt prompt = new Prompt(aiChatMessageDO.getContent());
+        // 提前创建一个 system message
+        AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
+                chatModel.getModel(), chatModel.getId(), "",
+                0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
 //        req.setTopK(req.getTopK());
 //        req.setTopP(req.getTopP());
 //        req.setTemperature(req.getTemperature());
-        // 保存 chat message
-        insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
-                chatModel.getModel(), chatModel.getId(), req.getContent(),
-                null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
         // 获取 client 类型
         AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
         StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
@@ -154,7 +165,8 @@ public class AiChatServiceImpl implements AiChatService {
         StringBuffer contentBuffer = new StringBuffer();
         AtomicInteger tokens = new AtomicInteger(0);
         return streamResponse.map(res -> {
-                    AiChatMessageRespVO aiChatMessageRespVO = new AiChatMessageRespVO();
+                    AiChatMessageRespVO aiChatMessageRespVO =
+                            AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(systemMessage);
                     aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent());
                     contentBuffer.append(res.getResult().getOutput().getContent());
                     tokens.incrementAndGet();
@@ -165,22 +177,46 @@ public class AiChatServiceImpl implements AiChatService {
             public void run() {
                 log.info("发送完成!");
                 // 保存 chat message
-                insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
-                        chatModel.getModel(), chatModel.getId(), contentBuffer.toString(),
-                        tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
+                aiChatMessageMapper.updateById(new AiChatMessageDO()
+                        .setId(systemMessage.getId())
+                        .setContent(contentBuffer.toString())
+                        .setTokens(tokens.get())
+                );
             }
         }).doOnError(new Consumer<Throwable>() {
             @Override
             public void accept(Throwable throwable) {
                 log.error("发送错误 {}!", throwable.getMessage());
-                // 保存 chat message
-                insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
-                        chatModel.getModel(), chatModel.getId(), throwable.getMessage(),
-                        tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
+                // 更新错误信息
+                aiChatMessageMapper.updateById(new AiChatMessageDO()
+                        .setId(systemMessage.getId())
+                        .setContent(throwable.getMessage())
+                        .setTokens(tokens.get())
+                );
             }
         });
     }
 
+    @Override
+    public AiChatMessageRespVO add(AiChatMessageAddReqVO req) {
+        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
+        // 查询对话
+        AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
+        // 获取对话模型
+        AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
+        // 获取角色信息
+        AiChatRoleDO aiChatRoleDO = null;
+        if (conversation.getRoleId() != null) {
+            aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
+        }
+        // 校验角色是否公开
+        aiChatRoleService.validateIsPublic(aiChatRoleDO);
+        AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
+                chatModel.getModel(), chatModel.getId(), req.getContent(),
+                null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
+       return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(userMessage);
+    }
+
     @Override
     public List<AiChatMessageRespVO> getMessageListByConversationId(Long conversationId) {
         // 校验对话是否存在
@@ -207,4 +243,5 @@ public class AiChatServiceImpl implements AiChatService {
     public Boolean deleteMessage(Long id) {
         return aiChatMessageMapper.deleteById(id) > 0;
     }
+
 }

+ 10 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java

@@ -6,6 +6,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatMode
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import jakarta.validation.Valid;
 
+import java.util.List;
+import java.util.Set;
+
 /**
  * AI 聊天模型 Service 接口
  *
@@ -60,4 +63,11 @@ public interface AiChatModelService {
      */
     AiChatModelDO validateChatModel(Long id);
 
+    /**
+     * 获取 - 根据多个 ids 获取
+     *
+     * @param modalIds
+     * @return
+     */
+    List<AiChatModelDO> getModalByIds(Set<Long> modalIds);
 }

+ 8 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java

@@ -12,6 +12,9 @@ import jakarta.annotation.Resource;
 import org.springframework.stereotype.Service;
 import org.springframework.validation.annotation.Validated;
 
+import java.util.List;
+import java.util.Set;
+
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
 import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.*;
 
@@ -89,4 +92,9 @@ public class AiChatModelServiceImpl implements AiChatModelService {
         return model;
     }
 
+    @Override
+    public List<AiChatModelDO> getModalByIds(Set<Long> modalIds) {
+        return chatModelMapper.selectByIds(modalIds);
+    }
+
 }