Browse Source

!1013 新增文心一言、智谱 AI 的绘图能力
Merge pull request !1013 from 芋道源码/master-jdk21-ai

芋道源码 8 months ago
parent
commit
3aad33c3cd

+ 8 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -33,6 +33,7 @@ import org.springframework.ai.image.ImageResponse;
 import org.springframework.ai.openai.OpenAiImageOptions;
 import org.springframework.ai.qianfan.QianFanImageOptions;
 import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
+import org.springframework.ai.zhipuai.ZhiPuAiImageOptions;
 import org.springframework.scheduling.annotation.Async;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
@@ -104,7 +105,9 @@ public class AiImageServiceImpl implements AiImageService {
             ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request));
 
             // 2. 上传到文件服务
-            byte[] fileContent = Base64.decode(response.getResult().getOutput().getB64Json());
+            String b64Json = response.getResult().getOutput().getB64Json();
+            byte[] fileContent = StrUtil.isNotEmpty(b64Json) ? Base64.decode(b64Json)
+                    : HttpUtil.downloadBytes(response.getResult().getOutput().getUrl());
             String filePath = fileApi.createFile(fileContent);
 
             // 3. 更新数据库
@@ -148,6 +151,10 @@ public class AiImageServiceImpl implements AiImageService {
                     .withModel(draw.getModel()).withN(1)
                     .withHeight(draw.getHeight()).withWidth(draw.getWidth())
                     .build();
+        } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
+            return ZhiPuAiImageOptions.builder()
+                    .withModel(draw.getModel())
+                    .build();
         }
         throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
     }

+ 19 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -30,6 +30,7 @@ import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
+import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.image.ImageModel;
 import org.springframework.ai.model.function.FunctionCallbackContext;
@@ -47,7 +48,9 @@ import org.springframework.ai.qianfan.api.QianFanImageApi;
 import org.springframework.ai.stabilityai.StabilityAiImageModel;
 import org.springframework.ai.stabilityai.api.StabilityAiApi;
 import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
+import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
 import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
+import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
 import org.springframework.retry.support.RetryTemplate;
 import org.springframework.web.client.ResponseErrorHandler;
 import org.springframework.web.client.RestClient;
@@ -118,6 +121,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
                 return SpringUtil.getBean(TongYiImagesModel.class);
             case YI_YAN:
                 return SpringUtil.getBean(QianFanImageModel.class);
+            case ZHI_PU:
+                return SpringUtil.getBean(ZhiPuAiImageModel.class);
             case OPENAI:
                 return SpringUtil.getBean(OpenAiImageModel.class);
             case STABLE_DIFFUSION:
@@ -135,6 +140,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
                 return buildTongYiImagesModel(apiKey);
             case YI_YAN:
                 return buildQianFanImageModel(apiKey);
+            case ZHI_PU:
+                return buildZhiPuAiImageModel(apiKey, url);
             case OPENAI:
                 return buildOpenAiImageModel(apiKey, url);
             case STABLE_DIFFUSION:
@@ -222,7 +229,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
     }
 
     /**
-     * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
+     * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(
+     * ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
      */
     private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
         url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
@@ -230,6 +238,16 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new ZhiPuAiChatModel(zhiPuAiApi);
     }
 
+    /**
+     * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(
+     * ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
+     */
+    private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
+        url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
+        ZhiPuAiImageApi zhiPuAiApi = new ZhiPuAiImageApi(url, apiKey, RestClient.builder());
+        return new ZhiPuAiImageModel(zhiPuAiApi);
+    }
+
     /**
      * 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)}
      */

+ 35 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/ZhiPuAiImageModelTests.java

@@ -0,0 +1,35 @@
+package cn.iocoder.yudao.framework.ai.image;
+
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.image.ImagePrompt;
+import org.springframework.ai.image.ImageResponse;
+import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
+import org.springframework.ai.zhipuai.ZhiPuAiImageOptions;
+import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
+
+/**
+ * {@link ZhiPuAiImageModel} 集成测试
+ */
+public class ZhiPuAiImageModelTests {
+
+    private final ZhiPuAiImageApi imageApi = new ZhiPuAiImageApi(
+            "78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy");
+    private final ZhiPuAiImageModel imageModel = new ZhiPuAiImageModel(imageApi);
+
+    @Test
+    @Disabled
+    public void testCall() {
+        // 准备参数
+        ZhiPuAiImageOptions imageOptions = ZhiPuAiImageOptions.builder()
+                .withModel(ZhiPuAiImageApi.ImageModel.CogView_3.getValue())
+                .build();
+        ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions);
+
+        // 方法调用
+        ImageResponse response = imageModel.call(prompt);
+        // 打印结果
+        System.out.println(response);
+    }
+
+}