mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
推进 ActionExtractor: 新增语义向量计算工具;开始推进语义缓存相关;调整配置类格式
This commit is contained in:
@@ -4,6 +4,18 @@ import lombok.Data;
|
|||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class Config {
|
public class Config {
|
||||||
private int port;
|
|
||||||
private String agentId;
|
private String agentId;
|
||||||
|
private WebSocketConfig webSocketConfig;
|
||||||
|
private VectorConfig vectorConfig;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class VectorConfig {
|
||||||
|
private String ollamaEmbeddingUrl;
|
||||||
|
private String ollamaEmbeddingModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class WebSocketConfig {
|
||||||
|
private int port;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,8 +33,9 @@ public final class PartnerAgentConfigManager extends FileAgentConfigManager {
|
|||||||
if (config == null || config.getAgentId() == null) {
|
if (config == null || config.getAgentId() == null) {
|
||||||
throw new ConfigLoadFailedException("Partner Config Load Failed: " + COMMON_CONFIG_FILE);
|
throw new ConfigLoadFailedException("Partner Config Load Failed: " + COMMON_CONFIG_FILE);
|
||||||
}
|
}
|
||||||
if (config.getPort() <= 0 || config.getPort() > 65535) {
|
int port = config.getWebSocketConfig().getPort();
|
||||||
throw new ConfigLoadFailedException("Invalid Websocket port: " + config.getPort());
|
if (port <= 0 || port > 65535) {
|
||||||
|
throw new ConfigLoadFailedException("Invalid Websocket port: " + port);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,59 @@
|
|||||||
|
package work.slhaf.partner.common.util;
|
||||||
|
|
||||||
|
import cn.hutool.http.HttpRequest;
|
||||||
|
import cn.hutool.http.HttpResponse;
|
||||||
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
|
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||||
|
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class VectorUtil {
|
||||||
|
|
||||||
|
private static final String OLLAMA_EMBEDDING_URL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingUrl();
|
||||||
|
private static final String OLLAMA_EMBEDDING_MODEL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingModel();
|
||||||
|
|
||||||
|
private VectorUtil() {
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 如果计算失败将返回null
|
||||||
|
*
|
||||||
|
* @param input 需要计算向量的字符串
|
||||||
|
* @return 向量计算结果
|
||||||
|
*/
|
||||||
|
public static float[] compute(String input) {
|
||||||
|
Map<String, String> param = Map.of("model", OLLAMA_EMBEDDING_MODEL, "input", input);
|
||||||
|
HttpRequest request = HttpRequest.get(OLLAMA_EMBEDDING_URL).body(JSONObject.toJSONString(param));
|
||||||
|
try (HttpResponse response = request.execute()) {
|
||||||
|
if (!response.isOk()) return null;
|
||||||
|
String resStr = response.body();
|
||||||
|
EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class);
|
||||||
|
return embeddingResponse.getEmbeddings()[0];
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("嵌入模型执行出错", e);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double compare(float[] v1, float[] v2) {
|
||||||
|
try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) {
|
||||||
|
return Transforms.cosineSim(a1, a2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
private static class EmbeddingModelResponse {
|
||||||
|
private String model;
|
||||||
|
private float[][] embeddings;
|
||||||
|
private long total_duration;
|
||||||
|
private long load_duration;
|
||||||
|
private int prompt_eval_count;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import lombok.Getter;
|
|||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
|
||||||
|
import work.slhaf.partner.common.util.VectorUtil;
|
||||||
import work.slhaf.partner.core.PartnerCore;
|
import work.slhaf.partner.core.PartnerCore;
|
||||||
import work.slhaf.partner.core.action.entity.ActionCacheData;
|
import work.slhaf.partner.core.action.entity.ActionCacheData;
|
||||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||||
@@ -82,8 +83,15 @@ public class ActionCore extends PartnerCore<ActionCore> {
|
|||||||
@CapabilityMethod
|
@CapabilityMethod
|
||||||
public List<String> computeActionCache(String input){
|
public List<String> computeActionCache(String input){
|
||||||
//计算本次输入的向量
|
//计算本次输入的向量
|
||||||
|
float[] vector = VectorUtil.compute(input);
|
||||||
|
if (vector == null) return null;
|
||||||
//与现有缓存比对,如果存在,则使缓存计数+1
|
//与现有缓存比对,如果存在,则使缓存计数+1
|
||||||
|
actionCache.stream()
|
||||||
|
.filter(ActionCacheData::isActivated)
|
||||||
|
.forEach(data -> {
|
||||||
|
double compared = VectorUtil.compare(vector, data.getInputVector());
|
||||||
|
});
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,17 @@
|
|||||||
package work.slhaf.partner.core.action.entity;
|
package work.slhaf.partner.core.action.entity;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ActionCacheData {
|
public class ActionCacheData {
|
||||||
private INDArray inputArray;
|
private float[] inputVector;
|
||||||
private INDArray tendencyArray;
|
private float[] tendencyVector;
|
||||||
private String tendency;
|
private String tendency;
|
||||||
private int count;
|
private int inputMatchCount;
|
||||||
private List<String> activateInputs = new ArrayList<>();
|
|
||||||
private boolean activated;
|
private boolean activated;
|
||||||
|
private List<float[]> validSamples = new ArrayList<>();
|
||||||
|
private double threshold;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<Pa
|
|||||||
private final ConcurrentHashMap<WebSocket, Long> lastPongTimes = new ConcurrentHashMap<>();
|
private final ConcurrentHashMap<WebSocket, Long> lastPongTimes = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
public WebSocketGateway() {
|
public WebSocketGateway() {
|
||||||
this(((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getPort());
|
this(((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getWebSocketConfig().getPort());
|
||||||
}
|
}
|
||||||
|
|
||||||
private WebSocketGateway(int port) {
|
private WebSocketGateway(int port) {
|
||||||
|
|||||||
Reference in New Issue
Block a user