Ver Fonte

【增加】图片绘画增加 绘画失败状态,并保存绘画错误信息,同时可以将错误信息返回给前端

cherishsince há 1 ano atrás
pai
commit
99d477ac0a

+ 35 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatDrawingStatusEnum.java

@@ -0,0 +1,35 @@
+package cn.iocoder.yudao.module.ai.enums;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+/**
+ * ai绘画状态
+ *
+ * @author fansili
+ * @time 2024/4/28 17:05
+ * @since 1.0
+ */
+@AllArgsConstructor
+@Getter
+public enum AiChatDrawingStatusEnum {
+
+    COMPLETE("complete", "完成"),
+    FAIL("fail", "失败"),
+
+    ;
+
+    private String status;
+
+    private String name;
+
+
+    public static AiChatDrawingStatusEnum valueOfStatus(String status) {
+        for (AiChatDrawingStatusEnum itemEnum : AiChatDrawingStatusEnum.values()) {
+            if (itemEnum.getStatus().equals(status)) {
+                return itemEnum;
+            }
+        }
+        throw new IllegalArgumentException("Invalid MessageType value: " + status);
+    }
+}

+ 3 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/AiImageDO.java

@@ -41,5 +41,8 @@ public class AiImageDO extends BaseDO {
     @Schema(description = "绘画图片地址")
     private String drawingImageUrl;
 
+    @Schema(description = "绘画错误信息")
+    private String drawingError;
+
 }
 

+ 31 - 9
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java

@@ -1,15 +1,17 @@
 package cn.iocoder.yudao.module.ai.service.impl;
 
-import cn.iocoder.yudao.framework.ai.image.Image;
+import cn.iocoder.yudao.framework.ai.exception.AiException;
 import cn.iocoder.yudao.framework.ai.image.ImageGeneration;
 import cn.iocoder.yudao.framework.ai.image.ImagePrompt;
 import cn.iocoder.yudao.framework.ai.image.ImageResponse;
 import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient;
 import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageModelEnum;
 import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
+import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageStyleEnum;
 import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
 import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
 import cn.iocoder.yudao.module.ai.dal.dataobject.AiImageDO;
+import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
 import cn.iocoder.yudao.module.ai.mapper.AiImageMapper;
 import cn.iocoder.yudao.module.ai.service.AiImageService;
 import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
@@ -39,25 +41,43 @@ public class AiImageServiceImpl implements AiImageService {
     public void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter) {
         // 获取 model
         OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModal());
+        OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
         //
         OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
         openAiImageOptions.setModel(openAiImageModelEnum);
-        ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions));
-        // 发送
-        ImageGeneration imageGeneration = imageResponse.getResult();
+        openAiImageOptions.setStyle(openAiImageStyleEnum);
+        openAiImageOptions.setSize(req.getSize());
+        ImageResponse imageResponse;
         try {
-            sseEmitter.send(imageGeneration, MediaType.APPLICATION_JSON);
+            imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions));
+            // 发送
+            ImageGeneration imageGeneration = imageResponse.getResult();
+            // 发送信息
+            sendSseEmitter(sseEmitter, imageGeneration);
+            // 保存数据库
+            doSave(req, imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
+        } catch (AiException aiException) {
+            // 保存数据库
+            doSave(req, null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
+            // 发送错误信息
+            sendSseEmitter(sseEmitter, aiException.getMessage());
+        }
+    }
+
+    private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
+        try {
+            sseEmitter.send(object, MediaType.APPLICATION_JSON);
         } catch (IOException e) {
             throw new RuntimeException(e);
         } finally {
             // 发送 complete
             sseEmitter.complete();
         }
-        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
-        //
+    }
+
+    private void doSave(AiImageDallDrawingReq req, String imageUrl, AiChatDrawingStatusEnum drawingStatusEnum, String drawingError) {
         // 保存数据库
-        Image output = imageGeneration.getOutput();
-        String imageUrl = output.getUrl();
+        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
         AiImageDO aiImageDO = new AiImageDO();
         aiImageDO.setId(null);
         aiImageDO.setPrompt(req.getPrompt());
@@ -65,6 +85,8 @@ public class AiImageServiceImpl implements AiImageService {
         aiImageDO.setModal(req.getModal());
         aiImageDO.setUserId(loginUserId);
         aiImageDO.setDrawingImageUrl(imageUrl);
+        aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());
+        aiImageDO.setDrawingError(drawingError);
         aiImageMapper.insert(aiImageDO);
     }
 }

+ 4 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/vo/AiImageDallDrawingReq.java

@@ -24,6 +24,10 @@ public class AiImageDallDrawingReq {
     @NotNull(message = "模型不能为空")
     private String modal;
 
+    @Schema(description = "图像生成的风格。可为vivid(生动)或natural(自然)")
+    @NotNull(message = "图像生成的风格,不能为空!")
+    private String style;
+
     @Schema(description = "生成图像的尺寸大小。对于dall-e-2模型,尺寸可为256x256, 512x512, 或 1024x1024。对于dall-e-3模型,尺寸可为1024x1024, 1792x1024, 或 1024x1792。")
     @NotNull(message = "size不能为空!")
     private String size;