refactor(vector): refactor VectorClient configuration loading method

This commit is contained in:
2026-04-06 15:52:30 +08:00
parent 332792daa2
commit f79a0521b2
5 changed files with 106 additions and 37 deletions

View File

@@ -2,7 +2,6 @@ package work.slhaf.partner;
import work.slhaf.partner.api.agent.Agent; import work.slhaf.partner.api.agent.Agent;
import work.slhaf.partner.common.config.PartnerAgentConfigLoader; import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
import work.slhaf.partner.common.vector.VectorClient;
import work.slhaf.partner.runtime.exception.PartnerExceptionCallback; import work.slhaf.partner.runtime.exception.PartnerExceptionCallback;
import work.slhaf.partner.runtime.interaction.WebSocketGateway; import work.slhaf.partner.runtime.interaction.WebSocketGateway;
@@ -12,7 +11,6 @@ public class Main {
.setAgentConfigManager(PartnerAgentConfigLoader.class) .setAgentConfigManager(PartnerAgentConfigLoader.class)
.setGateway(WebSocketGateway.class) .setGateway(WebSocketGateway.class)
.setAgentExceptionCallback(PartnerExceptionCallback.class) .setAgentExceptionCallback(PartnerExceptionCallback.class)
.addAfterLaunchRunners(VectorClient::load)
.launch(); .launch();
} }
} }

View File

@@ -28,6 +28,7 @@ public class OnnxVectorClient extends VectorClient {
loadTokenizer(); loadTokenizer();
loadModel(); loadModel();
compute("test"); compute("test");
} }

View File

@@ -4,47 +4,22 @@ import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader;
import work.slhaf.partner.common.config.Config.VectorConfig;
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
import work.slhaf.partner.common.exception.ServiceLoadFailedException;
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
@Slf4j @Slf4j
public abstract class VectorClient { public abstract class VectorClient {
public static boolean status; public static boolean status = false;
public static VectorClient INSTANCE; public static VectorClient INSTANCE;
public static void load() { public static void startClient(VectorConfig config) {
PartnerAgentConfigLoader configManager = (PartnerAgentConfigLoader) AgentConfigLoader.INSTANCE; if (config instanceof VectorConfig.Ollama ollama) {
VectorConfig vectorConfig = configManager.getConfig().getVectorConfig(); INSTANCE = new OllamaVectorClient(ollama.ollamaEmbeddingUrl, ollama.ollamaEmbeddingModel);
int type = vectorConfig.getType(); } else if (config instanceof VectorConfig.Onnx onnx) {
try { INSTANCE = new OnnxVectorClient(onnx.tokenizerPath, onnx.embeddingModelPath);
switch (type) { } else {
case 0: return;
status = false;
break;
case 1:
status = true;
INSTANCE = new OllamaVectorClient(vectorConfig.getOllamaEmbeddingUrl(),
vectorConfig.getOllamaEmbeddingModel());
break;
case 2:
status = true;
INSTANCE = new OnnxVectorClient(vectorConfig.getTokenizerPath(),
vectorConfig.getEmbeddingModelPath());
break;
default:
throw new ServiceLoadFailedException(
"加载向量客户端失败! type: 0 -> 不启用语义缓存; type: 1 -> ollama; type: 2 -> ONNX RUNTIME");
}
log.info("向量客户端加载完毕");
} catch (VectorClientLoadFailedException | VectorClientExecuteException exception) {
status = false;
log.error("向量客户端加载失败", exception);
} }
status = true;
} }
public float[] compute(String input) { public float[] compute(String input) {
@@ -74,7 +49,7 @@ public abstract class VectorClient {
// 2⃣ 根据相似度决定更新比例 α(差异越大,新输入影响越强) // 2⃣ 根据相似度决定更新比例 α(差异越大,新输入影响越强)
double alpha = (1.0 - similarity) * 0.5; double alpha = (1.0 - similarity) * 0.5;
alpha = Math.max(0.05, Math.min(alpha, 0.5)); alpha = Math.clamp(alpha, 0.05, 0.5);
// 3⃣ 按比例混合旧向量与新向量 // 3⃣ 按比例混合旧向量与新向量
INDArray updated = primary.mul(1 - alpha).add(latest.mul(alpha)); INDArray updated = primary.mul(1 - alpha).add(latest.mul(alpha));

View File

@@ -0,0 +1,51 @@
package work.slhaf.partner.common.vector;
import com.alibaba.fastjson2.JSONObject;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import work.slhaf.partner.api.agent.runtime.config.Config;
import work.slhaf.partner.api.agent.runtime.config.ConfigRegistration;
import work.slhaf.partner.api.agent.runtime.config.Configurable;
import java.nio.file.Path;
import java.util.Map;
public class VectorClientRegistry implements Configurable, ConfigRegistration<VectorConfig> {
@Override
public void init(@NotNull VectorConfig config, @Nullable JSONObject json) {
if (!config.enabled) {
return;
}
if (config.type == null) {
return;
}
if (json == null) {
return;
}
config = switch (config.type) {
case ONNX -> json.toJavaObject(VectorConfig.Onnx.class);
case OLLAMA -> json.toJavaObject(VectorConfig.Ollama.class);
};
VectorClient.startClient(config);
}
@Override
@NotNull
public Class<VectorConfig> type() {
return VectorConfig.class;
}
@Nullable
@Override
public VectorConfig defaultConfig() {
return new VectorConfig(false, null);
}
@Override
public @NotNull Map<Path, ConfigRegistration<? extends Config>> declare() {
return Map.of(Path.of("vector", "config.json"), this);
}
}

View File

@@ -0,0 +1,44 @@
package work.slhaf.partner.common.vector;
import work.slhaf.partner.api.agent.runtime.config.Config;
public sealed class VectorConfig extends Config permits VectorConfig.Ollama, VectorConfig.Onnx {
final boolean enabled;
final Type type;
public VectorConfig(boolean enabled, Type type) {
this.enabled = enabled;
this.type = type;
}
public enum Type {
ONNX,
OLLAMA
}
static final class Onnx extends VectorConfig {
final String tokenizerPath;
final String embeddingModelPath;
public Onnx(boolean enabled, Type type, String tokenizerPath, String embeddingModelPath) {
super(enabled, type);
this.tokenizerPath = tokenizerPath;
this.embeddingModelPath = embeddingModelPath;
}
}
static final class Ollama extends VectorConfig {
final String ollamaEmbeddingUrl;
final String ollamaEmbeddingModel;
public Ollama(boolean enabled, Type type, String ollamaEmbeddingUrl, String ollamaEmbeddingModel) {
super(enabled, type);
this.ollamaEmbeddingUrl = ollamaEmbeddingUrl;
this.ollamaEmbeddingModel = ollamaEmbeddingModel;
}
}
}