From d1ea8dde790d7e62a09399c1f35e8da6dd370ad6 Mon Sep 17 00:00:00 2001 From: slhafzjw Date: Fri, 17 Oct 2025 11:20:11 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8E=A8=E8=BF=9B=20ActionExtractor=20?= =?UTF-8?q?=E8=AF=AD=E4=B9=89=E7=BC=93=E5=AD=98=E6=9C=BA=E5=88=B6:=20?= =?UTF-8?q?=E7=A7=BB=E9=99=A4=E4=BA=86=20VectorUtil=EF=BC=8C=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E4=BA=86=20ollama=E3=80=81onnx=20runtime=20=E4=B8=A4?= =?UTF-8?q?=E7=A7=8D=E5=90=91=E9=87=8F=E5=AE=A2=E6=88=B7=E7=AB=AF=EF=BC=8C?= =?UTF-8?q?=E9=80=9A=E8=BF=87=20Agent=20=E5=90=AF=E5=8A=A8=E7=B1=BB?= =?UTF-8?q?=E6=9A=B4=E9=9C=B2=E7=9A=84=E5=90=8E=E7=BD=AE=E5=90=AF=E5=8A=A8?= =?UTF-8?q?=E4=BB=BB=E5=8A=A1=E5=8A=A0=E8=BD=BD=E5=B9=B6=E8=BF=9B=E8=A1=8C?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../main/java/work/slhaf/partner/Main.java | 2 + .../slhaf/partner/common/config/Config.java | 3 + .../slhaf/partner/common/util/VectorUtil.java | 59 ------------- .../common/vector/OllamaVectorClient.java | 50 +++++++++++ .../common/vector/OnnxVectorClient.java | 83 +++++++++++++++++++ .../partner/common/vector/VectorClient.java | 68 +++++++++++++++ .../VectorClientExecuteException.java | 15 ++++ .../VectorClientLoadFailedException.java | 15 ++++ 8 files changed, 236 insertions(+), 59 deletions(-) delete mode 100644 Partner-Main/src/main/java/work/slhaf/partner/common/util/VectorUtil.java create mode 100644 Partner-Main/src/main/java/work/slhaf/partner/common/vector/OllamaVectorClient.java create mode 100644 Partner-Main/src/main/java/work/slhaf/partner/common/vector/OnnxVectorClient.java create mode 100644 Partner-Main/src/main/java/work/slhaf/partner/common/vector/VectorClient.java create mode 100644 Partner-Main/src/main/java/work/slhaf/partner/common/vector/exception/VectorClientExecuteException.java create mode 100644 Partner-Main/src/main/java/work/slhaf/partner/common/vector/exception/VectorClientLoadFailedException.java diff --git a/Partner-Main/src/main/java/work/slhaf/partner/Main.java b/Partner-Main/src/main/java/work/slhaf/partner/Main.java index 14dc1c1d..5d5e27a0 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/Main.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/Main.java @@ -2,6 +2,7 @@ package work.slhaf.partner; import work.slhaf.partner.api.agent.Agent; import work.slhaf.partner.common.config.PartnerAgentConfigManager; +import work.slhaf.partner.common.vector.VectorClient; import work.slhaf.partner.runtime.exception.PartnerExceptionCallback; import work.slhaf.partner.runtime.interaction.WebSocketGateway; @@ -11,6 +12,7 @@ public class Main { .setAgentConfigManager(PartnerAgentConfigManager.class) .setGateway(WebSocketGateway.class) .setAgentExceptionCallback(PartnerExceptionCallback.class) + .addAfterLaunchRunners(() -> VectorClient.load()) .launch(); } } \ No newline at end of file diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/config/Config.java b/Partner-Main/src/main/java/work/slhaf/partner/common/config/Config.java index 8b33aa5d..273f83e8 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/common/config/Config.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/config/Config.java @@ -10,8 +10,11 @@ public class Config { @Data public static class VectorConfig { + private int type; private String ollamaEmbeddingUrl; private String ollamaEmbeddingModel; + private String tokenizerPath; + private String embeddingModelPath; } @Data diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/util/VectorUtil.java b/Partner-Main/src/main/java/work/slhaf/partner/common/util/VectorUtil.java deleted file mode 100644 index 81e94ea8..00000000 --- a/Partner-Main/src/main/java/work/slhaf/partner/common/util/VectorUtil.java +++ /dev/null @@ -1,59 +0,0 @@ -package work.slhaf.partner.common.util; - -import cn.hutool.http.HttpRequest; -import cn.hutool.http.HttpResponse; -import com.alibaba.fastjson2.JSONObject; -import lombok.Data; -import lombok.extern.slf4j.Slf4j; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.ops.transforms.Transforms; -import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager; -import work.slhaf.partner.common.config.PartnerAgentConfigManager; - -import java.util.Map; - -@Slf4j -public class VectorUtil { - - private static final String OLLAMA_EMBEDDING_URL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingUrl(); - private static final String OLLAMA_EMBEDDING_MODEL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingModel(); - - private VectorUtil() { - } - - /** - * 如果计算失败将返回null - * - * @param input 需要计算向量的字符串 - * @return 向量计算结果 - */ - public static float[] compute(String input) { - Map param = Map.of("model", OLLAMA_EMBEDDING_MODEL, "input", input); - HttpRequest request = HttpRequest.get(OLLAMA_EMBEDDING_URL).body(JSONObject.toJSONString(param)); - try (HttpResponse response = request.execute()) { - if (!response.isOk()) return null; - String resStr = response.body(); - EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class); - return embeddingResponse.getEmbeddings()[0]; - } catch (Exception e) { - log.error("嵌入模型执行出错", e); - return null; - } - } - - public static double compare(float[] v1, float[] v2) { - try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) { - return Transforms.cosineSim(a1, a2); - } - } - - @Data - private static class EmbeddingModelResponse { - private String model; - private float[][] embeddings; - private long total_duration; - private long load_duration; - private int prompt_eval_count; - } -} diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OllamaVectorClient.java b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OllamaVectorClient.java new file mode 100644 index 00000000..84c2cdfa --- /dev/null +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OllamaVectorClient.java @@ -0,0 +1,50 @@ +package work.slhaf.partner.common.vector; + +import java.util.Map; + +import com.alibaba.fastjson2.JSONObject; + +import cn.hutool.http.HttpRequest; +import cn.hutool.http.HttpResponse; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import work.slhaf.partner.common.vector.exception.VectorClientExecuteException; +import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException; + +@Slf4j +public class OllamaVectorClient extends VectorClient { + + private String ollamaEmbeddingUrl; + private String ollamaEmbeddingModel; + + protected OllamaVectorClient(String url, String model) { + this.ollamaEmbeddingUrl = url; + this.ollamaEmbeddingModel = model; + + compute("test"); + } + + @Override + protected float[] doCompute(String input) { + Map param = Map.of("model", ollamaEmbeddingModel, "input", input); + HttpRequest request = HttpRequest.get(ollamaEmbeddingUrl).body(JSONObject.toJSONString(param)); + try (HttpResponse response = request.execute()) { + if (!response.isOk()) + throw new VectorClientExecuteException("嵌入模型执行出错"); + String resStr = response.body(); + EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class); + return embeddingResponse.getEmbeddings()[0]; + } catch (Exception e) { + throw new VectorClientExecuteException("嵌入模型执行出错", e); + } + } + + @Data + private static class EmbeddingModelResponse { + private String model; + private float[][] embeddings; + private long total_duration; + private long load_duration; + private int prompt_eval_count; + } +} diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OnnxVectorClient.java b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OnnxVectorClient.java new file mode 100644 index 00000000..4a4c6581 --- /dev/null +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OnnxVectorClient.java @@ -0,0 +1,83 @@ +package work.slhaf.partner.common.vector; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import work.slhaf.partner.common.vector.exception.VectorClientExecuteException; +import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException; + +public class OnnxVectorClient extends VectorClient { + + private String tokenizerPath; + private String modelPath; + + private HuggingFaceTokenizer tokenizer; + private OrtSession session; + private OrtEnvironment env; + + protected OnnxVectorClient(String tokenizer, String model) { + this.tokenizerPath = tokenizer; + this.modelPath = model; + + loadTokenizer(); + loadModel(); + compute("test"); + } + + private void loadModel() { + try { + env = OrtEnvironment.getEnvironment(); + OrtSession.SessionOptions ops = new OrtSession.SessionOptions(); + session = env.createSession(modelPath, ops); + } catch (Exception e) { + throw new VectorClientLoadFailedException("加载ONNX模型失败", e); + } + } + + private void loadTokenizer() { + try { + tokenizer = HuggingFaceTokenizer.newInstance(Path.of(tokenizerPath)); + } catch (Exception e) { + throw new VectorClientLoadFailedException("加载Tokenizer失败", e); + } + } + + @Override + protected float[] doCompute(String input) { + try { + Encoding encode = tokenizer.encode(input); + long[] ids = encode.getIds(); + long[] attentionMask = encode.getAttentionMask(); + + long[][] inputIdsBatch = { ids }; + long[][] attentionMaskBatch = { attentionMask }; + long[][] tokenTypeIdsBatch = { new long[ids.length] }; // 初始化全 0 + for (int i = 0; i < ids.length; i++) + tokenTypeIdsBatch[0][i] = 0; + + OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputIdsBatch); + OnnxTensor maskTensor = OnnxTensor.createTensor(env, attentionMaskBatch); + OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(env, tokenTypeIdsBatch); + + Map inputs = new HashMap<>(); + inputs.put("input_ids", inputTensor); + inputs.put("attention_mask", maskTensor); + inputs.put("token_type_ids", tokenTypeTensor); + + OrtSession.Result result = session.run(inputs); + OnnxTensor embeddingTensor = (OnnxTensor) result.get(0); + return embeddingTensor.getFloatBuffer().array(); + } catch (Exception e) { + throw new VectorClientExecuteException("嵌入模型执行出错", e); + } + } + +} diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/vector/VectorClient.java b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/VectorClient.java new file mode 100644 index 00000000..76b3ebfd --- /dev/null +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/VectorClient.java @@ -0,0 +1,68 @@ +package work.slhaf.partner.common.vector; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.ops.transforms.Transforms; + +import lombok.extern.slf4j.Slf4j; +import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager; +import work.slhaf.partner.common.config.Config.VectorConfig; +import work.slhaf.partner.common.exception.ServiceLoadFailedException; +import work.slhaf.partner.common.vector.exception.VectorClientExecuteException; +import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException; +import work.slhaf.partner.common.config.PartnerAgentConfigManager; + +@Slf4j +public abstract class VectorClient { + + public static boolean status; + public static VectorClient INSTANCE; + + public static void load() { + PartnerAgentConfigManager configManager = (PartnerAgentConfigManager) AgentConfigManager.INSTANCE; + VectorConfig vectorConfig = configManager.getConfig().getVectorConfig(); + int type = vectorConfig.getType(); + try { + switch (type) { + case 0: + status = false; + break; + case 1: + status = true; + INSTANCE = new OllamaVectorClient(vectorConfig.getOllamaEmbeddingUrl(), + vectorConfig.getOllamaEmbeddingModel()); + break; + case 2: + status = true; + INSTANCE = new OnnxVectorClient(vectorConfig.getTokenizerPath(), + vectorConfig.getEmbeddingModelPath()); + break; + default: + throw new ServiceLoadFailedException( + "加载向量客户端失败! type: 0 -> 不启用语义缓存; type: 1 -> ollama; type: 2 -> ONNX RUNTIME"); + } + log.info("向量客户端加载完毕"); + } catch (VectorClientLoadFailedException | VectorClientExecuteException exception) { + status = false; + log.error("向量客户端加载失败", exception); + } + } + + public float[] compute(String input) { + if (!status) { + return null; + } + return doCompute(input); + } + + protected abstract float[] doCompute(String input); + + public double compare(float[] v1, float[] v2) { + if (!status) { + return 0; + } + try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) { + return Transforms.cosineSim(a1, a2); + } + } +} \ No newline at end of file diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/vector/exception/VectorClientExecuteException.java b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/exception/VectorClientExecuteException.java new file mode 100644 index 00000000..9f7d7fb0 --- /dev/null +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/exception/VectorClientExecuteException.java @@ -0,0 +1,15 @@ +package work.slhaf.partner.common.vector.exception; + +import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException; + +public class VectorClientExecuteException extends AgentRuntimeException { + + public VectorClientExecuteException(String message) { + super(message); + } + + public VectorClientExecuteException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/vector/exception/VectorClientLoadFailedException.java b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/exception/VectorClientLoadFailedException.java new file mode 100644 index 00000000..2c7afc41 --- /dev/null +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/exception/VectorClientLoadFailedException.java @@ -0,0 +1,15 @@ +package work.slhaf.partner.common.vector.exception; + +import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException; + +public class VectorClientLoadFailedException extends AgentRuntimeException { + + public VectorClientLoadFailedException(String message) { + super(message); + } + + public VectorClientLoadFailedException(String message, Throwable cause) { + super(message, cause); + } + +}