|
@@ -7,9 +7,11 @@ import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
|
|
|
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode;
|
|
|
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler;
|
|
|
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener;
|
|
|
+import lombok.Getter;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.apache.tomcat.websocket.Constants;
|
|
|
import org.jetbrains.annotations.NotNull;
|
|
|
+import org.springframework.util.concurrent.ListenableFuture;
|
|
|
import org.springframework.util.concurrent.ListenableFutureCallback;
|
|
|
import org.springframework.web.socket.CloseStatus;
|
|
|
import org.springframework.web.socket.WebSocketHttpHeaders;
|
|
@@ -22,19 +24,43 @@ import java.util.concurrent.TimeoutException;
|
|
|
|
|
|
@Slf4j
|
|
|
public class MjWebSocketStarter implements WebSocketStarter {
|
|
|
+ /**
|
|
|
+ * 链接重试次数
|
|
|
+ */
|
|
|
private static final int CONNECT_RETRY_LIMIT = 5;
|
|
|
-
|
|
|
+ /**
|
|
|
+ * mj 配置文件
|
|
|
+ */
|
|
|
private final MidjourneyConfig midjourneyConfig;
|
|
|
+ /**
|
|
|
+ * mj 监听(所有message 都会 callback到这里)
|
|
|
+ */
|
|
|
private final MjMessageListener userMessageListener;
|
|
|
+ /**
|
|
|
+ * wss 服务器
|
|
|
+ */
|
|
|
private final String wssServer;
|
|
|
+ /**
|
|
|
+ *
|
|
|
+ */
|
|
|
private final String resumeWss;
|
|
|
-
|
|
|
+ /**
|
|
|
+ *
|
|
|
+ */
|
|
|
+ private ResumeData resumeData = null;
|
|
|
+ /**
|
|
|
+ * 是否运行成功
|
|
|
+ */
|
|
|
private boolean running = false;
|
|
|
-
|
|
|
+ /**
|
|
|
+ * 链接成功的 session
|
|
|
+ */
|
|
|
private WebSocketSession webSocketSession = null;
|
|
|
- private ResumeData resumeData = null;
|
|
|
|
|
|
- public MjWebSocketStarter(String wssServer, String resumeWss, MidjourneyConfig midjourneyConfig, MjMessageListener userMessageListener) {
|
|
|
+ public MjWebSocketStarter(String wssServer,
|
|
|
+ String resumeWss,
|
|
|
+ MidjourneyConfig midjourneyConfig,
|
|
|
+ MjMessageListener userMessageListener) {
|
|
|
this.wssServer = wssServer;
|
|
|
this.resumeWss = resumeWss;
|
|
|
this.midjourneyConfig = midjourneyConfig;
|
|
@@ -42,11 +68,12 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public void start() throws Exception {
|
|
|
+ public void start() {
|
|
|
start(false);
|
|
|
}
|
|
|
|
|
|
private void start(boolean reconnect) {
|
|
|
+ // 设置header
|
|
|
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
|
|
|
headers.add("Accept-Encoding", "gzip, deflate, br");
|
|
|
headers.add("Accept-Language", "zh-CN,zh;q=0.9");
|
|
@@ -54,19 +81,26 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|
|
headers.add("Pragma", "no-cache");
|
|
|
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
|
|
|
headers.add("User-Agent", this.midjourneyConfig.getUserAage());
|
|
|
- var handler = new MjWebSocketHandler(this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
|
|
|
+ // 创建 mjHeader
|
|
|
+ MjWebSocketHandler mjWebSocketHandler = new MjWebSocketHandler(
|
|
|
+ this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
|
|
|
+ //
|
|
|
String gatewayUrl;
|
|
|
if (reconnect) {
|
|
|
- gatewayUrl = getGatewayServer(this.resumeData.resumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
|
|
|
- handler.setSessionId(this.resumeData.sessionId());
|
|
|
- handler.setSequence(this.resumeData.sequence());
|
|
|
- handler.setResumeGatewayUrl(this.resumeData.resumeGatewayUrl());
|
|
|
+ gatewayUrl = getGatewayServer(this.resumeData.getResumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
|
|
|
+ mjWebSocketHandler.setSessionId(this.resumeData.getSessionId());
|
|
|
+ mjWebSocketHandler.setSequence(this.resumeData.getSequence());
|
|
|
+ mjWebSocketHandler.setResumeGatewayUrl(this.resumeData.getResumeGatewayUrl());
|
|
|
} else {
|
|
|
gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream";
|
|
|
}
|
|
|
- var webSocketClient = new StandardWebSocketClient();
|
|
|
+ // 创建 StandardWebSocketClient
|
|
|
+ StandardWebSocketClient webSocketClient = new StandardWebSocketClient();
|
|
|
+ // 设置 io timeout 时间
|
|
|
webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000");
|
|
|
- var socketSessionFuture = webSocketClient.doHandshake(handler, headers, URI.create(gatewayUrl));
|
|
|
+ //
|
|
|
+ ListenableFuture<WebSocketSession> socketSessionFuture = webSocketClient.doHandshake(mjWebSocketHandler, headers, URI.create(gatewayUrl));
|
|
|
+ // 添加 callback 进行回调
|
|
|
socketSessionFuture.addCallback(new ListenableFutureCallback<>() {
|
|
|
@Override
|
|
|
public void onFailure(@NotNull Throwable e) {
|
|
@@ -87,14 +121,18 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|
|
}
|
|
|
|
|
|
private void onSocketFailure(int code, String reason) {
|
|
|
+ // 1001异常可以忽略
|
|
|
if (code == 1001) {
|
|
|
return;
|
|
|
}
|
|
|
+ // 关闭 socket
|
|
|
closeSocketSessionWhenIsOpen();
|
|
|
+ // 没有运行通知
|
|
|
if (!this.running) {
|
|
|
notifyWssLock(code, reason);
|
|
|
return;
|
|
|
}
|
|
|
+ // 已经运行先设置为false,发起
|
|
|
this.running = false;
|
|
|
if (code >= 4000) {
|
|
|
log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason);
|
|
@@ -107,36 +145,34 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * 重连
|
|
|
+ */
|
|
|
private void tryReconnect() {
|
|
|
try {
|
|
|
tryStart(true);
|
|
|
} catch (Exception e) {
|
|
|
- if (e instanceof TimeoutException) {
|
|
|
- closeSocketSessionWhenIsOpen();
|
|
|
- }
|
|
|
- log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage());
|
|
|
+ log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage());
|
|
|
ThreadUtil.sleep(1000);
|
|
|
tryNewConnect();
|
|
|
}
|
|
|
}
|
|
|
|
|
|
private void tryNewConnect() {
|
|
|
+ // 链接重试次数5
|
|
|
for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) {
|
|
|
try {
|
|
|
tryStart(false);
|
|
|
return;
|
|
|
} catch (Exception e) {
|
|
|
- if (e instanceof TimeoutException) {
|
|
|
- closeSocketSessionWhenIsOpen();
|
|
|
- }
|
|
|
- log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage());
|
|
|
+ log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage());
|
|
|
ThreadUtil.sleep(5000);
|
|
|
}
|
|
|
}
|
|
|
log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId());
|
|
|
}
|
|
|
|
|
|
- public void tryStart(boolean reconnect) throws Exception {
|
|
|
+ public void tryStart(boolean reconnect) {
|
|
|
start(reconnect);
|
|
|
}
|
|
|
|
|
@@ -144,6 +180,9 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|
|
System.err.println("notifyWssLock: " + code + " - " + reason);
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * 关闭 socket session
|
|
|
+ */
|
|
|
private void closeSocketSessionWhenIsOpen() {
|
|
|
try {
|
|
|
if (this.webSocketSession != null && this.webSocketSession.isOpen()) {
|
|
@@ -161,6 +200,20 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|
|
return this.wssServer;
|
|
|
}
|
|
|
|
|
|
- public record ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
|
|
|
+ @Getter
|
|
|
+ public static class ResumeData {
|
|
|
+
|
|
|
+ public ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
|
|
|
+ this.sessionId = sessionId;
|
|
|
+ this.sequence = sequence;
|
|
|
+ this.resumeGatewayUrl = resumeGatewayUrl;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * socket session
|
|
|
+ */
|
|
|
+ private final String sessionId;
|
|
|
+ private final Object sequence;
|
|
|
+ private final String resumeGatewayUrl;
|
|
|
}
|
|
|
}
|