Browse Source

【代码优化】AI:通义千问的 tests 类

YunaiV 8 months ago
parent
commit
c6c003707e

+ 1 - 1
script/idea/http-client.env.json

@@ -1,7 +1,7 @@
 {
   "local": {
     "baseUrl": "http://127.0.0.1:48080/admin-api",
-    "token": "Bearer 1c2ce60de96a4fb0bf5bea9604099a3d",
+    "token": "test1",
     "adminTenentId": "1",
 
     "appApi": "http://127.0.0.1:48080/app-api",

+ 1 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java

@@ -39,6 +39,7 @@ public enum AiChatRoleEnum implements IntArrayValuable {
             除此之外不要任何解释性语句。
             """);
 
+    // TODO @xin:这个 role 是不是删除掉好点哈。= = 目前主要是没做角色枚举。这里多了 role 反倒容易误解哈
     /**
      * 角色
      */

+ 0 - 4
yudao-module-ai/yudao-module-ai-biz/pom.xml

@@ -60,9 +60,5 @@
             <groupId>cn.iocoder.boot</groupId>
             <artifactId>yudao-spring-boot-starter-test</artifactId>
         </dependency>
-        <dependency>
-            <groupId>cn.iocoder.boot</groupId>
-            <artifactId>yudao-spring-boot-starter-excel</artifactId>
-        </dependency>
     </dependencies>
 </project>

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/mindmap/AiMindMapMapper.java

@@ -5,7 +5,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
 import org.apache.ibatis.annotations.Mapper;
 
 /**
- * AI 音乐 Mapper
+ * AI 思维导图 Mapper
  *
  * @author xiaoxin
  */

+ 31 - 8
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -18,12 +18,15 @@ import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
 import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
 import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
 import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
+import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
 import com.alibaba.dashscope.aigc.generation.Generation;
+import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
 import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
 import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
 import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
 import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
 import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
+import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
@@ -111,6 +114,10 @@ public class AiModelFactoryImpl implements AiModelFactory {
     public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
         //noinspection EnhancedSwitchMigration
         switch (platform) {
+            case TONG_YI:
+                return SpringUtil.getBean(TongYiImagesModel.class);
+            case YI_YAN:
+                return SpringUtil.getBean(QianFanImageModel.class);
             case OPENAI:
                 return SpringUtil.getBean(OpenAiImageModel.class);
             case STABLE_DIFFUSION:
@@ -124,14 +131,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
     public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
         //noinspection EnhancedSwitchMigration
         switch (platform) {
+            case TONG_YI:
+                return buildTongYiImagesModel(apiKey);
+            case YI_YAN:
+                return buildQianFanImageModel(apiKey);
             case OPENAI:
                 return buildOpenAiImageModel(apiKey, url);
             case STABLE_DIFFUSION:
                 return buildStabilityAiImageModel(apiKey, url);
-            case TONG_YI:
-                return SpringUtil.getBean(TongYiImagesModel.class);
-            case YI_YAN:
-                return buildQianFanImageModel(apiKey);
             default:
                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
         }
@@ -175,6 +182,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
     }
 
+    private static TongYiImagesModel buildTongYiImagesModel(String key) {
+        ImageSynthesis imageSynthesis = SpringUtil.getBean(ImageSynthesis.class);
+        TongYiImagesProperties imagesOptions = SpringUtil.getBean(TongYiImagesProperties.class);
+        TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
+        connectionProperties.setApiKey(key);
+        return new TongYiAutoConfiguration().tongYiImagesClient(imageSynthesis, imagesOptions, connectionProperties);
+    }
+
     /**
      * 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
      */
@@ -187,6 +202,18 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new QianFanChatModel(qianFanApi);
     }
 
+    /**
+     * 可参考 {@link QianFanAutoConfiguration#qianFanImageModel(QianFanConnectionProperties, QianFanImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
+     */
+    private QianFanImageModel buildQianFanImageModel(String key) {
+        List<String> keys = StrUtil.split(key, '|');
+        Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
+        String appKey = keys.get(0);
+        String secretKey = keys.get(1);
+        QianFanImageApi qianFanApi = new QianFanImageApi(appKey, secretKey);
+        return new QianFanImageModel(qianFanApi);
+    }
+
     /**
      * 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)}
      */
@@ -246,8 +273,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new StabilityAiImageModel(stabilityAiApi);
     }
 
-    private QianFanImageModel buildQianFanImageModel(String key) {
-        List<String> keys = StrUtil.split(key, '|');
-        return new QianFanImageModel(new QianFanImageApi(keys.get(0), keys.get(1)));
-    }
 }

+ 2 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java

@@ -21,7 +21,7 @@ public class OpenAiImageModelTests {
             "https://api.holdai.top",
             "sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf",
             RestClient.builder());
-    private final OpenAiImageModel imageClient = new OpenAiImageModel(imageApi);
+    private final OpenAiImageModel imageModel = new OpenAiImageModel(imageApi);
 
     @Test
     @Disabled
@@ -34,7 +34,7 @@ public class OpenAiImageModelTests {
         ImagePrompt prompt = new ImagePrompt("中国长城!", options);
 
         // 方法调用
-        ImageResponse response = imageClient.call(prompt);
+        ImageResponse response = imageModel.call(prompt);
         // 打印结果
         System.out.println(response);
     }

+ 3 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/QianFanImageTests.java

@@ -7,7 +7,6 @@ import org.springframework.ai.image.ImagePrompt;
 import org.springframework.ai.image.ImageResponse;
 import org.springframework.ai.qianfan.QianFanImageModel;
 import org.springframework.ai.qianfan.QianFanImageOptions;
-import org.springframework.ai.qianfan.api.QianFanApi;
 import org.springframework.ai.qianfan.api.QianFanImageApi;
 
 /**
@@ -19,7 +18,7 @@ public class QianFanImageTests {
     public void callTest() {
         // todo @芋艿 千帆sdk有个错误,暂时没找到问题
         QianFanImageApi qianFanImageApi = new QianFanImageApi(
-                "ghbbvbW2t7HK7WtYmEITAupm", "njJEr5AsQ5fkB3ucYYDjiQqsOZK20SGb");
+                "qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
         QianFanImageModel qianFanImageModel = new QianFanImageModel(qianFanImageApi);
 
         QianFanImageOptions imageOptions = QianFanImageOptions.builder()
@@ -45,4 +44,6 @@ public class QianFanImageTests {
         ImageResponse imageResponse = imageModel.call(imagePrompt);
     }
 
+
+
 }

+ 2 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/StabilityAiImageModelTests.java

@@ -24,7 +24,7 @@ public class StabilityAiImageModelTests {
 
     private final StabilityAiApi imageApi = new StabilityAiApi(
             "sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx");
-    private final StabilityAiImageModel imageClient = new StabilityAiImageModel(imageApi);
+    private final StabilityAiImageModel imageModel = new StabilityAiImageModel(imageApi);
 
     @Test
     @Disabled
@@ -37,7 +37,7 @@ public class StabilityAiImageModelTests {
         ImagePrompt prompt = new ImagePrompt("great wall", options);
 
         // 方法调用
-        ImageResponse response = imageClient.call(prompt);
+        ImageResponse response = imageModel.call(prompt);
         // 打印结果
         String b64Json = response.getResult().getOutput().getB64Json();
         System.out.println(response);

+ 41 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTest.java

@@ -0,0 +1,41 @@
+package cn.iocoder.yudao.framework.ai.image;
+
+import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
+import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
+import com.alibaba.dashscope.utils.Constants;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.image.ImageOptions;
+import org.springframework.ai.image.ImagePrompt;
+import org.springframework.ai.image.ImageResponse;
+import org.springframework.ai.openai.OpenAiImageOptions;
+
+/**
+ * {@link com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel} 集成测试类
+ *
+ * @author fansili
+ */
+public class TongYiImagesModelTest {
+
+    private final ImageSynthesis imageApi = new ImageSynthesis();
+    private final TongYiImagesModel imageModel = new TongYiImagesModel(imageApi);
+
+    static {
+        Constants.apiKey = "sk-Zsd81gZYg7";
+    }
+
+    @Test
+    public void imageCallTest() {
+        // 准备参数
+        ImageOptions options = OpenAiImageOptions.builder()
+                .withModel(ImageSynthesis.Models.WANX_V1)
+                .withHeight(256).withWidth(256)
+                .build();
+        ImagePrompt prompt = new ImagePrompt("中国长城!", options);
+
+        // 方法调用
+        ImageResponse response = imageModel.call(prompt);
+        // 打印结果
+        System.out.println(response);
+    }
+
+}

+ 0 - 39
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/TongYiImagesModelTests.java

@@ -1,39 +0,0 @@
-package cn.iocoder.yudao.framework.ai.image;
-
-import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
-import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
-import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
-import com.alibaba.dashscope.exception.NoApiKeyException;
-import com.alibaba.dashscope.utils.Constants;
-import com.alibaba.fastjson.JSON;
-import org.junit.jupiter.api.Test;
-
-import java.util.Map;
-
-// TODO @fan:改成 TongYiImagesModel 哈
-/**
- * 通义万象
- */
-public class TongYiImagesModelTests {
-
-    @Test
-    public void imageCallTest() throws NoApiKeyException {
-        // 设置 api key
-        Constants.apiKey = "sk-Zsd81gZYg7";
-        ImageSynthesisParam param =
-                ImageSynthesisParam.builder()
-                        .model(ImageSynthesis.Models.WANX_V1)
-                        .n(4)
-                        .size("1024*1024")
-                        .prompt("雄鹰自由自在的在蓝天白云下飞翔")
-                        .build();
-        // 创建 ImageSynthesis
-        ImageSynthesis is = new ImageSynthesis();
-        // 调用 call 生成 image
-        ImageSynthesisResult call = is.call(param);
-        System.err.println(JSON.toJSON(call));
-        for (Map<String, String> result : call.getOutput().getResults()) {
-            System.err.println("地址: " + result.get("url"));
-        }
-    }
-}