mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
refactor(vector): refactor VectorClient configuration loading method
This commit is contained in:
@@ -2,7 +2,6 @@ package work.slhaf.partner;
|
|||||||
|
|
||||||
import work.slhaf.partner.api.agent.Agent;
|
import work.slhaf.partner.api.agent.Agent;
|
||||||
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
||||||
import work.slhaf.partner.common.vector.VectorClient;
|
|
||||||
import work.slhaf.partner.runtime.exception.PartnerExceptionCallback;
|
import work.slhaf.partner.runtime.exception.PartnerExceptionCallback;
|
||||||
import work.slhaf.partner.runtime.interaction.WebSocketGateway;
|
import work.slhaf.partner.runtime.interaction.WebSocketGateway;
|
||||||
|
|
||||||
@@ -12,7 +11,6 @@ public class Main {
|
|||||||
.setAgentConfigManager(PartnerAgentConfigLoader.class)
|
.setAgentConfigManager(PartnerAgentConfigLoader.class)
|
||||||
.setGateway(WebSocketGateway.class)
|
.setGateway(WebSocketGateway.class)
|
||||||
.setAgentExceptionCallback(PartnerExceptionCallback.class)
|
.setAgentExceptionCallback(PartnerExceptionCallback.class)
|
||||||
.addAfterLaunchRunners(VectorClient::load)
|
|
||||||
.launch();
|
.launch();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -28,6 +28,7 @@ public class OnnxVectorClient extends VectorClient {
|
|||||||
|
|
||||||
loadTokenizer();
|
loadTokenizer();
|
||||||
loadModel();
|
loadModel();
|
||||||
|
|
||||||
compute("test");
|
compute("test");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,47 +4,22 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader;
|
|
||||||
import work.slhaf.partner.common.config.Config.VectorConfig;
|
|
||||||
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
|
||||||
import work.slhaf.partner.common.exception.ServiceLoadFailedException;
|
|
||||||
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
|
||||||
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class VectorClient {
|
public abstract class VectorClient {
|
||||||
|
|
||||||
public static boolean status;
|
public static boolean status = false;
|
||||||
public static VectorClient INSTANCE;
|
public static VectorClient INSTANCE;
|
||||||
|
|
||||||
public static void load() {
|
public static void startClient(VectorConfig config) {
|
||||||
PartnerAgentConfigLoader configManager = (PartnerAgentConfigLoader) AgentConfigLoader.INSTANCE;
|
if (config instanceof VectorConfig.Ollama ollama) {
|
||||||
VectorConfig vectorConfig = configManager.getConfig().getVectorConfig();
|
INSTANCE = new OllamaVectorClient(ollama.ollamaEmbeddingUrl, ollama.ollamaEmbeddingModel);
|
||||||
int type = vectorConfig.getType();
|
} else if (config instanceof VectorConfig.Onnx onnx) {
|
||||||
try {
|
INSTANCE = new OnnxVectorClient(onnx.tokenizerPath, onnx.embeddingModelPath);
|
||||||
switch (type) {
|
} else {
|
||||||
case 0:
|
return;
|
||||||
status = false;
|
|
||||||
break;
|
|
||||||
case 1:
|
|
||||||
status = true;
|
|
||||||
INSTANCE = new OllamaVectorClient(vectorConfig.getOllamaEmbeddingUrl(),
|
|
||||||
vectorConfig.getOllamaEmbeddingModel());
|
|
||||||
break;
|
|
||||||
case 2:
|
|
||||||
status = true;
|
|
||||||
INSTANCE = new OnnxVectorClient(vectorConfig.getTokenizerPath(),
|
|
||||||
vectorConfig.getEmbeddingModelPath());
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new ServiceLoadFailedException(
|
|
||||||
"加载向量客户端失败! type: 0 -> 不启用语义缓存; type: 1 -> ollama; type: 2 -> ONNX RUNTIME");
|
|
||||||
}
|
|
||||||
log.info("向量客户端加载完毕");
|
|
||||||
} catch (VectorClientLoadFailedException | VectorClientExecuteException exception) {
|
|
||||||
status = false;
|
|
||||||
log.error("向量客户端加载失败", exception);
|
|
||||||
}
|
}
|
||||||
|
status = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
public float[] compute(String input) {
|
public float[] compute(String input) {
|
||||||
@@ -74,7 +49,7 @@ public abstract class VectorClient {
|
|||||||
|
|
||||||
// 2️⃣ 根据相似度决定更新比例 α(差异越大,新输入影响越强)
|
// 2️⃣ 根据相似度决定更新比例 α(差异越大,新输入影响越强)
|
||||||
double alpha = (1.0 - similarity) * 0.5;
|
double alpha = (1.0 - similarity) * 0.5;
|
||||||
alpha = Math.max(0.05, Math.min(alpha, 0.5));
|
alpha = Math.clamp(alpha, 0.05, 0.5);
|
||||||
|
|
||||||
// 3️⃣ 按比例混合旧向量与新向量
|
// 3️⃣ 按比例混合旧向量与新向量
|
||||||
INDArray updated = primary.mul(1 - alpha).add(latest.mul(alpha));
|
INDArray updated = primary.mul(1 - alpha).add(latest.mul(alpha));
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
package work.slhaf.partner.common.vector;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
|
import org.jetbrains.annotations.Nullable;
|
||||||
|
import work.slhaf.partner.api.agent.runtime.config.Config;
|
||||||
|
import work.slhaf.partner.api.agent.runtime.config.ConfigRegistration;
|
||||||
|
import work.slhaf.partner.api.agent.runtime.config.Configurable;
|
||||||
|
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class VectorClientRegistry implements Configurable, ConfigRegistration<VectorConfig> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void init(@NotNull VectorConfig config, @Nullable JSONObject json) {
|
||||||
|
if (!config.enabled) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (config.type == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (json == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
config = switch (config.type) {
|
||||||
|
case ONNX -> json.toJavaObject(VectorConfig.Onnx.class);
|
||||||
|
case OLLAMA -> json.toJavaObject(VectorConfig.Ollama.class);
|
||||||
|
};
|
||||||
|
VectorClient.startClient(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
@NotNull
|
||||||
|
public Class<VectorConfig> type() {
|
||||||
|
return VectorConfig.class;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Nullable
|
||||||
|
@Override
|
||||||
|
public VectorConfig defaultConfig() {
|
||||||
|
return new VectorConfig(false, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public @NotNull Map<Path, ConfigRegistration<? extends Config>> declare() {
|
||||||
|
return Map.of(Path.of("vector", "config.json"), this);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
package work.slhaf.partner.common.vector;
|
||||||
|
|
||||||
|
import work.slhaf.partner.api.agent.runtime.config.Config;
|
||||||
|
|
||||||
|
public sealed class VectorConfig extends Config permits VectorConfig.Ollama, VectorConfig.Onnx {
|
||||||
|
final boolean enabled;
|
||||||
|
final Type type;
|
||||||
|
|
||||||
|
public VectorConfig(boolean enabled, Type type) {
|
||||||
|
this.enabled = enabled;
|
||||||
|
this.type = type;
|
||||||
|
}
|
||||||
|
|
||||||
|
public enum Type {
|
||||||
|
ONNX,
|
||||||
|
OLLAMA
|
||||||
|
}
|
||||||
|
|
||||||
|
static final class Onnx extends VectorConfig {
|
||||||
|
|
||||||
|
final String tokenizerPath;
|
||||||
|
final String embeddingModelPath;
|
||||||
|
|
||||||
|
public Onnx(boolean enabled, Type type, String tokenizerPath, String embeddingModelPath) {
|
||||||
|
super(enabled, type);
|
||||||
|
this.tokenizerPath = tokenizerPath;
|
||||||
|
this.embeddingModelPath = embeddingModelPath;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static final class Ollama extends VectorConfig {
|
||||||
|
|
||||||
|
final String ollamaEmbeddingUrl;
|
||||||
|
final String ollamaEmbeddingModel;
|
||||||
|
|
||||||
|
public Ollama(boolean enabled, Type type, String ollamaEmbeddingUrl, String ollamaEmbeddingModel) {
|
||||||
|
super(enabled, type);
|
||||||
|
this.ollamaEmbeddingUrl = ollamaEmbeddingUrl;
|
||||||
|
this.ollamaEmbeddingModel = ollamaEmbeddingModel;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user