Przeglądaj źródła

【增加】mj图片处理成功消息,增加component操作,处理error信息保存

cherishsince 1 rok temu
rodzic
commit
38a9c1a7ee

+ 10 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiImageConvert.java

@@ -1,8 +1,10 @@
 package cn.iocoder.yudao.module.ai.convert;
 
+import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListRespVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperationsVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
 import org.mapstruct.Mapper;
 import org.mapstruct.factory.Mappers;
@@ -36,4 +38,12 @@ public interface AiImageConvert {
      * @return
      */
     List<AiImageListRespVO> convertAiImageListRespVO(List<AiImageDO> list);
+
+    /**
+     * 转换 - AiImageMidjourneyOperationsVO
+     *
+     * @param component
+     * @return
+     */
+    AiImageMidjourneyOperationsVO convertAiImageMidjourneyOperationsVO(MidjourneyMessage.Component component);
 }

+ 57 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/midjourneyHandler/YuDaoMidjourneyMessageHandler.java

@@ -5,14 +5,21 @@ import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
 import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum;
 import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler;
+import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperationsVO;
+import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
-import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper;
+import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum;
 import com.alibaba.fastjson2.JSON;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Component;
 
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
 /**
  * yudao message handler
  *
@@ -45,6 +52,36 @@ public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
         if (StrUtil.isBlank(midjourneyMessage.getNonce())) {
             return;
         }
+        // 根据 Embeds 来判断是否异常
+        if (CollUtil.isEmpty(midjourneyMessage.getEmbeds())) {
+            successHandler(midjourneyMessage);
+        } else {
+            errorHandler(midjourneyMessage);
+        }
+    }
+
+    private void errorHandler(MidjourneyMessage midjourneyMessage) {
+        // image 编号
+        Long aiImageId = Long.valueOf(midjourneyMessage.getNonce());
+        // 获取 error message
+        String errorMessage = getErrorMessage(midjourneyMessage);
+        aiImageMapper.updateById(
+                new AiImageDO()
+                        .setId(aiImageId)
+                        .setDrawingErrorMessage(errorMessage)
+                        .setDrawingStatus(AiImageDrawingStatusEnum.FAIL.getStatus())
+        );
+    }
+
+    private String getErrorMessage(MidjourneyMessage midjourneyMessage) {
+        StringBuilder errorMessage = new StringBuilder();
+        for (MidjourneyMessage.Embed embed : midjourneyMessage.getEmbeds()) {
+            errorMessage.append(embed.getDescription());
+        }
+        return errorMessage.toString();
+    }
+
+    private void successHandler(MidjourneyMessage midjourneyMessage) {
         // 获取id
         Long aiImageId = Long.valueOf(midjourneyMessage.getNonce());
         // 获取生成 url
@@ -59,14 +96,32 @@ public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
             drawingStatusEnum = AiImageDrawingStatusEnum.COMPLETE;
         } else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) {
             drawingStatusEnum = AiImageDrawingStatusEnum.IN_PROGRESS;
-        }  else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) {
+        } else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) {
             drawingStatusEnum = AiImageDrawingStatusEnum.WAITING;
         }
+        // 获取 midjourneyOperations
+        List<AiImageMidjourneyOperationsVO> midjourneyOperations = getMidjourneyOperationsList(midjourneyMessage);
+        // 更新数据库
         aiImageMapper.updateById(
                 new AiImageDO()
                         .setId(aiImageId)
                         .setDrawingImageUrl(imageUrl)
                         .setDrawingStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus())
+                        .setMjMessageId(midjourneyMessage.getId())
+                        .setMjOperations(JsonUtils.toJsonString(midjourneyOperations))
         );
     }
+
+    private List<AiImageMidjourneyOperationsVO> getMidjourneyOperationsList(MidjourneyMessage midjourneyMessage) {
+        // 为空直接返回
+        if (CollUtil.isEmpty(midjourneyMessage.getComponents())) {
+            return Collections.emptyList();
+        }
+        // 将 component 转成 AiImageMidjourneyOperationsVO
+        return midjourneyMessage.getComponents().stream()
+                .map(componentType -> componentType.getComponents().stream()
+                        .map(AiImageConvert.INSTANCE::convertAiImageMidjourneyOperationsVO)
+                        .collect(Collectors.toList()))
+                .toList().stream().flatMap(List::stream).toList();
+    }
 }