推进 ActionExtractor 语义缓存机制: 移除了 VectorUtil,实现了 ollama、onnx runtime 两种向量客户端,通过 Agent 启动类暴露的后置启动任务加载并进行测试。

This commit is contained in:
2025-10-17 11:20:11 +08:00
parent 7094a8a68b
commit d1ea8dde79
8 changed files with 236 additions and 59 deletions

View File

@@ -2,6 +2,7 @@ package work.slhaf.partner;
import work.slhaf.partner.api.agent.Agent; import work.slhaf.partner.api.agent.Agent;
import work.slhaf.partner.common.config.PartnerAgentConfigManager; import work.slhaf.partner.common.config.PartnerAgentConfigManager;
import work.slhaf.partner.common.vector.VectorClient;
import work.slhaf.partner.runtime.exception.PartnerExceptionCallback; import work.slhaf.partner.runtime.exception.PartnerExceptionCallback;
import work.slhaf.partner.runtime.interaction.WebSocketGateway; import work.slhaf.partner.runtime.interaction.WebSocketGateway;
@@ -11,6 +12,7 @@ public class Main {
.setAgentConfigManager(PartnerAgentConfigManager.class) .setAgentConfigManager(PartnerAgentConfigManager.class)
.setGateway(WebSocketGateway.class) .setGateway(WebSocketGateway.class)
.setAgentExceptionCallback(PartnerExceptionCallback.class) .setAgentExceptionCallback(PartnerExceptionCallback.class)
.addAfterLaunchRunners(() -> VectorClient.load())
.launch(); .launch();
} }
} }

View File

@@ -10,8 +10,11 @@ public class Config {
/**
 * Settings for the embedding ("vector") client.
 */
@Data
public static class VectorConfig {
    /**
     * Selects the vector client implementation. {@code VectorClient.load()} treats
     * 0 as "semantic cache disabled", 1 as the Ollama client and 2 as the ONNX Runtime client.
     */
    private int type;
    /** URL of the Ollama embedding endpoint; passed to the Ollama client when {@code type} is 1. */
    private String ollamaEmbeddingUrl;
    /** Name of the Ollama embedding model; passed to the Ollama client when {@code type} is 1. */
    private String ollamaEmbeddingModel;
    /** Path of the tokenizer; passed to the ONNX client when {@code type} is 2. */
    private String tokenizerPath;
    /** Path of the ONNX embedding model; passed to the ONNX client when {@code type} is 2. */
    private String embeddingModelPath;
}
@Data @Data

View File

@@ -1,59 +0,0 @@
package work.slhaf.partner.common.util;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
import java.util.Map;
@Slf4j
public class VectorUtil {
private static final String OLLAMA_EMBEDDING_URL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingUrl();
private static final String OLLAMA_EMBEDDING_MODEL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingModel();
private VectorUtil() {
}
/**
* 如果计算失败将返回null
*
* @param input 需要计算向量的字符串
* @return 向量计算结果
*/
public static float[] compute(String input) {
Map<String, String> param = Map.of("model", OLLAMA_EMBEDDING_MODEL, "input", input);
HttpRequest request = HttpRequest.get(OLLAMA_EMBEDDING_URL).body(JSONObject.toJSONString(param));
try (HttpResponse response = request.execute()) {
if (!response.isOk()) return null;
String resStr = response.body();
EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class);
return embeddingResponse.getEmbeddings()[0];
} catch (Exception e) {
log.error("嵌入模型执行出错", e);
return null;
}
}
public static double compare(float[] v1, float[] v2) {
try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) {
return Transforms.cosineSim(a1, a2);
}
}
@Data
private static class EmbeddingModelResponse {
private String model;
private float[][] embeddings;
private long total_duration;
private long load_duration;
private int prompt_eval_count;
}
}

View File

@@ -0,0 +1,50 @@
package work.slhaf.partner.common.vector;
import java.util.Map;
import com.alibaba.fastjson2.JSONObject;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
/**
 * {@link VectorClient} implementation that obtains embeddings from an Ollama HTTP endpoint.
 */
@Slf4j
public class OllamaVectorClient extends VectorClient {

    private final String ollamaEmbeddingUrl;
    private final String ollamaEmbeddingModel;

    /**
     * Creates the client and immediately issues a probe request so that an unreachable
     * or misconfigured endpoint is detected at construction time.
     *
     * @param url   URL of the Ollama embedding endpoint
     * @param model name of the embedding model to request
     * @throws VectorClientExecuteException if the probe request fails
     */
    protected OllamaVectorClient(String url, String model) {
        this.ollamaEmbeddingUrl = url;
        this.ollamaEmbeddingModel = model;
        // Probe call; compute() only reaches doCompute() while VectorClient.status is true.
        compute("test");
    }

    /**
     * Requests the embedding of {@code input} from the configured endpoint.
     *
     * @param input the text to embed
     * @return the first embedding vector of the response
     * @throws VectorClientExecuteException on a non-OK response, an empty result, or any other failure
     */
    @Override
    protected float[] doCompute(String input) {
        try {
            // Built inside the try block: Map.of rejects null values, so a null input or an
            // unconfigured model/URL is reported as VectorClientExecuteException instead of
            // escaping as a raw NullPointerException.
            Map<String, String> param = Map.of("model", ollamaEmbeddingModel, "input", input);
            // The Ollama embed endpoint is a POST endpoint; send the JSON body with POST explicitly.
            HttpRequest request = HttpRequest.post(ollamaEmbeddingUrl).body(JSONObject.toJSONString(param));
            try (HttpResponse response = request.execute()) {
                if (!response.isOk()) {
                    throw new VectorClientExecuteException("嵌入模型执行出错, HTTP status: " + response.getStatus());
                }
                EmbeddingModelResponse embeddingResponse =
                        JSONObject.parseObject(response.body(), EmbeddingModelResponse.class);
                float[][] embeddings = embeddingResponse == null ? null : embeddingResponse.getEmbeddings();
                if (embeddings == null || embeddings.length == 0) {
                    throw new VectorClientExecuteException("嵌入模型执行出错: empty embeddings in response");
                }
                return embeddings[0];
            }
        } catch (VectorClientExecuteException e) {
            // Already carries a specific message; do not wrap it a second time.
            throw e;
        } catch (Exception e) {
            throw new VectorClientExecuteException("嵌入模型执行出错", e);
        }
    }

    /**
     * JSON shape of the embedding response; field names mirror the JSON keys.
     */
    @Data
    private static class EmbeddingModelResponse {
        private String model;
        private float[][] embeddings;
        private long total_duration;
        private long load_duration;
        private int prompt_eval_count;
    }
}

View File

@@ -0,0 +1,83 @@
package work.slhaf.partner.common.vector;
import java.io.IOException;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
/**
 * {@link VectorClient} implementation that runs an embedding model locally through
 * ONNX Runtime, using a HuggingFace tokenizer to encode the input text.
 */
public class OnnxVectorClient extends VectorClient {

    private final String tokenizerPath;
    private final String modelPath;
    private HuggingFaceTokenizer tokenizer;
    private OrtSession session;
    private OrtEnvironment env;

    /**
     * Loads the tokenizer and the model, then runs one probe inference so that a broken
     * setup is detected at construction time.
     *
     * @param tokenizer path of the tokenizer
     * @param model     path of the ONNX model
     * @throws VectorClientLoadFailedException if the tokenizer or the model cannot be loaded
     * @throws VectorClientExecuteException    if the probe inference fails
     */
    protected OnnxVectorClient(String tokenizer, String model) {
        this.tokenizerPath = tokenizer;
        this.modelPath = model;
        loadTokenizer();
        loadModel();
        // Probe call; compute() only reaches doCompute() while VectorClient.status is true.
        compute("test");
    }

    /** Obtains the ONNX Runtime environment and opens a session for the configured model. */
    private void loadModel() {
        try {
            env = OrtEnvironment.getEnvironment();
            OrtSession.SessionOptions ops = new OrtSession.SessionOptions();
            session = env.createSession(modelPath, ops);
        } catch (Exception e) {
            throw new VectorClientLoadFailedException("加载ONNX模型失败", e);
        }
    }

    /** Creates the HuggingFace tokenizer from the configured path. */
    private void loadTokenizer() {
        try {
            tokenizer = HuggingFaceTokenizer.newInstance(Path.of(tokenizerPath));
        } catch (Exception e) {
            throw new VectorClientLoadFailedException("加载Tokenizer失败", e);
        }
    }

    /**
     * Tokenizes {@code input}, runs the model on a batch of one, and returns the first
     * model output as a flat float array.
     *
     * <p>NOTE(review): the first output is returned flattened as-is. If the exported model's
     * first output is per-token rather than a pooled sentence embedding, the vector length
     * will vary with the input length — confirm against the model in use.
     *
     * @param input the text to embed
     * @return the values of the first model output
     * @throws VectorClientExecuteException if tokenization or inference fails
     */
    @Override
    protected float[] doCompute(String input) {
        try {
            Encoding encoding = tokenizer.encode(input);
            long[] ids = encoding.getIds();
            long[][] inputIdsBatch = { ids };
            long[][] attentionMaskBatch = { encoding.getAttentionMask() };
            // A new long[] is already zero-filled, which is the all-zero token_type_ids we need.
            long[][] tokenTypeIdsBatch = { new long[ids.length] };

            // Tensors and the run result hold native memory; close them deterministically
            // so repeated calls do not leak.
            try (OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputIdsBatch);
                 OnnxTensor maskTensor = OnnxTensor.createTensor(env, attentionMaskBatch);
                 OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(env, tokenTypeIdsBatch)) {

                Map<String, OnnxTensor> inputs = new HashMap<>();
                inputs.put("input_ids", inputTensor);
                inputs.put("attention_mask", maskTensor);
                inputs.put("token_type_ids", tokenTypeTensor);

                try (OrtSession.Result result = session.run(inputs)) {
                    OnnxTensor embeddingTensor = (OnnxTensor) result.get(0);
                    // Per the ONNX Runtime Java docs getFloatBuffer() returns a copy, so the
                    // array stays valid after the result is closed.
                    return embeddingTensor.getFloatBuffer().array();
                }
            }
        } catch (Exception e) {
            throw new VectorClientExecuteException("嵌入模型执行出错", e);
        }
    }
}

View File

@@ -0,0 +1,68 @@
package work.slhaf.partner.common.vector;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
import work.slhaf.partner.common.config.Config.VectorConfig;
import work.slhaf.partner.common.exception.ServiceLoadFailedException;
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
/**
 * Base class and global holder of the embedding ("vector") client used for semantic caching.
 * A concrete client is selected from configuration by {@link #load()} and published
 * through {@link #INSTANCE}; {@link #status} tells whether the client is usable.
 */
@Slf4j
public abstract class VectorClient {

    /** {@code true} while a client is enabled; set by {@link #load()}. */
    public static boolean status;

    /**
     * The active client. It is only assigned for type 1 and 2 when construction succeeds,
     * so it remains {@code null} when the feature is disabled or loading failed.
     */
    public static VectorClient INSTANCE;

    /**
     * Reads the vector configuration and initialises the matching client:
     * type 0 disables the feature, type 1 uses Ollama, type 2 uses ONNX Runtime.
     * A load or probe failure of the selected client is logged and leaves {@link #status}
     * {@code false}.
     *
     * @throws ServiceLoadFailedException if the configured type is not 0, 1 or 2
     */
    public static void load() {
        PartnerAgentConfigManager configManager = (PartnerAgentConfigManager) AgentConfigManager.INSTANCE;
        VectorConfig vectorConfig = configManager.getConfig().getVectorConfig();
        int type = vectorConfig.getType();
        try {
            switch (type) {
                case 0:
                    status = false;
                    break;
                case 1:
                    // status is raised before construction because the client constructor
                    // probes itself through compute(), which is gated on status.
                    status = true;
                    INSTANCE = new OllamaVectorClient(vectorConfig.getOllamaEmbeddingUrl(),
                            vectorConfig.getOllamaEmbeddingModel());
                    break;
                case 2:
                    status = true;
                    INSTANCE = new OnnxVectorClient(vectorConfig.getTokenizerPath(),
                            vectorConfig.getEmbeddingModelPath());
                    break;
                default:
                    throw new ServiceLoadFailedException(
                            "加载向量客户端失败! type: 0 -> 不启用语义缓存; type: 1 -> ollama; type: 2 -> ONNX RUNTIME");
            }
            log.info("向量客户端加载完毕");
        } catch (VectorClientLoadFailedException | VectorClientExecuteException exception) {
            status = false;
            log.error("向量客户端加载失败", exception);
        }
    }

    /**
     * Computes the embedding of {@code input}.
     *
     * @param input the text to embed
     * @return the embedding vector, or {@code null} while {@link #status} is {@code false}
     */
    public float[] compute(String input) {
        if (!status) {
            return null;
        }
        return doCompute(input);
    }

    /**
     * Implementation hook that performs the actual embedding computation.
     *
     * @param input the text to embed
     * @return the embedding vector
     */
    protected abstract float[] doCompute(String input);

    /**
     * Returns the cosine similarity of two vectors.
     *
     * @param v1 first vector
     * @param v2 second vector
     * @return the cosine similarity, or 0 when the client is disabled or the vectors cannot
     *         be compared (either is {@code null} or empty, or their lengths differ)
     */
    public double compare(float[] v1, float[] v2) {
        if (!status) {
            return 0;
        }
        // compute() can legitimately return null; treat missing vectors as "no similarity"
        // instead of letting the native array creation fail.
        if (v1 == null || v2 == null || v1.length == 0 || v2.length == 0) {
            return 0;
        }
        if (v1.length != v2.length) {
            log.warn("向量维度不一致, 无法比较: {} vs {}", v1.length, v2.length);
            return 0;
        }
        try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) {
            return Transforms.cosineSim(a1, a2);
        }
    }
}

View File

@@ -0,0 +1,15 @@
package work.slhaf.partner.common.vector.exception;
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
/**
 * Signals that a vector client failed while computing an embedding.
 */
public class VectorClientExecuteException extends AgentRuntimeException {

    /**
     * @param message description of the failure
     * @param cause   the underlying error
     */
    public VectorClientExecuteException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * @param message description of the failure
     */
    public VectorClientExecuteException(String message) {
        super(message);
    }
}

View File

@@ -0,0 +1,15 @@
package work.slhaf.partner.common.vector.exception;
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
/**
 * Signals that a vector client could not be loaded (for example its model or tokenizer).
 */
public class VectorClientLoadFailedException extends AgentRuntimeException {

    /**
     * @param message description of the failure
     * @param cause   the underlying error
     */
    public VectorClientLoadFailedException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * @param message description of the failure
     */
    public VectorClientLoadFailedException(String message) {
        super(message);
    }
}