feat(LocalRunnerClient): support registering MCP clients in CommonMcp

This commit is contained in:
2026-01-05 23:06:17 +08:00
parent ec30ac1922
commit 44ab6cfac8

View File

@@ -1,6 +1,7 @@
package work.slhaf.partner.core.action.runner; package work.slhaf.partner.core.action.runner;
import cn.hutool.core.io.FileUtil; import cn.hutool.core.io.FileUtil;
import cn.hutool.core.io.IORuntimeException;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONArray; import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
@@ -20,9 +21,9 @@ import io.modelcontextprotocol.spec.McpSchema;
import javassist.NotFoundException; import javassist.NotFoundException;
import lombok.Data; import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.UnknownNullability;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers; import reactor.core.scheduler.Schedulers;
import work.slhaf.partner.common.mcp.InProcessMcpTransport; import work.slhaf.partner.common.mcp.InProcessMcpTransport;
@@ -37,6 +38,7 @@ import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.net.URI; import java.net.URI;
import java.net.http.HttpRequest;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.*; import java.nio.file.*;
import java.time.Duration; import java.time.Duration;
@@ -251,18 +253,6 @@ public class LocalRunnerClient extends RunnerClient {
return sysDependencies; return sysDependencies;
} }
/**
* 该部分主要发生在扫描到新的MCP Server描述文件时出现的注册逻辑
*
* @param id MCP Client 的 id
* @param mcpClientTransportParams MCP Server 的参数
*/
private void registerMcpClient(String id, McpClientTransportParams mcpClientTransportParams) {
McpClientTransport clientTransport = createTransport(mcpClientTransportParams);
int timeout = mcpClientTransportParams.timeout;
registerMcpClient(id, clientTransport, timeout);
}
private void registerMcpClient(String id, McpClientTransport clientTransport, int timeout) { private void registerMcpClient(String id, McpClientTransport clientTransport, int timeout) {
McpSyncClient client = McpClient.sync(clientTransport) McpSyncClient client = McpClient.sync(clientTransport)
.requestTimeout(Duration.ofSeconds(timeout)) .requestTimeout(Duration.ofSeconds(timeout))
@@ -292,27 +282,6 @@ public class LocalRunnerClient extends RunnerClient {
return info; return info;
} }
private McpClientTransport createTransport(McpClientTransportParams mcpClientTransportParams) {
return switch (mcpClientTransportParams) {
case McpClientTransportParams.Stdio params -> {
ServerParameters serverParameters = ServerParameters.builder(params.command)
.env(params.env)
.args(params.args)
.build();
yield new StdioClientTransport(serverParameters, McpJsonMapper.getDefault());
}
case McpClientTransportParams.Http params -> {
McpSyncHttpClientRequestCustomizer customizer = (builder, method, endpoint, body, context) -> {
params.headers.forEach(builder::setHeader);
};
yield HttpClientSseClientTransport.builder(params.baseUri)
.httpRequestCustomizer(customizer)
.sseEndpoint(params.endpoint)
.build();
}
};
}
private void setupShutdownHook() { private void setupShutdownHook() {
Runtime.getRuntime().addShutdownHook(new Thread(() -> { Runtime.getRuntime().addShutdownHook(new Thread(() -> {
dynamicActionMcpServer.close(); dynamicActionMcpServer.close();
@@ -490,6 +459,18 @@ public class LocalRunnerClient extends RunnerClient {
protected abstract @NotNull LocalWatchServiceBuild.EventHandler buildOverflow(); protected abstract @NotNull LocalWatchServiceBuild.EventHandler buildOverflow();
protected File[] loadFiles() {
val root = ctx.root;
if (!Files.isDirectory(root)) {
throw new ActionInitFailedException("未找到目录: " + root);
}
val files = root.toFile().listFiles();
if (files == null) {
throw new ActionInitFailedException("目录无法正常读取: " + root);
}
return files;
}
@SuppressWarnings("LoggingSimilarMessage") @SuppressWarnings("LoggingSimilarMessage")
private static final class Dynamic extends LocalWatchEventProcessor { private static final class Dynamic extends LocalWatchEventProcessor {
@@ -502,14 +483,7 @@ public class LocalRunnerClient extends RunnerClient {
@SuppressWarnings("BooleanMethodIsAlwaysInverted") @SuppressWarnings("BooleanMethodIsAlwaysInverted")
private boolean normalPath(Path path) { private boolean normalPath(Path path) {
File file = path.toFile(); val files = loadFiles();
if (file.isFile()) {
return false;
}
File[] files = file.listFiles();
if (files == null) {
return false;
}
if (files.length < 2) { if (files.length < 2) {
return false; return false;
} }
@@ -879,14 +853,7 @@ public class LocalRunnerClient extends RunnerClient {
return () -> { return () -> {
// DescMcp 的加载逻辑只负责读取已有的 *.desc.json 并注册为 resources // DescMcp 的加载逻辑只负责读取已有的 *.desc.json 并注册为 resources
// 正常来讲 root 直接对应 MCP_DESC_PATH先检查 root 是否为目录,否则拒绝启动 // 正常来讲 root 直接对应 MCP_DESC_PATH先检查 root 是否为目录,否则拒绝启动
Path root = ctx.root; val files = loadFiles();
if (!Files.isDirectory(root)) {
throw new ActionInitFailedException("未找到目录: " + root);
}
File[] files = root.toFile().listFiles();
if (files == null) {
throw new ActionInitFailedException("目录无法正常读取: " + root);
}
for (File file : files) { for (File file : files) {
addResource(file); addResource(file);
} }
@@ -1005,9 +972,130 @@ public class LocalRunnerClient extends RunnerClient {
this.mcpClients = mcpClients; this.mcpClients = mcpClients;
} }
/**
* 该部分主要发生在扫描到新的MCP Server描述文件时出现的注册逻辑
*
* @param id MCP Client 的 id
* @param mcpClientTransportParams MCP Server 的参数
*/
private void registerMcpClient(String id, McpClientTransportParams mcpClientTransportParams) {
val clientTransport = createTransport(mcpClientTransportParams);
val timeout = mcpClientTransportParams.timeout;
val client = McpClient.sync(clientTransport)
.requestTimeout(Duration.ofSeconds(timeout))
.clientInfo(new McpSchema.Implementation(id, "PARTNER"))
.build();
mcpClients.put(id, client);
}
private McpClientTransport createTransport(McpClientTransportParams mcpClientTransportParams) {
return switch (mcpClientTransportParams) {
case McpClientTransportParams.Stdio params -> {
val serverParameters = ServerParameters.builder(params.command).env(params.env).args(params.args).build();
yield new StdioClientTransport(serverParameters, McpJsonMapper.getDefault());
}
case McpClientTransportParams.Http params -> {
val customizer = new McpSyncHttpClientRequestCustomizer() {
@Override
public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body, McpTransportContext context) {
params.headers.forEach(builder::setHeader);
}
};
yield HttpClientSseClientTransport.builder(params.baseUri).httpRequestCustomizer(customizer).sseEndpoint(params.endpoint).build();
}
};
}
@Override @Override
@NotNull @NotNull
protected LocalWatchServiceBuild.InitLoader buildLoad() { protected LocalWatchServiceBuild.InitLoader buildLoad() {
return () -> {
// For CommonMcp, we need to list all files in MCP_SERVER_PATH,
// and search for files with extend name .json,
// and then reading them as JSONObject to get McpClientTransportParams.
val files = loadFiles();
for (File file : files) {
if (file.isFile() && file.getName().endsWith(".json")) {
registerMcpClients(file);
}
}
};
}
private void registerMcpClients(File file) {
val json = readJson(file);
if (json == null) {
return;
}
for (String id : json.keySet()) {
val mcp = readMcp(json, id);
if (mcp == null) {
continue;
}
val params = readParams(mcp);
if (params == null) {
continue;
}
registerMcpClient(id, params);
}
}
private cn.hutool.json.JSONObject readJson(File file) {
try {
return JSONUtil.readJSONObject(file, StandardCharsets.UTF_8);
} catch (IORuntimeException ignored) {
return null;
}
}
private cn.hutool.json.JSONObject readMcp(cn.hutool.json.JSONObject json, String id) {
try {
return json.getJSONObject(id);
} catch (Exception ignored) {
return null;
}
}
@SuppressWarnings("unchecked")
private McpClientTransportParams readParams(cn.hutool.json.JSONObject mcp) {
val stdioKeys = Set.of("command", "args", "env");
val httpKeys = Set.of("uri", "endpoint", "headers");
val httpKey = Set.of("url");
val keys = mcp.keySet();
val timeout = mcp.getInt("timeout", 10);
if (keys.equals(stdioKeys)) {
val command = mcp.getStr("command");
val env = mcp.getBean("env", Map.class);
val args = mcp.getBeanList("args", String.class);
if (command == null || env == null || args == null) {
return null;
}
return new McpClientTransportParams.Stdio(timeout, command, env, args);
}
if (keys.equals(httpKeys)) {
val uri = mcp.getStr("uri");
val endpoint = mcp.getStr("endpoint");
val headers = mcp.getBean("headers", Map.class);
if (uri == null || endpoint == null || headers == null) {
return null;
}
return new McpClientTransportParams.Http(timeout, uri, endpoint, headers);
}
if (keys.equals(httpKey)) {
val url = mcp.getStr("url");
if (url == null) {
return null;
}
return new McpClientTransportParams.Http(timeout, url, "", Map.of());
}
return null; return null;
} }
@@ -1105,8 +1193,7 @@ public class LocalRunnerClient extends RunnerClient {
.start(); .start();
Thread stdoutThread = new Thread(() -> { Thread stdoutThread = new Thread(() -> {
try (BufferedReader r = new BufferedReader( try (BufferedReader r = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
new InputStreamReader(process.getInputStream()))) {
String line; String line;
while ((line = r.readLine()) != null) { while ((line = r.readLine()) != null) {
output.add(line); output.add(line);
@@ -1116,8 +1203,7 @@ public class LocalRunnerClient extends RunnerClient {
}); });
Thread stderrThread = new Thread(() -> { Thread stderrThread = new Thread(() -> {
try (BufferedReader r = new BufferedReader( try (BufferedReader r = new BufferedReader(new InputStreamReader(process.getErrorStream()))) {
new InputStreamReader(process.getErrorStream()))) {
String line; String line;
while ((line = r.readLine()) != null) { while ((line = r.readLine()) != null) {
error.add(line); error.add(line);
@@ -1135,8 +1221,7 @@ public class LocalRunnerClient extends RunnerClient {
result.setOk(exitCode == 0); result.setOk(exitCode == 0);
result.setResultList(output.isEmpty() ? error : output); result.setResultList(output.isEmpty() ? error : output);
result.setTotal(String.join("\n", result.setTotal(String.join("\n", output.isEmpty() ? error : output));
output.isEmpty() ? error : output));
} catch (Exception e) { } catch (Exception e) {
result.setOk(false); result.setOk(false);