Explorar el Código

增加mj 图片生成消息转换

cherishsince hace 1 año
padre
commit
f2b9c14819

+ 124 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/MjMessage.java

@@ -0,0 +1,124 @@
+package cn.iocoder.yudao.framework.ai.midjourney;
+
+import lombok.Data;
+import lombok.experimental.Accessors;
+
+import java.util.List;
+
+@Data
+@Accessors(chain = true)
+public class MjMessage {
+
+	/**
+	 * id是一个重要的字段,在同时生成多个的时候,可以区分生成信息
+	 */
+	private String id;
+	/**
+	 * 现在已知:
+	 * 0:我们发送的消息,和指令
+	 * 20: mj生成图片发送过程中
+	 */
+	private Integer type;
+	/**
+	 * content
+	 */
+	private Content content;
+	/**
+	 * 图片生成完成才有
+	 */
+	private List<ComponentType> components;
+	/**
+	 * 生成过程中如果有,预展示图片,这里会有
+	 */
+	private List<Attachment> attachments;
+	/**
+	 * 原始数据(discard 返回的原始数据)
+	 */
+	private String rawData;
+	/**
+	 * 生成状态(用于区分生成状态)
+	 * 1、等待
+	 * 2、进行中
+	 * 3、完成
+	 * {@link cn.iocoder.yudao.framework.ai.midjourney.constants.MjGennerateStatusEnum}
+	 */
+	private String generateStatus;
+
+	@Data
+	@Accessors(chain = true)
+	public static class ComponentType {
+
+		private int type;
+
+		private List<Component> components;
+	}
+
+	@Data
+	@Accessors(chain = true)
+	public static class Component {
+		/**
+		 * 自定义ID,用于唯一标识特定交互动作及其上下文信息。
+		 */
+		private String customId;
+
+		/**
+		 * 样式编号,用于确定按钮的样式外观。
+		 * 在某些应用中,例如Discord,2可能表示一种特定的颜色或形状的按钮。
+		 */
+		private int style;
+
+		/**
+		 * 按钮的标签文本,用户可见的内容。
+		 */
+		private String label;
+
+		/**
+		 * 组件类型,此处为2可能表示这是一种特定类型的交互组件,
+		 * 如在Discord API中,类型2对应的是一个可点击的按钮组件。
+		 */
+		private int type;
+
+	}
+
+	@Data
+	@Accessors(chain = true)
+	public static class Attachment {
+		// 文件名
+		private String filename;
+
+		// 附件大小(字节)
+		private int size;
+
+		// 内容类型(例如:image/webp)
+		private String contentType;
+
+		// 图像宽度(像素)
+		private int width;
+
+		// 占位符版本号
+		private int placeholderVersion;
+
+		// 代理URL,用于访问附件资源
+		private String proxyUrl;
+
+		// 占位符标识符
+		private String placeholder;
+
+		// 附件ID
+		private String id;
+
+		// 直接访问附件资源的URL
+		private String url;
+
+		// 图像高度(像素)
+		private int height;
+	}
+
+	@Data
+	@Accessors(chain = true)
+	public static class Content {
+		private String prompt;
+		private String progress;
+		private String status;
+	}
+}

+ 29 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/constants/MjGennerateStatusEnum.java

@@ -0,0 +1,29 @@
+package cn.iocoder.yudao.framework.ai.midjourney.constants;
+
+import lombok.Getter;
+
+/**
+ * mj 生成状态
+ *
+ * author: fansili
+ * time: 2024/4/6 21:07
+ */
+@Getter
+public enum MjGennerateStatusEnum {
+
+
+    WAITING("waiting", "等待..."),
+    IN_PROGRESS("in_progress", "进行中"),
+    COMPLETED("completed", "完成"),
+
+    ;
+
+    MjGennerateStatusEnum(String value, String message) {
+        this.value = value;
+        this.message = message;
+    }
+
+    private String value;
+
+    private String message;
+}

+ 69 - 30
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/listener/MjMessageListener.java

@@ -1,44 +1,83 @@
 package cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener;
 
 
+import cn.hutool.core.collection.CollUtil;
 import cn.hutool.core.text.CharSequenceUtil;
+import cn.hutool.core.util.StrUtil;
+import cn.hutool.json.JSONUtil;
 import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
+import cn.iocoder.yudao.framework.ai.midjourney.MjMessage;
 import cn.iocoder.yudao.framework.ai.midjourney.constants.MjConstants;
+import cn.iocoder.yudao.framework.ai.midjourney.constants.MjGennerateStatusEnum;
 import cn.iocoder.yudao.framework.ai.midjourney.constants.MjMessageTypeEnum;
+import cn.iocoder.yudao.framework.ai.midjourney.util.MjUtil;
+import com.alibaba.fastjson.JSON;
+import com.google.common.collect.Lists;
 import lombok.extern.slf4j.Slf4j;
 import net.dv8tion.jda.api.utils.data.DataObject;
 
+import java.io.UnsupportedEncodingException;
+import java.util.Arrays;
+import java.util.List;
+
 @Slf4j
 public class MjMessageListener {
 
-	private MidjourneyConfig midjourneyConfig;
-
-	public MjMessageListener(MidjourneyConfig midjourneyConfig) {
-		this.midjourneyConfig = midjourneyConfig;
-	}
-
-	public void onMessage(DataObject raw) {
-		MjMessageTypeEnum messageType = MjMessageTypeEnum.of(raw.getString("t"));
-		if (messageType == null || MjMessageTypeEnum.DELETE == messageType) {
-			return;
-		}
-		DataObject data = raw.getObject("d");
-		if (ignoreAndLogMessage(data, messageType)) {
-			return;
-		}
-		System.err.println(data);
-//		if (data.getBoolean(Constants.MJ_MESSAGE_HANDLED, false)) {
-//			return;
-//		}
-	}
-
-	private boolean ignoreAndLogMessage(DataObject data, MjMessageTypeEnum messageType) {
-		String channelId = data.getString(MjConstants.CHANNEL_ID);
-		if (!CharSequenceUtil.equals(channelId, midjourneyConfig.getChannelId())) {
-			return true;
-		}
-		String authorName = data.optObject("author").map(a -> a.getString("username")).orElse("System");
-		log.debug("{} - {} - {}: {}", midjourneyConfig.getChannelId(), messageType.name(), authorName, data.opt("content").orElse(""));
-		return false;
-	}
+    private MidjourneyConfig midjourneyConfig;
+
+    public MjMessageListener(MidjourneyConfig midjourneyConfig) {
+        this.midjourneyConfig = midjourneyConfig;
+    }
+
+    public void onMessage(DataObject raw) {
+        MjMessageTypeEnum messageType = MjMessageTypeEnum.of(raw.getString("t"));
+        if (messageType == null || MjMessageTypeEnum.DELETE == messageType) {
+            return;
+        }
+        DataObject data = raw.getObject("d");
+        if (ignoreAndLogMessage(data, messageType)) {
+            return;
+        }
+
+        MjMessage mjMessage = new MjMessage();
+        mjMessage.setId(data.getString("id"));
+        mjMessage.setType(data.getInt("type"));
+        mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8"));
+		mjMessage.setContent(MjUtil.parseContent(data.getString("content")));
+
+        if (!data.getArray("components").isEmpty()) {
+            String componentsJson = StrUtil.str(data.getArray("components").toJson(), "UTF-8");
+            List<MjMessage.ComponentType> components = JSON.parseArray(componentsJson, MjMessage.ComponentType.class);
+            mjMessage.setComponents(components);
+        }
+        if (!data.getArray("attachments").isEmpty()) {
+            String attachmentsJson = StrUtil.str(data.getArray("attachments").toJson(), "UTF-8");
+            List<MjMessage.Attachment> attachments = JSON.parseArray(attachmentsJson, MjMessage.Attachment.class);
+            mjMessage.setAttachments(attachments);
+        }
+
+        // 转换状态
+        convertGenerateStatus(mjMessage);
+        System.err.println(JSONUtil.toJsonPrettyStr(mjMessage));
+    }
+
+    private void convertGenerateStatus(MjMessage mjMessage) {
+        if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) {
+            mjMessage.setGenerateStatus(MjGennerateStatusEnum.WAITING.getValue());
+        } else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) {
+            mjMessage.setGenerateStatus(MjGennerateStatusEnum.IN_PROGRESS.getValue());
+        } else if (mjMessage.getType() == 0 && !CollUtil.isEmpty(mjMessage.getComponents())) {
+            mjMessage.setGenerateStatus(MjGennerateStatusEnum.COMPLETED.getValue());
+        }
+    }
+
+    private boolean ignoreAndLogMessage(DataObject data, MjMessageTypeEnum messageType) {
+        String channelId = data.getString(MjConstants.CHANNEL_ID);
+        if (!CharSequenceUtil.equals(channelId, midjourneyConfig.getChannelId())) {
+            return true;
+        }
+        String authorName = data.optObject("author").map(a -> a.getString("username")).orElse("System");
+        log.debug("{} - {} - {}: {}", midjourneyConfig.getChannelId(), messageType.name(), authorName, data.opt("content").orElse(""));
+        return false;
+    }
 }