Parcourir la source

【同步】AI:最新 MJ 的 code review

YunaiV il y a 1 an
Parent
commit
e781129dbe

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/MidjourneyProxyClient.java

@@ -17,7 +17,7 @@ import org.springframework.web.client.RestTemplate;
 import java.util.Collection;
 import java.util.List;
 
-// TODO @fan:这个写到 starter-ai 里哈。搞个 MidjourneyApi,参考 https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java 的风格写哈
+// TODO @fan:【高优】这个写到 starter-ai 里哈。搞个 MidjourneyApi,参考 https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java 的风格写哈
 /**
  * Midjourney Proxy 客户端
  *

+ 2 - 6
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/enums/MidjourneySubmitCodeEnum.java

@@ -11,18 +11,14 @@ import java.util.List;
  * Midjourney 提交任务 code 枚举
  *
  * @author fansili
- * @time 2024/5/30 14:33
- * @since 1.0
  */
 @Getter
 @AllArgsConstructor
 public enum MidjourneySubmitCodeEnum {
 
-    // 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)
     SUBMIT_SUCCESS("1", "提交成功"),
     ALREADY_EXISTS("21", "已存在"),
     QUEUING("22", "排队中"),
-
     ;
 
     public static final List<String> SUCCESS_CODES = Lists.newArrayList(
@@ -31,7 +27,7 @@ public enum MidjourneySubmitCodeEnum {
             QUEUING.code
     );
 
-    private String code;
-    private String name;
+    private final String code;
+    private final String name;
 
 }

+ 6 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java

@@ -63,25 +63,27 @@ public class AiImageController {
         return success(true);
     }
 
-    // ================ midjourney 接口
+    // ================ midjourney 接口 ================
 
-    @Operation(summary = "midjourney-imagine 绘画", description = "...")
+    @Operation(summary = "Midjourney imagine(绘画)")
     @PostMapping("/midjourney/imagine")
     public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiImageMidjourneyImagineReqVO req) {
         return success(imageService.midjourneyImagine(getLoginUserId(), req));
     }
 
-    @Operation(summary = "midjourney proxy - 回调通知")
+    @Operation(summary = "Midjourney 回调通知", description = "由 Midjourney Proxy 回调")
     @PostMapping("/midjourney-notify")
     @PermitAll
     public CommonResult<Boolean> midjourneyNotify(@RequestBody MidjourneyNotifyReqVO notifyReqVO) {
         return success(imageService.midjourneyNotify(notifyReqVO));
     }
 
-    @Operation(summary = "midjourney - action(放大、缩小、U1、U2...)")
+    @Operation(summary = "Midjourney Action", description = "例如说:放大、缩小、U1、U2 等")
     @GetMapping("/midjourney/action")
+    // TODO @fan:id、customerId 的 swagger 注解
     public CommonResult<Boolean> midjourneyAction(@RequestParam("id") Long imageId,
                                                   @RequestParam("customId") String customId) {
         return success(imageService.midjourneyAction(getLoginUserId(), imageId, customId));
     }
+
 }

+ 2 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageRespVO.java

@@ -50,9 +50,11 @@ public class AiImageRespVO {
     @Schema(description = "绘画 response")
     private MidjourneyNotifyReqVO response;
 
+    // TODO @fan:进度是百分比,还是一个数字哈?感觉这个可以统一成通用字段;
     @Schema(description = "mj 进度")
     private String progress;
 
     @Schema(description = "mj buttons 按钮")
     private List<MidjourneyNotifyReqVO.Button> buttons;
+
 }

+ 2 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/image/AiImageDO.java

@@ -123,6 +123,7 @@ public class AiImageDO extends BaseDO {
      */
     private String errorMessage;
 
+    // TODO @芋艿:看看是不是 MidjourneyNotifyReqVO.Button 搞到 MJ API 那
     public static class ButtonTypeHandler extends AbstractJsonTypeHandler<Object> {
 
         @Override
@@ -134,6 +135,7 @@ public class AiImageDO extends BaseDO {
         protected String toJson(Object obj) {
             return JsonUtils.toJsonString(obj);
         }
+
     }
 
 }

+ 10 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/job/MidjourneyJob.java

@@ -30,6 +30,7 @@ import java.util.stream.Collectors;
 @Slf4j
 public class MidjourneyJob implements JobHandler {
 
+    // TODO @fan:@Resource
     @Autowired
     private MidjourneyProxyClient midjourneyProxyClient;
     @Autowired
@@ -37,10 +38,13 @@ public class MidjourneyJob implements JobHandler {
     @Autowired
     private AiImageService imageService;
 
+    // TODO @fan:这个方法,建议实现到 AiImageService,例如说 midjourneySync,返回 int 同步数量。
     @Override
     public String execute(String param) throws Exception {
         // 1、获取 midjourney 平台,状态在 “进行中” 的 image
+        // TODO @fan:43 和 51 其实有点重叠,日志,建议只打 51
         log.info("Midjourney 同步 - 开始...");
+        // TODO @fan:Job、Service 等业务层,不要直接使用 LambdaUpdateWrapper,这样会导致 mapper 穿透到逻辑层。要收敛到 mapper 里。
         List<AiImageDO> imageList = imageMapper.selectList(
                 new LambdaUpdateWrapper<AiImageDO>()
                         .eq(AiImageDO::getStatus, AiImageStatusEnum.IN_PROGRESS.getStatus())
@@ -48,11 +52,14 @@ public class MidjourneyJob implements JobHandler {
         );
         log.info("Midjourney 同步 - 任务数量 {}!", imageList.size());
         if (CollUtil.isEmpty(imageList)) {
+            // TODO @fan:51 和 54,其实有点重叠。建议 51 挪到 55 之后打。
             return "Midjourney 同步 - 数量为空!";
         }
         // 2、批量拉去 task 信息
+        // TODO @fan:imageList.stream().map(AiImageDO::getTaskId).collect(Collectors.toSet())),可以使用 CollectionUtils.convertSet 简化
         List<MidjourneyNotifyReqVO> taskList = midjourneyProxyClient
                 .listByCondition(imageList.stream().map(AiImageDO::getTaskId).collect(Collectors.toSet()));
+        // TODO @fan:taskList.stream().collect(Collectors.toMap(MidjourneyNotifyReqVO::getId, o -> o)),也可以使用 CollectionUtils.convertMap;本质上,重用 set、map 转换,要 convert 简化
         Map<String, MidjourneyNotifyReqVO> taskIdMap = taskList.stream().collect(Collectors.toMap(MidjourneyNotifyReqVO::getId, o -> o));
         // 3、更新 image 状态
         List<AiImageDO> updateImageList = new ArrayList<>();
@@ -62,13 +69,16 @@ public class MidjourneyJob implements JobHandler {
                 log.warn("Midjourney 同步 - {} 任务为空!", aiImageDO.getTaskId());
                 continue;
             }
+            // TODO @ 3.1 和 3.2 是不是融合下;get,然后判空,continue;
             // 3.2 获取通知对象
             MidjourneyNotifyReqVO notifyReqVO = taskIdMap.get(aiImageDO.getTaskId());
             // 3.2 构建更新对象
+            // TODO @fan:建议 List<MidjourneyNotifyReqVO> 作为 imageService 去更新;
             updateImageList.add(imageService.buildUpdateImage(aiImageDO.getId(), notifyReqVO));
         }
         // 4、批了更新 updateImageList
         imageMapper.updateBatch(updateImageList);
         return "Midjourney 同步 - ".concat(String.valueOf(updateImageList.size())).concat(" 任务!");
     }
+
 }

+ 6 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java

@@ -36,17 +36,18 @@ public interface AiImageService {
      *
      * @param userId 用户编号
      * @param drawReqVO 绘制请求
+     * @return 绘画编号
      */
     Long drawImage(Long userId, AiImageDrawReqVO drawReqVO);
 
     /**
-     * midjourney 图片生成
+     * Midjourney imagine(绘画)
      *
-     * @param loginUserId
-     * @param req
-     * @return
+     * @param userId 用户编号
+     * @param imagineReqVO 绘制请求
+     * @return 绘画编号
      */
-    Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req);
+    Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO imagineReqVO);
 
     /**
      * 删除【我的】绘画记录

+ 18 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -133,10 +133,12 @@ public class AiImageServiceImpl implements AiImageService {
 
     @Override
     @Transactional(rollbackFor = Exception.class)
-    public Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req) {
+    public Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO req) {
+        // TODO @fan:1 和 2 应该放在一个 1 里面。不然 = = 一个逻辑就显得有很多 1、/2、/3、/4;这么分的原因,是方便阅读的时候,容易理解。
         // 1、构建 AiImageDO
+        // TODO @fan:1)aiImageDO 可以缩写成 image 更简洁;2)可以链式调用,把相同的放在一行里,这样更容易分组
         AiImageDO aiImageDO = new AiImageDO();
-        aiImageDO.setUserId(loginUserId);
+        aiImageDO.setUserId(userId);
         aiImageDO.setPrompt(req.getPrompt());
         aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
         aiImageDO.setModel(req.getModel());
@@ -145,6 +147,7 @@ public class AiImageServiceImpl implements AiImageService {
         aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
         // 2、保存 image
         imageMapper.insert(aiImageDO);
+        // TODO @fan:3 和 2 之间,应该空一行;因为这里是开始发起请求第三方,是个单独的小块逻辑
         // 3、调用 MidjourneyProxy 提交任务
         MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class);
         imagineReqVO.setNotifyHook(midjourneyNotifyUrl);
@@ -152,12 +155,15 @@ public class AiImageServiceImpl implements AiImageService {
         imagineReqVO.setState(buildParams(req.getWidth(),
                 req.getHeight(), req.getVersion(), MidjourneyModelEnum.valueOfModel(req.getModel())));
         // 5、提交绘画请求
+        // TODO @fan:5 这里,失败的情况,到底抛出异常,还是 RespVO,可以参考 OpenAI 的 API 封装
         MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO);
         // 6、保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
         if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
             throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitRespVO.getDescription());
         }
+        // TODO @fan:7 和 6 之间,应该空一行;这样,最终这个逻辑,就会有 2 个空行,3 小块逻辑:1)插入;2)调用;3)更新
         // 7、构建 imageOptions 参数
+        // TODO @fan:1)链式调用;2)其实可以直接使用 AiImageMidjourneyImagineReqVO。不用单独一个 options 类哈。
         MidjourneyImageOptions imageOptions = new MidjourneyImageOptions()
                 .setWidth(req.getWidth())
                 .setHeight(req.getHeight())
@@ -181,10 +187,11 @@ public class AiImageServiceImpl implements AiImageService {
         if (ObjUtil.notEqual(image.getUserId(), userId)) {
             throw exception(AI_IMAGE_NOT_EXISTS);
         }
-        // 2删除记录
+        // 2. 删除记录
         imageMapper.deleteById(id);
     }
 
+    // TODO @fan:建议返回 void;然后如果不存在,就抛出异常哈;
     @Override
     public Boolean midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO) {
         // 1、根据 job id 查询关联的 image
@@ -228,15 +235,19 @@ public class AiImageServiceImpl implements AiImageService {
                 .setErrorMessage(notifyReqVO.getFailReason());
     }
 
+    // TODO @fan:1)不用返回 Boolean
     @Override
-    @Transactional(rollbackFor = Exception.class)
+    @Transactional(rollbackFor = Exception.class) // TODO @fan:只操作一个 db,不用事务哈;
     public Boolean midjourneyAction(Long loginUserId, Long imageId, String customId) {
+        // TODO @fan:1)1 和 2,可以写成 1.1、1.2,都是在做校验;2)validateCustomId 可以直接抛出 AI_IMAGE_CUSTOM_ID_NOT_EXISTS 异常;一般情况下,validateXXX 都是失败抛出异常,isXXXValid 返回 true、false
         // 1、检查 image
+        // TODO @fan:1)aiImageDO 缩写成 image;
         AiImageDO aiImageDO = validateImageExists(imageId);
         // 2、检查 customId
         if (!validateCustomId(customId, aiImageDO.getButtons())) {
             throw exception(AI_IMAGE_CUSTOM_ID_NOT_EXISTS);
         }
+        // TODO @fan:2 和 3 之间,空一行
         // 3、调用 midjourney proxy
         MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.action(
                 new MidjourneyActionReqVO()
@@ -248,8 +259,10 @@ public class AiImageServiceImpl implements AiImageService {
         if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
             throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitRespVO.getDescription());
         }
+        // TODO 6 和 4 之间空一行;
         // 4、新增 image 记录
         AiImageDO newImage = BeanUtils.toBean(aiImageDO, AiImageDO.class);
+        // TODO @fan:最好不要 copy 属性。因为未来如果加属性,可能会导致额外 copy 了;最好是 new 赋值下,显示声明。
         // 4.1、重置参数
         newImage.setId(null);
         newImage.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
@@ -290,6 +303,7 @@ public class AiImageServiceImpl implements AiImageService {
         return SpringUtil.getBean(getClass());
     }
 
+    // TODO @fan:这个是不是应该放在 MJ API 的封装里面搞哈?
     /**
      * 构建 Midjourney 自定义参数
      *

+ 1 - 2
yudao-server/src/main/resources/application-local.yaml

@@ -76,14 +76,13 @@ server:
       enabled: true
       charset: UTF-8
       force: true
-# ai
+# ai TODO @fan:这个融合到 yudao.ai 那好点哈
 ai:
   midjourney-proxy:
     url: https://api.holdai.top/mj
     notifyUrl: http://61d61685.r21.cpolar.top/admin-api/ai/image/midjourney-notify
     key: sk-c3qxUCVKsPfdQiYU8440E3Fc8dE5424d9cB124A4Ee2489E3
 
-
 --- #################### 定时任务相关配置 ####################
 
 # Quartz 配置项,对应 QuartzProperties 配置类