diff --git a/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java b/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java index 8c8cfa68..6637ce21 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java @@ -8,11 +8,9 @@ import org.java_websocket.framing.Framedata; import org.java_websocket.handshake.ClientHandshake; import org.java_websocket.server.WebSocketServer; import org.jetbrains.annotations.NotNull; -import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader; import work.slhaf.partner.api.agent.runtime.interaction.AgentGateway; import work.slhaf.partner.api.agent.runtime.interaction.data.InputData; import work.slhaf.partner.api.agent.runtime.interaction.data.InteractionEvent; -import work.slhaf.partner.common.config.PartnerAgentConfigLoader; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; import java.net.InetSocketAddress; @@ -23,7 +21,7 @@ import java.util.concurrent.Executors; @Slf4j public class WebSocketGateway extends WebSocketServer implements AgentGateway { - private static final long HEARTBEAT_INTERVAL = 10_000; + private final long heartbeatInterval; @ToString.Exclude private final ConcurrentHashMap userSessions = new ConcurrentHashMap<>(); @@ -32,12 +30,9 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway lastPongTimes = new ConcurrentHashMap<>(); - public WebSocketGateway() { - this(((PartnerAgentConfigLoader) AgentConfigLoader.INSTANCE).getConfig().getWebSocketConfig().getPort()); - } - - private WebSocketGateway(int port) { + public WebSocketGateway(int port, long heartbeatInterval) { super(new InetSocketAddress(port)); + this.heartbeatInterval = heartbeatInterval; this.setReuseAddr(true); this.executor = Executors.newSingleThreadExecutor(); } @@ -71,7 +66,7 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway { while (!Thread.interrupted()) { try { - Thread.sleep(HEARTBEAT_INTERVAL); + Thread.sleep(heartbeatInterval); checkConnections(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); @@ -90,7 +85,7 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway HEARTBEAT_INTERVAL * 2) { + if (lastPong != null && now - lastPong > heartbeatInterval * 2) { log.warn("Connection {} timed out, closing...", conn.getRemoteSocketAddress()); conn.close(1001, "No Pong response"); } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGatewayRegistry.java b/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGatewayRegistry.java new file mode 100644 index 00000000..897ee27f --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGatewayRegistry.java @@ -0,0 +1,48 @@ +package work.slhaf.partner.runtime.interaction; + +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import work.slhaf.partner.api.agent.runtime.config.Config; +import work.slhaf.partner.api.agent.runtime.config.ConfigRegistration; +import work.slhaf.partner.api.agent.runtime.config.Configurable; + +import java.nio.file.Path; +import java.util.Map; + +public class WebSocketGatewayRegistry implements Configurable { + // TODO 在 Agent 入口处,针对这类内容提供统一注册 + @Override + public @NotNull Map> declare() { + return Map.of(Path.of("gateway", "websocket.json"), new WebSocketRegistration()); + } + + static class WebSocketRegistration implements ConfigRegistration { + + @Override + @NotNull + public Class type() { + return WebSocketConfig.class; + } + + @Override + public void init(@NotNull WebSocketConfig config) { + new WebSocketGateway(config.port, config.heartbeatInterval); + } + + @Nullable + @Override + public WebSocketConfig defaultConfig() { + return new WebSocketConfig(29600, 10_000); + } + } + + static class WebSocketConfig extends Config { + final int port; + final int heartbeatInterval; + + WebSocketConfig(int port, int heartbeatInterval) { + this.port = port; + this.heartbeatInterval = heartbeatInterval; + } + } +}