Переглянути джерело

【代码优化】AI:MJ 配置类的简化

YunaiV 1 рік тому
батько
коміт
b4eed07d61

+ 0 - 24
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/config/YudaoMidjourneyConfiguration.java

@@ -1,24 +0,0 @@
-package cn.iocoder.yudao.module.ai.config;
-
-import cn.iocoder.yudao.framework.ai.core.model.midjourney.MidjourneyConfig;
-import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
-import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
-import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
-import org.springframework.context.annotation.Bean;
-import org.springframework.context.annotation.Configuration;
-
-/**
- * 配置
- *
- * @author fansili
- * @time 2024/6/13 09:50
- */
-@Configuration
-public class YudaoMidjourneyConfiguration {
-
-    @Bean
-    @ConditionalOnProperty(value = "ai.midjourney-proxy.enable", havingValue = "true")
-    public MidjourneyApi midjourneyApi(YudaoMidjourneyProperties midjourneyProperties) {
-        return new MidjourneyApi(BeanUtils.toBean(midjourneyProperties, MidjourneyConfig.class));
-    }
-}

+ 0 - 22
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/config/YudaoMidjourneyProperties.java

@@ -1,22 +0,0 @@
-package cn.iocoder.yudao.module.ai.config;
-
-import lombok.Data;
-import org.springframework.boot.context.properties.ConfigurationProperties;
-import org.springframework.context.annotation.Configuration;
-
-/**
- * Midjourney 属性
- *
- * @author fansili
- * @time 2024/6/5 15:02
- * @since 1.0
- */
-@Configuration
-@ConfigurationProperties(prefix = "ai.midjourney-proxy")
-@Data
-public class YudaoMidjourneyProperties {
-
-    private String enable;
-    private String key;
-    private String url;
-}

+ 2 - 6
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -29,7 +29,6 @@ import org.springframework.ai.image.ImagePrompt;
 import org.springframework.ai.image.ImageResponse;
 import org.springframework.ai.openai.OpenAiImageOptions;
 import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
-import org.springframework.beans.factory.annotation.Value;
 import org.springframework.scheduling.annotation.Async;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
@@ -63,9 +62,6 @@ public class AiImageServiceImpl implements AiImageService {
     @Resource
     private MidjourneyApi midjourneyApi;
 
-    @Value("${ai.midjourney-proxy.notifyUrl:http://127.0.0.1:48080/admin-api/ai/image/midjourney-notify}")
-    private String midjourneyNotifyUrl;
-
     @Override
     public PageResult<AiImageDO> getImagePageMy(Long userId, PageParam pageReqVO) {
         return imageMapper.selectPage(userId, pageReqVO);
@@ -159,7 +155,7 @@ public class AiImageServiceImpl implements AiImageService {
 
         // 2. 调用 Midjourney Proxy 提交任务
         MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(
-                null, midjourneyNotifyUrl, reqVO.getPrompt(),
+                null, reqVO.getPrompt(),null,
                 MidjourneyApi.ImagineRequest.buildState(reqVO.getWidth(), reqVO.getHeight(), reqVO.getVersion(), reqVO.getModel()));
         MidjourneyApi.SubmitResponse imagineResponse = midjourneyApi.imagine(imagineRequest);
 
@@ -258,7 +254,7 @@ public class AiImageServiceImpl implements AiImageService {
 
         // 2. 调用 Midjourney Proxy 提交任务
         MidjourneyApi.SubmitResponse actionResponse = midjourneyApi.action(
-                new MidjourneyApi.ActionRequest(button.customId(), image.getTaskId(), midjourneyNotifyUrl));
+                new MidjourneyApi.ActionRequest(button.customId(), image.getTaskId(), null));
         if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(actionResponse.code())) {
             String description = actionResponse.description().contains("quota_not_enough") ?
                     "账户余额不足" : actionResponse.description();

+ 8 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java

@@ -2,6 +2,7 @@ package cn.iocoder.yudao.framework.ai.config;
 
 import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
 import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl;
+import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
 import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
@@ -96,6 +97,13 @@ public class YudaoAiAutoConfiguration {
         );
     }
 
+    @Bean
+    @ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
+    public MidjourneyApi midjourneyApi(YudaoAiProperties yudaoAiProperties) {
+        YudaoAiProperties.MidjourneyProperties config = yudaoAiProperties.getMidjourney();
+        return new MidjourneyApi(config.getBaseUrl(), config.getApiKey(), config.getNotifyUrl());
+    }
+
     @Bean
     @ConditionalOnProperty(value = "yudao.ai.suno.enable", havingValue = "true")
     public SunoApi sunoApi(YudaoAiProperties yudaoAiProperties) {

+ 8 - 18
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java

@@ -64,15 +64,18 @@ public class YudaoAiProperties {
     @Data
     @Accessors(chain = true)
     public static class XingHuoProperties extends ChatProperties {
+
         private String appId;
         private String appKey;
         private String secretKey;
         private XingHuoChatModel model;
+
     }
 
     @Data
     @Accessors(chain = true)
     public static class YiYanProperties extends ChatProperties {
+
         /**
          * appKey
          */
@@ -92,26 +95,13 @@ public class YudaoAiProperties {
     }
 
     @Data
-    @Accessors(chain = true)
     public static class MidjourneyProperties {
-        private boolean enable = false;
 
-        /**
-         * socket 链接地址
-         */
-        private String wssUrl = "wss://gateway.discord.gg";
-        /**
-         * token
-         */
-        private String token;
-        /**
-         * 服务id
-         */
-        private String guildId;
-        /**
-         * 频道id
-         */
-        private String channelId;
+        private String enable;
+        private String apiKey;
+        private String baseUrl;
+        private String notifyUrl;
+
     }
 
     @Data

+ 59 - 24
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/midjourney/api/MidjourneyApi.java

@@ -1,10 +1,11 @@
 package cn.iocoder.yudao.framework.ai.core.model.midjourney.api;
 
-import cn.iocoder.yudao.framework.ai.core.model.midjourney.MidjourneyConfig;
+import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
 import lombok.AllArgsConstructor;
+import lombok.Data;
 import lombok.Getter;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.openai.api.ApiUtils;
@@ -26,11 +27,17 @@ public class MidjourneyApi {
 
     private final WebClient webClient;
 
-    public MidjourneyApi(MidjourneyConfig midjourneyConfig) {
+    /**
+     * 回调地址
+     */
+    private final String notifyUrl;
+
+    public MidjourneyApi(String baseUrl, String apiKey, String notifyUrl) {
         this.webClient = WebClient.builder()
-                .baseUrl(midjourneyConfig.getUrl())
-                .defaultHeaders(ApiUtils.getJsonContentHeaders(midjourneyConfig.getKey()))
+                .baseUrl(baseUrl)
+                .defaultHeaders(ApiUtils.getJsonContentHeaders(apiKey))
                 .build();
+        this.notifyUrl = notifyUrl;
     }
 
     /**
@@ -40,6 +47,9 @@ public class MidjourneyApi {
      * @return 提交结果
      */
     public SubmitResponse imagine(ImagineRequest request) {
+        if (StrUtil.isEmpty(request.getNotifyHook())) {
+            request.setNotifyHook(notifyUrl);
+        }
         String response = post("/submit/imagine", request);
         return JsonUtils.parseObject(response, SubmitResponse.class);
     }
@@ -51,8 +61,11 @@ public class MidjourneyApi {
      * @return 提交结果
      */
     public SubmitResponse action(ActionRequest request) {
-        String res = post("/submit/action", request);
-        return JsonUtils.parseObject(res, SubmitResponse.class);
+        if (StrUtil.isEmpty(request.getNotifyHook())) {
+            request.setNotifyHook(notifyUrl);
+        }
+        String response = post("/submit/action", request);
+        return JsonUtils.parseObject(response, SubmitResponse.class);
     }
 
     /**
@@ -86,23 +99,40 @@ public class MidjourneyApi {
 
     /**
      * Imagine 请求(生成图片)
-     *
-     * @param base64Array 垫图(参考图) base64数 组
-     * @param notifyHook 通知地址
-     * @param prompt 提示词
-     * @param state 自定义参数
      */
-    public record ImagineRequest(List<String> base64Array,
-                                 String notifyHook,
-                                 String prompt,
-                                 String state) {
+    @Data
+    public static final class ImagineRequest {
+
+        /**
+         * 垫图(参考图) base64 数组
+         */
+        private List<String> base64Array;
+        /**
+         * 提示词
+         */
+        private String prompt;
+        /**
+         * 通知地址
+         */
+        private String notifyHook;
+        /**
+         * 自定义参数
+         */
+        private String state;
+
+        public ImagineRequest(List<String> base64Array, String prompt, String notifyHook, String state) {
+            this.base64Array = base64Array;
+            this.prompt = prompt;
+            this.notifyHook = notifyHook;
+            this.state = state;
+        }
 
         public static String buildState(Integer width, Integer height, String version, String model) {
             StringBuilder params = new StringBuilder();
             //  --ar 来设置尺寸
             params.append(String.format(" --ar %s:%s ", width, height));
             // --niji 模型
-            if (MidjourneyApi.ModelEnum.NIJI.getModel().equals(model)) {
+            if (ModelEnum.NIJI.getModel().equals(model)) {
                 params.append(String.format(" --niji %s ", version));
             } else {
                 params.append(String.format(" --v %s ", version));
@@ -114,15 +144,20 @@ public class MidjourneyApi {
 
     /**
      * Action 请求
-     *
-     * @param customId   操作按钮id
-     * @param taskId     操作按钮id
-     * @param notifyHook 通知地址
      */
-    public record ActionRequest(String customId,
-                                String taskId,
-                                String notifyHook
-    ) {
+    @Data
+    public static final class ActionRequest {
+
+        private String customId;
+        private String taskId;
+        private String notifyHook;
+
+        public ActionRequest(String taskId, String customId, String notifyHook) {
+            this.customId = customId;
+            this.taskId = taskId;
+            this.notifyHook = notifyHook;
+        }
+
     }
 
     /**

+ 4 - 10
yudao-server/src/main/resources/application.yaml

@@ -194,20 +194,14 @@ yudao.ai:
     api-key: sk-Zsd81gZYg7
   midjourney:
     enable: true
-    token: MTE4MjE3MjY2MjkxNTY3ODIzOA.GEV1SG.c49F8lZoGCUHwsj8O0UdodmM6nyQHvuD2fXflw
-    guild-id: 1237948819677904956
-    channel-id: 1237948819677904960
+#    base-url: https://api.holdai.top/mj-relax/mj
+    base-url: https://api.holdai.top/mj
+    api-key: sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf
+    notify-url: http://java.nat300.top/admin-api/ai/image/midjourney/notify
   suno:
     enable: true
     base-url: https://suno-imrqwwui8-status2xxs-projects.vercel.app
 
-ai:
-  midjourney-proxy:
-    enable: true
-    url: https://api.holdai.top/mj
-    notifyUrl: http://61d61685.r21.cpolar.top/admin-api/ai/image/midjourney-notify
-    key: sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf
-
 --- #################### 芋道相关配置 ####################
 
 yudao: