Преглед изворни кода

【增加】midjourney 自动配置。

cherishsince пре 1 година
родитељ
комит
d14c13b2fc

+ 53 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java

@@ -1,5 +1,6 @@
 package cn.iocoder.yudao.framework.ai.config;
 
+import cn.hutool.core.io.IoUtil;
 import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
 import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatModal;
 import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenOptions;
@@ -13,10 +14,21 @@ import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi;
 import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageApi;
 import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient;
 import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
+import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
+import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi;
+import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter;
+import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MidjourneyMessageListener;
+import org.jetbrains.annotations.NotNull;
 import org.springframework.boot.autoconfigure.AutoConfiguration;
 import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
 import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.context.ApplicationContext;
 import org.springframework.context.annotation.Bean;
+import org.springframework.core.io.Resource;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
 
 /**
  * ai 自动配置
@@ -103,4 +115,45 @@ public class YudaoAiAutoConfiguration {
                         .setStyle(openAiImageProperties.getStyle())
         );
     }
+
+    @Bean
+    @ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
+    public MidjourneyWebSocketStarter midjourneyWebSocketStarter(ApplicationContext applicationContext, YudaoAiProperties yudaoAiProperties) {
+        // 获取 midjourneyProperties
+        YudaoAiProperties.MidjourneyProperties midjourneyProperties = yudaoAiProperties.getMidjourney();
+        // 获取 midjourneyConfig
+        MidjourneyConfig midjourneyConfig = getMidjourneyConfig(applicationContext, midjourneyProperties);
+        // 创建 socket messageListener
+        MidjourneyMessageListener messageListener = new MidjourneyMessageListener(midjourneyConfig);
+        // 创建 MidjourneyWebSocketStarter
+        return new MidjourneyWebSocketStarter(midjourneyProperties.getWssUrl(), null, midjourneyConfig, messageListener);
+    }
+
+    @Bean
+    @ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
+    public MidjourneyInteractionsApi midjourneyInteractionsApi(ApplicationContext applicationContext, YudaoAiProperties yudaoAiProperties) {
+        // 获取 midjourneyConfig
+        MidjourneyConfig midjourneyConfig = getMidjourneyConfig(applicationContext, yudaoAiProperties.getMidjourney());
+        // 创建 MidjourneyInteractionsApi
+        return new MidjourneyInteractionsApi(midjourneyConfig);
+    }
+
+
+    private static @NotNull MidjourneyConfig getMidjourneyConfig(ApplicationContext applicationContext,
+                                                                 YudaoAiProperties.MidjourneyProperties midjourneyProperties) {
+        Map<String, String> requestTemplates = new HashMap<>();
+        try {
+            Resource[] resources = applicationContext.getResources("classpath:http-body/*.json");
+            for (var resource : resources) {
+                String filename = resource.getFilename();
+                String params = IoUtil.readUtf8(resource.getInputStream());
+                requestTemplates.put(filename.substring(0, filename.length() - 5), params);
+            }
+        } catch (IOException e) {
+            throw new IllegalArgumentException("Midjourney json模板初始化出错! " + e.getMessage());
+        }
+        // 创建 midjourneyConfig
+        return new MidjourneyConfig(midjourneyProperties.getToken(),
+                midjourneyProperties.getGuildId(), midjourneyProperties.getChannelId(), requestTemplates);
+    }
 }

+ 26 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java

@@ -26,6 +26,7 @@ public class YudaoAiProperties {
     private XingHuoProperties xinghuo;
     private YiYanProperties yiyan;
     private OpenAiImageProperties openAiImage;
+    private MidjourneyProperties midjourney;
 
     @Data
     @Accessors(chain = true)
@@ -94,6 +95,8 @@ public class YudaoAiProperties {
     @Data
     @Accessors(chain = true)
     public static class OpenAiImageProperties {
+        private boolean enable = false;
+
         /**
          * api key
          */
@@ -107,4 +110,27 @@ public class YudaoAiProperties {
          */
         private OpenAiImageStyleEnum style = OpenAiImageStyleEnum.VIVID;
     }
+
+    @Data
+    @Accessors(chain = true)
+    public static class MidjourneyProperties {
+        private boolean enable = false;
+
+        /**
+         * socket 链接地址
+         */
+        private String wssUrl = "wss://gateway.discord.gg";
+        /**
+         * token
+         */
+        private String token;
+        /**
+         * 服务id
+         */
+        private String guildId;
+        /**
+         * 频道id
+         */
+        private String channelId;
+    }
 }

+ 5 - 1
yudao-server/src/main/resources/application-local.yaml

@@ -260,7 +260,11 @@ yudao:
       api-key: ${OPEN_AI_KEY}
       model: dall_e_2
       style: vivid
-
+    midjourney:
+      enable: true
+      token: OTcyNzIxMzA0ODkxNDUzNDUw.G_vMOz.BO_Q0sXAD80u5ZKIHPNYDTRX_FgeKL3cKFc53I
+      guild-id: 1225608134878302329
+      channel-id: 1225608134878302332
   captcha:
     enable: false # 本地环境,暂时关闭图片验证码,方便登录等接口的测试;
   security: