|
@@ -15,6 +15,7 @@ import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
|
|
|
import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum;
|
|
|
import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum;
|
|
|
import cn.iocoder.yudao.module.ai.client.enums.MidjourneyTaskStatusEnum;
|
|
|
+import cn.iocoder.yudao.module.ai.client.vo.MidjourneyActionReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
|
|
@@ -39,9 +40,10 @@ import org.springframework.scheduling.annotation.Async;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
|
|
|
+import java.util.List;
|
|
|
+
|
|
|
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
|
|
-import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL;
|
|
|
-import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.AI_IMAGE_NOT_EXISTS;
|
|
|
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
|
|
|
|
|
|
/**
|
|
|
* AI 绘画 Service 实现类
|
|
@@ -136,30 +138,21 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
aiImageDO.setUserId(loginUserId);
|
|
|
aiImageDO.setPrompt(req.getPrompt());
|
|
|
aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
|
|
|
- // todo @范 平台需要转换(mj 模型一般分版本)
|
|
|
aiImageDO.setModel(null);
|
|
|
aiImageDO.setWidth(null);
|
|
|
aiImageDO.setHeight(null);
|
|
|
aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
|
|
-
|
|
|
// 2、保存 image
|
|
|
imageMapper.insert(aiImageDO);
|
|
|
-
|
|
|
// 3、调用 MidjourneyProxy 提交任务
|
|
|
MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class);
|
|
|
imagineReqVO.setNotifyHook(midjourneyNotifyUrl);
|
|
|
- // 设置 midjourney 扩展参数
|
|
|
- // --ar 来设置尺寸
|
|
|
- String midjourneySizeParam = String.format(" --ar %s:%s ", req.getWidth(), req.getHeight());
|
|
|
- // --v 版本
|
|
|
- String midjourneyVersionParam = String.format(" --v %s ", req.getVersion());
|
|
|
- // --niji 模型
|
|
|
- MidjourneyModelEnum midjourneyModelEnum = MidjourneyModelEnum.valueOfModel(req.getModel());
|
|
|
- String midjourneyNijiParam = MidjourneyModelEnum.NIJI == midjourneyModelEnum ? " --niji " : "";
|
|
|
- // 设置参数
|
|
|
- imagineReqVO.setState(midjourneySizeParam.concat(midjourneyVersionParam).concat(midjourneyNijiParam));
|
|
|
+ // 4、设置 midjourney 扩展参数
|
|
|
+ imagineReqVO.setState(buildParams(req.getWidth(),
|
|
|
+ req.getHeight(), req.getVersion(), MidjourneyModelEnum.valueOfModel(req.getModel())));
|
|
|
+ // 5、提交绘画请求
|
|
|
MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO);
|
|
|
- // 4、保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
|
|
|
+ // 6、保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
|
|
|
if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
|
|
|
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitRespVO.getDescription());
|
|
|
}
|
|
@@ -170,6 +163,8 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
return aiImageDO.getId();
|
|
|
}
|
|
|
|
|
|
+
|
|
|
+
|
|
|
@Override
|
|
|
public void deleteImageMy(Long id, Long userId) {
|
|
|
// 1. 校验是否存在
|
|
@@ -182,7 +177,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO) {
|
|
|
+ public Boolean midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO) {
|
|
|
// 1、根据 job id 查询关联的 image
|
|
|
AiImageDO image = imageMapper.selectByJobId(notifyReqVO.getId());
|
|
|
if (image == null) {
|
|
@@ -220,6 +215,34 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
+ @Override
|
|
|
+ @Transactional(rollbackFor = Exception.class)
|
|
|
+ public Boolean midjourneyAction(Long loginUserId, Long imageId, String customId) {
|
|
|
+ // 1、检查 image
|
|
|
+ AiImageDO aiImageDO = validateImageExists(imageId);
|
|
|
+ // 2、检查 customId
|
|
|
+ if (!validateCustomId(customId, aiImageDO.getButtons())) {
|
|
|
+ throw exception(AI_IMAGE_CUSTOM_ID_NOT_EXISTS);
|
|
|
+ }
|
|
|
+ // 3、调用 midjourney proxy
|
|
|
+ midjourneyProxyClient.action(
|
|
|
+ new MidjourneyActionReqVO()
|
|
|
+ .setCustomId(customId)
|
|
|
+ .setTaskId(aiImageDO.getJobId())
|
|
|
+ .setNotifyHook(midjourneyNotifyUrl)
|
|
|
+ );
|
|
|
+ return Boolean.TRUE;
|
|
|
+ }
|
|
|
+
|
|
|
+ private static boolean validateCustomId(String customId, List<MidjourneyNotifyReqVO.Button> buttons) {
|
|
|
+ for (MidjourneyNotifyReqVO.Button button : buttons) {
|
|
|
+ if (button.getCustomId().equals(customId)) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
private AiImageDO validateImageExists(Long id) {
|
|
|
AiImageDO image = imageMapper.selectById(id);
|
|
|
if (image == null) {
|
|
@@ -237,4 +260,25 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
return SpringUtil.getBean(getClass());
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * 构建 Midjourney 自定义参数
|
|
|
+ *
|
|
|
+ * @param width
|
|
|
+ * @param height
|
|
|
+ * @param version
|
|
|
+ * @param model
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ private String buildParams(String width, String height, String version, MidjourneyModelEnum model) {
|
|
|
+ StringBuilder params = new StringBuilder();
|
|
|
+ // --ar 来设置尺寸
|
|
|
+ params.append(String.format(" --ar %s:%s ", width, height));
|
|
|
+ // --v 版本
|
|
|
+ params.append(String.format(" --v %s ", version));
|
|
|
+ // --niji 模型
|
|
|
+ if (MidjourneyModelEnum.NIJI == model) {
|
|
|
+ params.append(" --niji ");
|
|
|
+ }
|
|
|
+ return params.toString();
|
|
|
+ }
|
|
|
}
|