Browse Source

【优化】AI 写作:做角色设定,提高准确率

xiaoxin 8 months ago
parent
commit
bcdb23b89d

+ 63 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java

@@ -0,0 +1,63 @@
+package cn.iocoder.yudao.module.ai.enums;
+
+import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+import java.util.Arrays;
+
+/**
+ * AI 写作类型的枚举
+ *
+ * @author xiaoxin
+ */
+@AllArgsConstructor
+@Getter
+public enum AiChatRoleEnum implements IntArrayValuable {
+
+    AI_WRITE_ROLE(1, "写作助手", """
+            你是一位出色的写作助手,能够帮助用户生成创意和灵感,并在用户提供场景和提示词时生成对应的回复。你的任务包括:
+            1.	撰写建议:根据用户提供的主题或问题,提供详细的写作建议、情节发展方向、角色设定以及背景描写,确保内容结构清晰、有逻辑。
+            2.	回复生成:根据用户提供的场景和提示词,生成合适的对话或文字回复,确保语气和风格符合场景需求。
+            除此之外不需要除了正文内容外的其他回复,如标题、开头、任何解释性语句或道歉。
+            """),
+    AI_MIND_MAP_ROLE(2, "脑图助手", """
+             你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
+             # Geek-AI 助手
+             ## 完整的开源系统
+             ### 前端开源
+             ### 后端开源
+             ## 支持各种大模型
+             ### OpenAI
+             ### Azure
+             ### 文心一言
+             ### 通义千问
+             ## 集成多种收费方式
+             ### 支付宝
+             ### 微信
+            除此之外不要任何解释性语句。
+            """);
+
+
+    /**
+     * 角色
+     */
+    private final Integer role;
+    /**
+     * 角色名
+     */
+    private final String name;
+
+    /**
+     * 角色设定
+     */
+    private final String prompt;
+
+    public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiChatRoleEnum::getRole).toArray();
+
+    @Override
+    public int[] array() {
+        return ARRAYS;
+    }
+
+}

+ 0 - 7
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java

@@ -1,7 +1,5 @@
 package cn.iocoder.yudao.module.ai.enums.write;
 
-import cn.hutool.core.util.ArrayUtil;
-import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
 import lombok.AllArgsConstructor;
 import lombok.Getter;
@@ -41,9 +39,4 @@ public enum AiWriteTypeEnum implements IntArrayValuable {
         return ARRAYS;
     }
 
-    public static void validateType(Integer type) {
-        if (ArrayUtil.contains(ARRAYS, type)) return;
-        throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", type));
-    }
-
 }

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/AiMindMapController.java

@@ -26,7 +26,7 @@ public class AiMindMapController {
     private AiMindMapService mindMapService;
 
     @PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
-    @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
+    @Operation(summary = "脑图生成(流式)", description = "流式返回,响应较快")
     @PermitAll  // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题
     public Flux<CommonResult<String>> generateMindMap(@RequestBody @Valid AiMindMapGenerateReqVO generateReqVO) {
         return mindMapService.generateMindMap(generateReqVO, getLoginUserId());

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java

@@ -11,7 +11,7 @@ import lombok.Data;
 public class AiWriteGenerateReqVO {
 
     @Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
-    @InEnum(AiWriteTypeEnum.class)
+    @InEnum(value = AiWriteTypeEnum.class, message = "写作类型必须是 {value}")
     private Integer type;
 
     @Schema(description = "写作内容提示", example = "1.撰写:田忌赛马;2.回复:不批")

+ 6 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatRoleMapper.java

@@ -4,9 +4,7 @@ import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
 import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
-import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import org.apache.ibatis.annotations.Mapper;
 
@@ -47,4 +45,10 @@ public interface AiChatRoleMapper extends BaseMapperX<AiChatRoleDO> {
                 .groupBy(AiChatRoleDO::getCategory));
     }
 
+    default List<AiChatRoleDO> selectListByName(String name) {
+        return selectList(new LambdaQueryWrapperX<AiChatRoleDO>()
+                .likeIfPresent(AiChatRoleDO::getName, name)
+                .orderByAsc(AiChatRoleDO::getSort));
+    }
+
 }

+ 9 - 40
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java

@@ -5,15 +5,14 @@ import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
-import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.mindmap.AiMindMapMapper;
+import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
 import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
@@ -56,61 +55,40 @@ public class AiMindMapServiceImpl implements AiMindMapService {
     @Resource
     private AiMindMapMapper mindMapMapper;
 
-    private static final String DEFAULT_SYSTEM_MESSAGE = """
-             你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
-             # Geek-AI 助手
-             
-             ## 完整的开源系统
-             ### 前端开源
-             ### 后端开源
-                      
-             ## 支持各种大模型
-             ### OpenAI
-             ### Azure
-             ### 文心一言
-             ### 通义千问
-                        
-             ## 集成多种收费方式
-             ### 支付宝
-             ### 微信
-                       
-             另外,除此之外不要任何解释性语句。
-            """;
-
     @Override
     public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
         // 1.1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型
-        AiChatRoleDO mindMapRole = selectOneMindMapRole();
+        AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
         AiChatModelDO model;
         String systemMessage;
-        if (Objects.nonNull(mindMapRole)) {
+        if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) {
             model = chatModalService.getChatModel(mindMapRole.getModelId());
             systemMessage = mindMapRole.getSystemMessage();
         } else {
             model = chatModalService.getRequiredDefaultChatModel();
-            systemMessage = DEFAULT_SYSTEM_MESSAGE;
+            systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getPrompt();
         }
 
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
 
-        // 1.2 插入思维导图信息
+        // 2 插入思维导图信息
         AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
         mindMapMapper.insert(mindMapDO);
 
         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
-        // 2.1 角色设定
+        // 3.1 角色设定
         List<Message> chatMessages = new ArrayList<>();
         if (StrUtil.isNotBlank(systemMessage)) {
             chatMessages.add(new SystemMessage(systemMessage));
         }
-        // 2.2 用户输入
+        // 3.2 用户输入
         chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
-        // 2.3 构建提示词
+        // 3.3 构建提示词
         Prompt prompt = new Prompt(chatMessages, chatOptions);
 
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
-        // 2.4 流式返回
+        // 3.4 流式返回
         StringBuffer contentBuffer = new StringBuffer();
         return streamResponse.map(chunk -> {
             String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@@ -131,13 +109,4 @@ public class AiMindMapServiceImpl implements AiMindMapService {
 
     }
 
-    private AiChatRoleDO selectOneMindMapRole() {
-        AiChatRoleDO chatRoleDO = null;
-        PageResult<AiChatRoleDO> mindMapRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("思维导图助手"));
-        List<AiChatRoleDO> list = mindMapRolePage.getList();
-        if (CollUtil.isNotEmpty(list)) {
-            chatRoleDO = list.get(0);
-        }
-        return chatRoleDO;
-    }
 }

+ 7 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleService.java

@@ -118,4 +118,11 @@ public interface AiChatRoleService {
      */
     List<String> getChatRoleCategoryList();
 
+    /**
+     * 根据名字获得聊天角色
+     * @param name 名字
+     * @return 聊天角色列表
+     */
+    List<AiChatRoleDO> getChatRoleListByName(String name);
+
 }

+ 5 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleServiceImpl.java

@@ -137,5 +137,10 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
         return convertList(list, AiChatRoleDO::getCategory, role -> role != null && StrUtil.isNotBlank(role.getCategory()));
     }
 
+    @Override
+    public List<AiChatRoleDO> getChatRoleListByName(String name) {
+        return chatRoleMapper.selectListByName(name);
+    }
+
 }
 

+ 19 - 23
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java

@@ -5,15 +5,14 @@ import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
-import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
+import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
 import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
 import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
@@ -23,6 +22,9 @@ import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
 import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.StreamingChatModel;
 import org.springframework.ai.chat.prompt.ChatOptions;
@@ -30,6 +32,7 @@ import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.stereotype.Service;
 import reactor.core.publisher.Flux;
 
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 
@@ -61,13 +64,15 @@ public class AiWriteServiceImpl implements AiWriteService {
     @Override
     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
         // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型
-        AiChatRoleDO writeRole = selectOneWriteRole();
+        AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
         AiChatModelDO model;
-        // TODO @xin:writeRole.getModelId 可能为空。所以,最好是先通过 chatRole 拿。如果它没拿到,通过 getRequiredDefaultChatModel 再拿。
-        if (Objects.nonNull(writeRole)) {
+        String systemMessage;
+        if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
             model = chatModalService.getChatModel(writeRole.getModelId());
+            systemMessage = writeRole.getSystemMessage();
         } else {
             model = chatModalService.getRequiredDefaultChatModel();
+            systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getPrompt();
         }
         // 1.2 校验平台
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
@@ -77,9 +82,16 @@ public class AiWriteServiceImpl implements AiWriteService {
         AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
         writeMapper.insert(writeDO);
 
-        // 3.1 构建提示词
         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
-        Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
+        // 3.1 角色设定
+        List<Message> chatMessages = new ArrayList<>();
+        if (StrUtil.isNotBlank(systemMessage)) {
+            chatMessages.add(new SystemMessage(systemMessage));
+        }
+        // 3.2 用户输入
+        chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
+        // 3.3 构建提示词
+        Prompt prompt = new Prompt(chatMessages, chatOptions);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
         // 3.2 流式返回
@@ -102,24 +114,8 @@ public class AiWriteServiceImpl implements AiWriteService {
         }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
     }
 
-    // TODO @xin:chatRoleService 增加一个 getChatRoleListByName;
-    private AiChatRoleDO selectOneWriteRole() {
-        AiChatRoleDO chatRoleDO = null;
-        // TODO @xin:"写作助手" 枚举下。
-        PageResult<AiChatRoleDO> writeRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("写作助手"));
-        List<AiChatRoleDO> list = writeRolePage.getList();
-        // TODO @xin:CollUtil.getFirst 简化下
-        if (CollUtil.isNotEmpty(list)) {
-            chatRoleDO = list.get(0);
-        }
-        return chatRoleDO;
-    }
-
     private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
-        // 校验写作类型是否合法
         Integer type = generateReqVO.getType();
-        // TODO @xin:这里可以搞到 validator 的校验。InEnum
-        AiWriteTypeEnum.validateType(type);
         String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
         String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
         String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());