Browse Source

【代码优化】AI:ChatGlm 替换成 ZhiPuAiImage 实现

YunaiV 8 months ago
parent
commit
73502d565f

+ 6 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -9,7 +9,6 @@ import cn.hutool.core.util.StrUtil;
 import cn.hutool.extra.spring.SpringUtil;
 import cn.hutool.http.HttpUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageOptions;
 import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.common.pojo.PageParam;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
@@ -34,6 +33,7 @@ import org.springframework.ai.image.ImageResponse;
 import org.springframework.ai.openai.OpenAiImageOptions;
 import org.springframework.ai.qianfan.QianFanImageOptions;
 import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
+import org.springframework.ai.zhipuai.ZhiPuAiImageOptions;
 import org.springframework.scheduling.annotation.Async;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
@@ -105,7 +105,9 @@ public class AiImageServiceImpl implements AiImageService {
             ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request));
 
             // 2. 上传到文件服务
-            byte[] fileContent = Base64.decode(response.getResult().getOutput().getB64Json());
+            String b64Json = response.getResult().getOutput().getB64Json();
+            byte[] fileContent = StrUtil.isNotEmpty(b64Json) ? Base64.decode(b64Json)
+                    : HttpUtil.downloadBytes(response.getResult().getOutput().getUrl());
             String filePath = fileApi.createFile(fileContent);
 
             // 3. 更新数据库
@@ -149,8 +151,8 @@ public class AiImageServiceImpl implements AiImageService {
                     .withModel(draw.getModel()).withN(1)
                     .withHeight(draw.getHeight()).withWidth(draw.getWidth())
                     .build();
-        } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.CHATGLM.getPlatform())) {
-            return ChatGlmImageOptions.builder()
+        } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
+            return ZhiPuAiImageOptions.builder()
                     .withModel(draw.getModel())
                     .build();
         }

+ 0 - 7
yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml

@@ -60,13 +60,6 @@
             <version>2.14.0</version>
         </dependency>
 
-        <!-- bigmodel -->
-        <dependency>
-            <groupId>cn.bigmodel.openapi</groupId>
-            <artifactId>oapi-java-sdk</artifactId>
-            <version>release-V4-2.0.2</version>
-        </dependency>
-
         <!-- Test 测试相关 -->
         <dependency>
             <groupId>org.springframework.boot</groupId>

+ 0 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java

@@ -28,7 +28,6 @@ public enum AiPlatformEnum {
     STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI
     MIDJOURNEY("Midjourney", "Midjourney"), // Midjourney
     SUNO("Suno", "Suno"), // Suno AI
-    CHATGLM("ChatGlm", "ChatGlm"), // Suno AI
 
     ;
 

+ 19 - 7
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -9,7 +9,6 @@ import cn.hutool.extra.spring.SpringUtil;
 import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
 import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageModel;
 import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
 import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
@@ -31,6 +30,7 @@ import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
+import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.image.ImageModel;
 import org.springframework.ai.model.function.FunctionCallbackContext;
@@ -48,7 +48,9 @@ import org.springframework.ai.qianfan.api.QianFanImageApi;
 import org.springframework.ai.stabilityai.StabilityAiImageModel;
 import org.springframework.ai.stabilityai.api.StabilityAiApi;
 import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
+import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
 import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
+import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
 import org.springframework.retry.support.RetryTemplate;
 import org.springframework.web.client.ResponseErrorHandler;
 import org.springframework.web.client.RestClient;
@@ -119,6 +121,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
                 return SpringUtil.getBean(TongYiImagesModel.class);
             case YI_YAN:
                 return SpringUtil.getBean(QianFanImageModel.class);
+            case ZHI_PU:
+                return SpringUtil.getBean(ZhiPuAiImageModel.class);
             case OPENAI:
                 return SpringUtil.getBean(OpenAiImageModel.class);
             case STABLE_DIFFUSION:
@@ -136,12 +140,12 @@ public class AiModelFactoryImpl implements AiModelFactory {
                 return buildTongYiImagesModel(apiKey);
             case YI_YAN:
                 return buildQianFanImageModel(apiKey);
+            case ZHI_PU:
+                return buildZhiPuAiImageModel(apiKey, url);
             case OPENAI:
                 return buildOpenAiImageModel(apiKey, url);
             case STABLE_DIFFUSION:
                 return buildStabilityAiImageModel(apiKey, url);
-            case CHATGLM:
-                return buildChatGlmModel(apiKey);
             default:
                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
         }
@@ -225,7 +229,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
     }
 
     /**
-     * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
+     * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(
+     * ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
      */
     private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
         url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
@@ -233,6 +238,16 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new ZhiPuAiChatModel(zhiPuAiApi);
     }
 
+    /**
+     * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(
+     * ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
+     */
+    private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
+        url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
+        ZhiPuAiImageApi zhiPuAiApi = new ZhiPuAiImageApi(url, apiKey, RestClient.builder());
+        return new ZhiPuAiImageModel(zhiPuAiApi);
+    }
+
     /**
      * 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)}
      */
@@ -276,7 +291,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new StabilityAiImageModel(stabilityAiApi);
     }
 
-    private ChatGlmImageModel buildChatGlmModel(String apiKey) {
-        return new ChatGlmImageModel(apiKey);
-    }
 }

+ 0 - 75
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/chatglm/ChatGlmImageModel.java

@@ -1,75 +0,0 @@
-package cn.iocoder.yudao.framework.ai.core.model.chatglm;
-
-import cn.iocoder.yudao.framework.ai.core.model.chatglm.api.ChatGlmResponseMetadata;
-import com.zhipu.oapi.ClientV4;
-import com.zhipu.oapi.service.v4.image.CreateImageRequest;
-import com.zhipu.oapi.service.v4.image.ImageApiResponse;
-import org.springframework.ai.image.*;
-
-import java.io.ByteArrayOutputStream;
-import java.net.URL;
-import java.util.Base64;
-import java.util.stream.Collectors;
-
-public class ChatGlmImageModel implements ImageModel {
-
-    private ClientV4 client;
-
-    public ChatGlmImageModel(String apiSecretKey) {
-        client = new ClientV4.Builder(apiSecretKey).build();
-    }
-
-    @Override
-    public ImageResponse call(ImagePrompt request) {
-        CreateImageRequest imageRequest = CreateImageRequest.builder()
-                .model(request.getOptions().getModel())
-                .prompt(request.getInstructions().get(0).getText())
-                .build();
-        return convert(client.createImage(imageRequest));
-    }
-
-    private ImageResponse convert(ImageApiResponse result) {
-        return new ImageResponse(
-                result.getData().getData().stream().map(item -> {
-                    try {
-                        String url = item.getUrl();
-                        String base64Image = convertImageToBase64(url);
-                        Image image = new Image(url, base64Image);
-                        return new ImageGeneration(image);
-                    } catch (Exception e) {
-                        throw new RuntimeException(e);
-                    }
-                }).collect(Collectors.toList()),
-                new ChatGlmResponseMetadata(result)
-        );
-    }
-
-
-    /**
-     * Convert image to base64.
-     * @param imageUrl the image url.
-     * @return the base64 image.
-     * @throws Exception the exception.
-     */
-    public String convertImageToBase64(String imageUrl) throws Exception {
-
-        var url = new URL(imageUrl);
-        var inputStream = url.openStream();
-        var outputStream = new ByteArrayOutputStream();
-        var buffer = new byte[4096];
-        int bytesRead;
-
-        while ((bytesRead = inputStream.read(buffer)) != -1) {
-            outputStream.write(buffer, 0, bytesRead);
-        }
-
-        var imageBytes = outputStream.toByteArray();
-
-        String base64Image = Base64.getEncoder().encodeToString(imageBytes);
-
-        inputStream.close();
-        outputStream.close();
-
-        return base64Image;
-    }
-}

+ 0 - 115
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/chatglm/ChatGlmImageOptions.java

@@ -1,115 +0,0 @@
-package cn.iocoder.yudao.framework.ai.core.model.chatglm;
-
-import com.fasterxml.jackson.annotation.JsonProperty;
-import lombok.Setter;
-import org.springframework.ai.image.ImageOptions;
-
-/**
- * chatglm
- * api地址:https://open.bigmodel.cn/dev/api#cogview
- */
-@Setter
-public class ChatGlmImageOptions implements ImageOptions {
-
-    @JsonProperty("n")
-    private Integer n;
-
-    @JsonProperty("model")
-    private String model = "cogview-3";
-
-    @JsonProperty("size_width")
-    private Integer width;
-
-    @JsonProperty("size_height")
-    private Integer height;
-
-    @JsonProperty("size")
-    private String size;
-
-    @JsonProperty("style")
-    private String style;
-
-    @JsonProperty("user_id")
-    private String user;
-
-    @JsonProperty("responseFormat")
-    private String responseFormat;
-
-    // ==== build
-
-
-    public static ChatGlmImageOptions.Builder builder() {
-        return new ChatGlmImageOptions.Builder();
-    }
-
-    public static class Builder {
-
-        private final ChatGlmImageOptions options;
-
-        private Builder() {
-            this.options = new ChatGlmImageOptions();
-        }
-
-        public ChatGlmImageOptions.Builder withN(Integer n) {
-            options.setN(n);
-            return this;
-        }
-
-        public ChatGlmImageOptions.Builder withModel(String model) {
-            options.setModel(model);
-            return this;
-        }
-
-        public ChatGlmImageOptions.Builder withWidth(Integer width) {
-            options.setWidth(width);
-            return this;
-        }
-
-        public ChatGlmImageOptions.Builder withHeight(Integer height) {
-            options.setHeight(height);
-            return this;
-        }
-
-        public ChatGlmImageOptions.Builder withStyle(String style) {
-            options.setStyle(style);
-            return this;
-        }
-
-        public ChatGlmImageOptions.Builder withUser(String user) {
-            options.setUser(user);
-            return this;
-        }
-
-        public ChatGlmImageOptions build() {
-            return options;
-        }
-
-    }
-
-    // ==== get
-
-    @Override
-    public Integer getN() {
-        return n;
-    }
-
-    @Override
-    public String getModel() {
-        return model;
-    }
-
-    @Override
-    public Integer getWidth() {
-        return width;
-    }
-
-    @Override
-    public Integer getHeight() {
-        return height;
-    }
-
-    @Override
-    public String getResponseFormat() {
-        return responseFormat;
-    }
-}

+ 0 - 24
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/chatglm/api/ChatGlmResponseMetadata.java

@@ -1,24 +0,0 @@
-package cn.iocoder.yudao.framework.ai.core.model.chatglm.api;
-
-import com.zhipu.oapi.service.v4.image.ImageApiResponse;
-import org.springframework.ai.image.ImageResponseMetadata;
-
-import java.util.HashMap;
-
-public class ChatGlmResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
-
-    private Long created;
-
-    public ChatGlmResponseMetadata(ImageApiResponse result) {
-        created = result.getData().getCreated();
-    }
-
-    @Override
-    public Long getCreated() {
-        return created;
-    }
-
-    public void setCreated(Long created) {
-        this.created = created;
-    }
-}

+ 0 - 40
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/ChatGlmImageModelTests.java

@@ -1,40 +0,0 @@
-package cn.iocoder.yudao.framework.ai.image;
-
-import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageModel;
-import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageOptions;
-import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
-import com.alibaba.fastjson.JSON;
-import com.zhipu.oapi.ClientV4;
-import com.zhipu.oapi.core.httpclient.ApacheHttpClientTransport;
-import com.zhipu.oapi.service.v4.image.CreateImageRequest;
-import com.zhipu.oapi.service.v4.image.ImageApiResponse;
-import org.junit.jupiter.api.Test;
-import org.springframework.ai.image.ImageOptionsBuilder;
-import org.springframework.ai.image.ImagePrompt;
-import org.springframework.ai.image.ImageResponse;
-import org.springframework.ai.qianfan.QianFanImageModel;
-import org.springframework.ai.qianfan.QianFanImageOptions;
-import org.springframework.ai.qianfan.api.QianFanImageApi;
-
-/**
- * 百度千帆 image
- */
-public class ChatGlmImageModelTests {
-
-    @Test
-    public void callTest() {
-        ChatGlmImageModel model = new ChatGlmImageModel("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy");
-        ImageResponse call = model.call(new ImagePrompt("万里长城", ChatGlmImageOptions.builder().build()));
-        System.err.println(call.getResult().getOutput().getUrl());
-    }
-
-    @Test
-    public void createImageTest() {
-        ClientV4 client = new ClientV4.Builder("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy").build();
-        CreateImageRequest createImageRequest = new CreateImageRequest();
-        createImageRequest.setModel("cogview-3");
-        createImageRequest.setPrompt("长城!");
-        ImageApiResponse image = client.createImage(createImageRequest);
-        System.err.println(JSON.toJSONString(image));
-    }
-}

+ 35 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/ZhiPuAiImageModelTests.java

@@ -0,0 +1,35 @@
+package cn.iocoder.yudao.framework.ai.image;
+
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.image.ImagePrompt;
+import org.springframework.ai.image.ImageResponse;
+import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
+import org.springframework.ai.zhipuai.ZhiPuAiImageOptions;
+import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
+
+/**
+ * {@link ZhiPuAiImageModel} 集成测试
+ */
+public class ZhiPuAiImageModelTests {
+
+    private final ZhiPuAiImageApi imageApi = new ZhiPuAiImageApi(
+            "78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy");
+    private final ZhiPuAiImageModel imageModel = new ZhiPuAiImageModel(imageApi);
+
+    @Test
+    @Disabled
+    public void testCall() {
+        // 准备参数
+        ZhiPuAiImageOptions imageOptions = ZhiPuAiImageOptions.builder()
+                .withModel(ZhiPuAiImageApi.ImageModel.CogView_3.getValue())
+                .build();
+        ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions);
+
+        // 方法调用
+        ImageResponse response = imageModel.call(prompt);
+        // 打印结果
+        System.out.println(response);
+    }
+
+}