cherishsince пре 1 година
родитељ
комит
44e44dc4bb

+ 0 - 28
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/constants/MjNotifyCode.java

@@ -8,35 +8,7 @@ public final class MjNotifyCode {
 	 * 成功.
 	 */
 	public static final int SUCCESS = 1;
-	/**
-	 * 数据未找到.
-	 */
-	public static final int NOT_FOUND = 3;
-	/**
-	 * 校验错误.
-	 */
-	public static final int VALIDATION_ERROR = 4;
-	/**
-	 * 系统异常.
-	 */
-	public static final int FAILURE = 9;
 
-	/**
-	 * 已存在.
-	 */
-	public static final int EXISTED = 21;
-	/**
-	 * 排队中.
-	 */
-	public static final int IN_QUEUE = 22;
-	/**
-	 * 队列已满.
-	 */
-	public static final int QUEUE_REJECTED = 23;
-	/**
-	 * prompt包含敏感词.
-	 */
-	public static final int BANNED_PROMPT = 24;
 
 
 }

+ 76 - 23
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/MjWebSocketStarter.java

@@ -7,9 +7,11 @@ import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
 import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode;
 import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler;
 import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener;
+import lombok.Getter;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.tomcat.websocket.Constants;
 import org.jetbrains.annotations.NotNull;
+import org.springframework.util.concurrent.ListenableFuture;
 import org.springframework.util.concurrent.ListenableFutureCallback;
 import org.springframework.web.socket.CloseStatus;
 import org.springframework.web.socket.WebSocketHttpHeaders;
@@ -22,19 +24,43 @@ import java.util.concurrent.TimeoutException;
 
 @Slf4j
 public class MjWebSocketStarter implements WebSocketStarter {
+	/**
+	 * 链接重试次数
+	 */
 	private static final int CONNECT_RETRY_LIMIT = 5;
-
+	/**
+	 * mj 配置文件
+	 */
 	private final MidjourneyConfig midjourneyConfig;
+	/**
+	 * mj 监听(所有message 都会 callback到这里)
+	 */
 	private final MjMessageListener userMessageListener;
+	/**
+	 * wss 服务器
+	 */
 	private final String wssServer;
+	/**
+	 *
+	 */
 	private final String resumeWss;
-
+	/**
+	 *
+	 */
+	private ResumeData resumeData = null;
+	/**
+	 * 是否运行成功
+	 */
 	private boolean running = false;
-
+	/**
+	 * 链接成功的 session
+	 */
 	private WebSocketSession webSocketSession = null;
-	private ResumeData resumeData = null;
 
-	public MjWebSocketStarter(String wssServer, String resumeWss, MidjourneyConfig midjourneyConfig, MjMessageListener userMessageListener) {
+	public MjWebSocketStarter(String wssServer,
+							  String resumeWss,
+							  MidjourneyConfig midjourneyConfig,
+							  MjMessageListener userMessageListener) {
 		this.wssServer = wssServer;
 		this.resumeWss = resumeWss;
 		this.midjourneyConfig = midjourneyConfig;
@@ -42,11 +68,12 @@ public class MjWebSocketStarter implements WebSocketStarter {
 	}
 
 	@Override
-	public void start() throws Exception {
+	public void start() {
 		start(false);
 	}
 
 	private void start(boolean reconnect) {
+		// 设置header
 		WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
 		headers.add("Accept-Encoding", "gzip, deflate, br");
 		headers.add("Accept-Language", "zh-CN,zh;q=0.9");
@@ -54,19 +81,26 @@ public class MjWebSocketStarter implements WebSocketStarter {
 		headers.add("Pragma", "no-cache");
 		headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
 		headers.add("User-Agent", this.midjourneyConfig.getUserAage());
-		var handler = new MjWebSocketHandler(this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
+		// 创建 mjHeader
+		MjWebSocketHandler mjWebSocketHandler = new MjWebSocketHandler(
+				this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
+		//
 		String gatewayUrl;
 		if (reconnect) {
-			gatewayUrl = getGatewayServer(this.resumeData.resumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
-			handler.setSessionId(this.resumeData.sessionId());
-			handler.setSequence(this.resumeData.sequence());
-			handler.setResumeGatewayUrl(this.resumeData.resumeGatewayUrl());
+			gatewayUrl = getGatewayServer(this.resumeData.getResumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
+			mjWebSocketHandler.setSessionId(this.resumeData.getSessionId());
+			mjWebSocketHandler.setSequence(this.resumeData.getSequence());
+			mjWebSocketHandler.setResumeGatewayUrl(this.resumeData.getResumeGatewayUrl());
 		} else {
 			gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream";
 		}
-		var webSocketClient = new StandardWebSocketClient();
+		// 创建 StandardWebSocketClient
+		StandardWebSocketClient webSocketClient = new StandardWebSocketClient();
+		// 设置 io timeout 时间
 		webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000");
-		var socketSessionFuture = webSocketClient.doHandshake(handler, headers, URI.create(gatewayUrl));
+		//
+		ListenableFuture<WebSocketSession> socketSessionFuture = webSocketClient.doHandshake(mjWebSocketHandler, headers, URI.create(gatewayUrl));
+		// 添加 callback 进行回调
 		socketSessionFuture.addCallback(new ListenableFutureCallback<>() {
 			@Override
 			public void onFailure(@NotNull Throwable e) {
@@ -87,14 +121,18 @@ public class MjWebSocketStarter implements WebSocketStarter {
 	}
 
 	private void onSocketFailure(int code, String reason) {
+		// 1001异常可以忽略
 		if (code == 1001) {
 			return;
 		}
+		// 关闭 socket
 		closeSocketSessionWhenIsOpen();
+		// 没有运行通知
 		if (!this.running) {
 			notifyWssLock(code, reason);
 			return;
 		}
+		// 已经运行先设置为false,发起
 		this.running = false;
 		if (code >= 4000) {
 			log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason);
@@ -107,36 +145,34 @@ public class MjWebSocketStarter implements WebSocketStarter {
 		}
 	}
 
+	/**
+	 * 重连
+	 */
 	private void tryReconnect() {
 		try {
 			tryStart(true);
 		} catch (Exception e) {
-			if (e instanceof TimeoutException) {
-				closeSocketSessionWhenIsOpen();
-			}
-			log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage());
+            log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage());
 			ThreadUtil.sleep(1000);
 			tryNewConnect();
 		}
 	}
 
 	private void tryNewConnect() {
+		// 链接重试次数5
 		for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) {
 			try {
 				tryStart(false);
 				return;
 			} catch (Exception e) {
-				if (e instanceof TimeoutException) {
-					closeSocketSessionWhenIsOpen();
-				}
-				log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage());
+                log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage());
 				ThreadUtil.sleep(5000);
 			}
 		}
 		log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId());
 	}
 
-	public void tryStart(boolean reconnect) throws Exception {
+	public void tryStart(boolean reconnect) {
 		start(reconnect);
 	}
 
@@ -144,6 +180,9 @@ public class MjWebSocketStarter implements WebSocketStarter {
 		System.err.println("notifyWssLock: " + code + " - " + reason);
 	}
 
+	/**
+	 * 关闭 socket session
+	 */
 	private void closeSocketSessionWhenIsOpen() {
 		try {
 			if (this.webSocketSession != null && this.webSocketSession.isOpen()) {
@@ -161,6 +200,20 @@ public class MjWebSocketStarter implements WebSocketStarter {
 		return this.wssServer;
 	}
 
-	public record ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
+	@Getter
+	public static class ResumeData {
+
+		public ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
+			this.sessionId = sessionId;
+			this.sequence = sequence;
+			this.resumeGatewayUrl = resumeGatewayUrl;
+		}
+
+		/**
+		 * socket session
+		 */
+		private final String sessionId;
+		private final Object sequence;
+		private final String resumeGatewayUrl;
 	}
 }

+ 37 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/handler/MjWebSocketHandler.java

@@ -30,16 +30,41 @@ import java.util.concurrent.TimeUnit;
 
 @Slf4j
 public class MjWebSocketHandler implements WebSocketHandler {
+	/**
+	 * close 错误码:重连
+	 */
 	public static final int CLOSE_CODE_RECONNECT = 2001;
+	/**
+	 * close 错误码:无效、作废
+	 */
 	public static final int CLOSE_CODE_INVALIDATE = 1009;
+	/**
+	 * close 错误码:异常
+	 */
 	public static final int CLOSE_CODE_EXCEPTION = 1011;
-
+	/**
+	 * mj配置文件
+	 */
 	private final MidjourneyConfig midjourneyConfig;
+	/**
+	 * mj 消息监听
+	 */
 	private final MjMessageListener userMessageListener;
+	/**
+	 * 成功回调
+	 */
 	private final SuccessCallback successCallback;
+	/**
+	 * 失败回调
+	 */
 	private final FailureCallback failureCallback;
-
+	/**
+	 * 心跳执行器
+	 */
 	private final ScheduledExecutorService heartExecutor;
+	/**
+	 * auth数据
+	 */
 	private final DataObject authData;
 
 	@Setter
@@ -55,6 +80,9 @@ public class MjWebSocketHandler implements WebSocketHandler {
 	private Future<?> heartbeatInterval;
 	private Future<?> heartbeatTimeout;
 
+	/**
+	 * 处理 message 消息的 Decompressor
+	 */
 	private final Decompressor decompressor = new ZlibDecompressor(2048);
 
 	public MjWebSocketHandler(MidjourneyConfig account,
@@ -77,11 +105,13 @@ public class MjWebSocketHandler implements WebSocketHandler {
 	@Override
 	public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception {
 		log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e);
+		// 通知链接异常
 		onFailure(CLOSE_CODE_EXCEPTION, "transport error");
 	}
 
 	@Override
 	public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception {
+		// 链接关闭
 		onFailure(closeStatus.getCode(), closeStatus.getReason());
 	}
 
@@ -92,13 +122,18 @@ public class MjWebSocketHandler implements WebSocketHandler {
 
 	@Override
 	public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage<?> message) throws Exception {
+		// 获取 message 消息
 		ByteBuffer buffer = (ByteBuffer) message.getPayload();
+		// 解析 message
 		byte[] decompressed = decompressor.decompress(buffer.array());
 		if (decompressed == null) {
 			return;
 		}
+		// 转换 json
 		String json = new String(decompressed, StandardCharsets.UTF_8);
+		// 转换 jda 自带的 dataObject(和json object 差不多)
 		DataObject data = DataObject.fromJson(json);
+		// 获取消息类型
 		int opCode = data.getInt("op");
 		switch (opCode) {
 			case WebSocketCode.HEARTBEAT -> handleHeartbeat(session);