Browse Source

对接openai image

cherishsince 1 năm trước cách đây
mục cha
commit
5999b80471

+ 0 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/image/ImageClient.java

@@ -19,7 +19,6 @@ package cn.iocoder.yudao.framework.ai.image;
 
 import cn.iocoder.yudao.framework.ai.model.ModelClient;
 
-@FunctionalInterface
 public interface ImageClient extends ModelClient<ImagePrompt, ImageResponse> {
 
 	/**

+ 49 - 13
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageApi.java

@@ -1,14 +1,25 @@
 package cn.iocoder.yudao.framework.ai.imageopenai;
 
+import cn.hutool.json.JSONUtil;
 import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageRequest;
 import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageResponse;
 import cn.iocoder.yudao.framework.ai.util.JacksonUtil;
 import io.netty.channel.ChannelOption;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.http.HttpEntity;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.StringEntity;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.HttpClients;
+import org.apache.http.util.EntityUtils;
 import org.springframework.http.client.reactive.ReactorClientHttpConnector;
 import org.springframework.web.reactive.function.BodyInserters;
 import org.springframework.web.reactive.function.client.WebClient;
 import reactor.netty.http.client.HttpClient;
 
+import java.io.IOException;
+import java.net.URI;
 import java.time.Duration;
 
 /**
@@ -17,6 +28,7 @@ import java.time.Duration;
  * author: fansili
  * time: 2024/3/17 09:53
  */
+@Slf4j
 public class OpenAiImageApi {
 
     private static final String DEFAULT_BASE_URL = "https://api.openai.com";
@@ -24,6 +36,8 @@ public class OpenAiImageApi {
     // 发送请求 webClient
     private final WebClient webClient;
 
+    private CloseableHttpClient httpclient = HttpClients.createDefault();
+
     public OpenAiImageApi(String apiKey) {
         this.apiKey = apiKey;
         // 创建一个HttpClient实例并设置超时
@@ -37,18 +51,40 @@ public class OpenAiImageApi {
     }
 
     public OpenAiImageResponse createImage(OpenAiImageRequest request) {
-        String res = webClient.post()
-                .uri(uriBuilder -> uriBuilder.path("/v1/images/generations").build())
-                .header("Content-Type", "application/json")
-                .header("Authorization", "Bearer " + apiKey)
-                // 设置请求体(这里假设jsonStr是一个JSON格式的字符串)
-                .body(BodyInserters.fromValue(JacksonUtil.toJson(request)))
-                // 发送请求并获取响应体
-                .retrieve()
-                // 转换响应体为String类型
-                .bodyToMono(String.class)
-                .block();
-        // TODO: 2024/3/17 这里发送请求会失败!
-        return null;
+        HttpPost httpPost = new HttpPost();
+        httpPost.setURI(URI.create(DEFAULT_BASE_URL.concat("/v1/images/generations")));
+        httpPost.setHeader("Content-Type", "application/json");
+        httpPost.setHeader("Authorization", "Bearer " + apiKey);
+        httpPost.setEntity(new StringEntity(JacksonUtil.toJson(request), "UTF-8"));
+
+        CloseableHttpResponse response= null;
+        try {
+            response = httpclient.execute(httpPost);
+            HttpEntity entity = response.getEntity();
+            String resultJson = EntityUtils.toString(entity);
+            log.info("openai 图片生成结果: {}", resultJson);
+            return JSONUtil.toBean(resultJson, OpenAiImageResponse.class);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        } finally {
+            if (response != null) {
+                try {
+                    response.close();
+                } catch (IOException e) {
+                    throw new RuntimeException(e);
+                }
+            }
+        }
+//        String res = webClient.post()
+//                .uri(uriBuilder -> uriBuilder.path("/v1/images/generations").build())
+//                .header("Content-Type", "application/json")
+//                .header("Authorization", "Bearer " + apiKey)
+//                // 设置请求体(这里假设jsonStr是一个JSON格式的字符串)
+//                .body(BodyInserters.fromValue(JacksonUtil.toJson(request)))
+//                // 发送请求并获取响应体
+//                .retrieve()
+//                // 转换响应体为String类型
+//                .bodyToMono(String.class)
+//                .block();
     }
 }

+ 10 - 6
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageClient.java

@@ -1,15 +1,13 @@
 package cn.iocoder.yudao.framework.ai.imageopenai;
 
 import cn.hutool.core.bean.BeanUtil;
+import cn.hutool.core.codec.Base64;
+import cn.hutool.http.HttpUtil;
 import cn.iocoder.yudao.framework.ai.chat.ChatException;
 import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException;
-import cn.iocoder.yudao.framework.ai.image.ImageClient;
-import cn.iocoder.yudao.framework.ai.image.ImageOptions;
-import cn.iocoder.yudao.framework.ai.image.ImagePrompt;
-import cn.iocoder.yudao.framework.ai.image.ImageResponse;
+import cn.iocoder.yudao.framework.ai.image.*;
 import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageRequest;
 import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageResponse;
-import jdk.jfr.Frequency;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.retry.RetryCallback;
 import org.springframework.retry.RetryContext;
@@ -74,9 +72,15 @@ public class OpenAiImageClient implements ImageClient {
             // 创建请求
             OpenAiImageRequest request = new OpenAiImageRequest();
             BeanUtil.copyProperties(openAiImageOptions, request);
+            request.setPrompt(imagePrompt.getInstructions().get(0).getText());
             // 发送请求
             OpenAiImageResponse response = openAiImageApi.createImage(request);
-            return null;
+            return new ImageResponse(response.getData().stream().map(res -> {
+                byte[] bytes = HttpUtil.downloadBytes(res.getUrl());
+                String base64 = Base64.encode(bytes);
+                return new ImageGeneration(new Image(res.getUrl(), base64));
+            }).toList());
         });
     }
+
 }

+ 16 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/OpenAiImageOptions.java

@@ -2,6 +2,7 @@ package cn.iocoder.yudao.framework.ai.imageopenai;
 
 import cn.iocoder.yudao.framework.ai.image.ImageOptions;
 import lombok.Data;
+import lombok.Getter;
 import lombok.experimental.Accessors;
 
 /**
@@ -47,6 +48,21 @@ public class OpenAiImageOptions implements ImageOptions {
     // 代表您的终端用户的唯一标识符,有助于OpenAI监控并检测滥用行为。了解更多信息请参考官方文档。
     private String endUserId;
 
+    @Getter
+    public enum ResponseFormatEnum {
+
+        URL("url"),
+        BASE64("b64_json"),
+
+        ;
+
+        ResponseFormatEnum(String value) {
+            this.value = value;
+        }
+
+        private String value;
+    }
+
     //
     // 适配 spring ai
 

+ 1 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/imageopenai/api/OpenAiImageResponse.java

@@ -23,6 +23,7 @@ public class OpenAiImageResponse {
     public static class Item {
 
         private String url;
+        private String b64_json;
 
     }
 }

+ 38 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageClientTests.java

@@ -6,6 +6,14 @@ import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
 import org.junit.Before;
 import org.junit.Test;
 
+import javax.imageio.ImageIO;
+import javax.swing.*;
+import java.awt.image.BufferedImage;
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.util.Base64;
+import java.util.Scanner;
+
 /**
  * author: fansili
  * time: 2024/3/17 10:40
@@ -20,12 +28,40 @@ public class OpenAiImageClientTests {
         // 初始化 openAiImageClient
         this.openAiImageClient = new OpenAiImageClient(
                 new OpenAiImageApi(""),
-                new OpenAiImageOptions()
+                new OpenAiImageOptions().setResponseFormat(OpenAiImageOptions.ResponseFormatEnum.URL.getValue())
         );
     }
 
     @Test
     public void callTest() {
-        openAiImageClient.call(new ImagePrompt("我和我的小狗,一起在北极和企鹅玩排球。"));
+        ImageResponse call = openAiImageClient.call(new ImagePrompt("我和我的小狗,一起在北极和企鹅玩排球。"));
+        System.err.println("url: " + call.getResult().getOutput().getUrl());
+        System.err.println("base64: " + call.getResult().getOutput().getB64Json());
+
+        String base64String = call.getResult().getOutput().getB64Json();
+        ImageIcon imageIcon = new ImageIcon(decodeBase64ToImage(base64String));
+        JLabel label = new JLabel(imageIcon);
+
+        JFrame frame = new JFrame("Base64 Image Display");
+        frame.getContentPane().add(label);
+        frame.pack();
+        frame.setVisible(true);
+
+        // 阻止退出
+        Scanner scanner = new Scanner(System.in);
+        scanner.nextLine();
+    }
+
+
+    // 将Base64解码为BufferedImage
+    private static BufferedImage decodeBase64ToImage(String base64String) {
+        try {
+            byte[] decodedBytes = Base64.getDecoder().decode(base64String);
+            ByteArrayInputStream bis = new ByteArrayInputStream(decodedBytes);
+            return ImageIO.read(bis);
+        } catch (IOException e) {
+            System.out.println("Error decoding the base64 image: " + e.getMessage());
+            return null;
+        }
     }
 }