Bläddra i källkod

【代码评审】AI:写作部分的建议

YunaiV 8 månader sedan
förälder
incheckning
4c21ae32fe

+ 1 - 1
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java

@@ -32,7 +32,7 @@ public enum AiWriteTypeEnum implements IntArrayValuable {
     /**
      * 模版
      */
-    private final String template;
+    private final String prompt;
 
     public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteTypeEnum::getType).toArray();
 

+ 13 - 10
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java

@@ -63,26 +63,26 @@ public class AiWriteServiceImpl implements AiWriteService {
         // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型
         AiChatRoleDO writeRole = selectOneWriteRole();
         AiChatModelDO model;
+        // TODO @xin:writeRole.getModelId 可能为空。所以,最好是先通过 chatRole 拿。如果它没拿到,通过 getRequiredDefaultChatModel 再拿。
         if (Objects.nonNull(writeRole)) {
             model = chatModalService.getChatModel(writeRole.getModelId());
         } else {
             model = chatModalService.getRequiredDefaultChatModel();
         }
-
+        // 1.2 校验平台
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
-
         StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
 
-        // 1.2 插入写作信息
+        // 2. 插入写作信息
         AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
         writeMapper.insert(writeDO);
 
-        // 2.1 构建提示词
+        // 3.1 构建提示词
         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
         Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
-        // 2.2 流式返回
+        // 3.2 流式返回
         StringBuffer contentBuffer = new StringBuffer();
         return streamResponse.map(chunk -> {
             String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@@ -102,10 +102,13 @@ public class AiWriteServiceImpl implements AiWriteService {
         }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
     }
 
+    // TODO @xin:chatRoleService 增加一个 getChatRoleListByName;
     private AiChatRoleDO selectOneWriteRole() {
         AiChatRoleDO chatRoleDO = null;
+        // TODO @xin:"写作助手" 枚举下。
         PageResult<AiChatRoleDO> writeRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("写作助手"));
         List<AiChatRoleDO> list = writeRolePage.getList();
+        // TODO @xin:CollUtil.getFirst 简化下
         if (CollUtil.isNotEmpty(list)) {
             chatRoleDO = list.get(0);
         }
@@ -113,19 +116,19 @@ public class AiWriteServiceImpl implements AiWriteService {
     }
 
     private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
+        // 校验写作类型是否合法
         Integer type = generateReqVO.getType();
+        // TODO @xin:这里可以搞到 validator 的校验。InEnum
+        AiWriteTypeEnum.validateType(type);
         String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
         String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
         String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
         String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength());
         String prompt = generateReqVO.getPrompt();
-        // 校验写作类型是否合法
-        AiWriteTypeEnum.validateType(type);
-
         if (Objects.equals(type, AiWriteTypeEnum.WRITING.getType())) {
-            return StrUtil.format(AiWriteTypeEnum.WRITING.getTemplate(), prompt, format, tone, language, length);
+            return StrUtil.format(AiWriteTypeEnum.WRITING.getPrompt(), prompt, format, tone, language, length);
         } else {
-            return StrUtil.format(AiWriteTypeEnum.REPLY.getTemplate(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length);
+            return StrUtil.format(AiWriteTypeEnum.REPLY.getPrompt(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length);
         }
     }
 

+ 3 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java

@@ -10,15 +10,16 @@ import org.junit.jupiter.api.Test;
 
 import java.util.Map;
 
+// TODO @fan:改成 TongYiImagesModel 哈
 /**
- * 通义万象 - 测试
+ * 通义万象
  */
 public class TongYiImagesModelTests {
 
     @Test
     public void imageCallTest() throws NoApiKeyException {
         // 设置 api key
-        Constants.apiKey="sk-Zsd81gZYg7";
+        Constants.apiKey = "sk-Zsd81gZYg7";
         ImageSynthesisParam param =
                 ImageSynthesisParam.builder()
                         .model(ImageSynthesis.Models.WANX_V1)