Browse Source

Merge remote-tracking branch 'origin/master-jdk21-ai' into master-jdk21-ai

cherishsince 8 months ago
parent
commit
2ae90b9edc
15 changed files with 209 additions and 145 deletions
  1. 1 1
      script/idea/http-client.env.json
  2. 1 0
      yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java
  3. 0 4
      yudao-module-ai/yudao-module-ai-biz/pom.xml
  4. 2 3
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java
  5. 1 1
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java
  6. 1 1
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java
  7. 6 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java
  8. 50 28
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
  9. 46 27
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java
  10. 31 8
      yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java
  11. 2 2
      yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java
  12. 23 29
      yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java
  13. 2 2
      yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java
  14. 43 0
      yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java
  15. 0 39
      yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java

+ 1 - 1
script/idea/http-client.env.json

@@ -1,7 +1,7 @@
 {
   "local": {
     "baseUrl": "http://127.0.0.1:48080/admin-api",
-    "token": "Bearer 1c2ce60de96a4fb0bf5bea9604099a3d",
+    "token": "test1",
     "adminTenentId": "1",
 
     "appApi": "http://127.0.0.1:48080/app-api",

+ 1 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java

@@ -39,6 +39,7 @@ public enum AiChatRoleEnum implements IntArrayValuable {
             除此之外不要任何解释性语句。
             """);
 
+    // TODO @xin:这个 role 是不是删除掉好点哈。= = 目前主要是没做角色枚举。这里多了 role 反倒容易误解哈
     /**
      * 角色
      */

+ 0 - 4
yudao-module-ai/yudao-module-ai-biz/pom.xml

@@ -60,9 +60,5 @@
             <groupId>cn.iocoder.boot</groupId>
             <artifactId>yudao-spring-boot-starter-test</artifactId>
         </dependency>
-        <dependency>
-            <groupId>cn.iocoder.boot</groupId>
-            <artifactId>yudao-spring-boot-starter-excel</artifactId>
-        </dependency>
     </dependencies>
 </project>

+ 2 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java

@@ -12,8 +12,7 @@ import lombok.Data;
  *
  * @author xiaoxin
  */
-// TODO @xin:如果没 typehandler 的需求,autoResultMap 可以去掉哈
-@TableName(value = "ai_mind_map", autoResultMap = true)
+@TableName(value = "ai_mind_map")
 @Data
 public class AiMindMapDO extends BaseDO {
 
@@ -25,7 +24,7 @@ public class AiMindMapDO extends BaseDO {
 
     /**
      * 用户编号
-     *
+     * <p>
      * 关联 AdminUserDO 的 userId 字段
      */
     private Long userId;

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java

@@ -5,7 +5,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
 import org.apache.ibatis.annotations.Mapper;
 
 /**
- * AI 音乐 Mapper
+ * AI 思维导图 Mapper
  *
  * @author xiaoxin
  */

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -111,7 +111,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
                 userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
 
-        // 3.2 创建 chat 需要的 Prompt
+        // 3.2 构建 Prompt,并进行调用
         Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 

+ 6 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -31,6 +31,7 @@ import org.springframework.ai.image.ImageOptions;
 import org.springframework.ai.image.ImagePrompt;
 import org.springframework.ai.image.ImageResponse;
 import org.springframework.ai.openai.OpenAiImageOptions;
+import org.springframework.ai.qianfan.QianFanImageOptions;
 import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
 import org.springframework.scheduling.annotation.Async;
 import org.springframework.stereotype.Service;
@@ -142,6 +143,11 @@ public class AiImageServiceImpl implements AiImageService {
                     .withModel(draw.getModel()).withN(1)
                     .withHeight(draw.getHeight()).withWidth(draw.getWidth())
                     .build();
+        } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
+            return QianFanImageOptions.builder()
+                    .withModel(draw.getModel()).withN(1)
+                    .withHeight(draw.getHeight()).withWidth(draw.getWidth())
+                    .build();
         }
         throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
     }

+ 50 - 28
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.module.ai.service.mindmap;
 
 import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.lang.Assert;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@@ -31,13 +32,12 @@ import reactor.core.publisher.Flux;
 
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Objects;
 
 import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
 import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
 
 /**
- * AI 写作 Service 实现类
+ * AI 思维导图 Service 实现类
  *
  * @author xiaoxin
  */
@@ -57,38 +57,28 @@ public class AiMindMapServiceImpl implements AiMindMapService {
 
     @Override
     public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
-        // 1.1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型
-        AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
-        AiChatModelDO model;
-        String systemMessage;
-        if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) {
-            model = chatModalService.getChatModel(mindMapRole.getModelId());
-            systemMessage = mindMapRole.getSystemMessage();
-        } else {
-            model = chatModalService.getRequiredDefaultChatModel();
-            systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
-        }
-
+        // 1. 获取脑图模型。尝试获取思维导图助手角色,如果没有则使用默认模型
+        AiChatRoleDO role = CollUtil.getFirst(
+                chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
+        // 1.1 获取脑图执行模型
+        AiChatModelDO model = getModel(role);
+        // 1.2 获取角色设定消息
+        String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
+                ? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
+        // 1.3 校验平台
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
 
-        // 2 插入思维导图信息
-        AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
+        // 2. 插入思维导图信息
+        AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
+                mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
         mindMapMapper.insert(mindMapDO);
 
-        ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
-        // 3.1 角色设定
-        List<Message> chatMessages = new ArrayList<>();
-        if (StrUtil.isNotBlank(systemMessage)) {
-            chatMessages.add(new SystemMessage(systemMessage));
-        }
-        // 3.2 用户输入
-        chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
-        // 3.3 构建提示词
-        Prompt prompt = new Prompt(chatMessages, chatOptions);
-
+        // 3.1 构建 Prompt,并进行调用
+        Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
-        // 3.4 流式返回
+
+        // 3.2 流式返回
         StringBuffer contentBuffer = new StringBuffer();
         return streamResponse.map(chunk -> {
             String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@@ -109,4 +99,36 @@ public class AiMindMapServiceImpl implements AiMindMapService {
 
     }
 
+    private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
+        // 1. 构建 message 列表
+        List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
+        // 2. 构建 options 对象
+        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
+        ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
+        return new Prompt(chatMessages, options);
+    }
+
+    private static List<Message> buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) {
+        List<Message> chatMessages = new ArrayList<>();
+        // 1. 角色设定
+        if (StrUtil.isNotBlank(systemMessage)) {
+            chatMessages.add(new SystemMessage(systemMessage));
+        }
+        // 2. 用户输入
+        chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
+        return chatMessages;
+    }
+
+    private AiChatModelDO getModel(AiChatRoleDO role) {
+        AiChatModelDO model = null;
+        if (role != null && role.getModelId() != null) {
+            model = chatModalService.getChatModel(role.getModelId());
+        }
+        if (model != null) {
+            model = chatModalService.getRequiredDefaultChatModel();
+        }
+        Assert.notNull(model, "[AI] 获取不到模型");
+        return model;
+    }
+
 }

+ 46 - 27
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.module.ai.service.write;
 
 import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.lang.Assert;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
@@ -67,19 +68,15 @@ public class AiWriteServiceImpl implements AiWriteService {
 
     @Override
     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
-        // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型
-        AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
-        // TODO @xin:如果有 writeRole,但是没 modeId,是不是也可以用 systemMessage 哈?建议的写法是:先通过 modelId 获取 model。如果 model == null,则 chatModalService.getRequiredDefaultChatModel();如果还是 null,则抛出异常;。。。。。。。。。。。。。。然后,systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理。
-        AiChatModelDO model;
-        String systemMessage;
-        if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
-            model = chatModalService.getChatModel(writeRole.getModelId());
-            systemMessage = writeRole.getSystemMessage();
-        } else {
-            model = chatModalService.getRequiredDefaultChatModel();
-            systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
-        }
-        // 1.2 校验平台
+        // 1 获取写作模型。尝试获取写作助手角色,没有则使用默认模型
+        AiChatRoleDO writeRole = CollUtil.getFirst(
+                chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
+        // 1.1 获取写作执行模型
+        AiChatModelDO model = getModel(writeRole);
+        // 1.2 获取角色设定消息
+        String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
+                ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
+        // 1.3 校验平台
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
 
@@ -88,21 +85,11 @@ public class AiWriteServiceImpl implements AiWriteService {
                 write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
         writeMapper.insert(writeDO);
 
-        // 3. 调用大模型,写作生成
-        ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
-        // 3.1 角色设定
-        // TODO @xin:要不把 90 到 97 这部分,合并到一个方法里。目的是:让这个方法的主干更明确
-        List<Message> chatMessages = new ArrayList<>();
-        if (StrUtil.isNotBlank(systemMessage)) {
-            chatMessages.add(new SystemMessage(systemMessage));
-        }
-        // 3.2 用户输入
-        chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
-        // 3.3 构建提示词
-        Prompt prompt = new Prompt(chatMessages, chatOptions);
+        // 3.1 构建 Prompt,并进行调用
+        Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
-        // 4. 流式返回
+        // 3.2 流式返回
         StringBuffer contentBuffer = new StringBuffer();
         return streamResponse.map(chunk -> {
             String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@@ -122,7 +109,39 @@ public class AiWriteServiceImpl implements AiWriteService {
         }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
     }
 
-    private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
+    private AiChatModelDO getModel(AiChatRoleDO writeRole) {
+        AiChatModelDO model = null;
+        if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
+            model = chatModalService.getChatModel(writeRole.getModelId());
+        }
+        if (Objects.isNull(model)) {
+            model = chatModalService.getRequiredDefaultChatModel();
+        }
+        Assert.notNull(model, "[AI] 获取不到模型");
+        return model;
+    }
+
+    private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
+        // 1. 构建 message 列表
+        List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
+        // 2. 构建 options 对象
+        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
+        ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
+        return new Prompt(chatMessages, options);
+    }
+
+    private List<Message> buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
+        List<Message> chatMessages = new ArrayList<>();
+        if (StrUtil.isNotBlank(systemMessage)) {
+            // 1.1 角色设定
+            chatMessages.add(new SystemMessage(systemMessage));
+        }
+        // 1.2 用户输入
+        chatMessages.add(new UserMessage(buildUserMessage(generateReqVO)));
+        return chatMessages;
+    }
+
+    private String buildUserMessage(AiWriteGenerateReqVO generateReqVO) {
         String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
         String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
         String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());

+ 31 - 8
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -18,12 +18,15 @@ import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
 import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
 import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
 import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
+import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
 import com.alibaba.dashscope.aigc.generation.Generation;
+import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
 import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
 import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
 import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
 import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
 import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
+import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
@@ -111,6 +114,10 @@ public class AiModelFactoryImpl implements AiModelFactory {
     public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
         //noinspection EnhancedSwitchMigration
         switch (platform) {
+            case TONG_YI:
+                return SpringUtil.getBean(TongYiImagesModel.class);
+            case YI_YAN:
+                return SpringUtil.getBean(QianFanImageModel.class);
             case OPENAI:
                 return SpringUtil.getBean(OpenAiImageModel.class);
             case STABLE_DIFFUSION:
@@ -124,14 +131,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
     public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
         //noinspection EnhancedSwitchMigration
         switch (platform) {
+            case TONG_YI:
+                return buildTongYiImagesModel(apiKey);
+            case YI_YAN:
+                return buildQianFanImageModel(apiKey);
             case OPENAI:
                 return buildOpenAiImageModel(apiKey, url);
             case STABLE_DIFFUSION:
                 return buildStabilityAiImageModel(apiKey, url);
-            case TONG_YI:
-                return SpringUtil.getBean(TongYiImagesModel.class);
-            case YI_YAN:
-                return buildQianFanImageModel(apiKey);
             default:
                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
         }
@@ -175,6 +182,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
     }
 
+    private static TongYiImagesModel buildTongYiImagesModel(String key) {
+        ImageSynthesis imageSynthesis = SpringUtil.getBean(ImageSynthesis.class);
+        TongYiImagesProperties imagesOptions = SpringUtil.getBean(TongYiImagesProperties.class);
+        TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
+        connectionProperties.setApiKey(key);
+        return new TongYiAutoConfiguration().tongYiImagesClient(imageSynthesis, imagesOptions, connectionProperties);
+    }
+
     /**
      * 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
      */
@@ -187,6 +202,18 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new QianFanChatModel(qianFanApi);
     }
 
+    /**
+     * 可参考 {@link QianFanAutoConfiguration#qianFanImageModel(QianFanConnectionProperties, QianFanImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
+     */
+    private QianFanImageModel buildQianFanImageModel(String key) {
+        List<String> keys = StrUtil.split(key, '|');
+        Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
+        String appKey = keys.get(0);
+        String secretKey = keys.get(1);
+        QianFanImageApi qianFanApi = new QianFanImageApi(appKey, secretKey);
+        return new QianFanImageModel(qianFanApi);
+    }
+
     /**
      * 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)}
      */
@@ -246,8 +273,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new StabilityAiImageModel(stabilityAiApi);
     }
 
-    private QianFanImageModel buildQianFanImageModel(String key) {
-        List<String> keys = StrUtil.split(key, '|');
-        return new QianFanImageModel(new QianFanImageApi(keys.get(0), keys.get(1)));
-    }
 }

+ 2 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java

@@ -21,7 +21,7 @@ public class OpenAiImageModelTests {
             "https://api.holdai.top",
             "sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf",
             RestClient.builder());
-    private final OpenAiImageModel imageClient = new OpenAiImageModel(imageApi);
+    private final OpenAiImageModel imageModel = new OpenAiImageModel(imageApi);
 
     @Test
     @Disabled
@@ -34,7 +34,7 @@ public class OpenAiImageModelTests {
         ImagePrompt prompt = new ImagePrompt("中国长城!", options);
 
         // 方法调用
-        ImageResponse response = imageClient.call(prompt);
+        ImageResponse response = imageModel.call(prompt);
         // 打印结果
         System.out.println(response);
     }

+ 23 - 29
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java

@@ -1,48 +1,42 @@
 package cn.iocoder.yudao.framework.ai.image;
 
-import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
+import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
-import org.springframework.ai.image.ImageOptionsBuilder;
 import org.springframework.ai.image.ImagePrompt;
 import org.springframework.ai.image.ImageResponse;
 import org.springframework.ai.qianfan.QianFanImageModel;
 import org.springframework.ai.qianfan.QianFanImageOptions;
-import org.springframework.ai.qianfan.api.QianFanApi;
 import org.springframework.ai.qianfan.api.QianFanImageApi;
 
+import static cn.iocoder.yudao.framework.ai.image.StabilityAiImageModelTests.viewImage;
+
 /**
- * 百度千帆 image
+ * {@link QianFanImageModel} 集成测试类
  */
 public class QianFanImageTests {
 
-    @Test
-    public void callTest() {
-        // todo @芋艿 千帆sdk有个错误,暂时没找到问题
-        QianFanImageApi qianFanImageApi = new QianFanImageApi(
-                "ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb");
-        QianFanImageModel qianFanImageModel = new QianFanImageModel(qianFanImageApi);
+    private final QianFanImageApi imageApi = new QianFanImageApi(
+            "qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
+    private final QianFanImageModel imageModel = new QianFanImageModel(imageApi);
 
+    @Test
+    @Disabled
+    public void testCall() {
+        // 准备参数
+        // 只支持 1024x1024、768x768、768x1024、1024x768、576x1024、1024x576
         QianFanImageOptions imageOptions = QianFanImageOptions.builder()
-                .withWidth(512)
-                .withHeight(512)
+                .withModel(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
+                .withWidth(1024).withHeight(1024)
+                .withN(1)
                 .build();
-        ImagePrompt imagePrompt = new ImagePrompt("薄涂炫酷少女头像,田野花朵盛开", imageOptions);
-        ImageResponse call = qianFanImageModel.call(imagePrompt);
-        System.err.println(JsonUtils.toJsonString(call));
-    }
-
-    @Test
-    public void call2Test() {
-        // 官方测试 test https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java
-        var options = ImageOptionsBuilder.builder().withHeight(1024).withWidth(1024).build();
-        var instructions = "薄涂炫酷少女头像,田野花朵盛开";
-
-        ImagePrompt imagePrompt = new ImagePrompt(instructions, options);
-
-        QianFanImageApi qianFanImageApi = new QianFanImageApi(
-                "ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb");
-        QianFanImageModel imageModel = new QianFanImageModel(qianFanImageApi);
-        ImageResponse imageResponse = imageModel.call(imagePrompt);
+        ImagePrompt prompt = new ImagePrompt("good", imageOptions);
+
+        // 方法调用
+        ImageResponse response = imageModel.call(prompt);
+        // 打印结果
+        String b64Json = response.getResult().getOutput().getB64Json();
+        System.out.println(response);
+        viewImage(b64Json);
     }
 
 }

+ 2 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java

@@ -24,7 +24,7 @@ public class StabilityAiImageModelTests {
 
     private final StabilityAiApi imageApi = new StabilityAiApi(
             "sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx");
-    private final StabilityAiImageModel imageClient = new StabilityAiImageModel(imageApi);
+    private final StabilityAiImageModel imageModel = new StabilityAiImageModel(imageApi);
 
     @Test
     @Disabled
@@ -37,7 +37,7 @@ public class StabilityAiImageModelTests {
         ImagePrompt prompt = new ImagePrompt("great wall", options);
 
         // 方法调用
-        ImageResponse response = imageClient.call(prompt);
+        ImageResponse response = imageModel.call(prompt);
         // 打印结果
         String b64Json = response.getResult().getOutput().getB64Json();
         System.out.println(response);

+ 43 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java

@@ -0,0 +1,43 @@
+package cn.iocoder.yudao.framework.ai.image;
+
+import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
+import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
+import com.alibaba.dashscope.utils.Constants;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.image.ImageOptions;
+import org.springframework.ai.image.ImagePrompt;
+import org.springframework.ai.image.ImageResponse;
+import org.springframework.ai.openai.OpenAiImageOptions;
+
+/**
+ * {@link com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel} 集成测试类
+ *
+ * @author fansili
+ */
+public class TongYiImagesModelTest {
+
+    private final ImageSynthesis imageApi = new ImageSynthesis();
+    private final TongYiImagesModel imageModel = new TongYiImagesModel(imageApi);
+
+    static {
+        Constants.apiKey = "sk-Zsd81gZYg7";
+    }
+
+    @Test
+    @Disabled
+    public void imageCallTest() {
+        // 准备参数
+        ImageOptions options = OpenAiImageOptions.builder()
+                .withModel(ImageSynthesis.Models.WANX_V1)
+                .withHeight(256).withWidth(256)
+                .build();
+        ImagePrompt prompt = new ImagePrompt("中国长城!", options);
+
+        // 方法调用
+        ImageResponse response = imageModel.call(prompt);
+        // 打印结果
+        System.out.println(response);
+    }
+
+}

+ 0 - 39
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java

@@ -1,39 +0,0 @@
-package cn.iocoder.yudao.framework.ai.image;
-
-import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
-import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
-import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
-import com.alibaba.dashscope.exception.NoApiKeyException;
-import com.alibaba.dashscope.utils.Constants;
-import com.alibaba.fastjson.JSON;
-import org.junit.jupiter.api.Test;
-
-import java.util.Map;
-
-// TODO @fan:改成 TongYiImagesModel 哈
-/**
- * 通义万象
- */
-public class TongYiImagesModelTests {
-
-    @Test
-    public void imageCallTest() throws NoApiKeyException {
-        // 设置 api key
-        Constants.apiKey = "sk-Zsd81gZYg7";
-        ImageSynthesisParam param =
-                ImageSynthesisParam.builder()
-                        .model(ImageSynthesis.Models.WANX_V1)
-                        .n(4)
-                        .size("1024*1024")
-                        .prompt("雄鹰自由自在的在蓝天白云下飞翔")
-                        .build();
-        // 创建 ImageSynthesis
-        ImageSynthesis is = new ImageSynthesis();
-        // 调用 call 生成 image
-        ImageSynthesisResult call = is.call(param);
-        System.err.println(JSON.toJSON(call));
-        for (Map<String, String> result : call.getOutput().getResults()) {
-            System.err.println("地址: " + result.get("url"));
-        }
-    }
-}