mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
推进 ActionExtractor 语义缓存机制: 移除了 VectorUtil,实现了 ollama、onnx runtime 两种向量客户端,通过 Agent 启动类暴露的后置启动任务加载并进行测试。
This commit is contained in:
@@ -2,6 +2,7 @@ package work.slhaf.partner;
|
||||
|
||||
import work.slhaf.partner.api.agent.Agent;
|
||||
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||
import work.slhaf.partner.common.vector.VectorClient;
|
||||
import work.slhaf.partner.runtime.exception.PartnerExceptionCallback;
|
||||
import work.slhaf.partner.runtime.interaction.WebSocketGateway;
|
||||
|
||||
@@ -11,6 +12,7 @@ public class Main {
|
||||
.setAgentConfigManager(PartnerAgentConfigManager.class)
|
||||
.setGateway(WebSocketGateway.class)
|
||||
.setAgentExceptionCallback(PartnerExceptionCallback.class)
|
||||
.addAfterLaunchRunners(() -> VectorClient.load())
|
||||
.launch();
|
||||
}
|
||||
}
|
||||
@@ -10,8 +10,11 @@ public class Config {
|
||||
|
||||
@Data
|
||||
public static class VectorConfig {
|
||||
private int type;
|
||||
private String ollamaEmbeddingUrl;
|
||||
private String ollamaEmbeddingModel;
|
||||
private String tokenizerPath;
|
||||
private String embeddingModelPath;
|
||||
}
|
||||
|
||||
@Data
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
package work.slhaf.partner.common.util;
|
||||
|
||||
import cn.hutool.http.HttpRequest;
|
||||
import cn.hutool.http.HttpResponse;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
public class VectorUtil {
|
||||
|
||||
private static final String OLLAMA_EMBEDDING_URL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingUrl();
|
||||
private static final String OLLAMA_EMBEDDING_MODEL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingModel();
|
||||
|
||||
private VectorUtil() {
|
||||
}
|
||||
|
||||
/**
|
||||
* 如果计算失败将返回null
|
||||
*
|
||||
* @param input 需要计算向量的字符串
|
||||
* @return 向量计算结果
|
||||
*/
|
||||
public static float[] compute(String input) {
|
||||
Map<String, String> param = Map.of("model", OLLAMA_EMBEDDING_MODEL, "input", input);
|
||||
HttpRequest request = HttpRequest.get(OLLAMA_EMBEDDING_URL).body(JSONObject.toJSONString(param));
|
||||
try (HttpResponse response = request.execute()) {
|
||||
if (!response.isOk()) return null;
|
||||
String resStr = response.body();
|
||||
EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class);
|
||||
return embeddingResponse.getEmbeddings()[0];
|
||||
} catch (Exception e) {
|
||||
log.error("嵌入模型执行出错", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static double compare(float[] v1, float[] v2) {
|
||||
try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) {
|
||||
return Transforms.cosineSim(a1, a2);
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
private static class EmbeddingModelResponse {
|
||||
private String model;
|
||||
private float[][] embeddings;
|
||||
private long total_duration;
|
||||
private long load_duration;
|
||||
private int prompt_eval_count;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package work.slhaf.partner.common.vector;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
|
||||
import cn.hutool.http.HttpRequest;
|
||||
import cn.hutool.http.HttpResponse;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
||||
|
||||
@Slf4j
|
||||
public class OllamaVectorClient extends VectorClient {
|
||||
|
||||
private String ollamaEmbeddingUrl;
|
||||
private String ollamaEmbeddingModel;
|
||||
|
||||
protected OllamaVectorClient(String url, String model) {
|
||||
this.ollamaEmbeddingUrl = url;
|
||||
this.ollamaEmbeddingModel = model;
|
||||
|
||||
compute("test");
|
||||
}
|
||||
|
||||
@Override
|
||||
protected float[] doCompute(String input) {
|
||||
Map<String, String> param = Map.of("model", ollamaEmbeddingModel, "input", input);
|
||||
HttpRequest request = HttpRequest.get(ollamaEmbeddingUrl).body(JSONObject.toJSONString(param));
|
||||
try (HttpResponse response = request.execute()) {
|
||||
if (!response.isOk())
|
||||
throw new VectorClientExecuteException("嵌入模型执行出错");
|
||||
String resStr = response.body();
|
||||
EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class);
|
||||
return embeddingResponse.getEmbeddings()[0];
|
||||
} catch (Exception e) {
|
||||
throw new VectorClientExecuteException("嵌入模型执行出错", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
private static class EmbeddingModelResponse {
|
||||
private String model;
|
||||
private float[][] embeddings;
|
||||
private long total_duration;
|
||||
private long load_duration;
|
||||
private int prompt_eval_count;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package work.slhaf.partner.common.vector;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import ai.djl.huggingface.tokenizers.Encoding;
|
||||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
||||
|
||||
public class OnnxVectorClient extends VectorClient {
|
||||
|
||||
private String tokenizerPath;
|
||||
private String modelPath;
|
||||
|
||||
private HuggingFaceTokenizer tokenizer;
|
||||
private OrtSession session;
|
||||
private OrtEnvironment env;
|
||||
|
||||
protected OnnxVectorClient(String tokenizer, String model) {
|
||||
this.tokenizerPath = tokenizer;
|
||||
this.modelPath = model;
|
||||
|
||||
loadTokenizer();
|
||||
loadModel();
|
||||
compute("test");
|
||||
}
|
||||
|
||||
private void loadModel() {
|
||||
try {
|
||||
env = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions ops = new OrtSession.SessionOptions();
|
||||
session = env.createSession(modelPath, ops);
|
||||
} catch (Exception e) {
|
||||
throw new VectorClientLoadFailedException("加载ONNX模型失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void loadTokenizer() {
|
||||
try {
|
||||
tokenizer = HuggingFaceTokenizer.newInstance(Path.of(tokenizerPath));
|
||||
} catch (Exception e) {
|
||||
throw new VectorClientLoadFailedException("加载Tokenizer失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected float[] doCompute(String input) {
|
||||
try {
|
||||
Encoding encode = tokenizer.encode(input);
|
||||
long[] ids = encode.getIds();
|
||||
long[] attentionMask = encode.getAttentionMask();
|
||||
|
||||
long[][] inputIdsBatch = { ids };
|
||||
long[][] attentionMaskBatch = { attentionMask };
|
||||
long[][] tokenTypeIdsBatch = { new long[ids.length] }; // 初始化全 0
|
||||
for (int i = 0; i < ids.length; i++)
|
||||
tokenTypeIdsBatch[0][i] = 0;
|
||||
|
||||
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputIdsBatch);
|
||||
OnnxTensor maskTensor = OnnxTensor.createTensor(env, attentionMaskBatch);
|
||||
OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(env, tokenTypeIdsBatch);
|
||||
|
||||
Map<String, OnnxTensor> inputs = new HashMap<>();
|
||||
inputs.put("input_ids", inputTensor);
|
||||
inputs.put("attention_mask", maskTensor);
|
||||
inputs.put("token_type_ids", tokenTypeTensor);
|
||||
|
||||
OrtSession.Result result = session.run(inputs);
|
||||
OnnxTensor embeddingTensor = (OnnxTensor) result.get(0);
|
||||
return embeddingTensor.getFloatBuffer().array();
|
||||
} catch (Exception e) {
|
||||
throw new VectorClientExecuteException("嵌入模型执行出错", e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package work.slhaf.partner.common.vector;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||
import work.slhaf.partner.common.config.Config.VectorConfig;
|
||||
import work.slhaf.partner.common.exception.ServiceLoadFailedException;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
||||
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||
|
||||
@Slf4j
|
||||
public abstract class VectorClient {
|
||||
|
||||
public static boolean status;
|
||||
public static VectorClient INSTANCE;
|
||||
|
||||
public static void load() {
|
||||
PartnerAgentConfigManager configManager = (PartnerAgentConfigManager) AgentConfigManager.INSTANCE;
|
||||
VectorConfig vectorConfig = configManager.getConfig().getVectorConfig();
|
||||
int type = vectorConfig.getType();
|
||||
try {
|
||||
switch (type) {
|
||||
case 0:
|
||||
status = false;
|
||||
break;
|
||||
case 1:
|
||||
status = true;
|
||||
INSTANCE = new OllamaVectorClient(vectorConfig.getOllamaEmbeddingUrl(),
|
||||
vectorConfig.getOllamaEmbeddingModel());
|
||||
break;
|
||||
case 2:
|
||||
status = true;
|
||||
INSTANCE = new OnnxVectorClient(vectorConfig.getTokenizerPath(),
|
||||
vectorConfig.getEmbeddingModelPath());
|
||||
break;
|
||||
default:
|
||||
throw new ServiceLoadFailedException(
|
||||
"加载向量客户端失败! type: 0 -> 不启用语义缓存; type: 1 -> ollama; type: 2 -> ONNX RUNTIME");
|
||||
}
|
||||
log.info("向量客户端加载完毕");
|
||||
} catch (VectorClientLoadFailedException | VectorClientExecuteException exception) {
|
||||
status = false;
|
||||
log.error("向量客户端加载失败", exception);
|
||||
}
|
||||
}
|
||||
|
||||
public float[] compute(String input) {
|
||||
if (!status) {
|
||||
return null;
|
||||
}
|
||||
return doCompute(input);
|
||||
}
|
||||
|
||||
protected abstract float[] doCompute(String input);
|
||||
|
||||
public double compare(float[] v1, float[] v2) {
|
||||
if (!status) {
|
||||
return 0;
|
||||
}
|
||||
try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) {
|
||||
return Transforms.cosineSim(a1, a2);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package work.slhaf.partner.common.vector.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||
|
||||
public class VectorClientExecuteException extends AgentRuntimeException {
|
||||
|
||||
public VectorClientExecuteException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public VectorClientExecuteException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package work.slhaf.partner.common.vector.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||
|
||||
public class VectorClientLoadFailedException extends AgentRuntimeException {
|
||||
|
||||
public VectorClientLoadFailedException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public VectorClientLoadFailedException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user