refactor(gateway): manage gateway implementions via related registry and interface

This commit is contained in:
2026-04-08 22:24:07 +08:00
parent 0528890d60
commit 427d224f65
7 changed files with 293 additions and 58 deletions

View File

@@ -9,6 +9,8 @@ import org.java_websocket.handshake.ClientHandshake;
import org.java_websocket.server.WebSocketServer;
import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.framework.agent.interaction.AgentGateway;
import work.slhaf.partner.framework.agent.interaction.AgentGatewayRegistration;
import work.slhaf.partner.framework.agent.interaction.AgentRuntime;
import work.slhaf.partner.framework.agent.interaction.data.InputData;
import work.slhaf.partner.framework.agent.interaction.data.InteractionEvent;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
@@ -17,6 +19,7 @@ import java.net.InetSocketAddress;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
@Slf4j
public class WebSocketGateway extends WebSocketServer implements AgentGateway<InputData, PartnerRunningFlowContext> {
@@ -26,6 +29,7 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<In
@ToString.Exclude
private final ConcurrentHashMap<String, WebSocket> userSessions = new ConcurrentHashMap<>();
private final ExecutorService executor;
private final AtomicBoolean launched = new AtomicBoolean(false);
// 记录最后一次收到Pong的时间
private final ConcurrentHashMap<WebSocket, Long> lastPongTimes = new ConcurrentHashMap<>();
@@ -38,10 +42,13 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<In
}
public void launch() {
if (!launched.compareAndSet(false, true)) {
return;
}
this.start();
setShutDownHook();
startHeartbeatThread();
register();
AgentRuntime.INSTANCE.registerResponseChannel(getChannelName(), this);
}
@Override
@@ -154,9 +161,35 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<In
log.info("WebSocketServer 已启动...");
}
@Override
public AgentGatewayRegistration registration() {
return WebSocketGatewayRegistration.INSTANCE;
}
@Override
@NotNull
public String getChannelName() {
return "websocket_channel";
}
@Override
public void close() {
executor.shutdownNow();
lastPongTimes.clear();
userSessions.clear();
try {
for (WebSocket webSocket : getConnections()) {
if (webSocket != null && webSocket.isOpen()) {
webSocket.close(1001, "Server shutting down");
}
}
if (launched.get()) {
super.stop(1000);
}
} catch (Exception e) {
log.warn("关闭 WebSocketGateway 失败", e);
} finally {
launched.set(false);
}
}
}

View File

@@ -0,0 +1,25 @@
package work.slhaf.partner.runtime.interaction
import work.slhaf.partner.framework.agent.interaction.AgentGateway
import work.slhaf.partner.framework.agent.interaction.AgentGatewayRegistration
object WebSocketGatewayRegistration : AgentGatewayRegistration {
override val channelName: String = "websocket_channel"
override fun create(params: Map<String, String>): AgentGateway<*, *> {
val port = params["port"]?.toIntOrNull() ?: 29600
val heartbeatInterval = params["heartbeat_interval"]?.toLongOrNull() ?: 10_000L
require(port > 0) { "port must be greater than 0" }
require(heartbeatInterval > 0) { "heartbeat_interval must be greater than 0" }
return WebSocketGateway(port, heartbeatInterval)
}
override fun shutdown(instance: AgentGateway<*, *>) {
if (instance is WebSocketGateway) {
instance.close()
} else {
super.shutdown(instance)
}
}
}

View File

@@ -1,52 +0,0 @@
package work.slhaf.partner.runtime.interaction;
import com.alibaba.fastjson2.JSONObject;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import work.slhaf.partner.framework.agent.config.Config;
import work.slhaf.partner.framework.agent.config.ConfigDoc;
import work.slhaf.partner.framework.agent.config.ConfigRegistration;
import work.slhaf.partner.framework.agent.config.Configurable;
import java.nio.file.Path;
import java.util.Map;
public class WebSocketGatewayRegistry implements Configurable {
// TODO 在 Agent 入口处,针对这类内容提供统一注册
@Override
public @NotNull Map<Path, ConfigRegistration<? extends Config>> declare() {
return Map.of(Path.of("gateway", "websocket.json"), new WebSocketRegistration());
}
static class WebSocketRegistration implements ConfigRegistration<WebSocketConfig> {
@Override
@NotNull
public Class<WebSocketConfig> type() {
return WebSocketConfig.class;
}
@Override
public void init(@NotNull WebSocketConfig config, JSONObject json) {
new WebSocketGateway(config.port, config.heartbeatInterval);
}
@Nullable
@Override
public WebSocketConfig defaultConfig() {
return new WebSocketConfig(29600, 10_000);
}
}
static class WebSocketConfig extends Config {
@ConfigDoc(description = "WebSocket 监听端口")
final int port;
@ConfigDoc(description = "WebSocket 心跳间隔", unit = "ms", constraint = "> 0", example = "10000")
final int heartbeatInterval;
WebSocketConfig(int port, int heartbeatInterval) {
this.port = port;
this.heartbeatInterval = heartbeatInterval;
}
}
}

View File

@@ -3,10 +3,17 @@ package work.slhaf.partner.framework.agent.interaction;
import work.slhaf.partner.framework.agent.interaction.data.InputData;
import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext;
public interface AgentGateway<I extends InputData, C extends RunningFlowContext> extends ResponseChannel {
public interface AgentGateway<I extends InputData, C extends RunningFlowContext> extends ResponseChannel, AutoCloseable {
void launch();
AgentGatewayRegistration registration();
@Override
default void register() {
registration().register();
}
default void receive(I inputData) {
C parsedContext = parseRunningFlowContext(inputData);
AgentRuntime.INSTANCE.submit(parsedContext);

View File

@@ -0,0 +1,22 @@
package work.slhaf.partner.framework.agent.interaction
interface AgentGatewayRegistration {
val channelName: String
fun create(params: Map<String, String>): AgentGateway<*, *>
fun supportsHotReloadReuse(oldParams: Map<String, String>, newParams: Map<String, String>): Boolean {
return oldParams == newParams
}
fun shutdown(instance: AgentGateway<*, *>) {
if (instance is AutoCloseable) {
instance.close()
}
}
fun register() {
AgentGatewayRegistry.register(this)
}
}

View File

@@ -0,0 +1,187 @@
package work.slhaf.partner.framework.agent.interaction
import com.alibaba.fastjson2.JSONObject
import com.alibaba.fastjson2.annotation.JSONField
import org.slf4j.LoggerFactory
import work.slhaf.partner.framework.agent.config.Config
import work.slhaf.partner.framework.agent.config.ConfigDoc
import work.slhaf.partner.framework.agent.config.ConfigRegistration
import work.slhaf.partner.framework.agent.config.Configurable
import java.nio.file.Path
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
object AgentGatewayRegistry : Configurable, ConfigRegistration<AgentGatewayRegistryConfig> {
private val log = LoggerFactory.getLogger(AgentGatewayRegistry::class.java)
private val registryLock = ReentrantLock()
private val registrations = linkedMapOf<String, AgentGatewayRegistration>()
private val runningChannels = linkedMapOf<String, RunningGateway>()
init {
register()
}
override fun declare(): Map<Path, ConfigRegistration<out Config>> {
return mapOf(Path.of("gateway", "gateway.json") to this)
}
override fun type(): Class<AgentGatewayRegistryConfig> = AgentGatewayRegistryConfig::class.java
fun register(registration: AgentGatewayRegistration) = registryLock.withLock {
val previous = registrations.putIfAbsent(registration.channelName, registration)
check(previous == null || previous === registration) {
"AgentGateway channel already registered: ${registration.channelName}"
}
}
fun resolve(channelName: String): AgentGateway<*, *>? = registryLock.withLock {
runningChannels[channelName]?.gateway
}
internal fun snapshotRunningChannels(): Map<String, AgentGateway<*, *>> = registryLock.withLock {
runningChannels.mapValues { it.value.gateway }
}
override fun init(config: AgentGatewayRegistryConfig, json: JSONObject?) = registryLock.withLock {
applyConfig(config)
}
override fun onReload(config: AgentGatewayRegistryConfig, json: JSONObject?) = registryLock.withLock {
val runtimeSnapshot = LinkedHashMap(runningChannels)
val defaultSnapshot = AgentRuntime.defaultResponseChannel()
try {
applyConfig(config)
} catch (e: Exception) {
log.error("Error while reloading gateway config", e)
restoreSnapshot(runtimeSnapshot, defaultSnapshot)
}
}
override fun defaultConfig(): AgentGatewayRegistryConfig? = null
private fun applyConfig(config: AgentGatewayRegistryConfig) {
validateConfig(config)
reconcileChannels(config.channels)
AgentRuntime.setDefaultResponseChannel(config.defaultChannel)
}
private fun validateConfig(config: AgentGatewayRegistryConfig) {
require(config.defaultChannel.isNotBlank()) { "default_channel must not be blank" }
require(config.channels.isNotEmpty()) { "channels must not be empty" }
val channelNames = mutableSetOf<String>()
config.channels.forEach { channel ->
require(channel.channelName.isNotBlank()) { "channel_name must not be blank" }
require(channelNames.add(channel.channelName)) { "Duplicated channel_name: ${channel.channelName}" }
require(registrations.containsKey(channel.channelName)) {
"AgentGateway channel is not registered: ${channel.channelName}"
}
}
require(channelNames.contains(config.defaultChannel)) {
"default_channel must exist in channels: ${config.defaultChannel}"
}
}
private fun reconcileChannels(configuredChannels: List<AgentGatewayChannelConfig>) {
val expectedNames = configuredChannels.map { it.channelName }.toSet()
val removedNames = runningChannels.keys.filterNot(expectedNames::contains)
removedNames.forEach(this::stopChannel)
configuredChannels.forEach { channelConfig ->
val registration = registrations[channelConfig.channelName]
?: error("AgentGateway channel is not registered: ${channelConfig.channelName}")
val existing = runningChannels[channelConfig.channelName]
if (existing != null && existing.registration === registration &&
registration.supportsHotReloadReuse(existing.params, channelConfig.params)
) {
return@forEach
}
if (existing != null) {
stopChannel(channelConfig.channelName)
}
startChannel(registration, channelConfig)
}
}
private fun startChannel(
registration: AgentGatewayRegistration,
channelConfig: AgentGatewayChannelConfig
) {
val gateway = registration.create(channelConfig.params)
try {
gateway.launch()
AgentRuntime.registerResponseChannel(channelConfig.channelName, gateway)
runningChannels[channelConfig.channelName] = RunningGateway(
registration = registration,
params = LinkedHashMap(channelConfig.params),
gateway = gateway
)
} catch (e: Exception) {
runCatching { registration.shutdown(gateway) }
.onFailure { shutdownError ->
log.warn(
"Failed to shutdown gateway after launch failure: {}",
channelConfig.channelName,
shutdownError
)
}
throw e
}
}
private fun stopChannel(channelName: String) {
val running = runningChannels.remove(channelName) ?: return
runCatching { running.registration.shutdown(running.gateway) }
.onFailure { e -> log.warn("Failed to shutdown gateway: {}", channelName, e) }
AgentRuntime.unregisterResponseChannel(channelName)
}
private fun restoreSnapshot(
runtimeSnapshot: Map<String, RunningGateway>,
defaultSnapshot: String
) {
val currentChannels = runningChannels.keys.toList()
currentChannels.forEach(this::stopChannel)
runtimeSnapshot.forEach { (channelName, running) ->
AgentRuntime.registerResponseChannel(channelName, running.gateway)
runningChannels[channelName] = running
}
AgentRuntime.setDefaultResponseChannel(defaultSnapshot)
}
private data class RunningGateway(
val registration: AgentGatewayRegistration,
val params: Map<String, String>,
val gateway: AgentGateway<*, *>
)
}
data class AgentGatewayRegistryConfig(
@field:JSONField(name = "default_channel")
@field:ConfigDoc(description = "默认响应通道", example = "websocket_channel")
val defaultChannel: String,
@field:ConfigDoc(
description = "要启用的通道列表",
example = """[
{
"channel_name": "websocket_channel",
"params": {
"port": "29600",
"heartbeat_interval": "10000"
}
}
]"""
)
val channels: List<AgentGatewayChannelConfig>
) : Config()
data class AgentGatewayChannelConfig(
@field:JSONField(name = "channel_name")
@field:ConfigDoc(description = "通道名称,同时对应已注册的 gateway 名称", example = "websocket_channel")
val channelName: String,
@field:ConfigDoc(description = "通道参数", example = """{ "key1": "value1", "key2": "value2" }""")
val params: Map<String, String> = emptyMap()
)

View File

@@ -17,8 +17,8 @@ object AgentRuntime {
LogChannel.channelName to LogChannel
)
// TODO 暂时取 log_channel 为默认回复通道,若为空则只打印信息。后续将配合配置中心替换通过配置文件进行指定
private val defaultChannel: String = "log_channel"
@Volatile
private var defaultChannel: String = LogChannel.channelName
@Volatile
private var runningModules: Map<Int, List<AbstractAgentModule.Running<RunningFlowContext>>> = emptyMap()
@@ -35,11 +35,24 @@ object AgentRuntime {
responseChannels[channelName] = responseChannel
}
fun unregisterResponseChannel(channelName: String) {
if (channelName == LogChannel.channelName) {
return
}
responseChannels.remove(channelName)
}
fun setDefaultResponseChannel(channelName: String) {
defaultChannel = channelName
}
fun defaultResponseChannel(): String = defaultChannel
@JvmOverloads
fun response(event: InteractionEvent, channelName: String = defaultChannel) {
val channel = responseChannels[channelName]
if (channel == null) {
responseChannels[defaultChannel]!!.response(event)
responseChannels[defaultChannel]?.response(event) ?: LogChannel.response(event)
} else {
channel.response(event)
}
@@ -87,4 +100,4 @@ object AgentRuntime {
}
}
}
}