Ver Fonte

【优化】优化 chat event stream 模式交互,增加 add message 优先记录

cherishsince há 11 meses atrás
pai
commit
5a4162cdc1
13 ficheiros alterados com 155 adições e 233 exclusões
  1. 4 0
      yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java
  2. 8 3
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java
  3. 20 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddReqVO.java
  4. 17 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddRespVO.java
  5. 16 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendStreamReqVO.java
  6. 9 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatMessageConvert.java
  7. 0 1
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatModelMapper.java
  8. 0 78
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModelService.java
  9. 10 3
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java
  10. 0 132
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java
  11. 53 16
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java
  12. 10 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java
  13. 8 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java

+ 4 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/ErrorCodeConstants.java

@@ -36,5 +36,9 @@ public interface ErrorCodeConstants {
     ErrorCode AI_CHAT_ROLE_NOT_EXIST = new ErrorCode(1_022_000_060, "AI 角色不存在!");
     ErrorCode AI_CHAT_ROLE_NOT_PUBLIC = new ErrorCode(1_022_000_060, "AI 角色未公开!");
 
+    // chat
+
+    ErrorCode AI_CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_022_000_100, "AI 提问的 MessageId 不存在!");
+
 
 }

+ 8 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java

@@ -1,8 +1,7 @@
 package cn.iocoder.yudao.module.ai.controller.admin.chat;
 
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*;
 import cn.iocoder.yudao.module.ai.service.AiChatService;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.Parameter;
@@ -38,10 +37,16 @@ public class AiChatMessageController {
     // TODO @fan:要不要使用 Flux 来返回;可以使用 Flux<AiChatMessageRespVO>
     @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
     @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
-    public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
+    public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendStreamReqVO sendReqVO) {
         return chatService.chatStream(sendReqVO);
     }
 
+    @Operation(summary = "添加/提问", description = "先创建好 message 前端才好渲染")
+    @PostMapping(value = "/add")
+    public CommonResult<AiChatMessageRespVO> add(@Validated @RequestBody AiChatMessageAddReqVO req) {
+        return success(chatService.add(req));
+    }
+
     @Operation(summary = "获得指定会话的消息列表")
     @GetMapping("/list-by-conversation-id")
     @Parameter(name = "conversationId", required = true, description = "会话编号", example = "1024")

+ 20 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddReqVO.java

@@ -0,0 +1,20 @@
+package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.constraints.NotEmpty;
+import jakarta.validation.constraints.NotNull;
+import lombok.Data;
+
+@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
+@Data
+public class AiChatMessageAddReqVO {
+
+    @Schema(description = "聊天对话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
+    @NotNull(message = "聊天对话编号不能为空")
+    private Long conversationId;
+
+    @Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "帮我写个 Java 算法")
+    @NotEmpty(message = "聊天内容不能为空")
+    private String content;
+
+}

+ 17 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageAddRespVO.java

@@ -0,0 +1,17 @@
+package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import lombok.Data;
+
+import java.time.LocalDateTime;
+
+@Schema(description = "管理后台 - AI 聊天消息 Add Response VO")
+@Data
+public class AiChatMessageAddRespVO {
+
+    @Schema(description = "用户信息")
+    private AiChatMessageRespVO userMessage;
+
+    @Schema(description = "系统信息")
+    private AiChatMessageRespVO systemMessage;
+}

+ 16 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendStreamReqVO.java

@@ -0,0 +1,16 @@
+package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.constraints.NotEmpty;
+import jakarta.validation.constraints.NotNull;
+import lombok.Data;
+
+@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
+@Data
+public class AiChatMessageSendStreamReqVO {
+
+    @Schema(description = "提问的 messageId", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
+    @NotNull(message = "提问的 messageId 不能为空")
+    private Long id;
+
+}

+ 9 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiChatMessageConvert.java

@@ -26,4 +26,13 @@ public interface AiChatMessageConvert {
      * @return
      */
     List<AiChatMessageRespVO> convertAiChatMessageRespVOList(List<AiChatMessageDO> aiChatMessageDOList);
+
+    /**
+     * 转换 - aiChatMessageDO
+     *
+     * @param aiChatMessageDO
+     * @return
+     */
+    AiChatMessageRespVO convertAiChatMessageRespVO(AiChatMessageDO aiChatMessageDO);
+
 }

+ 0 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/AiChatModelMapper.java

@@ -11,7 +11,6 @@ import org.apache.ibatis.annotations.Mapper;
 
 import java.util.Collection;
 import java.util.List;
-import java.util.Set;
 
 /**
  * API 聊天模型 Mapper

+ 0 - 78
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatModelService.java

@@ -1,78 +0,0 @@
-package cn.iocoder.yudao.module.ai.service;
-
-import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.*;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
-
-import java.util.List;
-import java.util.Set;
-
-/**
- * ai modal
- *
- * @author fansili
- * @time 2024/4/24 19:42
- * @since 1.0
- */
-public interface AiChatModelService {
-
-    /**
-     * ai modal - 列表
-     *
-     * @param req
-     * @return
-     */
-    PageResult<AiChatModelListRespVO> list(AiChatModelListReqVO req);
-
-    /**
-     * ai modal - 添加
-     *
-     * @param req
-     */
-    void add(AiChatModelAddReqVO req);
-
-    /**
-     * ai modal - 更新
-     *
-     * @param req
-     */
-    void update(AiChatModelUpdateReqVO req);
-
-    /**
-     * ai modal - 删除
-     *
-     * @param id
-     */
-    void delete(Long id);
-
-    /**
-     * 获取 - 获取 modal
-     *
-     * @param modalId
-     * @return
-     */
-    AiChatModalRespVO getChatModalOfValidate(Long modalId);
-
-    /**
-     * 校验 - 是否存在
-     *
-     * @param id
-     * @return
-     */
-    AiChatModelDO validateExists(Long id);
-
-    /**
-     * 校验 - 校验是否可用
-     *
-     * @param chatModal
-     */
-    void validateAvailable(AiChatModalRespVO chatModal);
-
-    /**
-     * 获取 - 根据 ids 批量获取
-     *
-     * @param modalIds
-     * @return
-     */
-    List<AiChatModelDO> getModalByIds(Set<Long> modalIds);
-}

+ 10 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiChatService.java

@@ -1,7 +1,6 @@
 package cn.iocoder.yudao.module.ai.service;
 
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.*;
 import reactor.core.publisher.Flux;
 
 import java.util.List;
@@ -29,7 +28,15 @@ public interface AiChatService {
      * @param sendReqVO
      * @return
      */
-    Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO sendReqVO);
+    Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO sendReqVO);
+
+    /**
+     * 添加 - message
+     *
+     * @param sendReqVO
+     * @return
+     */
+    AiChatMessageRespVO add(AiChatMessageAddReqVO sendReqVO);
 
     /**
      * 获取 - 获取对话 message list

+ 0 - 132
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatModalServiceImpl.java

@@ -1,132 +0,0 @@
-package cn.iocoder.yudao.module.ai.service.impl;
-
-import cn.hutool.core.util.StrUtil;
-import cn.hutool.extra.validation.ValidationUtil;
-import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
-import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
-import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
-import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
-import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.*;
-import cn.iocoder.yudao.module.ai.convert.AiChatModelConvert;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
-import cn.iocoder.yudao.module.ai.dal.mysql.AiChatModelMapper;
-import cn.iocoder.yudao.module.ai.dal.vo.AiChatModalConfigVO;
-import cn.iocoder.yudao.module.ai.service.AiChatModelService;
-import jakarta.validation.ConstraintViolation;
-import lombok.AllArgsConstructor;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.stereotype.Service;
-
-import java.util.List;
-import java.util.Set;
-
-/**
- * ai 模型
- *
- * @author fansili
- * @time 2024/4/24 19:42
- * @since 1.0
- */
-@AllArgsConstructor
-@Service
-@Slf4j
-public class AiChatModalServiceImpl implements AiChatModelService {
-
-    private final AiChatModelMapper aiChatModelMapper;
-
-    @Override
-    public PageResult<AiChatModelListRespVO> list(AiChatModelListReqVO req) {
-        LambdaQueryWrapperX<AiChatModelDO> queryWrapperX = new LambdaQueryWrapperX<>();
-        // 查询的都是未禁用的模型
-        queryWrapperX.eq(AiChatModelDO::getStatus, CommonStatusEnum.ENABLE.getStatus());
-        // search
-        if (!StrUtil.isBlank(req.getSearch())) {
-            queryWrapperX.like(AiChatModelDO::getName, req.getSearch().trim());
-        }
-        // 默认排序
-        queryWrapperX.orderByAsc(AiChatModelDO::getSort);
-        // 查询
-        PageResult<AiChatModelDO> aiChatModalDOPageResult = aiChatModelMapper.selectPage(req, queryWrapperX);
-        // 转换 res
-        List<AiChatModelListRespVO> resList = AiChatModelConvert.INSTANCE.convertAiChatModalListRes(aiChatModalDOPageResult.getList());
-        return new PageResult<>(resList, aiChatModalDOPageResult.getTotal());
-    }
-
-    @Override
-    public void add(AiChatModelAddReqVO req) {
-        // 校验 platform、type
-        validatePlatform(req.getPlatform());
-        // 转换 do
-        AiChatModelDO insertChatModalDO = AiChatModelConvert.INSTANCE.convertAiChatModalDO(req);
-        // 设置默认属性
-        insertChatModalDO.setStatus(CommonStatusEnum.ENABLE.getStatus());
-        // 保存数据库
-        aiChatModelMapper.insert(insertChatModalDO);
-    }
-
-    @Override
-    public void update(AiChatModelUpdateReqVO req) {
-        // 校验 platform
-        validatePlatform(req.getPlatform());
-        // 校验模型是否存在
-        validateExists(req.getId());
-        // 转换 updateChatModalDO
-        AiChatModelDO updateChatModalDO = AiChatModelConvert.INSTANCE.convertAiChatModalDO(req);
-        updateChatModalDO.setId(req.getId());
-        // 更新数据库
-        aiChatModelMapper.updateById(updateChatModalDO);
-    }
-
-    @Override
-    public void delete(Long id) {
-        // 检查 modal 是否存在
-        validateExists(id);
-        // 删除 delete
-        aiChatModelMapper.deleteById(id);
-    }
-
-    @Override
-    public AiChatModalRespVO getChatModalOfValidate(Long modalId) {
-        // 检查 modal 是否存在
-        AiChatModelDO aiChatModalDO = validateExists(modalId);
-        return AiChatModelConvert.INSTANCE.convertAiChatModalRes(aiChatModalDO);
-    }
-
-    @Override
-    public void validateAvailable(AiChatModalRespVO chatModal) {
-        // 对话模型是否可用
-        if (!CommonStatusEnum.ENABLE.getStatus().equals(chatModal.getStatus())) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_DISABLE_NOT_USED);
-        }
-    }
-
-    @Override
-    public List<AiChatModelDO> getModalByIds(Set<Long> modalIds) {
-        return aiChatModelMapper.selectByIds(modalIds);
-    }
-
-    public AiChatModelDO validateExists(Long id) {
-        AiChatModelDO aiChatModalDO = aiChatModelMapper.selectById(id);
-        if (aiChatModalDO == null) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_NOT_EXIST);
-        }
-        return aiChatModalDO;
-    }
-
-    private void validatePlatform(String platform) {
-        try {
-            AiPlatformEnum.valueOfPlatform(platform);
-        } catch (IllegalArgumentException e) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_PLATFORM_PARAMS_INCORRECT, e.getMessage());
-        }
-    }
-
-    private void validateModalConfig(AiChatModalConfigVO aiChatModalConfigVO) {
-        Set<ConstraintViolation<AiChatModalConfigVO>> validate = ValidationUtil.validate(aiChatModalConfigVO);
-        for (ConstraintViolation<AiChatModalConfigVO> constraintViolation : validate) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODAL_CONFIG_PARAMS_INCORRECT, constraintViolation.getMessage());
-        }
-    }
-}

+ 53 - 16
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiChatServiceImpl.java

@@ -7,11 +7,15 @@ import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
 import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient;
 import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
 import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
+import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
 import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
+import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageAddReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendStreamReqVO;
 import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
@@ -19,11 +23,12 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
 import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
 import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
 import cn.iocoder.yudao.module.ai.service.AiChatService;
+import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
+import org.springframework.boot.autoconfigure.http.HttpMessageConverters;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 import reactor.core.publisher.Flux;
@@ -53,6 +58,7 @@ public class AiChatServiceImpl implements AiChatService {
     private final AiChatConversationService chatConversationService;
     private final AiChatModelService aiChatModalService;
     private final AiChatRoleService aiChatRoleService;
+    private final HttpMessageConverters messageConverters;
 
     @Transactional(rollbackFor = Exception.class)
     public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
@@ -124,10 +130,15 @@ public class AiChatServiceImpl implements AiChatService {
         return insertChatMessageDO;
     }
 
-    public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO req) {
+    public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO req) {
         Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
+        // 查询提问的 message
+        AiChatMessageDO aiChatMessageDO = aiChatMessageMapper.selectById(req.getId());
+        if (aiChatMessageDO == null) {
+            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST);
+        }
         // 查询对话
-        AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
+        AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(aiChatMessageDO.getConversationId());
         // 获取对话模型
         AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
         // 获取角色信息
@@ -138,14 +149,14 @@ public class AiChatServiceImpl implements AiChatService {
         // 校验角色是否公开
         aiChatRoleService.validateIsPublic(aiChatRoleDO);
         // 创建 chat 需要的 Prompt
-        Prompt prompt = new Prompt(req.getContent());
+        Prompt prompt = new Prompt(aiChatMessageDO.getContent());
+        // 提前创建一个 system message
+        AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
+                chatModel.getModel(), chatModel.getId(), "",
+                0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
 //        req.setTopK(req.getTopK());
 //        req.setTopP(req.getTopP());
 //        req.setTemperature(req.getTemperature());
-        // 保存 chat message
-        insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
-                chatModel.getModel(), chatModel.getId(), req.getContent(),
-                null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
         // 获取 client 类型
         AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
         StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
@@ -154,7 +165,8 @@ public class AiChatServiceImpl implements AiChatService {
         StringBuffer contentBuffer = new StringBuffer();
         AtomicInteger tokens = new AtomicInteger(0);
         return streamResponse.map(res -> {
-                    AiChatMessageRespVO aiChatMessageRespVO = new AiChatMessageRespVO();
+                    AiChatMessageRespVO aiChatMessageRespVO =
+                            AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(systemMessage);
                     aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent());
                     contentBuffer.append(res.getResult().getOutput().getContent());
                     tokens.incrementAndGet();
@@ -165,22 +177,46 @@ public class AiChatServiceImpl implements AiChatService {
             public void run() {
                 log.info("发送完成!");
                 // 保存 chat message
-                insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
-                        chatModel.getModel(), chatModel.getId(), contentBuffer.toString(),
-                        tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
+                aiChatMessageMapper.updateById(new AiChatMessageDO()
+                        .setId(systemMessage.getId())
+                        .setContent(contentBuffer.toString())
+                        .setTokens(tokens.get())
+                );
             }
         }).doOnError(new Consumer<Throwable>() {
             @Override
             public void accept(Throwable throwable) {
                 log.error("发送错误 {}!", throwable.getMessage());
-                // 保存 chat message
-                insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
-                        chatModel.getModel(), chatModel.getId(), throwable.getMessage(),
-                        tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
+                // 更新错误信息
+                aiChatMessageMapper.updateById(new AiChatMessageDO()
+                        .setId(systemMessage.getId())
+                        .setContent(throwable.getMessage())
+                        .setTokens(tokens.get())
+                );
             }
         });
     }
 
+    @Override
+    public AiChatMessageRespVO add(AiChatMessageAddReqVO req) {
+        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
+        // 查询对话
+        AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
+        // 获取对话模型
+        AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
+        // 获取角色信息
+        AiChatRoleDO aiChatRoleDO = null;
+        if (conversation.getRoleId() != null) {
+            aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
+        }
+        // 校验角色是否公开
+        aiChatRoleService.validateIsPublic(aiChatRoleDO);
+        AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
+                chatModel.getModel(), chatModel.getId(), req.getContent(),
+                null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
+       return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(userMessage);
+    }
+
     @Override
     public List<AiChatMessageRespVO> getMessageListByConversationId(Long conversationId) {
         // 校验对话是否存在
@@ -207,4 +243,5 @@ public class AiChatServiceImpl implements AiChatService {
     public Boolean deleteMessage(Long id) {
         return aiChatMessageMapper.deleteById(id) > 0;
     }
+
 }

+ 10 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java

@@ -6,6 +6,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatMode
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import jakarta.validation.Valid;
 
+import java.util.List;
+import java.util.Set;
+
 /**
  * AI 聊天模型 Service 接口
  *
@@ -60,4 +63,11 @@ public interface AiChatModelService {
      */
     AiChatModelDO validateChatModel(Long id);
 
+    /**
+     * 获取 - 根据多个 ids 获取
+     *
+     * @param modalIds
+     * @return
+     */
+    List<AiChatModelDO> getModalByIds(Set<Long> modalIds);
 }

+ 8 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java

@@ -12,6 +12,9 @@ import jakarta.annotation.Resource;
 import org.springframework.stereotype.Service;
 import org.springframework.validation.annotation.Validated;
 
+import java.util.List;
+import java.util.Set;
+
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
 import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.*;
 
@@ -89,4 +92,9 @@ public class AiChatModelServiceImpl implements AiChatModelService {
         return model;
     }
 
+    @Override
+    public List<AiChatModelDO> getModalByIds(Set<Long> modalIds) {
+        return chatModelMapper.selectByIds(modalIds);
+    }
+
 }