From 427d224f6504e13bc9e66f8d185feb8549383703 Mon Sep 17 00:00:00 2001 From: slhafzjw Date: Wed, 8 Apr 2026 22:24:07 +0800 Subject: [PATCH] refactor(gateway): manage gateway implementions via related registry and interface --- .../runtime/interaction/WebSocketGateway.java | 35 +++- .../WebSocketGatewayRegistration.kt | 25 +++ .../interaction/WebSocketGatewayRegistry.java | 52 ----- .../agent/interaction/AgentGateway.java | 9 +- .../interaction/AgentGatewayRegistration.kt | 22 +++ .../agent/interaction/AgentGatewayRegistry.kt | 187 ++++++++++++++++++ .../agent/interaction/AgentRuntime.kt | 21 +- 7 files changed, 293 insertions(+), 58 deletions(-) create mode 100644 Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGatewayRegistration.kt delete mode 100644 Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGatewayRegistry.java create mode 100644 Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentGatewayRegistration.kt create mode 100644 Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentGatewayRegistry.kt diff --git a/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java b/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java index 9babbdc7..c7e68ad9 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java @@ -9,6 +9,8 @@ import org.java_websocket.handshake.ClientHandshake; import org.java_websocket.server.WebSocketServer; import org.jetbrains.annotations.NotNull; import work.slhaf.partner.framework.agent.interaction.AgentGateway; +import work.slhaf.partner.framework.agent.interaction.AgentGatewayRegistration; +import work.slhaf.partner.framework.agent.interaction.AgentRuntime; import work.slhaf.partner.framework.agent.interaction.data.InputData; import work.slhaf.partner.framework.agent.interaction.data.InteractionEvent; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; @@ -17,6 +19,7 @@ import java.net.InetSocketAddress; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; @Slf4j public class WebSocketGateway extends WebSocketServer implements AgentGateway { @@ -26,6 +29,7 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway userSessions = new ConcurrentHashMap<>(); private final ExecutorService executor; + private final AtomicBoolean launched = new AtomicBoolean(false); // 记录最后一次收到Pong的时间 private final ConcurrentHashMap lastPongTimes = new ConcurrentHashMap<>(); @@ -38,10 +42,13 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway): AgentGateway<*, *> { + val port = params["port"]?.toIntOrNull() ?: 29600 + val heartbeatInterval = params["heartbeat_interval"]?.toLongOrNull() ?: 10_000L + require(port > 0) { "port must be greater than 0" } + require(heartbeatInterval > 0) { "heartbeat_interval must be greater than 0" } + return WebSocketGateway(port, heartbeatInterval) + } + + override fun shutdown(instance: AgentGateway<*, *>) { + if (instance is WebSocketGateway) { + instance.close() + } else { + super.shutdown(instance) + } + } +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGatewayRegistry.java b/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGatewayRegistry.java deleted file mode 100644 index f65899a9..00000000 --- a/Partner-Core/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGatewayRegistry.java +++ /dev/null @@ -1,52 +0,0 @@ -package work.slhaf.partner.runtime.interaction; - -import com.alibaba.fastjson2.JSONObject; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; -import work.slhaf.partner.framework.agent.config.Config; -import work.slhaf.partner.framework.agent.config.ConfigDoc; -import work.slhaf.partner.framework.agent.config.ConfigRegistration; -import work.slhaf.partner.framework.agent.config.Configurable; - -import java.nio.file.Path; -import java.util.Map; - -public class WebSocketGatewayRegistry implements Configurable { - // TODO 在 Agent 入口处,针对这类内容提供统一注册 - @Override - public @NotNull Map> declare() { - return Map.of(Path.of("gateway", "websocket.json"), new WebSocketRegistration()); - } - - static class WebSocketRegistration implements ConfigRegistration { - - @Override - @NotNull - public Class type() { - return WebSocketConfig.class; - } - - @Override - public void init(@NotNull WebSocketConfig config, JSONObject json) { - new WebSocketGateway(config.port, config.heartbeatInterval); - } - - @Nullable - @Override - public WebSocketConfig defaultConfig() { - return new WebSocketConfig(29600, 10_000); - } - } - - static class WebSocketConfig extends Config { - @ConfigDoc(description = "WebSocket 监听端口") - final int port; - @ConfigDoc(description = "WebSocket 心跳间隔", unit = "ms", constraint = "> 0", example = "10000") - final int heartbeatInterval; - - WebSocketConfig(int port, int heartbeatInterval) { - this.port = port; - this.heartbeatInterval = heartbeatInterval; - } - } -} diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentGateway.java b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentGateway.java index 56888660..8a8c7e50 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentGateway.java +++ b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentGateway.java @@ -3,10 +3,17 @@ package work.slhaf.partner.framework.agent.interaction; import work.slhaf.partner.framework.agent.interaction.data.InputData; import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext; -public interface AgentGateway extends ResponseChannel { +public interface AgentGateway extends ResponseChannel, AutoCloseable { void launch(); + AgentGatewayRegistration registration(); + + @Override + default void register() { + registration().register(); + } + default void receive(I inputData) { C parsedContext = parseRunningFlowContext(inputData); AgentRuntime.INSTANCE.submit(parsedContext); diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentGatewayRegistration.kt b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentGatewayRegistration.kt new file mode 100644 index 00000000..2b22c265 --- /dev/null +++ b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentGatewayRegistration.kt @@ -0,0 +1,22 @@ +package work.slhaf.partner.framework.agent.interaction + +interface AgentGatewayRegistration { + + val channelName: String + + fun create(params: Map): AgentGateway<*, *> + + fun supportsHotReloadReuse(oldParams: Map, newParams: Map): Boolean { + return oldParams == newParams + } + + fun shutdown(instance: AgentGateway<*, *>) { + if (instance is AutoCloseable) { + instance.close() + } + } + + fun register() { + AgentGatewayRegistry.register(this) + } +} diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentGatewayRegistry.kt b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentGatewayRegistry.kt new file mode 100644 index 00000000..4b81be26 --- /dev/null +++ b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentGatewayRegistry.kt @@ -0,0 +1,187 @@ +package work.slhaf.partner.framework.agent.interaction + +import com.alibaba.fastjson2.JSONObject +import com.alibaba.fastjson2.annotation.JSONField +import org.slf4j.LoggerFactory +import work.slhaf.partner.framework.agent.config.Config +import work.slhaf.partner.framework.agent.config.ConfigDoc +import work.slhaf.partner.framework.agent.config.ConfigRegistration +import work.slhaf.partner.framework.agent.config.Configurable +import java.nio.file.Path +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock + +object AgentGatewayRegistry : Configurable, ConfigRegistration { + + private val log = LoggerFactory.getLogger(AgentGatewayRegistry::class.java) + private val registryLock = ReentrantLock() + private val registrations = linkedMapOf() + private val runningChannels = linkedMapOf() + + init { + register() + } + + override fun declare(): Map> { + return mapOf(Path.of("gateway", "gateway.json") to this) + } + + override fun type(): Class = AgentGatewayRegistryConfig::class.java + + fun register(registration: AgentGatewayRegistration) = registryLock.withLock { + val previous = registrations.putIfAbsent(registration.channelName, registration) + check(previous == null || previous === registration) { + "AgentGateway channel already registered: ${registration.channelName}" + } + } + + fun resolve(channelName: String): AgentGateway<*, *>? = registryLock.withLock { + runningChannels[channelName]?.gateway + } + + internal fun snapshotRunningChannels(): Map> = registryLock.withLock { + runningChannels.mapValues { it.value.gateway } + } + + override fun init(config: AgentGatewayRegistryConfig, json: JSONObject?) = registryLock.withLock { + applyConfig(config) + } + + override fun onReload(config: AgentGatewayRegistryConfig, json: JSONObject?) = registryLock.withLock { + val runtimeSnapshot = LinkedHashMap(runningChannels) + val defaultSnapshot = AgentRuntime.defaultResponseChannel() + try { + applyConfig(config) + } catch (e: Exception) { + log.error("Error while reloading gateway config", e) + restoreSnapshot(runtimeSnapshot, defaultSnapshot) + } + } + + override fun defaultConfig(): AgentGatewayRegistryConfig? = null + + private fun applyConfig(config: AgentGatewayRegistryConfig) { + validateConfig(config) + reconcileChannels(config.channels) + AgentRuntime.setDefaultResponseChannel(config.defaultChannel) + } + + private fun validateConfig(config: AgentGatewayRegistryConfig) { + require(config.defaultChannel.isNotBlank()) { "default_channel must not be blank" } + require(config.channels.isNotEmpty()) { "channels must not be empty" } + + val channelNames = mutableSetOf() + config.channels.forEach { channel -> + require(channel.channelName.isNotBlank()) { "channel_name must not be blank" } + require(channelNames.add(channel.channelName)) { "Duplicated channel_name: ${channel.channelName}" } + require(registrations.containsKey(channel.channelName)) { + "AgentGateway channel is not registered: ${channel.channelName}" + } + } + + require(channelNames.contains(config.defaultChannel)) { + "default_channel must exist in channels: ${config.defaultChannel}" + } + } + + private fun reconcileChannels(configuredChannels: List) { + val expectedNames = configuredChannels.map { it.channelName }.toSet() + val removedNames = runningChannels.keys.filterNot(expectedNames::contains) + removedNames.forEach(this::stopChannel) + + configuredChannels.forEach { channelConfig -> + val registration = registrations[channelConfig.channelName] + ?: error("AgentGateway channel is not registered: ${channelConfig.channelName}") + val existing = runningChannels[channelConfig.channelName] + if (existing != null && existing.registration === registration && + registration.supportsHotReloadReuse(existing.params, channelConfig.params) + ) { + return@forEach + } + if (existing != null) { + stopChannel(channelConfig.channelName) + } + startChannel(registration, channelConfig) + } + } + + private fun startChannel( + registration: AgentGatewayRegistration, + channelConfig: AgentGatewayChannelConfig + ) { + val gateway = registration.create(channelConfig.params) + try { + gateway.launch() + AgentRuntime.registerResponseChannel(channelConfig.channelName, gateway) + runningChannels[channelConfig.channelName] = RunningGateway( + registration = registration, + params = LinkedHashMap(channelConfig.params), + gateway = gateway + ) + } catch (e: Exception) { + runCatching { registration.shutdown(gateway) } + .onFailure { shutdownError -> + log.warn( + "Failed to shutdown gateway after launch failure: {}", + channelConfig.channelName, + shutdownError + ) + } + throw e + } + } + + private fun stopChannel(channelName: String) { + val running = runningChannels.remove(channelName) ?: return + runCatching { running.registration.shutdown(running.gateway) } + .onFailure { e -> log.warn("Failed to shutdown gateway: {}", channelName, e) } + AgentRuntime.unregisterResponseChannel(channelName) + } + + private fun restoreSnapshot( + runtimeSnapshot: Map, + defaultSnapshot: String + ) { + val currentChannels = runningChannels.keys.toList() + currentChannels.forEach(this::stopChannel) + + runtimeSnapshot.forEach { (channelName, running) -> + AgentRuntime.registerResponseChannel(channelName, running.gateway) + runningChannels[channelName] = running + } + AgentRuntime.setDefaultResponseChannel(defaultSnapshot) + } + + private data class RunningGateway( + val registration: AgentGatewayRegistration, + val params: Map, + val gateway: AgentGateway<*, *> + ) +} + +data class AgentGatewayRegistryConfig( + @field:JSONField(name = "default_channel") + @field:ConfigDoc(description = "默认响应通道", example = "websocket_channel") + val defaultChannel: String, + @field:ConfigDoc( + description = "要启用的通道列表", + example = """[ + { + "channel_name": "websocket_channel", + "params": { + "port": "29600", + "heartbeat_interval": "10000" + } + } + ]""" + ) + val channels: List +) : Config() + +data class AgentGatewayChannelConfig( + @field:JSONField(name = "channel_name") + @field:ConfigDoc(description = "通道名称,同时对应已注册的 gateway 名称", example = "websocket_channel") + val channelName: String, + @field:ConfigDoc(description = "通道参数", example = """{ "key1": "value1", "key2": "value2" }""") + val params: Map = emptyMap() +) diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentRuntime.kt b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentRuntime.kt index 24a39c75..b8a218bc 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentRuntime.kt +++ b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/interaction/AgentRuntime.kt @@ -17,8 +17,8 @@ object AgentRuntime { LogChannel.channelName to LogChannel ) - // TODO 暂时取 log_channel 为默认回复通道,若为空则只打印信息。后续将配合配置中心替换通过配置文件进行指定 - private val defaultChannel: String = "log_channel" + @Volatile + private var defaultChannel: String = LogChannel.channelName @Volatile private var runningModules: Map>> = emptyMap() @@ -35,11 +35,24 @@ object AgentRuntime { responseChannels[channelName] = responseChannel } + fun unregisterResponseChannel(channelName: String) { + if (channelName == LogChannel.channelName) { + return + } + responseChannels.remove(channelName) + } + + fun setDefaultResponseChannel(channelName: String) { + defaultChannel = channelName + } + + fun defaultResponseChannel(): String = defaultChannel + @JvmOverloads fun response(event: InteractionEvent, channelName: String = defaultChannel) { val channel = responseChannels[channelName] if (channel == null) { - responseChannels[defaultChannel]!!.response(event) + responseChannels[defaultChannel]?.response(event) ?: LogChannel.response(event) } else { channel.response(event) } @@ -87,4 +100,4 @@ object AgentRuntime { } } -} \ No newline at end of file +}