From 44ab6cfac84ddac90e94a2b12a1fad9ea5dd3f54 Mon Sep 17 00:00:00 2001 From: slhafzjw Date: Mon, 5 Jan 2026 23:06:17 +0800 Subject: [PATCH] feat(LocalRunnerClient): support registering MCP clients in CommonMcp --- .../core/action/runner/LocalRunnerClient.java | 197 +++++++++++++----- 1 file changed, 141 insertions(+), 56 deletions(-) diff --git a/Partner-Main/src/main/java/work/slhaf/partner/core/action/runner/LocalRunnerClient.java b/Partner-Main/src/main/java/work/slhaf/partner/core/action/runner/LocalRunnerClient.java index 8c12f53b..c4c331ad 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/core/action/runner/LocalRunnerClient.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/core/action/runner/LocalRunnerClient.java @@ -1,6 +1,7 @@ package work.slhaf.partner.core.action.runner; import cn.hutool.core.io.FileUtil; +import cn.hutool.core.io.IORuntimeException; import cn.hutool.json.JSONUtil; import com.alibaba.fastjson2.JSONArray; import com.alibaba.fastjson2.JSONObject; @@ -20,9 +21,9 @@ import io.modelcontextprotocol.spec.McpSchema; import javassist.NotFoundException; import lombok.Data; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; -import org.jetbrains.annotations.UnknownNullability; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import work.slhaf.partner.common.mcp.InProcessMcpTransport; @@ -37,6 +38,7 @@ import java.io.File; import java.io.IOException; import java.io.InputStreamReader; import java.net.URI; +import java.net.http.HttpRequest; import java.nio.charset.StandardCharsets; import java.nio.file.*; import java.time.Duration; @@ -251,18 +253,6 @@ public class LocalRunnerClient extends RunnerClient { return sysDependencies; } - /** - * 该部分主要发生在扫描到新的MCP Server描述文件时出现的注册逻辑 - * - * @param id MCP Client 的 id - * @param mcpClientTransportParams MCP Server 的参数 - */ - private void registerMcpClient(String id, McpClientTransportParams mcpClientTransportParams) { - McpClientTransport clientTransport = createTransport(mcpClientTransportParams); - int timeout = mcpClientTransportParams.timeout; - registerMcpClient(id, clientTransport, timeout); - } - private void registerMcpClient(String id, McpClientTransport clientTransport, int timeout) { McpSyncClient client = McpClient.sync(clientTransport) .requestTimeout(Duration.ofSeconds(timeout)) @@ -292,27 +282,6 @@ public class LocalRunnerClient extends RunnerClient { return info; } - private McpClientTransport createTransport(McpClientTransportParams mcpClientTransportParams) { - return switch (mcpClientTransportParams) { - case McpClientTransportParams.Stdio params -> { - ServerParameters serverParameters = ServerParameters.builder(params.command) - .env(params.env) - .args(params.args) - .build(); - yield new StdioClientTransport(serverParameters, McpJsonMapper.getDefault()); - } - case McpClientTransportParams.Http params -> { - McpSyncHttpClientRequestCustomizer customizer = (builder, method, endpoint, body, context) -> { - params.headers.forEach(builder::setHeader); - }; - yield HttpClientSseClientTransport.builder(params.baseUri) - .httpRequestCustomizer(customizer) - .sseEndpoint(params.endpoint) - .build(); - } - }; - } - private void setupShutdownHook() { Runtime.getRuntime().addShutdownHook(new Thread(() -> { dynamicActionMcpServer.close(); @@ -490,6 +459,18 @@ public class LocalRunnerClient extends RunnerClient { protected abstract @NotNull LocalWatchServiceBuild.EventHandler buildOverflow(); + protected File[] loadFiles() { + val root = ctx.root; + if (!Files.isDirectory(root)) { + throw new ActionInitFailedException("未找到目录: " + root); + } + val files = root.toFile().listFiles(); + if (files == null) { + throw new ActionInitFailedException("目录无法正常读取: " + root); + } + return files; + } + @SuppressWarnings("LoggingSimilarMessage") private static final class Dynamic extends LocalWatchEventProcessor { @@ -502,14 +483,7 @@ public class LocalRunnerClient extends RunnerClient { @SuppressWarnings("BooleanMethodIsAlwaysInverted") private boolean normalPath(Path path) { - File file = path.toFile(); - if (file.isFile()) { - return false; - } - File[] files = file.listFiles(); - if (files == null) { - return false; - } + val files = loadFiles(); if (files.length < 2) { return false; } @@ -879,14 +853,7 @@ public class LocalRunnerClient extends RunnerClient { return () -> { // DescMcp 的加载逻辑只负责读取已有的 *.desc.json 并注册为 resources // 正常来讲 root 直接对应 MCP_DESC_PATH,先检查 root 是否为目录,否则拒绝启动 - Path root = ctx.root; - if (!Files.isDirectory(root)) { - throw new ActionInitFailedException("未找到目录: " + root); - } - File[] files = root.toFile().listFiles(); - if (files == null) { - throw new ActionInitFailedException("目录无法正常读取: " + root); - } + val files = loadFiles(); for (File file : files) { addResource(file); } @@ -1005,9 +972,130 @@ public class LocalRunnerClient extends RunnerClient { this.mcpClients = mcpClients; } + /** + * 该部分主要发生在扫描到新的MCP Server描述文件时出现的注册逻辑 + * + * @param id MCP Client 的 id + * @param mcpClientTransportParams MCP Server 的参数 + */ + private void registerMcpClient(String id, McpClientTransportParams mcpClientTransportParams) { + val clientTransport = createTransport(mcpClientTransportParams); + val timeout = mcpClientTransportParams.timeout; + val client = McpClient.sync(clientTransport) + .requestTimeout(Duration.ofSeconds(timeout)) + .clientInfo(new McpSchema.Implementation(id, "PARTNER")) + .build(); + mcpClients.put(id, client); + } + + private McpClientTransport createTransport(McpClientTransportParams mcpClientTransportParams) { + return switch (mcpClientTransportParams) { + case McpClientTransportParams.Stdio params -> { + val serverParameters = ServerParameters.builder(params.command).env(params.env).args(params.args).build(); + yield new StdioClientTransport(serverParameters, McpJsonMapper.getDefault()); + } + case McpClientTransportParams.Http params -> { + val customizer = new McpSyncHttpClientRequestCustomizer() { + @Override + public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body, McpTransportContext context) { + params.headers.forEach(builder::setHeader); + } + }; + yield HttpClientSseClientTransport.builder(params.baseUri).httpRequestCustomizer(customizer).sseEndpoint(params.endpoint).build(); + } + }; + } + @Override @NotNull protected LocalWatchServiceBuild.InitLoader buildLoad() { + return () -> { + // For CommonMcp, we need to list all files in MCP_SERVER_PATH, + // and search for files with extend name .json, + // and then reading them as JSONObject to get McpClientTransportParams. + val files = loadFiles(); + + for (File file : files) { + if (file.isFile() && file.getName().endsWith(".json")) { + registerMcpClients(file); + } + } + }; + } + + private void registerMcpClients(File file) { + val json = readJson(file); + if (json == null) { + return; + } + + for (String id : json.keySet()) { + val mcp = readMcp(json, id); + if (mcp == null) { + continue; + } + + val params = readParams(mcp); + if (params == null) { + continue; + } + + registerMcpClient(id, params); + } + } + + private cn.hutool.json.JSONObject readJson(File file) { + try { + return JSONUtil.readJSONObject(file, StandardCharsets.UTF_8); + } catch (IORuntimeException ignored) { + return null; + } + } + + private cn.hutool.json.JSONObject readMcp(cn.hutool.json.JSONObject json, String id) { + try { + return json.getJSONObject(id); + } catch (Exception ignored) { + return null; + } + } + + @SuppressWarnings("unchecked") + private McpClientTransportParams readParams(cn.hutool.json.JSONObject mcp) { + val stdioKeys = Set.of("command", "args", "env"); + val httpKeys = Set.of("uri", "endpoint", "headers"); + val httpKey = Set.of("url"); + val keys = mcp.keySet(); + val timeout = mcp.getInt("timeout", 10); + + if (keys.equals(stdioKeys)) { + val command = mcp.getStr("command"); + val env = mcp.getBean("env", Map.class); + val args = mcp.getBeanList("args", String.class); + if (command == null || env == null || args == null) { + return null; + } + return new McpClientTransportParams.Stdio(timeout, command, env, args); + } + + if (keys.equals(httpKeys)) { + val uri = mcp.getStr("uri"); + val endpoint = mcp.getStr("endpoint"); + val headers = mcp.getBean("headers", Map.class); + if (uri == null || endpoint == null || headers == null) { + return null; + } + return new McpClientTransportParams.Http(timeout, uri, endpoint, headers); + } + + if (keys.equals(httpKey)) { + val url = mcp.getStr("url"); + if (url == null) { + return null; + } + return new McpClientTransportParams.Http(timeout, url, "", Map.of()); + } + return null; } @@ -1105,8 +1193,7 @@ public class LocalRunnerClient extends RunnerClient { .start(); Thread stdoutThread = new Thread(() -> { - try (BufferedReader r = new BufferedReader( - new InputStreamReader(process.getInputStream()))) { + try (BufferedReader r = new BufferedReader(new InputStreamReader(process.getInputStream()))) { String line; while ((line = r.readLine()) != null) { output.add(line); @@ -1116,8 +1203,7 @@ public class LocalRunnerClient extends RunnerClient { }); Thread stderrThread = new Thread(() -> { - try (BufferedReader r = new BufferedReader( - new InputStreamReader(process.getErrorStream()))) { + try (BufferedReader r = new BufferedReader(new InputStreamReader(process.getErrorStream()))) { String line; while ((line = r.readLine()) != null) { error.add(line); @@ -1135,8 +1221,7 @@ public class LocalRunnerClient extends RunnerClient { result.setOk(exitCode == 0); result.setResultList(output.isEmpty() ? error : output); - result.setTotal(String.join("\n", - output.isEmpty() ? error : output)); + result.setTotal(String.join("\n", output.isEmpty() ? error : output)); } catch (Exception e) { result.setOk(false);