|
@@ -5,14 +5,21 @@ import cn.hutool.core.util.StrUtil;
|
|
|
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
|
|
|
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum;
|
|
|
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler;
|
|
|
+import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperationsVO;
|
|
|
+import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
|
|
-import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper;
|
|
|
+import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum;
|
|
|
import com.alibaba.fastjson2.JSON;
|
|
|
import lombok.AllArgsConstructor;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.springframework.stereotype.Component;
|
|
|
|
|
|
+import java.util.Collections;
|
|
|
+import java.util.List;
|
|
|
+import java.util.stream.Collectors;
|
|
|
+
|
|
|
/**
|
|
|
* yudao message handler
|
|
|
*
|
|
@@ -45,6 +52,36 @@ public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
|
|
|
if (StrUtil.isBlank(midjourneyMessage.getNonce())) {
|
|
|
return;
|
|
|
}
|
|
|
+ // 根据 Embeds 来判断是否异常
|
|
|
+ if (CollUtil.isEmpty(midjourneyMessage.getEmbeds())) {
|
|
|
+ successHandler(midjourneyMessage);
|
|
|
+ } else {
|
|
|
+ errorHandler(midjourneyMessage);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private void errorHandler(MidjourneyMessage midjourneyMessage) {
|
|
|
+ // image 编号
|
|
|
+ Long aiImageId = Long.valueOf(midjourneyMessage.getNonce());
|
|
|
+ // 获取 error message
|
|
|
+ String errorMessage = getErrorMessage(midjourneyMessage);
|
|
|
+ aiImageMapper.updateById(
|
|
|
+ new AiImageDO()
|
|
|
+ .setId(aiImageId)
|
|
|
+ .setDrawingErrorMessage(errorMessage)
|
|
|
+ .setDrawingStatus(AiImageDrawingStatusEnum.FAIL.getStatus())
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ private String getErrorMessage(MidjourneyMessage midjourneyMessage) {
|
|
|
+ StringBuilder errorMessage = new StringBuilder();
|
|
|
+ for (MidjourneyMessage.Embed embed : midjourneyMessage.getEmbeds()) {
|
|
|
+ errorMessage.append(embed.getDescription());
|
|
|
+ }
|
|
|
+ return errorMessage.toString();
|
|
|
+ }
|
|
|
+
|
|
|
+ private void successHandler(MidjourneyMessage midjourneyMessage) {
|
|
|
// 获取id
|
|
|
Long aiImageId = Long.valueOf(midjourneyMessage.getNonce());
|
|
|
// 获取生成 url
|
|
@@ -59,14 +96,32 @@ public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
|
|
|
drawingStatusEnum = AiImageDrawingStatusEnum.COMPLETE;
|
|
|
} else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) {
|
|
|
drawingStatusEnum = AiImageDrawingStatusEnum.IN_PROGRESS;
|
|
|
- } else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) {
|
|
|
+ } else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) {
|
|
|
drawingStatusEnum = AiImageDrawingStatusEnum.WAITING;
|
|
|
}
|
|
|
+ // 获取 midjourneyOperations
|
|
|
+ List<AiImageMidjourneyOperationsVO> midjourneyOperations = getMidjourneyOperationsList(midjourneyMessage);
|
|
|
+ // 更新数据库
|
|
|
aiImageMapper.updateById(
|
|
|
new AiImageDO()
|
|
|
.setId(aiImageId)
|
|
|
.setDrawingImageUrl(imageUrl)
|
|
|
.setDrawingStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus())
|
|
|
+ .setMjMessageId(midjourneyMessage.getId())
|
|
|
+ .setMjOperations(JsonUtils.toJsonString(midjourneyOperations))
|
|
|
);
|
|
|
}
|
|
|
+
|
|
|
+ private List<AiImageMidjourneyOperationsVO> getMidjourneyOperationsList(MidjourneyMessage midjourneyMessage) {
|
|
|
+ // 为空直接返回
|
|
|
+ if (CollUtil.isEmpty(midjourneyMessage.getComponents())) {
|
|
|
+ return Collections.emptyList();
|
|
|
+ }
|
|
|
+ // 将 component 转成 AiImageMidjourneyOperationsVO
|
|
|
+ return midjourneyMessage.getComponents().stream()
|
|
|
+ .map(componentType -> componentType.getComponents().stream()
|
|
|
+ .map(AiImageConvert.INSTANCE::convertAiImageMidjourneyOperationsVO)
|
|
|
+ .collect(Collectors.toList()))
|
|
|
+ .toList().stream().flatMap(List::stream).toList();
|
|
|
+ }
|
|
|
}
|