推进 ActionExtractor: 新增语义向量计算工具;开始推进语义缓存相关;调整配置类格式

This commit is contained in:
2025-10-16 10:14:39 +08:00
parent 2f09c0cd71
commit e78048f66d
6 changed files with 92 additions and 13 deletions

View File

@@ -4,6 +4,18 @@ import lombok.Data;
@Data @Data
public class Config { public class Config {
private int port;
private String agentId; private String agentId;
private WebSocketConfig webSocketConfig;
private VectorConfig vectorConfig;
@Data
public static class VectorConfig {
private String ollamaEmbeddingUrl;
private String ollamaEmbeddingModel;
}
@Data
public static class WebSocketConfig {
private int port;
}
} }

View File

@@ -33,8 +33,9 @@ public final class PartnerAgentConfigManager extends FileAgentConfigManager {
if (config == null || config.getAgentId() == null) { if (config == null || config.getAgentId() == null) {
throw new ConfigLoadFailedException("Partner Config Load Failed: " + COMMON_CONFIG_FILE); throw new ConfigLoadFailedException("Partner Config Load Failed: " + COMMON_CONFIG_FILE);
} }
if (config.getPort() <= 0 || config.getPort() > 65535) { int port = config.getWebSocketConfig().getPort();
throw new ConfigLoadFailedException("Invalid Websocket port: " + config.getPort()); if (port <= 0 || port > 65535) {
throw new ConfigLoadFailedException("Invalid Websocket port: " + port);
} }
} }
} }

View File

@@ -0,0 +1,59 @@
package work.slhaf.partner.common.util;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
import java.util.Map;
@Slf4j
public class VectorUtil {
private static final String OLLAMA_EMBEDDING_URL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingUrl();
private static final String OLLAMA_EMBEDDING_MODEL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingModel();
private VectorUtil() {
}
/**
* 如果计算失败将返回null
*
* @param input 需要计算向量的字符串
* @return 向量计算结果
*/
public static float[] compute(String input) {
Map<String, String> param = Map.of("model", OLLAMA_EMBEDDING_MODEL, "input", input);
HttpRequest request = HttpRequest.get(OLLAMA_EMBEDDING_URL).body(JSONObject.toJSONString(param));
try (HttpResponse response = request.execute()) {
if (!response.isOk()) return null;
String resStr = response.body();
EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class);
return embeddingResponse.getEmbeddings()[0];
} catch (Exception e) {
log.error("嵌入模型执行出错", e);
return null;
}
}
public static double compare(float[] v1, float[] v2) {
try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) {
return Transforms.cosineSim(a1, a2);
}
}
@Data
private static class EmbeddingModelResponse {
private String model;
private float[][] embeddings;
private long total_duration;
private long load_duration;
private int prompt_eval_count;
}
}

View File

@@ -4,6 +4,7 @@ import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability; import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod; import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.common.util.VectorUtil;
import work.slhaf.partner.core.PartnerCore; import work.slhaf.partner.core.PartnerCore;
import work.slhaf.partner.core.action.entity.ActionCacheData; import work.slhaf.partner.core.action.entity.ActionCacheData;
import work.slhaf.partner.core.action.entity.MetaActionInfo; import work.slhaf.partner.core.action.entity.MetaActionInfo;
@@ -82,8 +83,15 @@ public class ActionCore extends PartnerCore<ActionCore> {
@CapabilityMethod @CapabilityMethod
public List<String> computeActionCache(String input){ public List<String> computeActionCache(String input){
//计算本次输入的向量 //计算本次输入的向量
float[] vector = VectorUtil.compute(input);
if (vector == null) return null;
//与现有缓存比对,如果存在,则使缓存计数+1 //与现有缓存比对,如果存在,则使缓存计数+1
actionCache.stream()
.filter(ActionCacheData::isActivated)
.forEach(data -> {
double compared = VectorUtil.compare(vector, data.getInputVector());
});
return null; return null;
} }

View File

@@ -1,18 +1,17 @@
package work.slhaf.partner.core.action.entity; package work.slhaf.partner.core.action.entity;
import lombok.Data;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import lombok.Data;
@Data @Data
public class ActionCacheData { public class ActionCacheData {
private INDArray inputArray; private float[] inputVector;
private INDArray tendencyArray; private float[] tendencyVector;
private String tendency; private String tendency;
private int count; private int inputMatchCount;
private List<String> activateInputs = new ArrayList<>();
private boolean activated; private boolean activated;
private List<float[]> validSamples = new ArrayList<>();
private double threshold;
} }

View File

@@ -33,7 +33,7 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<Pa
private final ConcurrentHashMap<WebSocket, Long> lastPongTimes = new ConcurrentHashMap<>(); private final ConcurrentHashMap<WebSocket, Long> lastPongTimes = new ConcurrentHashMap<>();
public WebSocketGateway() { public WebSocketGateway() {
this(((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getPort()); this(((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getWebSocketConfig().getPort());
} }
private WebSocketGateway(int port) { private WebSocketGateway(int port) {