Browse Source

【修改todo】增加 sd 各种参数

cherishsince 10 months ago
parent
commit
f300b8a1ae

+ 1 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java

@@ -37,6 +37,7 @@ public interface ErrorCodeConstants {
     ErrorCode IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!");
     ErrorCode IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}");
     ErrorCode IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}");
+    ErrorCode IMAGE_FAIL = new ErrorCode(1_022_005_002, "图片绘画失败! {}");
 
     // ========== API 音乐 1-040-006-000 ==========
     ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!");

+ 8 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -123,9 +123,16 @@ public class AiImageServiceImpl implements AiImageService {
                     .withResponseFormat("b64_json")
                     .build();
         } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
+            // https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage
             // https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
             return StabilityAiImageOptions.builder().withModel(draw.getModel())
-                    .withHeight(draw.getHeight()).withWidth(draw.getWidth()) // TODO @范:各种参数的接入
+                    .withHeight(draw.getHeight()).withWidth(draw.getWidth())
+                    .withSeed(Long.valueOf(draw.getOptions().get("seed")))
+                    .withCfgScale(Float.valueOf(draw.getOptions().get("scale")))
+                    .withSteps(Integer.valueOf(draw.getOptions().get("steps")))
+                    .withSampler(String.valueOf(draw.getOptions().get("sampler")))
+                    .withStylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
+                    .withClipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
                     .build();
         }
         throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());

+ 105 - 105
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/QianWenChatClientTests.java

@@ -1,105 +1,105 @@
-package cn.iocoder.yudao.framework.ai.chat;
-
-import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
-import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
-import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
-import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
-import com.alibaba.dashscope.aigc.generation.GenerationResult;
-import com.alibaba.dashscope.aigc.generation.models.QwenParam;
-import com.alibaba.dashscope.common.Message;
-import com.alibaba.dashscope.common.MessageManager;
-import com.alibaba.dashscope.common.Role;
-import com.alibaba.dashscope.exception.InputRequiredException;
-import com.alibaba.dashscope.exception.NoApiKeyException;
-import org.junit.Before;
-import org.junit.Test;
-import org.springframework.ai.chat.messages.SystemMessage;
-import org.springframework.ai.chat.messages.UserMessage;
-import org.springframework.ai.chat.model.ChatResponse;
-import org.springframework.ai.chat.prompt.Prompt;
-import reactor.core.publisher.Flux;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Scanner;
-import java.util.function.Consumer;
-
-// TODO 芋艿:整理单测
-/**
- * author: fansili
- * time: 2024/3/13 21:37
- */
-public class QianWenChatClientTests {
-
-    private QianWenChatClient qianWenChatClient;
-
-    @Before
-    public void setup() {
-        QianWenApi qianWenApi = new QianWenApi("sk-Zsd81gZYg7", QianWenChatModal.QWEN_72B_CHAT);
-        QianWenOptions qianWenOptions = new QianWenOptions();
-        qianWenOptions.setTopP(0.8F);
-//        qianWenOptions.setTopK(3); TODO 芋艿:临时处理
-//        qianWenOptions.setTemperature(0.6F); TODO 芋艿:临时处理
-        qianWenChatClient = new QianWenChatClient(
-                qianWenApi,
-                qianWenOptions
-        );
-    }
-
-    @Test
-    public void callTest() {
-        List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
-        messages.add(new SystemMessage("你是一个优质的小红书文艺作者,抒写着各城市的美好文化和风景。"));
-        messages.add(new UserMessage("长沙怎么样?"));
-
-        ChatResponse call = qianWenChatClient.call(new Prompt(messages));
-        System.err.println(call.getResult());
-    }
-
-    @Test
-    public void streamTest() {
-        List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
-        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
-        messages.add(new UserMessage("长沙怎么样?"));
-
-        Flux<ChatResponse> flux = qianWenChatClient.stream(new Prompt(messages));
-        flux.subscribe(new Consumer<ChatResponse>() {
-            @Override
-            public void accept(ChatResponse chatResponse) {
-                System.err.print(chatResponse.getResult().getOutput().getContent());
-            }
-        });
-
-        // 阻止退出
-        Scanner scanner = new Scanner(System.in);
-        scanner.nextLine();
-    }
-
-    @Test
-    public void qianwenDemoTest() throws NoApiKeyException, InputRequiredException {
-        com.alibaba.dashscope.aigc.generation.Generation gen = new com.alibaba.dashscope.aigc.generation.Generation();
-        MessageManager msgManager = new MessageManager(10);
-        Message systemMsg =
-                Message.builder().role(Role.SYSTEM.getValue()).content("You are a helpful assistant.").build();
-        Message userMsg = Message.builder().role(Role.USER.getValue()).content("就当前的海洋污染的情况,写一份限塑的倡议书提纲,需要有理有据地号召大家克制地使用塑料制品").build();
-        msgManager.add(systemMsg);
-        msgManager.add(userMsg);
-        QwenParam param =
-                QwenParam.builder().model("qwen-72b-chat").messages(msgManager.get())
-                        .resultFormat(QwenParam.ResultFormat.MESSAGE)
-                        .topP(0.8)
-                        /* set the random seed, optional, default to 1234 if not set */
-                        .seed(100)
-                        .apiKey("sk-Zsd81gZYg7")
-                        .build();
-        GenerationResult result = gen.call(param);
-        System.out.println(result);
-        System.out.println("-----------------");
-        System.out.println("-----------------");
-        msgManager.add(result);
-        param.setPrompt("能否缩短一些,只讲三点");
-        param.setMessages(msgManager.get());
-        result = gen.call(param);
-        System.out.println(result);
-    }
-}
+//package cn.iocoder.yudao.framework.ai.chat;
+//
+//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
+//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
+//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
+//import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
+//import com.alibaba.dashscope.aigc.generation.GenerationResult;
+//import com.alibaba.dashscope.aigc.generation.models.QwenParam;
+//import com.alibaba.dashscope.common.Message;
+//import com.alibaba.dashscope.common.MessageManager;
+//import com.alibaba.dashscope.common.Role;
+//import com.alibaba.dashscope.exception.InputRequiredException;
+//import com.alibaba.dashscope.exception.NoApiKeyException;
+//import org.junit.Before;
+//import org.junit.Test;
+//import org.springframework.ai.chat.messages.SystemMessage;
+//import org.springframework.ai.chat.messages.UserMessage;
+//import org.springframework.ai.chat.model.ChatResponse;
+//import org.springframework.ai.chat.prompt.Prompt;
+//import reactor.core.publisher.Flux;
+//
+//import java.util.ArrayList;
+//import java.util.List;
+//import java.util.Scanner;
+//import java.util.function.Consumer;
+//
+//// TODO 芋艿:整理单测
+///**
+// * author: fansili
+// * time: 2024/3/13 21:37
+// */
+//public class QianWenChatClientTests {
+//
+//    private QianWenChatClient qianWenChatClient;
+//
+//    @Before
+//    public void setup() {
+//        QianWenApi qianWenApi = new QianWenApi("sk-Zsd81gZYg7", QianWenChatModal.QWEN_72B_CHAT);
+//        QianWenOptions qianWenOptions = new QianWenOptions();
+//        qianWenOptions.setTopP(0.8F);
+////        qianWenOptions.setTopK(3); TODO 芋艿:临时处理
+////        qianWenOptions.setTemperature(0.6F); TODO 芋艿:临时处理
+//        qianWenChatClient = new QianWenChatClient(
+//                qianWenApi,
+//                qianWenOptions
+//        );
+//    }
+//
+//    @Test
+//    public void callTest() {
+//        List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
+//        messages.add(new SystemMessage("你是一个优质的小红书文艺作者,抒写着各城市的美好文化和风景。"));
+//        messages.add(new UserMessage("长沙怎么样?"));
+//
+//        ChatResponse call = qianWenChatClient.call(new Prompt(messages));
+//        System.err.println(call.getResult());
+//    }
+//
+//    @Test
+//    public void streamTest() {
+//        List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
+//        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+//        messages.add(new UserMessage("长沙怎么样?"));
+//
+//        Flux<ChatResponse> flux = qianWenChatClient.stream(new Prompt(messages));
+//        flux.subscribe(new Consumer<ChatResponse>() {
+//            @Override
+//            public void accept(ChatResponse chatResponse) {
+//                System.err.print(chatResponse.getResult().getOutput().getContent());
+//            }
+//        });
+//
+//        // 阻止退出
+//        Scanner scanner = new Scanner(System.in);
+//        scanner.nextLine();
+//    }
+//
+//    @Test
+//    public void qianwenDemoTest() throws NoApiKeyException, InputRequiredException {
+//        com.alibaba.dashscope.aigc.generation.Generation gen = new com.alibaba.dashscope.aigc.generation.Generation();
+//        MessageManager msgManager = new MessageManager(10);
+//        Message systemMsg =
+//                Message.builder().role(Role.SYSTEM.getValue()).content("You are a helpful assistant.").build();
+//        Message userMsg = Message.builder().role(Role.USER.getValue()).content("就当前的海洋污染的情况,写一份限塑的倡议书提纲,需要有理有据地号召大家克制地使用塑料制品").build();
+//        msgManager.add(systemMsg);
+//        msgManager.add(userMsg);
+//        QwenParam param =
+//                QwenParam.builder().model("qwen-72b-chat").messages(msgManager.get())
+//                        .resultFormat(QwenParam.ResultFormat.MESSAGE)
+//                        .topP(0.8)
+//                        /* set the random seed, optional, default to 1234 if not set */
+//                        .seed(100)
+//                        .apiKey("sk-Zsd81gZYg7")
+//                        .build();
+//        GenerationResult result = gen.call(param);
+//        System.out.println(result);
+//        System.out.println("-----------------");
+//        System.out.println("-----------------");
+//        msgManager.add(result);
+//        param.setPrompt("能否缩短一些,只讲三点");
+//        param.setMessages(msgManager.get());
+//        result = gen.call(param);
+//        System.out.println(result);
+//    }
+//}