feat(RunnerClient): support MCP type-based dynamic client/server registration

This allows implementations of RunnerClient to dynamically register different types of MCP service, and also provides a shutdown hook to close client/server properly.
This commit is contained in:
2025-12-18 22:25:32 +08:00
parent cb28a5b068
commit e851e33b2e
2 changed files with 203 additions and 0 deletions

View File

@@ -42,6 +42,8 @@ public abstract class RunnerClient {
protected final Map<String, MetaActionInfo> existedMetaActions; protected final Map<String, MetaActionInfo> existedMetaActions;
protected final ExecutorService executor; protected final ExecutorService executor;
protected final Map<String, McpSyncClient> mcpClients = new HashMap<>();
protected final Map<String, McpStatelessAsyncServer> localMcpServers = new HashMap<>();
/** /**
* ActionCore 将注入虚拟线程池 * ActionCore 将注入虚拟线程池
@@ -49,6 +51,18 @@ public abstract class RunnerClient {
public RunnerClient(Map<String, MetaActionInfo> existedMetaActions, ExecutorService executor) { public RunnerClient(Map<String, MetaActionInfo> existedMetaActions, ExecutorService executor) {
this.existedMetaActions = existedMetaActions; this.existedMetaActions = existedMetaActions;
this.executor = executor; this.executor = executor;
setupShutdownHook();
}
protected void setupShutdownHook() {
this.mcpClients.forEach((id, client) -> {
client.close();
log.info("[{}] MCP-Client 已关闭", id);
});
this.localMcpServers.forEach((id, server) -> {
server.close();
log.info("[{}] MCP-Server 已关闭", id);
});
} }
/** /**
@@ -65,6 +79,77 @@ public abstract class RunnerClient {
result.setStatus(response.isOk() ? ResultStatus.SUCCESS : ResultStatus.FAILED); result.setStatus(response.isOk() ? ResultStatus.SUCCESS : ResultStatus.FAILED);
} }
protected void registerMcpClient(String id, McpServerParams mcpServerParams) {
McpClientTransport clientTransport = createTransport(mcpServerParams);
McpSyncClient client = McpClient.sync(clientTransport)
.requestTimeout(Duration.ofSeconds(mcpServerParams.timeout))
.clientInfo(new McpSchema.Implementation(id, "PARTNER"))
// 行动程序(现 MCP Tool)的描述文本将直接由resources返回
// 原因: ToolChange 发送的内容侧重调用,缺少可承担描述文本的字段
// ResourcesChange 事件传递的 Resource 可以由 Client 读取内容
// 预计在 Server 侧,收到客户端发送的新的行动程序信息,该信息由客户端处补充后,将其放置在指定位置
// 并写入描述文件、发起 ResourcesChange 事件
.resourcesChangeConsumer(resources -> updateExistedMetaActions(id, resources))
.build();
mcpClients.put(id, client);
}
private void updateExistedMetaActions(String id, List<McpSchema.Resource> resources) {
synchronized (existedMetaActions) {
McpSyncClient client = mcpClients.get(id);
for (McpSchema.Resource resource : resources) {
McpSchema.ReadResourceResult resourceResult = client.readResource(resource);
for (McpSchema.ResourceContents resourceContent : resourceResult.contents()) {
// 忽略非文本类型,行动描述信息只会以文本形式存在
if (resourceContent instanceof McpSchema.TextResourceContents content) {
MetaActionInfo metaActionInfo = JSONObject.parseObject(content.text(), MetaActionInfo.class);
existedMetaActions.put(id + "::" + metaActionInfo.getKey(), metaActionInfo);
}
}
}
}
}
private McpClientTransport createTransport(McpServerParams mcpServerParams) {
return switch (mcpServerParams) {
case InProcessMcpServerParams params -> {
InProcessMcpTransport.Pair pair = InProcessMcpTransport.pair();
createInProcessMcpServer(params.id, pair.serverSide);
yield pair.clientSide;
}
case StdioMcpServerParams params -> {
ServerParameters serverParameters = ServerParameters.builder(params.command)
.env(params.env)
.args(params.args)
.build();
yield new StdioClientTransport(serverParameters, McpJsonMapper.getDefault());
}
case HttpMcpServerParams params -> {
McpSyncHttpClientRequestCustomizer customizer = (builder, method, endpoint, body, context) -> {
params.headers.forEach(builder::setHeader);
};
yield HttpClientSseClientTransport.builder(params.baseUri)
.httpRequestCustomizer(customizer)
.sseEndpoint(params.endpoint)
.build();
}
};
}
private void createInProcessMcpServer(String id, InProcessMcpTransport serverSide) {
McpSchema.ServerCapabilities serverCapabilities = McpSchema.ServerCapabilities.builder()
.tools(true)
.resources(true, true)
.build();
McpStatelessAsyncServer server = McpServer.async(serverSide)
.capabilities(serverCapabilities)
.serverInfo(id, "PARTNER")
.build();
localMcpServers.put(id, server);
}
protected abstract RunnerResponse doRun(MetaAction metaAction); protected abstract RunnerResponse doRun(MetaAction metaAction);
public abstract Path buildTmpPath(MetaAction tempAction, String codeType); public abstract Path buildTmpPath(MetaAction tempAction, String codeType);
@@ -84,6 +169,58 @@ public abstract class RunnerClient {
private String data; private String data;
} }
protected sealed abstract static class McpServerParams permits HttpMcpServerParams, InProcessMcpServerParams, StdioMcpServerParams {
private final int timeout;
private McpServerParams(int timeout) {
this.timeout = timeout;
}
}
protected final static class HttpMcpServerParams extends McpServerParams {
private final String baseUri;
private final String endpoint;
private final Map<String, String> headers;
protected HttpMcpServerParams(int timeout, String baseUri, String endpoint, Map<String, String> header) {
super(timeout);
this.baseUri = baseUri;
this.endpoint = endpoint;
this.headers = header;
}
}
protected final static class StdioMcpServerParams extends McpServerParams {
private final String command;
private final Map<String, String> env;
private final List<String> args;
protected StdioMcpServerParams(int timeout, String command, Map<String, String> env, List<String> args) {
super(timeout);
this.command = command;
this.env = env;
this.args = args;
}
}
protected final static class InProcessMcpServerParams extends McpServerParams {
private final String id;
protected InProcessMcpServerParams(int timeout, String id) {
super(timeout);
this.id = id;
}
}
protected enum McpServerType {
HTTP,
STDIO,
/**
* 对应 Partner 内部的 Server 创建方式
*/
SELF
}
public static final class InProcessMcpTransport implements McpClientTransport, McpStatelessServerTransport { public static final class InProcessMcpTransport implements McpClientTransport, McpStatelessServerTransport {
// 每个 transport 只处理一条 inbound 流 // 每个 transport 只处理一条 inbound 流

View File

@@ -1,5 +1,6 @@
package work.slhaf.partner.core.action.runner; package work.slhaf.partner.core.action.runner;
import com.alibaba.fastjson2.JSONObject;
import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServer;
@@ -7,11 +8,45 @@ import io.modelcontextprotocol.server.McpStatelessServerFeatures;
import io.modelcontextprotocol.server.McpStatelessSyncServer; import io.modelcontextprotocol.server.McpStatelessSyncServer;
import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import work.slhaf.partner.core.action.entity.McpData;
import work.slhaf.partner.core.action.entity.MetaAction;
import work.slhaf.partner.core.action.entity.MetaActionInfo;
import java.io.IOException;
import java.nio.file.Path;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.concurrent.Executors;
public class RunnerClientTest { public class RunnerClientTest {
@Test
void httpMcpClientTest() {
TestRunnerClient testClient = new TestRunnerClient();
RunnerClient.HttpMcpServerParams params = new RunnerClient.HttpMcpServerParams(20, "https://dashscope.aliyuncs.com", "/api/v1/mcps/WebSearch/sse", Map.of("Authorization", "Bearer sk-xxx"));
testClient.registerMcpClient("test", params);
McpSyncClient client = testClient.mcpClients.values().stream().toList().getFirst();
List<McpSchema.Tool> tools = client.listTools().tools();
System.out.println(tools);
McpSchema.CallToolResult query = client.callTool(McpSchema.CallToolRequest.builder().name(tools.getFirst().name()).arguments(Map.of("query", "123")).build());
for (McpSchema.Content content : query.content()) {
System.out.println("\r\n---\r\n");
System.out.println(content);
}
}
@Test
void stdioMcpClientTest() {
TestRunnerClient testClient = new TestRunnerClient();
RunnerClient.StdioMcpServerParams params = new RunnerClient.StdioMcpServerParams(20, "uvx", Map.of("http_proxy", "http://127.0.0.1:7897", "https_proxy", "http://127.0.0.1:7897"), List.of("mcp-server-fetch"));
testClient.registerMcpClient("test", params);
McpSyncClient client = testClient.mcpClients.values().stream().toList().getFirst();
List<McpSchema.Tool> tools = client.listTools().tools();
System.out.println(tools);
McpSchema.CallToolResult query = client.callTool(McpSchema.CallToolRequest.builder().name(tools.getFirst().name()).arguments(Map.of("url", "https://gitea.slhaf.work")).build());
System.out.println(query.toString());
}
@Test @Test
void inProcessMcpTransportTest() { void inProcessMcpTransportTest() {
RunnerClient.InProcessMcpTransport.Pair pair = RunnerClient.InProcessMcpTransport.pair(); RunnerClient.InProcessMcpTransport.Pair pair = RunnerClient.InProcessMcpTransport.pair();
@@ -39,4 +74,35 @@ public class RunnerClientTest {
server.close(); server.close();
} }
private static class TestRunnerClient extends RunnerClient {
public TestRunnerClient() {
super(Map.of(), Executors.newVirtualThreadPerTaskExecutor());
}
@Override
protected RunnerResponse doRun(MetaAction metaAction) {
return null;
}
@Override
public Path buildTmpPath(MetaAction tempAction, String codeType) {
return null;
}
@Override
public void tmpSerialize(MetaAction tempAction, String code, String codeType) throws IOException {
}
@Override
public void persistSerialize(MetaActionInfo metaActionInfo, McpData mcpData) {
}
@Override
public JSONObject listSysDependencies() {
return null;
}
}
} }