Browse Source

【增加】AI Image mj 增加 action 操作

cherishsince 10 months ago
parent
commit
776d6e4e1e

+ 1 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java

@@ -44,5 +44,6 @@ public interface ErrorCodeConstants {
 
     ErrorCode AI_IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "image 不存在!");
     ErrorCode AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败! {}");
+    ErrorCode AI_IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}");
 
 }

+ 8 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java

@@ -74,8 +74,14 @@ public class AiImageController {
     @Operation(summary = "midjourney proxy - 回调通知")
     @PostMapping("/midjourney-notify")
     @PermitAll
-    public CommonResult<Boolean> midjourneyNotify( @RequestBody MidjourneyNotifyReqVO notifyReqVO) {
-        return success(imageService.midjourneyNotify(getLoginUserId(), notifyReqVO));
+    public CommonResult<Boolean> midjourneyNotify(@RequestBody MidjourneyNotifyReqVO notifyReqVO) {
+        return success(imageService.midjourneyNotify(notifyReqVO));
     }
 
+    @Operation(summary = "midjourney - action(放大、缩小、U1、U2...)")
+    @PostMapping("/midjourney/action")
+    public CommonResult<Boolean> midjourneyAction(@RequestParam("id") Long imageId,
+                                                  @RequestParam("customId") String customId) {
+        return success(imageService.midjourneyAction(getLoginUserId(), imageId, customId));
+    }
 }

+ 10 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java

@@ -59,10 +59,18 @@ public interface AiImageService {
     /**
      * midjourney proxy - 回调通知
      *
-     * @param loginUserId
      * @param notifyReqVO
      * @return
      */
-    Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO);
+    Boolean midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO);
 
+    /**
+     * midjourney - action(放大、缩小、U1、U2...)
+     *
+     * @param loginUserId
+     * @param imageId
+     * @param customId
+     * @return
+     */
+    Boolean midjourneyAction(Long loginUserId, Long imageId, String customId);
 }

+ 61 - 17
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -15,6 +15,7 @@ import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
 import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum;
 import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum;
 import cn.iocoder.yudao.module.ai.client.enums.MidjourneyTaskStatusEnum;
+import cn.iocoder.yudao.module.ai.client.vo.MidjourneyActionReqVO;
 import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
 import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
 import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
@@ -39,9 +40,10 @@ import org.springframework.scheduling.annotation.Async;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 
+import java.util.List;
+
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
-import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL;
-import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.AI_IMAGE_NOT_EXISTS;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
 
 /**
  * AI 绘画 Service 实现类
@@ -136,30 +138,21 @@ public class AiImageServiceImpl implements AiImageService {
         aiImageDO.setUserId(loginUserId);
         aiImageDO.setPrompt(req.getPrompt());
         aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
-        // todo @范 平台需要转换(mj 模型一般分版本)
         aiImageDO.setModel(null);
         aiImageDO.setWidth(null);
         aiImageDO.setHeight(null);
         aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
-
         // 2、保存 image
         imageMapper.insert(aiImageDO);
-
         // 3、调用 MidjourneyProxy 提交任务
         MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class);
         imagineReqVO.setNotifyHook(midjourneyNotifyUrl);
-        // 设置 midjourney 扩展参数
-        //  --ar 来设置尺寸
-        String midjourneySizeParam = String.format(" --ar %s:%s ", req.getWidth(), req.getHeight());
-        // --v 版本
-        String midjourneyVersionParam = String.format(" --v %s ", req.getVersion());
-        // --niji 模型
-        MidjourneyModelEnum midjourneyModelEnum = MidjourneyModelEnum.valueOfModel(req.getModel());
-        String midjourneyNijiParam = MidjourneyModelEnum.NIJI == midjourneyModelEnum ? " --niji " : "";
-        // 设置参数
-        imagineReqVO.setState(midjourneySizeParam.concat(midjourneyVersionParam).concat(midjourneyNijiParam));
+        // 4、设置 midjourney 扩展参数
+        imagineReqVO.setState(buildParams(req.getWidth(),
+                req.getHeight(), req.getVersion(), MidjourneyModelEnum.valueOfModel(req.getModel())));
+        // 5、提交绘画请求
         MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO);
-        // 4、保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
+        // 6、保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
         if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
             throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitRespVO.getDescription());
         }
@@ -170,6 +163,8 @@ public class AiImageServiceImpl implements AiImageService {
         return aiImageDO.getId();
     }
 
+
+
     @Override
     public void deleteImageMy(Long id, Long userId) {
         // 1. 校验是否存在
@@ -182,7 +177,7 @@ public class AiImageServiceImpl implements AiImageService {
     }
 
     @Override
-    public Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO) {
+    public Boolean midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO) {
         // 1、根据 job id 查询关联的 image
         AiImageDO image = imageMapper.selectByJobId(notifyReqVO.getId());
         if (image == null) {
@@ -220,6 +215,34 @@ public class AiImageServiceImpl implements AiImageService {
         return true;
     }
 
+    @Override
+    @Transactional(rollbackFor = Exception.class)
+    public Boolean midjourneyAction(Long loginUserId, Long imageId, String customId) {
+        // 1、检查 image
+        AiImageDO aiImageDO = validateImageExists(imageId);
+        // 2、检查 customId
+        if (!validateCustomId(customId, aiImageDO.getButtons())) {
+            throw exception(AI_IMAGE_CUSTOM_ID_NOT_EXISTS);
+        }
+        // 3、调用 midjourney proxy
+        midjourneyProxyClient.action(
+                new MidjourneyActionReqVO()
+                        .setCustomId(customId)
+                        .setTaskId(aiImageDO.getJobId())
+                        .setNotifyHook(midjourneyNotifyUrl)
+        );
+        return Boolean.TRUE;
+    }
+
+    private static boolean validateCustomId(String customId, List<MidjourneyNotifyReqVO.Button> buttons) {
+        for (MidjourneyNotifyReqVO.Button button : buttons) {
+            if (button.getCustomId().equals(customId)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
     private AiImageDO validateImageExists(Long id) {
         AiImageDO image = imageMapper.selectById(id);
         if (image == null) {
@@ -237,4 +260,25 @@ public class AiImageServiceImpl implements AiImageService {
         return SpringUtil.getBean(getClass());
     }
 
+    /**
+     * 构建 Midjourney 自定义参数
+     *
+     * @param width
+     * @param height
+     * @param version
+     * @param model
+     * @return
+     */
+    private String buildParams(String width, String height, String version, MidjourneyModelEnum model) {
+        StringBuilder params = new StringBuilder();
+        //  --ar 来设置尺寸
+        params.append(String.format(" --ar %s:%s ", width, height));
+        // --v 版本
+        params.append(String.format(" --v %s ", version));
+        // --niji 模型
+        if (MidjourneyModelEnum.NIJI == model) {
+            params.append(" --niji ");
+        }
+        return params.toString();
+    }
 }

+ 2 - 1
yudao-server/src/main/resources/application-local.yaml

@@ -80,7 +80,8 @@ server:
 ai:
   midjourney-proxy:
     url: https://api.holdai.top/mj
-    notifyUrl: http://7b1aada4.r26.cpolar.top/admin-api/ai/image/midjourney-notify
+    notifyUrl: http://61d61685.r21.cpolar.top/admin-api/ai/image/midjourney-notify
+    key: sk-c3qxUCVKsPfdQiYU8440E3Fc8dE5424d9cB124A4Ee2489E3
 
 
 --- #################### 定时任务相关配置 ####################