mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
推进 ActionExtractor: 新增语义向量计算工具;开始推进语义缓存相关;调整配置类格式
This commit is contained in:
@@ -4,6 +4,18 @@ import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class Config {
|
||||
private int port;
|
||||
private String agentId;
|
||||
private WebSocketConfig webSocketConfig;
|
||||
private VectorConfig vectorConfig;
|
||||
|
||||
@Data
|
||||
public static class VectorConfig {
|
||||
private String ollamaEmbeddingUrl;
|
||||
private String ollamaEmbeddingModel;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class WebSocketConfig {
|
||||
private int port;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,8 +33,9 @@ public final class PartnerAgentConfigManager extends FileAgentConfigManager {
|
||||
if (config == null || config.getAgentId() == null) {
|
||||
throw new ConfigLoadFailedException("Partner Config Load Failed: " + COMMON_CONFIG_FILE);
|
||||
}
|
||||
if (config.getPort() <= 0 || config.getPort() > 65535) {
|
||||
throw new ConfigLoadFailedException("Invalid Websocket port: " + config.getPort());
|
||||
int port = config.getWebSocketConfig().getPort();
|
||||
if (port <= 0 || port > 65535) {
|
||||
throw new ConfigLoadFailedException("Invalid Websocket port: " + port);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
package work.slhaf.partner.common.util;
|
||||
|
||||
import cn.hutool.http.HttpRequest;
|
||||
import cn.hutool.http.HttpResponse;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
public class VectorUtil {
|
||||
|
||||
private static final String OLLAMA_EMBEDDING_URL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingUrl();
|
||||
private static final String OLLAMA_EMBEDDING_MODEL = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getVectorConfig().getOllamaEmbeddingModel();
|
||||
|
||||
private VectorUtil() {
|
||||
}
|
||||
|
||||
/**
|
||||
* 如果计算失败将返回null
|
||||
*
|
||||
* @param input 需要计算向量的字符串
|
||||
* @return 向量计算结果
|
||||
*/
|
||||
public static float[] compute(String input) {
|
||||
Map<String, String> param = Map.of("model", OLLAMA_EMBEDDING_MODEL, "input", input);
|
||||
HttpRequest request = HttpRequest.get(OLLAMA_EMBEDDING_URL).body(JSONObject.toJSONString(param));
|
||||
try (HttpResponse response = request.execute()) {
|
||||
if (!response.isOk()) return null;
|
||||
String resStr = response.body();
|
||||
EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class);
|
||||
return embeddingResponse.getEmbeddings()[0];
|
||||
} catch (Exception e) {
|
||||
log.error("嵌入模型执行出错", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static double compare(float[] v1, float[] v2) {
|
||||
try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) {
|
||||
return Transforms.cosineSim(a1, a2);
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
private static class EmbeddingModelResponse {
|
||||
private String model;
|
||||
private float[][] embeddings;
|
||||
private long total_duration;
|
||||
private long load_duration;
|
||||
private int prompt_eval_count;
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
|
||||
import work.slhaf.partner.common.util.VectorUtil;
|
||||
import work.slhaf.partner.core.PartnerCore;
|
||||
import work.slhaf.partner.core.action.entity.ActionCacheData;
|
||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||
@@ -82,8 +83,15 @@ public class ActionCore extends PartnerCore<ActionCore> {
|
||||
@CapabilityMethod
|
||||
public List<String> computeActionCache(String input){
|
||||
//计算本次输入的向量
|
||||
|
||||
float[] vector = VectorUtil.compute(input);
|
||||
if (vector == null) return null;
|
||||
//与现有缓存比对,如果存在,则使缓存计数+1
|
||||
actionCache.stream()
|
||||
.filter(ActionCacheData::isActivated)
|
||||
.forEach(data -> {
|
||||
double compared = VectorUtil.compare(vector, data.getInputVector());
|
||||
});
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ActionCacheData {
|
||||
private INDArray inputArray;
|
||||
private INDArray tendencyArray;
|
||||
private float[] inputVector;
|
||||
private float[] tendencyVector;
|
||||
private String tendency;
|
||||
private int count;
|
||||
private List<String> activateInputs = new ArrayList<>();
|
||||
private int inputMatchCount;
|
||||
private boolean activated;
|
||||
private List<float[]> validSamples = new ArrayList<>();
|
||||
private double threshold;
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<Pa
|
||||
private final ConcurrentHashMap<WebSocket, Long> lastPongTimes = new ConcurrentHashMap<>();
|
||||
|
||||
public WebSocketGateway() {
|
||||
this(((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getPort());
|
||||
this(((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getWebSocketConfig().getPort());
|
||||
}
|
||||
|
||||
private WebSocketGateway(int port) {
|
||||
|
||||
Reference in New Issue
Block a user