mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
refactor(gateway): manage gateway implementions via related registry and interface
This commit is contained in:
@@ -9,6 +9,8 @@ import org.java_websocket.handshake.ClientHandshake;
|
||||
import org.java_websocket.server.WebSocketServer;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import work.slhaf.partner.framework.agent.interaction.AgentGateway;
|
||||
import work.slhaf.partner.framework.agent.interaction.AgentGatewayRegistration;
|
||||
import work.slhaf.partner.framework.agent.interaction.AgentRuntime;
|
||||
import work.slhaf.partner.framework.agent.interaction.data.InputData;
|
||||
import work.slhaf.partner.framework.agent.interaction.data.InteractionEvent;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
@@ -17,6 +19,7 @@ import java.net.InetSocketAddress;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
@Slf4j
|
||||
public class WebSocketGateway extends WebSocketServer implements AgentGateway<InputData, PartnerRunningFlowContext> {
|
||||
@@ -26,6 +29,7 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<In
|
||||
@ToString.Exclude
|
||||
private final ConcurrentHashMap<String, WebSocket> userSessions = new ConcurrentHashMap<>();
|
||||
private final ExecutorService executor;
|
||||
private final AtomicBoolean launched = new AtomicBoolean(false);
|
||||
|
||||
// 记录最后一次收到Pong的时间
|
||||
private final ConcurrentHashMap<WebSocket, Long> lastPongTimes = new ConcurrentHashMap<>();
|
||||
@@ -38,10 +42,13 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<In
|
||||
}
|
||||
|
||||
public void launch() {
|
||||
if (!launched.compareAndSet(false, true)) {
|
||||
return;
|
||||
}
|
||||
this.start();
|
||||
setShutDownHook();
|
||||
startHeartbeatThread();
|
||||
register();
|
||||
AgentRuntime.INSTANCE.registerResponseChannel(getChannelName(), this);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -154,9 +161,35 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<In
|
||||
log.info("WebSocketServer 已启动...");
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentGatewayRegistration registration() {
|
||||
return WebSocketGatewayRegistration.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
@NotNull
|
||||
public String getChannelName() {
|
||||
return "websocket_channel";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
executor.shutdownNow();
|
||||
lastPongTimes.clear();
|
||||
userSessions.clear();
|
||||
try {
|
||||
for (WebSocket webSocket : getConnections()) {
|
||||
if (webSocket != null && webSocket.isOpen()) {
|
||||
webSocket.close(1001, "Server shutting down");
|
||||
}
|
||||
}
|
||||
if (launched.get()) {
|
||||
super.stop(1000);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.warn("关闭 WebSocketGateway 失败", e);
|
||||
} finally {
|
||||
launched.set(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
package work.slhaf.partner.runtime.interaction
|
||||
|
||||
import work.slhaf.partner.framework.agent.interaction.AgentGateway
|
||||
import work.slhaf.partner.framework.agent.interaction.AgentGatewayRegistration
|
||||
|
||||
object WebSocketGatewayRegistration : AgentGatewayRegistration {
|
||||
|
||||
override val channelName: String = "websocket_channel"
|
||||
|
||||
override fun create(params: Map<String, String>): AgentGateway<*, *> {
|
||||
val port = params["port"]?.toIntOrNull() ?: 29600
|
||||
val heartbeatInterval = params["heartbeat_interval"]?.toLongOrNull() ?: 10_000L
|
||||
require(port > 0) { "port must be greater than 0" }
|
||||
require(heartbeatInterval > 0) { "heartbeat_interval must be greater than 0" }
|
||||
return WebSocketGateway(port, heartbeatInterval)
|
||||
}
|
||||
|
||||
override fun shutdown(instance: AgentGateway<*, *>) {
|
||||
if (instance is WebSocketGateway) {
|
||||
instance.close()
|
||||
} else {
|
||||
super.shutdown(instance)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
package work.slhaf.partner.runtime.interaction;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import work.slhaf.partner.framework.agent.config.Config;
|
||||
import work.slhaf.partner.framework.agent.config.ConfigDoc;
|
||||
import work.slhaf.partner.framework.agent.config.ConfigRegistration;
|
||||
import work.slhaf.partner.framework.agent.config.Configurable;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.Map;
|
||||
|
||||
public class WebSocketGatewayRegistry implements Configurable {
|
||||
// TODO 在 Agent 入口处,针对这类内容提供统一注册
|
||||
@Override
|
||||
public @NotNull Map<Path, ConfigRegistration<? extends Config>> declare() {
|
||||
return Map.of(Path.of("gateway", "websocket.json"), new WebSocketRegistration());
|
||||
}
|
||||
|
||||
static class WebSocketRegistration implements ConfigRegistration<WebSocketConfig> {
|
||||
|
||||
@Override
|
||||
@NotNull
|
||||
public Class<WebSocketConfig> type() {
|
||||
return WebSocketConfig.class;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void init(@NotNull WebSocketConfig config, JSONObject json) {
|
||||
new WebSocketGateway(config.port, config.heartbeatInterval);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public WebSocketConfig defaultConfig() {
|
||||
return new WebSocketConfig(29600, 10_000);
|
||||
}
|
||||
}
|
||||
|
||||
static class WebSocketConfig extends Config {
|
||||
@ConfigDoc(description = "WebSocket 监听端口")
|
||||
final int port;
|
||||
@ConfigDoc(description = "WebSocket 心跳间隔", unit = "ms", constraint = "> 0", example = "10000")
|
||||
final int heartbeatInterval;
|
||||
|
||||
WebSocketConfig(int port, int heartbeatInterval) {
|
||||
this.port = port;
|
||||
this.heartbeatInterval = heartbeatInterval;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,10 +3,17 @@ package work.slhaf.partner.framework.agent.interaction;
|
||||
import work.slhaf.partner.framework.agent.interaction.data.InputData;
|
||||
import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext;
|
||||
|
||||
public interface AgentGateway<I extends InputData, C extends RunningFlowContext> extends ResponseChannel {
|
||||
public interface AgentGateway<I extends InputData, C extends RunningFlowContext> extends ResponseChannel, AutoCloseable {
|
||||
|
||||
void launch();
|
||||
|
||||
AgentGatewayRegistration registration();
|
||||
|
||||
@Override
|
||||
default void register() {
|
||||
registration().register();
|
||||
}
|
||||
|
||||
default void receive(I inputData) {
|
||||
C parsedContext = parseRunningFlowContext(inputData);
|
||||
AgentRuntime.INSTANCE.submit(parsedContext);
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
package work.slhaf.partner.framework.agent.interaction
|
||||
|
||||
interface AgentGatewayRegistration {
|
||||
|
||||
val channelName: String
|
||||
|
||||
fun create(params: Map<String, String>): AgentGateway<*, *>
|
||||
|
||||
fun supportsHotReloadReuse(oldParams: Map<String, String>, newParams: Map<String, String>): Boolean {
|
||||
return oldParams == newParams
|
||||
}
|
||||
|
||||
fun shutdown(instance: AgentGateway<*, *>) {
|
||||
if (instance is AutoCloseable) {
|
||||
instance.close()
|
||||
}
|
||||
}
|
||||
|
||||
fun register() {
|
||||
AgentGatewayRegistry.register(this)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,187 @@
|
||||
package work.slhaf.partner.framework.agent.interaction
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject
|
||||
import com.alibaba.fastjson2.annotation.JSONField
|
||||
import org.slf4j.LoggerFactory
|
||||
import work.slhaf.partner.framework.agent.config.Config
|
||||
import work.slhaf.partner.framework.agent.config.ConfigDoc
|
||||
import work.slhaf.partner.framework.agent.config.ConfigRegistration
|
||||
import work.slhaf.partner.framework.agent.config.Configurable
|
||||
import java.nio.file.Path
|
||||
import java.util.concurrent.locks.ReentrantLock
|
||||
import kotlin.concurrent.withLock
|
||||
|
||||
object AgentGatewayRegistry : Configurable, ConfigRegistration<AgentGatewayRegistryConfig> {
|
||||
|
||||
private val log = LoggerFactory.getLogger(AgentGatewayRegistry::class.java)
|
||||
private val registryLock = ReentrantLock()
|
||||
private val registrations = linkedMapOf<String, AgentGatewayRegistration>()
|
||||
private val runningChannels = linkedMapOf<String, RunningGateway>()
|
||||
|
||||
init {
|
||||
register()
|
||||
}
|
||||
|
||||
override fun declare(): Map<Path, ConfigRegistration<out Config>> {
|
||||
return mapOf(Path.of("gateway", "gateway.json") to this)
|
||||
}
|
||||
|
||||
override fun type(): Class<AgentGatewayRegistryConfig> = AgentGatewayRegistryConfig::class.java
|
||||
|
||||
fun register(registration: AgentGatewayRegistration) = registryLock.withLock {
|
||||
val previous = registrations.putIfAbsent(registration.channelName, registration)
|
||||
check(previous == null || previous === registration) {
|
||||
"AgentGateway channel already registered: ${registration.channelName}"
|
||||
}
|
||||
}
|
||||
|
||||
fun resolve(channelName: String): AgentGateway<*, *>? = registryLock.withLock {
|
||||
runningChannels[channelName]?.gateway
|
||||
}
|
||||
|
||||
internal fun snapshotRunningChannels(): Map<String, AgentGateway<*, *>> = registryLock.withLock {
|
||||
runningChannels.mapValues { it.value.gateway }
|
||||
}
|
||||
|
||||
override fun init(config: AgentGatewayRegistryConfig, json: JSONObject?) = registryLock.withLock {
|
||||
applyConfig(config)
|
||||
}
|
||||
|
||||
override fun onReload(config: AgentGatewayRegistryConfig, json: JSONObject?) = registryLock.withLock {
|
||||
val runtimeSnapshot = LinkedHashMap(runningChannels)
|
||||
val defaultSnapshot = AgentRuntime.defaultResponseChannel()
|
||||
try {
|
||||
applyConfig(config)
|
||||
} catch (e: Exception) {
|
||||
log.error("Error while reloading gateway config", e)
|
||||
restoreSnapshot(runtimeSnapshot, defaultSnapshot)
|
||||
}
|
||||
}
|
||||
|
||||
override fun defaultConfig(): AgentGatewayRegistryConfig? = null
|
||||
|
||||
private fun applyConfig(config: AgentGatewayRegistryConfig) {
|
||||
validateConfig(config)
|
||||
reconcileChannels(config.channels)
|
||||
AgentRuntime.setDefaultResponseChannel(config.defaultChannel)
|
||||
}
|
||||
|
||||
private fun validateConfig(config: AgentGatewayRegistryConfig) {
|
||||
require(config.defaultChannel.isNotBlank()) { "default_channel must not be blank" }
|
||||
require(config.channels.isNotEmpty()) { "channels must not be empty" }
|
||||
|
||||
val channelNames = mutableSetOf<String>()
|
||||
config.channels.forEach { channel ->
|
||||
require(channel.channelName.isNotBlank()) { "channel_name must not be blank" }
|
||||
require(channelNames.add(channel.channelName)) { "Duplicated channel_name: ${channel.channelName}" }
|
||||
require(registrations.containsKey(channel.channelName)) {
|
||||
"AgentGateway channel is not registered: ${channel.channelName}"
|
||||
}
|
||||
}
|
||||
|
||||
require(channelNames.contains(config.defaultChannel)) {
|
||||
"default_channel must exist in channels: ${config.defaultChannel}"
|
||||
}
|
||||
}
|
||||
|
||||
private fun reconcileChannels(configuredChannels: List<AgentGatewayChannelConfig>) {
|
||||
val expectedNames = configuredChannels.map { it.channelName }.toSet()
|
||||
val removedNames = runningChannels.keys.filterNot(expectedNames::contains)
|
||||
removedNames.forEach(this::stopChannel)
|
||||
|
||||
configuredChannels.forEach { channelConfig ->
|
||||
val registration = registrations[channelConfig.channelName]
|
||||
?: error("AgentGateway channel is not registered: ${channelConfig.channelName}")
|
||||
val existing = runningChannels[channelConfig.channelName]
|
||||
if (existing != null && existing.registration === registration &&
|
||||
registration.supportsHotReloadReuse(existing.params, channelConfig.params)
|
||||
) {
|
||||
return@forEach
|
||||
}
|
||||
if (existing != null) {
|
||||
stopChannel(channelConfig.channelName)
|
||||
}
|
||||
startChannel(registration, channelConfig)
|
||||
}
|
||||
}
|
||||
|
||||
private fun startChannel(
|
||||
registration: AgentGatewayRegistration,
|
||||
channelConfig: AgentGatewayChannelConfig
|
||||
) {
|
||||
val gateway = registration.create(channelConfig.params)
|
||||
try {
|
||||
gateway.launch()
|
||||
AgentRuntime.registerResponseChannel(channelConfig.channelName, gateway)
|
||||
runningChannels[channelConfig.channelName] = RunningGateway(
|
||||
registration = registration,
|
||||
params = LinkedHashMap(channelConfig.params),
|
||||
gateway = gateway
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
runCatching { registration.shutdown(gateway) }
|
||||
.onFailure { shutdownError ->
|
||||
log.warn(
|
||||
"Failed to shutdown gateway after launch failure: {}",
|
||||
channelConfig.channelName,
|
||||
shutdownError
|
||||
)
|
||||
}
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
private fun stopChannel(channelName: String) {
|
||||
val running = runningChannels.remove(channelName) ?: return
|
||||
runCatching { running.registration.shutdown(running.gateway) }
|
||||
.onFailure { e -> log.warn("Failed to shutdown gateway: {}", channelName, e) }
|
||||
AgentRuntime.unregisterResponseChannel(channelName)
|
||||
}
|
||||
|
||||
private fun restoreSnapshot(
|
||||
runtimeSnapshot: Map<String, RunningGateway>,
|
||||
defaultSnapshot: String
|
||||
) {
|
||||
val currentChannels = runningChannels.keys.toList()
|
||||
currentChannels.forEach(this::stopChannel)
|
||||
|
||||
runtimeSnapshot.forEach { (channelName, running) ->
|
||||
AgentRuntime.registerResponseChannel(channelName, running.gateway)
|
||||
runningChannels[channelName] = running
|
||||
}
|
||||
AgentRuntime.setDefaultResponseChannel(defaultSnapshot)
|
||||
}
|
||||
|
||||
private data class RunningGateway(
|
||||
val registration: AgentGatewayRegistration,
|
||||
val params: Map<String, String>,
|
||||
val gateway: AgentGateway<*, *>
|
||||
)
|
||||
}
|
||||
|
||||
data class AgentGatewayRegistryConfig(
|
||||
@field:JSONField(name = "default_channel")
|
||||
@field:ConfigDoc(description = "默认响应通道", example = "websocket_channel")
|
||||
val defaultChannel: String,
|
||||
@field:ConfigDoc(
|
||||
description = "要启用的通道列表",
|
||||
example = """[
|
||||
{
|
||||
"channel_name": "websocket_channel",
|
||||
"params": {
|
||||
"port": "29600",
|
||||
"heartbeat_interval": "10000"
|
||||
}
|
||||
}
|
||||
]"""
|
||||
)
|
||||
val channels: List<AgentGatewayChannelConfig>
|
||||
) : Config()
|
||||
|
||||
data class AgentGatewayChannelConfig(
|
||||
@field:JSONField(name = "channel_name")
|
||||
@field:ConfigDoc(description = "通道名称,同时对应已注册的 gateway 名称", example = "websocket_channel")
|
||||
val channelName: String,
|
||||
@field:ConfigDoc(description = "通道参数", example = """{ "key1": "value1", "key2": "value2" }""")
|
||||
val params: Map<String, String> = emptyMap()
|
||||
)
|
||||
@@ -17,8 +17,8 @@ object AgentRuntime {
|
||||
LogChannel.channelName to LogChannel
|
||||
)
|
||||
|
||||
// TODO 暂时取 log_channel 为默认回复通道,若为空则只打印信息。后续将配合配置中心替换通过配置文件进行指定
|
||||
private val defaultChannel: String = "log_channel"
|
||||
@Volatile
|
||||
private var defaultChannel: String = LogChannel.channelName
|
||||
|
||||
@Volatile
|
||||
private var runningModules: Map<Int, List<AbstractAgentModule.Running<RunningFlowContext>>> = emptyMap()
|
||||
@@ -35,11 +35,24 @@ object AgentRuntime {
|
||||
responseChannels[channelName] = responseChannel
|
||||
}
|
||||
|
||||
fun unregisterResponseChannel(channelName: String) {
|
||||
if (channelName == LogChannel.channelName) {
|
||||
return
|
||||
}
|
||||
responseChannels.remove(channelName)
|
||||
}
|
||||
|
||||
fun setDefaultResponseChannel(channelName: String) {
|
||||
defaultChannel = channelName
|
||||
}
|
||||
|
||||
fun defaultResponseChannel(): String = defaultChannel
|
||||
|
||||
@JvmOverloads
|
||||
fun response(event: InteractionEvent, channelName: String = defaultChannel) {
|
||||
val channel = responseChannels[channelName]
|
||||
if (channel == null) {
|
||||
responseChannels[defaultChannel]!!.response(event)
|
||||
responseChannels[defaultChannel]?.response(event) ?: LogChannel.response(event)
|
||||
} else {
|
||||
channel.response(event)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user