|
@@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.service.image;
|
|
|
|
|
|
import cn.hutool.core.bean.BeanUtil;
|
|
|
import cn.hutool.core.codec.Base64;
|
|
|
-import cn.hutool.core.exceptions.ExceptionUtil;
|
|
|
+import cn.hutool.core.collection.CollUtil;
|
|
|
import cn.hutool.core.map.MapUtil;
|
|
|
import cn.hutool.core.util.ObjUtil;
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
@@ -14,8 +14,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
|
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
|
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
|
|
|
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
|
|
|
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiImageMidjourneyImagineReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
|
|
|
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
|
|
@@ -29,15 +28,17 @@ import org.springframework.ai.image.ImagePrompt;
|
|
|
import org.springframework.ai.image.ImageResponse;
|
|
|
import org.springframework.ai.openai.OpenAiImageOptions;
|
|
|
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
|
|
|
-import org.springframework.beans.factory.annotation.Autowired;
|
|
|
import org.springframework.beans.factory.annotation.Value;
|
|
|
import org.springframework.scheduling.annotation.Async;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
|
|
|
import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
|
|
|
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
|
|
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap;
|
|
|
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertSet;
|
|
|
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
|
|
|
|
|
|
/**
|
|
@@ -51,12 +52,16 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
|
|
|
@Resource
|
|
|
private AiImageMapper imageMapper;
|
|
|
+
|
|
|
@Resource
|
|
|
private FileApi fileApi;
|
|
|
+
|
|
|
@Resource
|
|
|
private AiApiKeyService apiKeyService;
|
|
|
- @Autowired(required = false)
|
|
|
+
|
|
|
+ @Resource
|
|
|
private MidjourneyApi midjourneyApi;
|
|
|
+
|
|
|
@Value("${ai.midjourney-proxy.notifyUrl:http://127.0.0.1:48080/admin-api/ai/image/midjourney-notify}")
|
|
|
private String midjourneyNotifyUrl;
|
|
|
|
|
@@ -74,7 +79,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) {
|
|
|
// 1. 保存数据库
|
|
|
AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
|
|
|
- .setWidth(drawReqVO.getWidth()).setHeight(drawReqVO.getHeight()).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
|
|
+ .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
|
|
imageMapper.insert(image);
|
|
|
// 2. 异步绘制,后续前端通过返回的 id 进行轮询结果
|
|
|
getSelf().executeDrawImage(image, drawReqVO);
|
|
@@ -122,101 +127,121 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- @Transactional(rollbackFor = Exception.class)
|
|
|
- public Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO req) {
|
|
|
- // 1、构建 AiImageDO 并 保存
|
|
|
- AiImageDO image = new AiImageDO()
|
|
|
- .setUserId(userId)
|
|
|
- .setPrompt(req.getPrompt())
|
|
|
- .setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform())
|
|
|
- .setModel(req.getModel())
|
|
|
- .setWidth(req.getWidth())
|
|
|
- .setHeight(req.getHeight())
|
|
|
- .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
|
|
- imageMapper.insert(image);
|
|
|
+ public void deleteImageMy(Long id, Long userId) {
|
|
|
+ // 1. 校验是否存在
|
|
|
+ AiImageDO image = validateImageExists(id);
|
|
|
+ if (ObjUtil.notEqual(image.getUserId(), userId)) {
|
|
|
+ throw exception(AI_IMAGE_NOT_EXISTS);
|
|
|
+ }
|
|
|
+ // 2. 删除记录
|
|
|
+ imageMapper.deleteById(id);
|
|
|
+ }
|
|
|
|
|
|
- // 3、调用 MidjourneyProxy 提交任务
|
|
|
+ private AiImageDO validateImageExists(Long id) {
|
|
|
+ AiImageDO image = imageMapper.selectById(id);
|
|
|
+ if (image == null) {
|
|
|
+ throw exception(AI_IMAGE_NOT_EXISTS);
|
|
|
+ }
|
|
|
+ return image;
|
|
|
+ }
|
|
|
|
|
|
- // 3.1、设置 midjourney 扩展参数
|
|
|
- MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(null, midjourneyNotifyUrl, req.getPrompt(),
|
|
|
- buildParams(req.getWidth(), req.getHeight(), req.getVersion(),
|
|
|
- MidjourneyApi.ModelEnum.valueOfModel(req.getModel())));
|
|
|
- // 3.2、提交绘画请求
|
|
|
- // TODO @fan:5 这里,失败的情况,到底抛出异常,还是 RespVO,可以参考 OpenAI 的 API 封装
|
|
|
- MidjourneyApi.SubmitResponse submitResponse = midjourneyApi.imagine(imagineRequest);
|
|
|
+ // ================ midjourney 专属 ================
|
|
|
|
|
|
- // 4、保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
|
|
|
- if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(submitResponse.code())) {
|
|
|
- if (submitResponse.description().contains("quota_not_enough")) {
|
|
|
- throw exception(AI_IMAGE_SYSTEM_ACCOUNT_INSUFFICIENT_BALANCE, submitResponse.description());
|
|
|
- }
|
|
|
- throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitResponse.description());
|
|
|
+ @Override
|
|
|
+ @Transactional(rollbackFor = Exception.class)
|
|
|
+ public Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO reqVO) {
|
|
|
+ // 1. 保存数据库
|
|
|
+ AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
|
|
|
+ .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
|
|
|
+ .setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
|
|
|
+ imageMapper.insert(image);
|
|
|
+
|
|
|
+ // 2. 调用 Midjourney Proxy 提交任务
|
|
|
+ MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(
|
|
|
+ null, midjourneyNotifyUrl, reqVO.getPrompt(),
|
|
|
+ MidjourneyApi.ImagineRequest.buildState(reqVO.getWidth(), reqVO.getHeight(), reqVO.getVersion(), reqVO.getModel()));
|
|
|
+ MidjourneyApi.SubmitResponse imagineResponse = midjourneyApi.imagine(imagineRequest);
|
|
|
+
|
|
|
+ // 3. 情况一【失败】:抛出业务异常
|
|
|
+ if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(imagineResponse.code())) {
|
|
|
+ String description = imagineResponse.description().contains("quota_not_enough") ?
|
|
|
+ "账户余额不足" : imagineResponse.description();
|
|
|
+ throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, description);
|
|
|
}
|
|
|
- // 4.1、更新 taskId 和参数
|
|
|
+
|
|
|
+ // 4. 情况二【成功】:更新 taskId 和参数
|
|
|
imageMapper.updateById(new AiImageDO()
|
|
|
.setId(image.getId())
|
|
|
- .setTaskId(submitResponse.result())
|
|
|
- .setOptions(BeanUtil.beanToMap(req))
|
|
|
+ .setTaskId(imagineResponse.result())
|
|
|
+ .setOptions(BeanUtil.beanToMap(reqVO))
|
|
|
);
|
|
|
return image.getId();
|
|
|
}
|
|
|
|
|
|
-
|
|
|
@Override
|
|
|
- public void deleteImageMy(Long id, Long userId) {
|
|
|
- // 1. 校验是否存在
|
|
|
- AiImageDO image = validateImageExists(id);
|
|
|
- if (ObjUtil.notEqual(image.getUserId(), userId)) {
|
|
|
- throw exception(AI_IMAGE_NOT_EXISTS);
|
|
|
+ public Integer midjourneySync() {
|
|
|
+ // 1.1 获取 Midjourney 平台,状态在 “进行中” 的 image
|
|
|
+ List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform(
|
|
|
+ AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
|
|
|
+ if (CollUtil.isEmpty(imageList)) {
|
|
|
+ return 0;
|
|
|
}
|
|
|
- // 2. 删除记录
|
|
|
- imageMapper.deleteById(id);
|
|
|
+ // 1.2 调用 Midjourney Proxy 获取任务进展
|
|
|
+ List<MidjourneyApi.Notify> taskList = midjourneyApi.getTaskList(convertSet(imageList, AiImageDO::getTaskId));
|
|
|
+ Map<String, MidjourneyApi.Notify> taskMap = convertMap(taskList, MidjourneyApi.Notify::id);
|
|
|
+
|
|
|
+ // 2. 逐个处理,更新进展
|
|
|
+ int count = 0;
|
|
|
+ for (AiImageDO image : imageList) {
|
|
|
+ MidjourneyApi.Notify notify = taskMap.get(image.getTaskId());
|
|
|
+ if (notify == null) {
|
|
|
+ log.error("[midjourneySync][image({}) 查询不到进展]", image);
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ count++;
|
|
|
+ updateMidjourneyStatus(image, notify);
|
|
|
+ }
|
|
|
+ return count;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public void midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO) {
|
|
|
- // 1、根据 job id 查询关联的 image
|
|
|
- AiImageDO image = imageMapper.selectByJobId(notifyReqVO.getId());
|
|
|
+ public void midjourneyNotify(MidjourneyApi.Notify notify) {
|
|
|
+ // 1. 校验 image 存在
|
|
|
+ AiImageDO image = imageMapper.selectByTaskId(notify.id());
|
|
|
if (image == null) {
|
|
|
- log.warn("midjourneyNotify 回调的 jobId 不存在! jobId: {}", notifyReqVO.getId());
|
|
|
+ log.warn("[midjourneyNotify][回调任务({}) 不存在]", notify.id());
|
|
|
+ return;
|
|
|
}
|
|
|
- // 2、转换状态
|
|
|
- AiImageDO updateImage = buildUpdateImage(image.getId(), notifyReqVO);
|
|
|
- // 3、更新 image 状态
|
|
|
- imageMapper.updateById(updateImage);
|
|
|
+ // 2. 更新状态
|
|
|
+ updateMidjourneyStatus(image, notify);
|
|
|
}
|
|
|
|
|
|
- public AiImageDO buildUpdateImage(Long imageId, MidjourneyNotifyReqVO notifyReqVO) {
|
|
|
- // 1、转换状态
|
|
|
- String imageStatus = null;
|
|
|
- if (StrUtil.isNotBlank(notifyReqVO.getStatus())) {
|
|
|
- MidjourneyApi.TaskStatusEnum taskStatusEnum = MidjourneyApi.TaskStatusEnum.valueOf(notifyReqVO.getStatus());
|
|
|
+ private void updateMidjourneyStatus(AiImageDO image, MidjourneyApi.Notify notify) {
|
|
|
+ // 1. 转换状态
|
|
|
+ Integer status = null;
|
|
|
+ if (StrUtil.isNotBlank(notify.status())) {
|
|
|
+ MidjourneyApi.TaskStatusEnum taskStatusEnum = MidjourneyApi.TaskStatusEnum.valueOf(notify.status());
|
|
|
if (MidjourneyApi.TaskStatusEnum.SUCCESS == taskStatusEnum) {
|
|
|
- imageStatus = AiImageStatusEnum.SUCCESS.getStatus();
|
|
|
+ status = AiImageStatusEnum.SUCCESS.getStatus();
|
|
|
} else if (MidjourneyApi.TaskStatusEnum.FAILURE == taskStatusEnum) {
|
|
|
- imageStatus = AiImageStatusEnum.FAIL.getStatus();
|
|
|
+ status = AiImageStatusEnum.FAIL.getStatus();
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // 2、上传图片
|
|
|
- String filePath = null;
|
|
|
- if (!StrUtil.isBlank(notifyReqVO.getImageUrl())) {
|
|
|
+ // 2. 上传图片
|
|
|
+ String picUrl = null;
|
|
|
+ if (StrUtil.isNotBlank(notify.imageUrl())) {
|
|
|
try {
|
|
|
- filePath = fileApi.createFile(HttpUtil.downloadBytes(notifyReqVO.getImageUrl()));
|
|
|
+ picUrl = fileApi.createFile(HttpUtil.downloadBytes(notify.imageUrl()));
|
|
|
} catch (Exception e) {
|
|
|
- log.warn("midjourneyNotify 图片上传失败! {} 异常:{}", notifyReqVO.getImageUrl(), ExceptionUtil.getMessage(e));
|
|
|
+ picUrl = notify.imageUrl();
|
|
|
+ log.warn("[updateMidjourneyStatus][图片({}) 地址({}) 上传失败]", image.getId(), notify.imageUrl(), e);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // 3、更新 image 状态
|
|
|
- return new AiImageDO()
|
|
|
- .setId(imageId)
|
|
|
- .setStatus(imageStatus)
|
|
|
- .setPicUrl(filePath)
|
|
|
- .setProgress(notifyReqVO.getProgress())
|
|
|
- .setResponse(notifyReqVO)
|
|
|
- .setButtons(notifyReqVO.getButtons())
|
|
|
- .setErrorMessage(notifyReqVO.getFailReason());
|
|
|
+ // 3. 更新 image 状态
|
|
|
+ imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(status)
|
|
|
+ .setPicUrl(picUrl).setButtons(notify.buttons()).setErrorMessage(notify.failReason()));
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -236,7 +261,6 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
|
|
|
// 5、新增 image 记录(根据 image 新增一个)
|
|
|
AiImageDO newImage = new AiImageDO();
|
|
|
- newImage.setId(null);
|
|
|
newImage.setUserId(image.getUserId());
|
|
|
newImage.setPrompt(image.getPrompt());
|
|
|
|
|
@@ -248,20 +272,15 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
newImage.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
|
|
newImage.setPublicStatus(image.getPublicStatus());
|
|
|
|
|
|
- newImage.setPicUrl(null);
|
|
|
- newImage.setProgress(null);
|
|
|
- newImage.setButtons(null);
|
|
|
newImage.setOptions(image.getOptions());
|
|
|
- newImage.setResponse(image.getResponse());
|
|
|
newImage.setTaskId(submitResponse.result());
|
|
|
- newImage.setErrorMessage(null);
|
|
|
imageMapper.insert(newImage);
|
|
|
}
|
|
|
|
|
|
- private static void validateCustomId(String customId, List<MidjourneyNotifyReqVO.Button> buttons) {
|
|
|
+ private static void validateCustomId(String customId, List<MidjourneyApi.Button> buttons) {
|
|
|
boolean isTrue = false;
|
|
|
- for (MidjourneyNotifyReqVO.Button button : buttons) {
|
|
|
- if (button.getCustomId().equals(customId)) {
|
|
|
+ for (MidjourneyApi.Button button : buttons) {
|
|
|
+ if (button.customId().equals(customId)) {
|
|
|
isTrue = true;
|
|
|
break;
|
|
|
}
|
|
@@ -271,14 +290,6 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private AiImageDO validateImageExists(Long id) {
|
|
|
- AiImageDO image = imageMapper.selectById(id);
|
|
|
- if (image == null) {
|
|
|
- throw exception(AI_IMAGE_NOT_EXISTS);
|
|
|
- }
|
|
|
- return image;
|
|
|
- }
|
|
|
-
|
|
|
/**
|
|
|
* 获得自身的代理对象,解决 AOP 生效问题
|
|
|
*
|
|
@@ -288,28 +299,4 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
return SpringUtil.getBean(getClass());
|
|
|
}
|
|
|
|
|
|
- // TODO @fan:这个是不是应该放在 MJ API 的封装里面搞哈?
|
|
|
-
|
|
|
- /**
|
|
|
- * 构建 Midjourney 自定义参数
|
|
|
- *
|
|
|
- * @param width
|
|
|
- * @param height
|
|
|
- * @param version
|
|
|
- * @param model
|
|
|
- * @return
|
|
|
- */
|
|
|
- private String buildParams(Integer width, Integer height, String version, MidjourneyApi.ModelEnum model) {
|
|
|
- StringBuilder params = new StringBuilder();
|
|
|
- // --ar 来设置尺寸
|
|
|
- params.append(String.format(" --ar %s:%s ", width, height));
|
|
|
- // --niji 模型
|
|
|
- if (MidjourneyApi.ModelEnum.NIJI == model) {
|
|
|
- params.append(String.format(" --niji %s ", version));
|
|
|
- } else {
|
|
|
- // --v 版本
|
|
|
- params.append(String.format(" --v %s ", version));
|
|
|
- }
|
|
|
- return params.toString();
|
|
|
- }
|
|
|
}
|