Эх сурвалжийг харах

【增加】chatglm 实现 spring ai 标准

cherishsince 9 сар өмнө
parent
commit
d865cc293b

+ 7 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml

@@ -60,6 +60,13 @@
             <version>2.14.0</version>
         </dependency>
 
+        <!-- bigmodel -->
+        <dependency>
+            <groupId>cn.bigmodel.openapi</groupId>
+            <artifactId>oapi-java-sdk</artifactId>
+            <version>release-V4-2.0.2</version>
+        </dependency>
+
         <!-- Test 测试相关 -->
         <dependency>
             <groupId>org.springframework.boot</groupId>

+ 75 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/chatglm/ChatGlmImageModel.java

@@ -0,0 +1,75 @@
+package cn.iocoder.yudao.framework.ai.core.model.chatglm;
+
+import cn.iocoder.yudao.framework.ai.core.model.chatglm.api.ChatGlmResponseMetadata;
+import com.zhipu.oapi.ClientV4;
+import com.zhipu.oapi.service.v4.image.CreateImageRequest;
+import com.zhipu.oapi.service.v4.image.ImageApiResponse;
+import org.springframework.ai.image.*;
+
+import java.io.ByteArrayOutputStream;
+import java.net.URL;
+import java.util.Base64;
+import java.util.stream.Collectors;
+
+public class ChatGlmImageModel implements ImageModel {
+
+    private ClientV4 client;
+
+    public ChatGlmImageModel(String apiSecretKey) {
+        client = new ClientV4.Builder(apiSecretKey).build();
+    }
+
+    @Override
+    public ImageResponse call(ImagePrompt request) {
+        CreateImageRequest imageRequest = CreateImageRequest.builder()
+                .model(request.getOptions().getModel())
+                .prompt(request.getInstructions().get(0).getText())
+                .build();
+        return convert(client.createImage(imageRequest));
+    }
+
+    private ImageResponse convert(ImageApiResponse result) {
+        return new ImageResponse(
+                result.getData().getData().stream().map(item -> {
+                    try {
+                        String url = item.getUrl();
+                        String base64Image = convertImageToBase64(url);
+                        Image image = new Image(url, base64Image);
+                        return new ImageGeneration(image);
+                    } catch (Exception e) {
+                        throw new RuntimeException(e);
+                    }
+                }).collect(Collectors.toList()),
+                new ChatGlmResponseMetadata(result)
+        );
+    }
+
+
+    /**
+     * Convert image to base64.
+     * @param imageUrl the image url.
+     * @return the base64 image.
+     * @throws Exception the exception.
+     */
+    public String convertImageToBase64(String imageUrl) throws Exception {
+
+        var url = new URL(imageUrl);
+        var inputStream = url.openStream();
+        var outputStream = new ByteArrayOutputStream();
+        var buffer = new byte[4096];
+        int bytesRead;
+
+        while ((bytesRead = inputStream.read(buffer)) != -1) {
+            outputStream.write(buffer, 0, bytesRead);
+        }
+
+        var imageBytes = outputStream.toByteArray();
+
+        String base64Image = Base64.getEncoder().encodeToString(imageBytes);
+
+        inputStream.close();
+        outputStream.close();
+
+        return base64Image;
+    }
+}

+ 115 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/chatglm/ChatGlmImageOptions.java

@@ -0,0 +1,115 @@
+package cn.iocoder.yudao.framework.ai.core.model.chatglm;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Setter;
+import org.springframework.ai.image.ImageOptions;
+
+/**
+ * chatglm
+ * api地址:https://open.bigmodel.cn/dev/api#cogview
+ */
+@Setter
+public class ChatGlmImageOptions implements ImageOptions {
+
+    @JsonProperty("n")
+    private Integer n;
+
+    @JsonProperty("model")
+    private String model = "cogview-3";
+
+    @JsonProperty("size_width")
+    private Integer width;
+
+    @JsonProperty("size_height")
+    private Integer height;
+
+    @JsonProperty("size")
+    private String size;
+
+    @JsonProperty("style")
+    private String style;
+
+    @JsonProperty("user_id")
+    private String user;
+
+    @JsonProperty("responseFormat")
+    private String responseFormat;
+
+    // ==== build
+
+
+    public static ChatGlmImageOptions.Builder builder() {
+        return new ChatGlmImageOptions.Builder();
+    }
+
+    public static class Builder {
+
+        private final ChatGlmImageOptions options;
+
+        private Builder() {
+            this.options = new ChatGlmImageOptions();
+        }
+
+        public ChatGlmImageOptions.Builder withN(Integer n) {
+            options.setN(n);
+            return this;
+        }
+
+        public ChatGlmImageOptions.Builder withModel(String model) {
+            options.setModel(model);
+            return this;
+        }
+
+        public ChatGlmImageOptions.Builder withWidth(Integer width) {
+            options.setWidth(width);
+            return this;
+        }
+
+        public ChatGlmImageOptions.Builder withHeight(Integer height) {
+            options.setHeight(height);
+            return this;
+        }
+
+        public ChatGlmImageOptions.Builder withStyle(String style) {
+            options.setStyle(style);
+            return this;
+        }
+
+        public ChatGlmImageOptions.Builder withUser(String user) {
+            options.setUser(user);
+            return this;
+        }
+
+        public ChatGlmImageOptions build() {
+            return options;
+        }
+
+    }
+
+    // ==== get
+
+    @Override
+    public Integer getN() {
+        return n;
+    }
+
+    @Override
+    public String getModel() {
+        return model;
+    }
+
+    @Override
+    public Integer getWidth() {
+        return width;
+    }
+
+    @Override
+    public Integer getHeight() {
+        return height;
+    }
+
+    @Override
+    public String getResponseFormat() {
+        return responseFormat;
+    }
+}

+ 24 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/chatglm/api/ChatGlmResponseMetadata.java

@@ -0,0 +1,24 @@
+package cn.iocoder.yudao.framework.ai.core.model.chatglm.api;
+
+import com.zhipu.oapi.service.v4.image.ImageApiResponse;
+import org.springframework.ai.image.ImageResponseMetadata;
+
+import java.util.HashMap;
+
+public class ChatGlmResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
+
+    private Long created;
+
+    public ChatGlmResponseMetadata(ImageApiResponse result) {
+        created = result.getData().getCreated();
+    }
+
+    @Override
+    public Long getCreated() {
+        return created;
+    }
+
+    public void setCreated(Long created) {
+        this.created = created;
+    }
+}

+ 40 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/ChatGlmImageModelTests.java

@@ -0,0 +1,40 @@
+package cn.iocoder.yudao.framework.ai.image;
+
+import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageModel;
+import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageOptions;
+import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
+import com.alibaba.fastjson.JSON;
+import com.zhipu.oapi.ClientV4;
+import com.zhipu.oapi.core.httpclient.ApacheHttpClientTransport;
+import com.zhipu.oapi.service.v4.image.CreateImageRequest;
+import com.zhipu.oapi.service.v4.image.ImageApiResponse;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.image.ImageOptionsBuilder;
+import org.springframework.ai.image.ImagePrompt;
+import org.springframework.ai.image.ImageResponse;
+import org.springframework.ai.qianfan.QianFanImageModel;
+import org.springframework.ai.qianfan.QianFanImageOptions;
+import org.springframework.ai.qianfan.api.QianFanImageApi;
+
+/**
+ * 百度千帆 image
+ */
+public class ChatGlmImageModelTests {
+
+    @Test
+    public void callTest() {
+        ChatGlmImageModel model = new ChatGlmImageModel("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy");
+        ImageResponse call = model.call(new ImagePrompt("万里长城", ChatGlmImageOptions.builder().build()));
+        System.err.println(call.getResult().getOutput().getUrl());
+    }
+
+    @Test
+    public void createImageTest() {
+        ClientV4 client = new ClientV4.Builder("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy").build();
+        CreateImageRequest createImageRequest = new CreateImageRequest();
+        createImageRequest.setModel("cogview-3");
+        createImageRequest.setPrompt("长城!");
+        ImageApiResponse image = client.createImage(createImageRequest);
+        System.err.println(JSON.toJSONString(image));
+    }
+}