From e78048f66d2716af5e3eb5af7c2f3ebb1e9dd30b Mon Sep 17 00:00:00 2001 From: slhafzjw Date: Thu, 16 Oct 2025 10:14:39 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8E=A8=E8=BF=9B=20ActionExtractor:=20?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=E8=AF=AD=E4=B9=89=E5=90=91=E9=87=8F=E8=AE=A1?= =?UTF-8?q?=E7=AE=97=E5=B7=A5=E5=85=B7;=E5=BC=80=E5=A7=8B=E6=8E=A8?= =?UTF-8?q?=E8=BF=9B=E8=AF=AD=E4=B9=89=E7=BC=93=E5=AD=98=E7=9B=B8=E5=85=B3?= =?UTF-8?q?;=E8=B0=83=E6=95=B4=E9=85=8D=E7=BD=AE=E7=B1=BB=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../slhaf/partner/common/config/Config.java | 14 ++++- .../config/PartnerAgentConfigManager.java | 5 +- .../slhaf/partner/common/util/VectorUtil.java | 59 +++++++++++++++++++ .../slhaf/partner/core/action/ActionCore.java | 10 +++- .../core/action/entity/ActionCacheData.java | 15 +++-- .../runtime/interaction/WebSocketGateway.java | 2 +- 6 files changed, 92 insertions(+), 13 deletions(-) create mode 100644 Partner-Main/src/main/java/work/slhaf/partner/common/util/VectorUtil.java diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/config/Config.java b/Partner-Main/src/main/java/work/slhaf/partner/common/config/Config.java index 0f37197d..8b33aa5d 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/common/config/Config.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/config/Config.java @@ -4,6 +4,18 @@ import lombok.Data; @Data public class Config { - private int port; private String agentId; + private WebSocketConfig webSocketConfig; + private VectorConfig vectorConfig; + + @Data + public static class VectorConfig { + private String ollamaEmbeddingUrl; + private String ollamaEmbeddingModel; + } + + @Data + public static class WebSocketConfig { + private int port; + } } diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/config/PartnerAgentConfigManager.java b/Partner-Main/src/main/java/work/slhaf/partner/common/config/PartnerAgentConfigManager.java index 86f84375..78ba1369 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/common/config/PartnerAgentConfigManager.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/config/PartnerAgentConfigManager.java @@ -33,8 +33,9 @@ public final class PartnerAgentConfigManager extends FileAgentConfigManager { if (config == null || config.getAgentId() == null) { throw new ConfigLoadFailedException("Partner Config Load Failed: " + COMMON_CONFIG_FILE); } - if (config.getPort() <= 0 || config.getPort() > 65535) { - throw new ConfigLoadFailedException("Invalid Websocket port: " + config.getPort()); + int port = config.getWebSocketConfig().getPort(); + if (port <= 0 || port > 65535) { + throw new ConfigLoadFailedException("Invalid Websocket port: " + port); } } } diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/util/VectorUtil.java b/Partner-Main/src/main/java/work/slhaf/partner/common/util/VectorUtil.java new file mode 100644 index 00000000..81e94ea8 --- /dev/null +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/util/VectorUtil.java @@ -0,0 +1,59 @@ +package work.slhaf.partner.common.util; + +import cn.hutool.http.HttpRequest; +import cn.hutool.http.HttpResponse; +import com.alibaba.fastjson2.JSONObject; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.ops.transforms.Transforms; +import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager; +import work.slhaf.partner.common.config.PartnerAgentConfigManager; + +import java.util.Map; + +@Slf4j +public class VectorUtil { + + private static final String OLLAMA_EMBEDDING_URL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingUrl(); + private static final String OLLAMA_EMBEDDING_MODEL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingModel(); + + private VectorUtil() { + } + + /** + * 如果计算失败将返回null + * + * @param input 需要计算向量的字符串 + * @return 向量计算结果 + */ + public static float[] compute(String input) { + Map param = Map.of("model", OLLAMA_EMBEDDING_MODEL, "input", input); + HttpRequest request = HttpRequest.get(OLLAMA_EMBEDDING_URL).body(JSONObject.toJSONString(param)); + try (HttpResponse response = request.execute()) { + if (!response.isOk()) return null; + String resStr = response.body(); + EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class); + return embeddingResponse.getEmbeddings()[0]; + } catch (Exception e) { + log.error("嵌入模型执行出错", e); + return null; + } + } + + public static double compare(float[] v1, float[] v2) { + try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) { + return Transforms.cosineSim(a1, a2); + } + } + + @Data + private static class EmbeddingModelResponse { + private String model; + private float[][] embeddings; + private long total_duration; + private long load_duration; + private int prompt_eval_count; + } +} diff --git a/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCore.java b/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCore.java index 9952bff5..ee69d8a1 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCore.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCore.java @@ -4,6 +4,7 @@ import lombok.Getter; import lombok.Setter; import work.slhaf.partner.api.agent.factory.capability.annotation.Capability; import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod; +import work.slhaf.partner.common.util.VectorUtil; import work.slhaf.partner.core.PartnerCore; import work.slhaf.partner.core.action.entity.ActionCacheData; import work.slhaf.partner.core.action.entity.MetaActionInfo; @@ -82,8 +83,15 @@ public class ActionCore extends PartnerCore { @CapabilityMethod public List computeActionCache(String input){ //计算本次输入的向量 - + float[] vector = VectorUtil.compute(input); + if (vector == null) return null; //与现有缓存比对,如果存在,则使缓存计数+1 + actionCache.stream() + .filter(ActionCacheData::isActivated) + .forEach(data -> { + double compared = VectorUtil.compare(vector, data.getInputVector()); + }); + return null; } diff --git a/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/ActionCacheData.java b/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/ActionCacheData.java index 2c9b4ed1..abc0a24a 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/ActionCacheData.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/ActionCacheData.java @@ -1,18 +1,17 @@ package work.slhaf.partner.core.action.entity; +import lombok.Data; + import java.util.ArrayList; import java.util.List; -import org.nd4j.linalg.api.ndarray.INDArray; - -import lombok.Data; - @Data public class ActionCacheData { - private INDArray inputArray; - private INDArray tendencyArray; + private float[] inputVector; + private float[] tendencyVector; private String tendency; - private int count; - private List activateInputs = new ArrayList<>(); + private int inputMatchCount; private boolean activated; + private List validSamples = new ArrayList<>(); + private double threshold; } diff --git a/Partner-Main/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java b/Partner-Main/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java index b4b14a77..3bc7adcb 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java @@ -33,7 +33,7 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway lastPongTimes = new ConcurrentHashMap<>(); public WebSocketGateway() { - this(((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getPort()); + this(((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getWebSocketConfig().getPort()); } private WebSocketGateway(int port) {