diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/config/Config.java b/Partner-Main/src/main/java/work/slhaf/partner/common/config/Config.java index 0f37197d..8b33aa5d 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/common/config/Config.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/config/Config.java @@ -4,6 +4,18 @@ import lombok.Data; @Data public class Config { - private int port; private String agentId; + private WebSocketConfig webSocketConfig; + private VectorConfig vectorConfig; + + @Data + public static class VectorConfig { + private String ollamaEmbeddingUrl; + private String ollamaEmbeddingModel; + } + + @Data + public static class WebSocketConfig { + private int port; + } } diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/config/PartnerAgentConfigManager.java b/Partner-Main/src/main/java/work/slhaf/partner/common/config/PartnerAgentConfigManager.java index 86f84375..78ba1369 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/common/config/PartnerAgentConfigManager.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/config/PartnerAgentConfigManager.java @@ -33,8 +33,9 @@ public final class PartnerAgentConfigManager extends FileAgentConfigManager { if (config == null || config.getAgentId() == null) { throw new ConfigLoadFailedException("Partner Config Load Failed: " + COMMON_CONFIG_FILE); } - if (config.getPort() <= 0 || config.getPort() > 65535) { - throw new ConfigLoadFailedException("Invalid Websocket port: " + config.getPort()); + int port = config.getWebSocketConfig().getPort(); + if (port <= 0 || port > 65535) { + throw new ConfigLoadFailedException("Invalid Websocket port: " + port); } } } diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/util/VectorUtil.java b/Partner-Main/src/main/java/work/slhaf/partner/common/util/VectorUtil.java new file mode 100644 index 00000000..81e94ea8 --- /dev/null +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/util/VectorUtil.java @@ -0,0 +1,59 @@ +package work.slhaf.partner.common.util; + +import cn.hutool.http.HttpRequest; +import cn.hutool.http.HttpResponse; +import com.alibaba.fastjson2.JSONObject; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.ops.transforms.Transforms; +import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager; +import work.slhaf.partner.common.config.PartnerAgentConfigManager; + +import java.util.Map; + +@Slf4j +public class VectorUtil { + + private static final String OLLAMA_EMBEDDING_URL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingUrl(); + private static final String OLLAMA_EMBEDDING_MODEL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingModel(); + + private VectorUtil() { + } + + /** + * 如果计算失败将返回null + * + * @param input 需要计算向量的字符串 + * @return 向量计算结果 + */ + public static float[] compute(String input) { + Map param = Map.of("model", OLLAMA_EMBEDDING_MODEL, "input", input); + HttpRequest request = HttpRequest.get(OLLAMA_EMBEDDING_URL).body(JSONObject.toJSONString(param)); + try (HttpResponse response = request.execute()) { + if (!response.isOk()) return null; + String resStr = response.body(); + EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class); + return embeddingResponse.getEmbeddings()[0]; + } catch (Exception e) { + log.error("嵌入模型执行出错", e); + return null; + } + } + + public static double compare(float[] v1, float[] v2) { + try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) { + return Transforms.cosineSim(a1, a2); + } + } + + @Data + private static class EmbeddingModelResponse { + private String model; + private float[][] embeddings; + private long total_duration; + private long load_duration; + private int prompt_eval_count; + } +} diff --git a/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCore.java b/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCore.java index 9952bff5..ee69d8a1 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCore.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCore.java @@ -4,6 +4,7 @@ import lombok.Getter; import lombok.Setter; import work.slhaf.partner.api.agent.factory.capability.annotation.Capability; import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod; +import work.slhaf.partner.common.util.VectorUtil; import work.slhaf.partner.core.PartnerCore; import work.slhaf.partner.core.action.entity.ActionCacheData; import work.slhaf.partner.core.action.entity.MetaActionInfo; @@ -82,8 +83,15 @@ public class ActionCore extends PartnerCore { @CapabilityMethod public List computeActionCache(String input){ //计算本次输入的向量 - + float[] vector = VectorUtil.compute(input); + if (vector == null) return null; //与现有缓存比对,如果存在,则使缓存计数+1 + actionCache.stream() + .filter(ActionCacheData::isActivated) + .forEach(data -> { + double compared = VectorUtil.compare(vector, data.getInputVector()); + }); + return null; } diff --git a/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/ActionCacheData.java b/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/ActionCacheData.java index 2c9b4ed1..abc0a24a 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/ActionCacheData.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/ActionCacheData.java @@ -1,18 +1,17 @@ package work.slhaf.partner.core.action.entity; +import lombok.Data; + import java.util.ArrayList; import java.util.List; -import org.nd4j.linalg.api.ndarray.INDArray; - -import lombok.Data; - @Data public class ActionCacheData { - private INDArray inputArray; - private INDArray tendencyArray; + private float[] inputVector; + private float[] tendencyVector; private String tendency; - private int count; - private List activateInputs = new ArrayList<>(); + private int inputMatchCount; private boolean activated; + private List validSamples = new ArrayList<>(); + private double threshold; } diff --git a/Partner-Main/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java b/Partner-Main/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java index b4b14a77..3bc7adcb 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/runtime/interaction/WebSocketGateway.java @@ -33,7 +33,7 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway lastPongTimes = new ConcurrentHashMap<>(); public WebSocketGateway() { - this(((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getPort()); + this(((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getWebSocketConfig().getPort()); } private WebSocketGateway(int port) {