feat(LocalRunnerClient): support registering MCP clients in CommonMcp

This commit is contained in:
2026-01-05 23:06:17 +08:00
parent ec30ac1922
commit 44ab6cfac8

View File

@@ -1,6 +1,7 @@
package work.slhaf.partner.core.action.runner;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.io.IORuntimeException;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
@@ -20,9 +21,9 @@ import io.modelcontextprotocol.spec.McpSchema;
import javassist.NotFoundException;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.UnknownNullability;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import work.slhaf.partner.common.mcp.InProcessMcpTransport;
@@ -37,6 +38,7 @@ import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.http.HttpRequest;
import java.nio.charset.StandardCharsets;
import java.nio.file.*;
import java.time.Duration;
@@ -251,18 +253,6 @@ public class LocalRunnerClient extends RunnerClient {
return sysDependencies;
}
/**
* 该部分主要发生在扫描到新的MCP Server描述文件时出现的注册逻辑
*
* @param id MCP Client 的 id
* @param mcpClientTransportParams MCP Server 的参数
*/
private void registerMcpClient(String id, McpClientTransportParams mcpClientTransportParams) {
McpClientTransport clientTransport = createTransport(mcpClientTransportParams);
int timeout = mcpClientTransportParams.timeout;
registerMcpClient(id, clientTransport, timeout);
}
private void registerMcpClient(String id, McpClientTransport clientTransport, int timeout) {
McpSyncClient client = McpClient.sync(clientTransport)
.requestTimeout(Duration.ofSeconds(timeout))
@@ -292,27 +282,6 @@ public class LocalRunnerClient extends RunnerClient {
return info;
}
private McpClientTransport createTransport(McpClientTransportParams mcpClientTransportParams) {
return switch (mcpClientTransportParams) {
case McpClientTransportParams.Stdio params -> {
ServerParameters serverParameters = ServerParameters.builder(params.command)
.env(params.env)
.args(params.args)
.build();
yield new StdioClientTransport(serverParameters, McpJsonMapper.getDefault());
}
case McpClientTransportParams.Http params -> {
McpSyncHttpClientRequestCustomizer customizer = (builder, method, endpoint, body, context) -> {
params.headers.forEach(builder::setHeader);
};
yield HttpClientSseClientTransport.builder(params.baseUri)
.httpRequestCustomizer(customizer)
.sseEndpoint(params.endpoint)
.build();
}
};
}
private void setupShutdownHook() {
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
dynamicActionMcpServer.close();
@@ -490,6 +459,18 @@ public class LocalRunnerClient extends RunnerClient {
protected abstract @NotNull LocalWatchServiceBuild.EventHandler buildOverflow();
protected File[] loadFiles() {
val root = ctx.root;
if (!Files.isDirectory(root)) {
throw new ActionInitFailedException("未找到目录: " + root);
}
val files = root.toFile().listFiles();
if (files == null) {
throw new ActionInitFailedException("目录无法正常读取: " + root);
}
return files;
}
@SuppressWarnings("LoggingSimilarMessage")
private static final class Dynamic extends LocalWatchEventProcessor {
@@ -502,14 +483,7 @@ public class LocalRunnerClient extends RunnerClient {
@SuppressWarnings("BooleanMethodIsAlwaysInverted")
private boolean normalPath(Path path) {
File file = path.toFile();
if (file.isFile()) {
return false;
}
File[] files = file.listFiles();
if (files == null) {
return false;
}
val files = loadFiles();
if (files.length < 2) {
return false;
}
@@ -879,14 +853,7 @@ public class LocalRunnerClient extends RunnerClient {
return () -> {
// DescMcp 的加载逻辑只负责读取已有的 *.desc.json 并注册为 resources
// 正常来讲 root 直接对应 MCP_DESC_PATH先检查 root 是否为目录,否则拒绝启动
Path root = ctx.root;
if (!Files.isDirectory(root)) {
throw new ActionInitFailedException("未找到目录: " + root);
}
File[] files = root.toFile().listFiles();
if (files == null) {
throw new ActionInitFailedException("目录无法正常读取: " + root);
}
val files = loadFiles();
for (File file : files) {
addResource(file);
}
@@ -1005,9 +972,130 @@ public class LocalRunnerClient extends RunnerClient {
this.mcpClients = mcpClients;
}
/**
* 该部分主要发生在扫描到新的MCP Server描述文件时出现的注册逻辑
*
* @param id MCP Client 的 id
* @param mcpClientTransportParams MCP Server 的参数
*/
private void registerMcpClient(String id, McpClientTransportParams mcpClientTransportParams) {
val clientTransport = createTransport(mcpClientTransportParams);
val timeout = mcpClientTransportParams.timeout;
val client = McpClient.sync(clientTransport)
.requestTimeout(Duration.ofSeconds(timeout))
.clientInfo(new McpSchema.Implementation(id, "PARTNER"))
.build();
mcpClients.put(id, client);
}
private McpClientTransport createTransport(McpClientTransportParams mcpClientTransportParams) {
return switch (mcpClientTransportParams) {
case McpClientTransportParams.Stdio params -> {
val serverParameters = ServerParameters.builder(params.command).env(params.env).args(params.args).build();
yield new StdioClientTransport(serverParameters, McpJsonMapper.getDefault());
}
case McpClientTransportParams.Http params -> {
val customizer = new McpSyncHttpClientRequestCustomizer() {
@Override
public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body, McpTransportContext context) {
params.headers.forEach(builder::setHeader);
}
};
yield HttpClientSseClientTransport.builder(params.baseUri).httpRequestCustomizer(customizer).sseEndpoint(params.endpoint).build();
}
};
}
@Override
@NotNull
protected LocalWatchServiceBuild.InitLoader buildLoad() {
return () -> {
// For CommonMcp, we need to list all files in MCP_SERVER_PATH,
// and search for files with extend name .json,
// and then reading them as JSONObject to get McpClientTransportParams.
val files = loadFiles();
for (File file : files) {
if (file.isFile() && file.getName().endsWith(".json")) {
registerMcpClients(file);
}
}
};
}
private void registerMcpClients(File file) {
val json = readJson(file);
if (json == null) {
return;
}
for (String id : json.keySet()) {
val mcp = readMcp(json, id);
if (mcp == null) {
continue;
}
val params = readParams(mcp);
if (params == null) {
continue;
}
registerMcpClient(id, params);
}
}
private cn.hutool.json.JSONObject readJson(File file) {
try {
return JSONUtil.readJSONObject(file, StandardCharsets.UTF_8);
} catch (IORuntimeException ignored) {
return null;
}
}
private cn.hutool.json.JSONObject readMcp(cn.hutool.json.JSONObject json, String id) {
try {
return json.getJSONObject(id);
} catch (Exception ignored) {
return null;
}
}
@SuppressWarnings("unchecked")
private McpClientTransportParams readParams(cn.hutool.json.JSONObject mcp) {
val stdioKeys = Set.of("command", "args", "env");
val httpKeys = Set.of("uri", "endpoint", "headers");
val httpKey = Set.of("url");
val keys = mcp.keySet();
val timeout = mcp.getInt("timeout", 10);
if (keys.equals(stdioKeys)) {
val command = mcp.getStr("command");
val env = mcp.getBean("env", Map.class);
val args = mcp.getBeanList("args", String.class);
if (command == null || env == null || args == null) {
return null;
}
return new McpClientTransportParams.Stdio(timeout, command, env, args);
}
if (keys.equals(httpKeys)) {
val uri = mcp.getStr("uri");
val endpoint = mcp.getStr("endpoint");
val headers = mcp.getBean("headers", Map.class);
if (uri == null || endpoint == null || headers == null) {
return null;
}
return new McpClientTransportParams.Http(timeout, uri, endpoint, headers);
}
if (keys.equals(httpKey)) {
val url = mcp.getStr("url");
if (url == null) {
return null;
}
return new McpClientTransportParams.Http(timeout, url, "", Map.of());
}
return null;
}
@@ -1105,8 +1193,7 @@ public class LocalRunnerClient extends RunnerClient {
.start();
Thread stdoutThread = new Thread(() -> {
try (BufferedReader r = new BufferedReader(
new InputStreamReader(process.getInputStream()))) {
try (BufferedReader r = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
String line;
while ((line = r.readLine()) != null) {
output.add(line);
@@ -1116,8 +1203,7 @@ public class LocalRunnerClient extends RunnerClient {
});
Thread stderrThread = new Thread(() -> {
try (BufferedReader r = new BufferedReader(
new InputStreamReader(process.getErrorStream()))) {
try (BufferedReader r = new BufferedReader(new InputStreamReader(process.getErrorStream()))) {
String line;
while ((line = r.readLine()) != null) {
error.add(line);
@@ -1135,8 +1221,7 @@ public class LocalRunnerClient extends RunnerClient {
result.setOk(exitCode == 0);
result.setResultList(output.isEmpty() ? error : output);
result.setTotal(String.join("\n",
output.isEmpty() ? error : output));
result.setTotal(String.join("\n", output.isEmpty() ? error : output));
} catch (Exception e) {
result.setOk(false);