Prechádzať zdrojové kódy

【添加】midjourney 增加 imagine 接口

cherishsince 1 rok pred
rodič
commit
330fd52b3e

+ 13 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/AiImageController.java

@@ -1,14 +1,17 @@
 package cn.iocoder.yudao.module.ai.controller;
 
+import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.module.ai.service.AiImageService;
 import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
+import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
+import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.tags.Tag;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.validation.annotation.Validated;
-import org.springframework.web.bind.annotation.GetMapping;
-import org.springframework.web.bind.annotation.ModelAttribute;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestBody;
 import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@@ -30,10 +33,16 @@ public class AiImageController {
     private final AiImageService aiImageService;
 
     @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
-    @GetMapping("/dallDrawing")
-    public SseEmitter dallDrawing(@Validated @ModelAttribute AiImageDallDrawingReq req) {
+    @PostMapping("/dallDrawing")
+    public SseEmitter dallDrawing(@Validated @RequestBody AiImageDallDrawingReq req) {
         Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
         aiImageService.dallDrawing(req, sseEmitter);
         return sseEmitter;
     }
+
+    @Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
+    @PostMapping("/midjourney")
+    public CommonResult<AiImageMidjourneyRes> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) {
+        return CommonResult.success(aiImageService.midjourney(req));
+    }
 }

+ 10 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java

@@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.ai.service;
 
 import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
 import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
+import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
+import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes;
 
 /**
  * ai 作图
@@ -19,4 +21,12 @@ public interface AiImageService {
      * @param sseEmitter
      */
     void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter);
+
+    /**
+     * midjourney 图片生成
+     *
+     * @param req
+     * @return
+     */
+    AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req);
 }

+ 56 - 6
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java

@@ -8,17 +8,26 @@ import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient;
 import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageModelEnum;
 import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
 import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageStyleEnum;
+import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi;
+import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter;
+import cn.iocoder.yudao.framework.ai.midjourney.webSocket.WssNotify;
+import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
 import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
+import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
 import cn.iocoder.yudao.module.ai.dal.dataobject.AiImageDO;
 import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
 import cn.iocoder.yudao.module.ai.mapper.AiImageMapper;
 import cn.iocoder.yudao.module.ai.service.AiImageService;
 import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
+import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
+import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes;
+import jakarta.annotation.PostConstruct;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.http.MediaType;
 import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
 
 import java.io.IOException;
 
@@ -36,6 +45,24 @@ public class AiImageServiceImpl implements AiImageService {
 
     private final AiImageMapper aiImageMapper;
     private final OpenAiImageClient openAiImageClient;
+    private final MidjourneyWebSocketStarter midjourneyWebSocketStarter;
+    private final MidjourneyInteractionsApi midjourneyInteractionsApi;
+
+    @PostConstruct
+    public void startMidjourney() {
+        log.info("midjourney web socket starter...");
+        midjourneyWebSocketStarter.start(new WssNotify() {
+            @Override
+            public void notify(int code, String message) {
+                log.info("code: {}, message: {}", code, message);
+                if (message.contains("Authentication failed")) {
+                    // TODO 芋艿,这里看怎么处理,token无效的时候会认证失败!
+                    // 认证失败
+                    log.error("midjourney socket 认证失败,检查token是否失效!");
+                }
+            }
+        });
+    }
 
     @Override
     public void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter) {
@@ -55,15 +82,33 @@ public class AiImageServiceImpl implements AiImageService {
             // 发送信息
             sendSseEmitter(sseEmitter, imageGeneration);
             // 保存数据库
-            doSave(req, imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
+            doSave(req.getPrompt(), req.getSize(), req.getModal(),
+                    imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
         } catch (AiException aiException) {
             // 保存数据库
-            doSave(req, null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
+            doSave(req.getPrompt(), req.getSize(), req.getModal(),
+                    null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
             // 发送错误信息
             sendSseEmitter(sseEmitter, aiException.getMessage());
         }
     }
 
+    @Override
+    @Transactional(rollbackFor = Exception.class)
+    public AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req) {
+        // 保存数据库
+        doSave(req.getPrompt(), null, "midjoureny",
+                null, AiChatDrawingStatusEnum.SUBMIT, null);
+        // 提交 midjourney 任务
+        Boolean imagine = midjourneyInteractionsApi.imagine(req.getPrompt());
+        if (!imagine) {
+            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
+        }
+        //
+
+        return null;
+    }
+
     private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
         try {
             sseEmitter.send(object, MediaType.APPLICATION_JSON);
@@ -75,14 +120,19 @@ public class AiImageServiceImpl implements AiImageService {
         }
     }
 
-    private void doSave(AiImageDallDrawingReq req, String imageUrl, AiChatDrawingStatusEnum drawingStatusEnum, String drawingError) {
+    private void doSave(String prompt,
+                        String size,
+                        String model,
+                        String imageUrl,
+                        AiChatDrawingStatusEnum drawingStatusEnum,
+                        String drawingError) {
         // 保存数据库
         Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
         AiImageDO aiImageDO = new AiImageDO();
         aiImageDO.setId(null);
-        aiImageDO.setPrompt(req.getPrompt());
-        aiImageDO.setSize(req.getSize());
-        aiImageDO.setModal(req.getModal());
+        aiImageDO.setPrompt(prompt);
+        aiImageDO.setSize(size);
+        aiImageDO.setModal(model);
         aiImageDO.setUserId(loginUserId);
         aiImageDO.setDrawingImageUrl(imageUrl);
         aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());