mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
推进 ActionExtractor 语义缓存机制: 移除了 VectorUtil,实现了 ollama、onnx runtime 两种向量客户端,通过 Agent 启动类暴露的后置启动任务加载并进行测试。
This commit is contained in:
@@ -2,6 +2,7 @@ package work.slhaf.partner;
|
|||||||
|
|
||||||
import work.slhaf.partner.api.agent.Agent;
|
import work.slhaf.partner.api.agent.Agent;
|
||||||
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||||
|
import work.slhaf.partner.common.vector.VectorClient;
|
||||||
import work.slhaf.partner.runtime.exception.PartnerExceptionCallback;
|
import work.slhaf.partner.runtime.exception.PartnerExceptionCallback;
|
||||||
import work.slhaf.partner.runtime.interaction.WebSocketGateway;
|
import work.slhaf.partner.runtime.interaction.WebSocketGateway;
|
||||||
|
|
||||||
@@ -11,6 +12,7 @@ public class Main {
|
|||||||
.setAgentConfigManager(PartnerAgentConfigManager.class)
|
.setAgentConfigManager(PartnerAgentConfigManager.class)
|
||||||
.setGateway(WebSocketGateway.class)
|
.setGateway(WebSocketGateway.class)
|
||||||
.setAgentExceptionCallback(PartnerExceptionCallback.class)
|
.setAgentExceptionCallback(PartnerExceptionCallback.class)
|
||||||
|
.addAfterLaunchRunners(() -> VectorClient.load())
|
||||||
.launch();
|
.launch();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -10,8 +10,11 @@ public class Config {
|
|||||||
|
|
||||||
@Data
|
@Data
|
||||||
public static class VectorConfig {
|
public static class VectorConfig {
|
||||||
|
private int type;
|
||||||
private String ollamaEmbeddingUrl;
|
private String ollamaEmbeddingUrl;
|
||||||
private String ollamaEmbeddingModel;
|
private String ollamaEmbeddingModel;
|
||||||
|
private String tokenizerPath;
|
||||||
|
private String embeddingModelPath;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
|||||||
@@ -1,59 +0,0 @@
|
|||||||
package work.slhaf.partner.common.util;
|
|
||||||
|
|
||||||
import cn.hutool.http.HttpRequest;
|
|
||||||
import cn.hutool.http.HttpResponse;
|
|
||||||
import com.alibaba.fastjson2.JSONObject;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
|
||||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
|
||||||
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class VectorUtil {
|
|
||||||
|
|
||||||
private static final String OLLAMA_EMBEDDING_URL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingUrl();
|
|
||||||
private static final String OLLAMA_EMBEDDING_MODEL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingModel();
|
|
||||||
|
|
||||||
private VectorUtil() {
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 如果计算失败将返回null
|
|
||||||
*
|
|
||||||
* @param input 需要计算向量的字符串
|
|
||||||
* @return 向量计算结果
|
|
||||||
*/
|
|
||||||
public static float[] compute(String input) {
|
|
||||||
Map<String, String> param = Map.of("model", OLLAMA_EMBEDDING_MODEL, "input", input);
|
|
||||||
HttpRequest request = HttpRequest.get(OLLAMA_EMBEDDING_URL).body(JSONObject.toJSONString(param));
|
|
||||||
try (HttpResponse response = request.execute()) {
|
|
||||||
if (!response.isOk()) return null;
|
|
||||||
String resStr = response.body();
|
|
||||||
EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class);
|
|
||||||
return embeddingResponse.getEmbeddings()[0];
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("嵌入模型执行出错", e);
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static double compare(float[] v1, float[] v2) {
|
|
||||||
try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) {
|
|
||||||
return Transforms.cosineSim(a1, a2);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Data
|
|
||||||
private static class EmbeddingModelResponse {
|
|
||||||
private String model;
|
|
||||||
private float[][] embeddings;
|
|
||||||
private long total_duration;
|
|
||||||
private long load_duration;
|
|
||||||
private int prompt_eval_count;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
package work.slhaf.partner.common.vector;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
|
|
||||||
|
import cn.hutool.http.HttpRequest;
|
||||||
|
import cn.hutool.http.HttpResponse;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||||
|
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class OllamaVectorClient extends VectorClient {
|
||||||
|
|
||||||
|
private String ollamaEmbeddingUrl;
|
||||||
|
private String ollamaEmbeddingModel;
|
||||||
|
|
||||||
|
protected OllamaVectorClient(String url, String model) {
|
||||||
|
this.ollamaEmbeddingUrl = url;
|
||||||
|
this.ollamaEmbeddingModel = model;
|
||||||
|
|
||||||
|
compute("test");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected float[] doCompute(String input) {
|
||||||
|
Map<String, String> param = Map.of("model", ollamaEmbeddingModel, "input", input);
|
||||||
|
HttpRequest request = HttpRequest.get(ollamaEmbeddingUrl).body(JSONObject.toJSONString(param));
|
||||||
|
try (HttpResponse response = request.execute()) {
|
||||||
|
if (!response.isOk())
|
||||||
|
throw new VectorClientExecuteException("嵌入模型执行出错");
|
||||||
|
String resStr = response.body();
|
||||||
|
EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class);
|
||||||
|
return embeddingResponse.getEmbeddings()[0];
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new VectorClientExecuteException("嵌入模型执行出错", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
private static class EmbeddingModelResponse {
|
||||||
|
private String model;
|
||||||
|
private float[][] embeddings;
|
||||||
|
private long total_duration;
|
||||||
|
private long load_duration;
|
||||||
|
private int prompt_eval_count;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
package work.slhaf.partner.common.vector;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import ai.djl.huggingface.tokenizers.Encoding;
|
||||||
|
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
|
||||||
|
import ai.onnxruntime.OnnxTensor;
|
||||||
|
import ai.onnxruntime.OrtEnvironment;
|
||||||
|
import ai.onnxruntime.OrtException;
|
||||||
|
import ai.onnxruntime.OrtSession;
|
||||||
|
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||||
|
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
||||||
|
|
||||||
|
public class OnnxVectorClient extends VectorClient {
|
||||||
|
|
||||||
|
private String tokenizerPath;
|
||||||
|
private String modelPath;
|
||||||
|
|
||||||
|
private HuggingFaceTokenizer tokenizer;
|
||||||
|
private OrtSession session;
|
||||||
|
private OrtEnvironment env;
|
||||||
|
|
||||||
|
protected OnnxVectorClient(String tokenizer, String model) {
|
||||||
|
this.tokenizerPath = tokenizer;
|
||||||
|
this.modelPath = model;
|
||||||
|
|
||||||
|
loadTokenizer();
|
||||||
|
loadModel();
|
||||||
|
compute("test");
|
||||||
|
}
|
||||||
|
|
||||||
|
private void loadModel() {
|
||||||
|
try {
|
||||||
|
env = OrtEnvironment.getEnvironment();
|
||||||
|
OrtSession.SessionOptions ops = new OrtSession.SessionOptions();
|
||||||
|
session = env.createSession(modelPath, ops);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new VectorClientLoadFailedException("加载ONNX模型失败", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void loadTokenizer() {
|
||||||
|
try {
|
||||||
|
tokenizer = HuggingFaceTokenizer.newInstance(Path.of(tokenizerPath));
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new VectorClientLoadFailedException("加载Tokenizer失败", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected float[] doCompute(String input) {
|
||||||
|
try {
|
||||||
|
Encoding encode = tokenizer.encode(input);
|
||||||
|
long[] ids = encode.getIds();
|
||||||
|
long[] attentionMask = encode.getAttentionMask();
|
||||||
|
|
||||||
|
long[][] inputIdsBatch = { ids };
|
||||||
|
long[][] attentionMaskBatch = { attentionMask };
|
||||||
|
long[][] tokenTypeIdsBatch = { new long[ids.length] }; // 初始化全 0
|
||||||
|
for (int i = 0; i < ids.length; i++)
|
||||||
|
tokenTypeIdsBatch[0][i] = 0;
|
||||||
|
|
||||||
|
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputIdsBatch);
|
||||||
|
OnnxTensor maskTensor = OnnxTensor.createTensor(env, attentionMaskBatch);
|
||||||
|
OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(env, tokenTypeIdsBatch);
|
||||||
|
|
||||||
|
Map<String, OnnxTensor> inputs = new HashMap<>();
|
||||||
|
inputs.put("input_ids", inputTensor);
|
||||||
|
inputs.put("attention_mask", maskTensor);
|
||||||
|
inputs.put("token_type_ids", tokenTypeTensor);
|
||||||
|
|
||||||
|
OrtSession.Result result = session.run(inputs);
|
||||||
|
OnnxTensor embeddingTensor = (OnnxTensor) result.get(0);
|
||||||
|
return embeddingTensor.getFloatBuffer().array();
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new VectorClientExecuteException("嵌入模型执行出错", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
package work.slhaf.partner.common.vector;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||||
|
import work.slhaf.partner.common.config.Config.VectorConfig;
|
||||||
|
import work.slhaf.partner.common.exception.ServiceLoadFailedException;
|
||||||
|
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||||
|
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
||||||
|
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public abstract class VectorClient {
|
||||||
|
|
||||||
|
public static boolean status;
|
||||||
|
public static VectorClient INSTANCE;
|
||||||
|
|
||||||
|
public static void load() {
|
||||||
|
PartnerAgentConfigManager configManager = (PartnerAgentConfigManager) AgentConfigManager.INSTANCE;
|
||||||
|
VectorConfig vectorConfig = configManager.getConfig().getVectorConfig();
|
||||||
|
int type = vectorConfig.getType();
|
||||||
|
try {
|
||||||
|
switch (type) {
|
||||||
|
case 0:
|
||||||
|
status = false;
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
status = true;
|
||||||
|
INSTANCE = new OllamaVectorClient(vectorConfig.getOllamaEmbeddingUrl(),
|
||||||
|
vectorConfig.getOllamaEmbeddingModel());
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
status = true;
|
||||||
|
INSTANCE = new OnnxVectorClient(vectorConfig.getTokenizerPath(),
|
||||||
|
vectorConfig.getEmbeddingModelPath());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new ServiceLoadFailedException(
|
||||||
|
"加载向量客户端失败! type: 0 -> 不启用语义缓存; type: 1 -> ollama; type: 2 -> ONNX RUNTIME");
|
||||||
|
}
|
||||||
|
log.info("向量客户端加载完毕");
|
||||||
|
} catch (VectorClientLoadFailedException | VectorClientExecuteException exception) {
|
||||||
|
status = false;
|
||||||
|
log.error("向量客户端加载失败", exception);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public float[] compute(String input) {
|
||||||
|
if (!status) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return doCompute(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract float[] doCompute(String input);
|
||||||
|
|
||||||
|
public double compare(float[] v1, float[] v2) {
|
||||||
|
if (!status) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) {
|
||||||
|
return Transforms.cosineSim(a1, a2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package work.slhaf.partner.common.vector.exception;
|
||||||
|
|
||||||
|
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||||
|
|
||||||
|
public class VectorClientExecuteException extends AgentRuntimeException {
|
||||||
|
|
||||||
|
public VectorClientExecuteException(String message) {
|
||||||
|
super(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
public VectorClientExecuteException(String message, Throwable cause) {
|
||||||
|
super(message, cause);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package work.slhaf.partner.common.vector.exception;
|
||||||
|
|
||||||
|
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||||
|
|
||||||
|
public class VectorClientLoadFailedException extends AgentRuntimeException {
|
||||||
|
|
||||||
|
public VectorClientLoadFailedException(String message) {
|
||||||
|
super(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
public VectorClientLoadFailedException(String message, Throwable cause) {
|
||||||
|
super(message, cause);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user