refactor(gateway): manage gateway implementions via related registry and interface

This commit is contained in:
2026-04-08 22:24:07 +08:00
parent 0528890d60
commit 427d224f65
7 changed files with 293 additions and 58 deletions

View File

@@ -9,6 +9,8 @@ import org.java_websocket.handshake.ClientHandshake;
import org.java_websocket.server.WebSocketServer;
import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.framework.agent.interaction.AgentGateway;
import work.slhaf.partner.framework.agent.interaction.AgentGatewayRegistration;
import work.slhaf.partner.framework.agent.interaction.AgentRuntime;
import work.slhaf.partner.framework.agent.interaction.data.InputData;
import work.slhaf.partner.framework.agent.interaction.data.InteractionEvent;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
@@ -17,6 +19,7 @@ import java.net.InetSocketAddress;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
@Slf4j
public class WebSocketGateway extends WebSocketServer implements AgentGateway<InputData, PartnerRunningFlowContext> {
@@ -26,6 +29,7 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<In
@ToString.Exclude
private final ConcurrentHashMap<String, WebSocket> userSessions = new ConcurrentHashMap<>();
private final ExecutorService executor;
private final AtomicBoolean launched = new AtomicBoolean(false);
// 记录最后一次收到Pong的时间
private final ConcurrentHashMap<WebSocket, Long> lastPongTimes = new ConcurrentHashMap<>();
@@ -38,10 +42,13 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<In
}
public void launch() {
if (!launched.compareAndSet(false, true)) {
return;
}
this.start();
setShutDownHook();
startHeartbeatThread();
register();
AgentRuntime.INSTANCE.registerResponseChannel(getChannelName(), this);
}
@Override
@@ -154,9 +161,35 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<In
log.info("WebSocketServer 已启动...");
}
@Override
public AgentGatewayRegistration registration() {
return WebSocketGatewayRegistration.INSTANCE;
}
@Override
@NotNull
public String getChannelName() {
return "websocket_channel";
}
@Override
public void close() {
executor.shutdownNow();
lastPongTimes.clear();
userSessions.clear();
try {
for (WebSocket webSocket : getConnections()) {
if (webSocket != null && webSocket.isOpen()) {
webSocket.close(1001, "Server shutting down");
}
}
if (launched.get()) {
super.stop(1000);
}
} catch (Exception e) {
log.warn("关闭 WebSocketGateway 失败", e);
} finally {
launched.set(false);
}
}
}

View File

@@ -0,0 +1,25 @@
package work.slhaf.partner.runtime.interaction
import work.slhaf.partner.framework.agent.interaction.AgentGateway
import work.slhaf.partner.framework.agent.interaction.AgentGatewayRegistration
object WebSocketGatewayRegistration : AgentGatewayRegistration {
override val channelName: String = "websocket_channel"
override fun create(params: Map<String, String>): AgentGateway<*, *> {
val port = params["port"]?.toIntOrNull() ?: 29600
val heartbeatInterval = params["heartbeat_interval"]?.toLongOrNull() ?: 10_000L
require(port > 0) { "port must be greater than 0" }
require(heartbeatInterval > 0) { "heartbeat_interval must be greater than 0" }
return WebSocketGateway(port, heartbeatInterval)
}
override fun shutdown(instance: AgentGateway<*, *>) {
if (instance is WebSocketGateway) {
instance.close()
} else {
super.shutdown(instance)
}
}
}

View File

@@ -1,52 +0,0 @@
package work.slhaf.partner.runtime.interaction;
import com.alibaba.fastjson2.JSONObject;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import work.slhaf.partner.framework.agent.config.Config;
import work.slhaf.partner.framework.agent.config.ConfigDoc;
import work.slhaf.partner.framework.agent.config.ConfigRegistration;
import work.slhaf.partner.framework.agent.config.Configurable;
import java.nio.file.Path;
import java.util.Map;
public class WebSocketGatewayRegistry implements Configurable {
// TODO 在 Agent 入口处,针对这类内容提供统一注册
@Override
public @NotNull Map<Path, ConfigRegistration<? extends Config>> declare() {
return Map.of(Path.of("gateway", "websocket.json"), new WebSocketRegistration());
}
static class WebSocketRegistration implements ConfigRegistration<WebSocketConfig> {
@Override
@NotNull
public Class<WebSocketConfig> type() {
return WebSocketConfig.class;
}
@Override
public void init(@NotNull WebSocketConfig config, JSONObject json) {
new WebSocketGateway(config.port, config.heartbeatInterval);
}
@Nullable
@Override
public WebSocketConfig defaultConfig() {
return new WebSocketConfig(29600, 10_000);
}
}
static class WebSocketConfig extends Config {
@ConfigDoc(description = "WebSocket 监听端口")
final int port;
@ConfigDoc(description = "WebSocket 心跳间隔", unit = "ms", constraint = "> 0", example = "10000")
final int heartbeatInterval;
WebSocketConfig(int port, int heartbeatInterval) {
this.port = port;
this.heartbeatInterval = heartbeatInterval;
}
}
}