Action 模块语义缓存机制实现完毕,支持三种情况的语义缓存相关行为: 命中缓存且评估通过、命中缓存但评估未通过、未命中缓存但评估通过。将在评估过后步入主模块之前,进行异步更新操作(借助@AfterExecute注解,通过虚拟线程进入异步流程,在真正调用处使用平台线程加速计算)

This commit is contained in:
2025-10-19 22:05:27 +08:00
parent aee6d879e9
commit 5864760f35
9 changed files with 317 additions and 42 deletions

View File

@@ -64,4 +64,25 @@ public abstract class VectorClient {
return Transforms.cosineSim(a1, a2); return Transforms.cosineSim(a1, a2);
} }
} }
/**
 * Blends a new embedding into an existing one and returns the L2-normalised result.
 * <p>
 * The blend factor {@code alpha} is derived from the cosine similarity of the two vectors:
 * {@code alpha = clamp((1 - similarity) * 0.5, 0.05, 0.5)}, so the less similar the new
 * vector is, the more it influences the result. The result is
 * {@code primary * (1 - alpha) + latest * alpha}, divided by its L2 norm.
 * <p>
 * Degenerate inputs are handled explicitly so that NaN values never reach the caller:
 * if the similarity is not a number (e.g. one vector has zero norm) or the blended vector
 * has zero norm, a copy of whichever input is still usable is returned instead.
 *
 * @param newVector     embedding of the latest input
 * @param primaryVector embedding currently stored
 * @return a new array holding the blended, normalised vector
 */
public float[] weightedAverage(float[] newVector, float[] primaryVector) {
    try (INDArray primary = Nd4j.create(primaryVector);
         INDArray latest = Nd4j.create(newVector)) {
        // 1. Cosine similarity between the stored and the incoming vector.
        double similarity = Transforms.cosineSim(primary, latest);
        if (Double.isNaN(similarity)) {
            // Math.max/Math.min propagate NaN, so the clamp below would not rescue alpha.
            // Keep the stored vector unless it is the degenerate (zero-norm) one.
            boolean primaryUsable = primary.norm2Number().doubleValue() > 0.0;
            return (primaryUsable ? primaryVector : newVector).clone();
        }
        // 2. Larger difference -> larger share for the new vector, clamped to [0.05, 0.5].
        double alpha = (1.0 - similarity) * 0.5;
        alpha = Math.max(0.05, Math.min(alpha, 0.5));
        // 3. Mix the old and new vectors by that ratio.
        INDArray updated = primary.mul(1 - alpha).add(latest.mul(alpha));
        // 4. Normalise so all stored vectors stay on the unit sphere.
        Number norm = updated.norm2Number();
        double normValue = norm.doubleValue();
        if (normValue == 0.0 || Double.isNaN(normValue)) {
            // Dividing by a zero norm would yield NaN/Infinity components.
            return primaryVector.clone();
        }
        updated = updated.div(norm);
        return updated.toFloatVector();
    }
}
} }

View File

@@ -22,5 +22,5 @@ public interface ActionCapability {
List<String> selectTendencyCache(String input); List<String> selectTendencyCache(String input);
void updateTendencyCache(List<CacheAdjustData> list); void updateTendencyCache(CacheAdjustData data);
} }

View File

@@ -1,21 +1,28 @@
package work.slhaf.partner.core.action; package work.slhaf.partner.core.action;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability; import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod; import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.common.vector.VectorClient; import work.slhaf.partner.common.vector.VectorClient;
import work.slhaf.partner.core.PartnerCore; import work.slhaf.partner.core.PartnerCore;
import work.slhaf.partner.core.action.entity.ActionCacheData; import work.slhaf.partner.core.action.entity.ActionCacheData;
import work.slhaf.partner.core.action.entity.CacheAdjustData; import work.slhaf.partner.core.action.entity.CacheAdjustData;
import work.slhaf.partner.core.action.entity.CacheAdjustMetaData;
import work.slhaf.partner.core.action.entity.MetaActionInfo; import work.slhaf.partner.core.action.entity.MetaActionInfo;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@SuppressWarnings("FieldMayBeFinal") @SuppressWarnings("FieldMayBeFinal")
@Capability(value = "action") @Capability(value = "action")
@Slf4j
public class ActionCore extends PartnerCore<ActionCore> { public class ActionCore extends PartnerCore<ActionCore> {
/** /**
@@ -33,6 +40,10 @@ public class ActionCore extends PartnerCore<ActionCore> {
*/ */
private List<ActionCacheData> actionCache = new ArrayList<>(); private List<ActionCacheData> actionCache = new ArrayList<>();
private Lock cacheLock = new ReentrantLock();
private Executor executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
public ActionCore() throws IOException, ClassNotFoundException { public ActionCore() throws IOException, ClassNotFoundException {
} }
@@ -105,43 +116,116 @@ public class ActionCore extends PartnerCore<ActionCore> {
} }
@CapabilityMethod @CapabilityMethod
public void updateTendencyCache(List<CacheAdjustData> list) { public void updateTendencyCache(CacheAdjustData data) {
List<CacheAdjustData> matchAndPassed = new ArrayList<>();
List<CacheAdjustData> matchNotPassed = new ArrayList<>();
List<CacheAdjustData> notMatchPassed = new ArrayList<>();
for (CacheAdjustData data : list) {
if (data.isHit() && data.isPassed()) {
matchAndPassed.add(data);
} else if (data.isHit()) {
matchNotPassed.add(data);
} else if (!data.isPassed()) {
notMatchPassed.add(data);
}
}
VectorClient vectorClient = VectorClient.INSTANCE; VectorClient vectorClient = VectorClient.INSTANCE;
adjustMatchAndPassed(matchAndPassed, vectorClient); List<CacheAdjustMetaData> list = data.getMetaDataList();
adjustMatchNotPassed(matchNotPassed, vectorClient); String input = data.getInput();
adjustNotMatchPassed(notMatchPassed, vectorClient); float[] inputVector = vectorClient.compute(input);
List<CacheAdjustMetaData> matchAndPassed = new ArrayList<>();
List<CacheAdjustMetaData> matchNotPassed = new ArrayList<>();
List<CacheAdjustMetaData> notMatchPassed = new ArrayList<>();
// Sort every evaluated tendency into one of the three cache-adjustment cases.
for (CacheAdjustMetaData metaData : list) {
    if (metaData.isHit() && metaData.isPassed()) {
        // cache hit, evaluation passed
        matchAndPassed.add(metaData);
    } else if (metaData.isHit()) {
        // cache hit, evaluation failed
        matchNotPassed.add(metaData);
    } else if (metaData.isPassed()) {
        // cache miss, evaluation passed: this is the case adjustNotMatchPassed handles.
        // (The previous condition was negated and collected miss + failed entries instead.)
        notMatchPassed.add(metaData);
    }
    // cache miss and evaluation failed: nothing to adjust.
}
executor.execute(() -> adjustMatchAndPassed(matchAndPassed, inputVector, input, vectorClient));
executor.execute(() -> adjustMatchNotPassed(matchNotPassed, vectorClient));
executor.execute(() -> adjustNotMatchPassed(notMatchPassed, inputVector, input, vectorClient));
} }
/** /**
* 命中缓存且评估通过时,根据输入内容的语义向量与现有的输入语义向量进行带权移动平均,以相似度为权重 * 命中缓存且评估通过时
* *
* @param matchAndPassed 该类型的带调整缓存信息列表 * @param matchAndPassed 该类型的带调整缓存信息列表
* @param inputVector 本次输入内容的语义向量
* @param vectorClient 向量客户端 * @param vectorClient 向量客户端
*/ */
private void adjustMatchAndPassed(List<CacheAdjustData> matchAndPassed, VectorClient vectorClient) { private void adjustMatchAndPassed(List<CacheAdjustMetaData> matchAndPassed, float[] inputVector, String input, VectorClient vectorClient) {
matchAndPassed.forEach(adjustData -> {
//获取原始缓存条目
String tendency = adjustData.getTendency();
ActionCacheData primaryCacheData = selectCacheData(tendency);
if (primaryCacheData == null) {
return;
}
primaryCacheData.updateAfterMatchAndPassed(inputVector, vectorClient, input);
});
} }
private void adjustMatchNotPassed(List<CacheAdjustData> matchNotPassed, VectorClient vectorClient) { /**
* 针对命中缓存、但评估未通过的条目与输入进行处理
*
* @param matchNotPassed 该类型的带调整缓存信息列表
* @param vectorClient 向量客户端
*/
private void adjustMatchNotPassed(List<CacheAdjustMetaData> matchNotPassed, VectorClient vectorClient) {
List<ActionCacheData> toRemove = new ArrayList<>();
matchNotPassed.forEach(adjustData -> {
//获取原始缓存条目
String tendency = adjustData.getTendency();
ActionCacheData primaryCacheData = selectCacheData(tendency);
if (primaryCacheData == null) {
return;
}
boolean remove = primaryCacheData.updateAfterMatchNotPassed(vectorClient);
if (remove) {
toRemove.add(primaryCacheData);
} }
private void adjustNotMatchPassed(List<CacheAdjustData> notMatchPassed, VectorClient vectorClient) { });
cacheLock.lock();
actionCache.removeAll(toRemove);
cacheLock.unlock();
}
/**
 * Adjusts the cache for tendencies that were not served from the cache but passed evaluation.
 * <ul>
 * <li>If a cache entry for the tendency already exists, the adjustment is delegated to
 * {@link ActionCacheData#updateAfterNotMatchPassed(String, float[], float[], VectorClient)}.</li>
 * <li>If no entry exists, a new one is created from the tendency, its vector, the input
 * vector and the input text, and appended to the cache.</li>
 * </ul>
 * The lookup and the insertion are performed while holding {@code cacheLock}: this method is
 * submitted to the executor alongside {@code adjustMatchNotPassed}, which removes entries from
 * the same list under that lock, and an unguarded lookup-then-add could also insert the same
 * tendency twice. Lookups elsewhere that iterate the list without the lock are not covered.
 *
 * @param notMatchPassed metadata of the tendencies in this category
 * @param inputVector    semantic vector of the current input
 * @param input          the current input text
 * @param vectorClient   vector client used to embed the tendency
 */
private void adjustNotMatchPassed(List<CacheAdjustMetaData> notMatchPassed, float[] inputVector, String input, VectorClient vectorClient) {
    for (CacheAdjustMetaData adjustData : notMatchPassed) {
        String tendency = adjustData.getTendency();
        // Embed before taking the lock so the lock is held only for the list access.
        float[] tendencyVector = vectorClient.compute(tendency);
        ActionCacheData primaryCacheData;
        cacheLock.lock();
        try {
            primaryCacheData = selectCacheData(tendency);
            if (primaryCacheData == null) {
                actionCache.add(new ActionCacheData(tendency, tendencyVector, inputVector, input));
                continue;
            }
        } finally {
            cacheLock.unlock();
        }
        // The entry synchronises its own state; no need to keep the list lock while updating it.
        primaryCacheData.updateAfterNotMatchPassed(input, inputVector, tendencyVector, vectorClient);
    }
}
private ActionCacheData selectCacheData(String tendency) {
for (ActionCacheData actionCacheData : actionCache) {
if (actionCacheData.getTendency().equals(tendency)) {
return actionCacheData;
}
}
log.warn("[{}] 未找到行为倾向[{}]对应的缓存条目,可能是代码逻辑存在错误", getCoreKey(), tendency);
return null;
} }
@Override @Override

View File

@@ -1,25 +1,181 @@
package work.slhaf.partner.core.action.entity; package work.slhaf.partner.core.action.entity;
import lombok.Data; import lombok.Data;
import work.slhaf.partner.common.vector.VectorClient;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@Data @Data
public class ActionCacheData { public class ActionCacheData {
private boolean activated; private boolean activated = false;
private int inputMatchCount; private int inputMatchCount = 1;
private float[] inputVector; private float[] inputVector;
private float[] tendencyVector; private float[] tendencyVector;
private String tendency; private String tendency;
private double threshold; private double threshold = 0.75;
private List<String> validSamples = new ArrayList<>(); private List<String> validSamples = new ArrayList<>();
private int failedCount; private int failedCount = 0;
private Type type; private Type type = Type.PRIMARY;
enum Type { public ActionCacheData(String tendency, float[] tendencyVector, float[] inputVector, String input) {
this.tendency = tendency;
this.inputVector = inputVector;
this.tendencyVector = tendencyVector;
this.validSamples.add(input);
}
/**
 * Applies the adjustments for a cache hit whose evaluation passed.
 * <p>
 * In order: blends the new input vector into the stored one, records the input as a valid
 * sample, decrements the failure counter (never below zero), steps the entry type back one
 * level when the counter is zero (see {@link ActionCacheData.Type}), and increments the
 * match counter.
 *
 * @param inputVector  semantic vector of the current input
 * @param vectorClient vector client used for the weighted average
 * @param input        the current input text
 */
public synchronized void updateAfterMatchAndPassed(float[] inputVector, VectorClient vectorClient, String input) {
    // Blend first: the stored vector is the "primary" side of the average.
    this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
    addValidSample(input);
    // Failure counter goes down by one but never below zero.
    this.failedCount = Math.max(this.failedCount - 1, 0);
    // Must run after the decrement: the type only steps back once the counter is zero.
    updateType();
    addInputMatchCount();
}
/**
 * Steps the entry type back one rebuild level, but only when the failure counter is zero.
 * {@code PRIMARY} and {@code REBUILD_V1} both map to {@code PRIMARY}.
 */
private void updateType() {
    if (this.failedCount != 0) {
        return;
    }
    switch (this.type) {
        case REBUILD_V3 -> this.type = Type.REBUILD_V2;
        case REBUILD_V2 -> this.type = Type.REBUILD_V1;
        case PRIMARY, REBUILD_V1 -> this.type = Type.PRIMARY;
    }
}
/** Decrements the failure counter by one, clamping the result at zero. */
private void reduceFailedCount() {
    int decremented = this.failedCount - 1;
    this.failedCount = decremented < 0 ? 0 : decremented;
}
/**
 * Appends an input to the valid-sample list; when the list holds exactly 12 samples the
 * first one is dropped beforehand.
 *
 * @param input the input text to remember
 */
private void addValidSample(String input) {
    boolean windowFull = this.validSamples.size() == 12;
    if (windowFull) {
        // Evict the head to make room for the new sample.
        this.validSamples.removeFirst();
    }
    this.validSamples.add(input);
}
/**
 * Replaces the stored input vector with its weighted average against a new input vector.
 *
 * @param inputVector  semantic vector of the current input
 * @param vectorClient vector client performing the weighted average
 */
private void updateInputVector(float[] inputVector, VectorClient vectorClient) {
    float[] blended = vectorClient.weightedAverage(inputVector, this.inputVector);
    this.inputVector = blended;
}
/**
 * Applies the adjustments for a cache hit whose evaluation failed.
 * <p>
 * The threshold is raised (by 0.03, capped at 0.95) and the failure counter is incremented
 * (capped at 3). No vector averaging happens here. Once the counter reaches 3, the entry is
 * either rebuilt and moved to the next rebuild level, or, if it is already at
 * {@code REBUILD_V3}, reported for removal.
 *
 * @param vectorClient vector client used when the entry has to be rebuilt
 * @return {@code true} if the caller should remove this entry (already {@code REBUILD_V3}
 * and the failure limit was reached), {@code false} otherwise
 */
public synchronized boolean updateAfterMatchNotPassed(VectorClient vectorClient) {
    adjustThreshold();
    addFailedCount();
    boolean failureLimitReached = this.failedCount >= 3;
    if (!failureLimitReached) {
        return false;
    }
    boolean lastRebuildLevel = this.type == Type.REBUILD_V3;
    if (!lastRebuildLevel) {
        rebuildAndSwitchType(vectorClient);
    }
    return lastRebuildLevel;
}
/**
 * Rebuilds the input vector from a progressively smaller tail of the valid samples and
 * advances the entry to the next rebuild level.
 * <ul>
 * <li>{@code PRIMARY} to {@code REBUILD_V1}: all samples, order reversed.</li>
 * <li>{@code REBUILD_V1} to {@code REBUILD_V2}: the last half of the samples, reversed.</li>
 * <li>{@code REBUILD_V2} to {@code REBUILD_V3}: the last quarter of the samples (at least
 * one), reversed.</li>
 * <li>{@code REBUILD_V3}: no further level exists; the type is left unchanged.</li>
 * </ul>
 * Afterwards the threshold is lowered by 0.05 (not below 0.75) and the failure counter is
 * reset.
 * <p>
 * The selected samples are copied into a fresh {@link ArrayList}. {@code reversed()} and
 * {@code subList()} only return views of the backing list; storing the view itself would
 * stack view upon view on every rebuild and keep the discarded samples reachable.
 *
 * @param vectorClient vector client used to re-embed the samples
 */
private void rebuildAndSwitchType(VectorClient vectorClient) {
    int size = this.validSamples.size();
    this.type = switch (this.type) {
        case PRIMARY -> {
            // Reverse the sample order and rebuild from all samples.
            this.validSamples = new ArrayList<>(this.validSamples.reversed());
            rebuildWithSamples(vectorClient);
            yield Type.REBUILD_V1;
        }
        case REBUILD_V1 -> {
            // Keep the last half of the samples, reversed.
            List<String> tail = this.validSamples.subList(size / 2, size);
            this.validSamples = new ArrayList<>(tail.reversed());
            rebuildWithSamples(vectorClient);
            yield Type.REBUILD_V2;
        }
        case REBUILD_V2 -> {
            // Keep the last quarter of the samples, reversed. The previous start index
            // (size / 4) kept the last three quarters instead. Always keep at least one
            // sample so the rebuild has something to work with.
            int keep = Math.min(size, Math.max(1, size / 4));
            List<String> tail = this.validSamples.subList(size - keep, size);
            this.validSamples = new ArrayList<>(tail.reversed());
            rebuildWithSamples(vectorClient);
            yield Type.REBUILD_V3;
        }
        // No level beyond V3. Keep the type instead of nulling it, which would make any
        // later switch over the type throw a NullPointerException.
        case REBUILD_V3 -> Type.REBUILD_V3;
    };
    // Lower the threshold by 0.05 so it does not keep climbing across rebuilds.
    this.threshold = Math.max(this.threshold - 0.05, 0.75);
    this.failedCount = 0;
}
/**
 * Recomputes the input vector from the current valid samples, in list order.
 * <p>
 * The first sample's embedding seeds the vector; each following sample is blended in through
 * {@code weightedAverage(newVector, primaryVector)}, with the sample as the new vector and the
 * running result as the primary one. The previous code passed them the other way round, which
 * is inconsistent with the other call sites in this class and lets each later sample largely
 * overwrite what had been accumulated. If there are no samples the stored vector is left
 * untouched.
 *
 * @param vectorClient vector client used to embed and blend the samples
 */
private void rebuildWithSamples(VectorClient vectorClient) {
    boolean seeded = false;
    for (String sample : this.validSamples) {
        float[] sampleVector = vectorClient.compute(sample);
        if (!seeded) {
            this.inputVector = sampleVector;
            seeded = true;
        } else {
            this.inputVector = vectorClient.weightedAverage(sampleVector, this.inputVector);
        }
    }
}
/** Increments the failure counter by one, capping the result at 3. */
private void addFailedCount() {
    int incremented = this.failedCount + 1;
    this.failedCount = incremented > 3 ? 3 : incremented;
}
/** Raises the match threshold by 0.03, capping it at 0.95. */
private void adjustThreshold() {
    this.threshold = Math.min(this.threshold + 0.03, 0.95);
}
/**
 * Adjusts an existing entry whose tendency passed evaluation although the cache did not
 * match the input.
 * <ul>
 * <li>Entry already activated: the threshold is lowered and the new input vector is blended
 * into the stored input vector.</li>
 * <li>Entry not yet activated: the input is recorded as a valid sample, the tendency vector
 * is blended with the supplied one, and the match counter is incremented.</li>
 * </ul>
 *
 * @param input          the current input text
 * @param inputVector    semantic vector of the current input
 * @param tendencyVector semantic vector of the tendency
 * @param vectorClient   vector client performing the weighted average
 */
public synchronized void updateAfterNotMatchPassed(String input, float[] inputVector, float[] tendencyVector, VectorClient vectorClient) {
    if (!this.activated) {
        // NOTE(review): this branch blends the tendency vector and never touches
        // inputVector; confirm that is intended rather than blending the input vector.
        addValidSample(input);
        this.tendencyVector = vectorClient.weightedAverage(tendencyVector, this.tendencyVector);
        addInputMatchCount();
        return;
    }
    reduceThreshold();
    this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
}
/** Lowers the match threshold by 0.02, never going below 0.75. */
private void reduceThreshold() {
    this.threshold = Math.max(this.threshold - 0.02, 0.75);
}
/** Increments the match counter and marks the entry as activated once it reaches 6. */
private void addInputMatchCount() {
    int matches = ++this.inputMatchCount;
    if (matches >= 6) {
        this.activated = true;
    }
}
public enum Type {
PRIMARY, REBUILD_V1, REBUILD_V2, REBUILD_V3 PRIMARY, REBUILD_V1, REBUILD_V2, REBUILD_V3
} }
} }

View File

@@ -2,10 +2,10 @@ package work.slhaf.partner.core.action.entity;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class CacheAdjustData { public class CacheAdjustData {
private String input; private String input;
private String tendency; private List<CacheAdjustMetaData> metaDataList;
private boolean passed;
private boolean hit;
} }

View File

@@ -0,0 +1,10 @@
package work.slhaf.partner.core.action.entity;
import lombok.Data;
@Data
public class CacheAdjustMetaData {
private String tendency;
private boolean passed;
private boolean hit;
}

View File

@@ -1,11 +1,13 @@
package work.slhaf.partner.module.modules.action.planner; package work.slhaf.partner.module.modules.action.planner;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule; import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule; import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor; import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.common.vector.VectorClient;
import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.entity.*; import work.slhaf.partner.core.action.entity.*;
import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.cognation.CognationCapability;
@@ -83,26 +85,30 @@ public class ActionPlanner extends PreRunningModule {
} }
EvaluatorInput evaluatorInput = assemblyHelper.buildEvaluatorInput(extractorResult, context.getUserId()); EvaluatorInput evaluatorInput = assemblyHelper.buildEvaluatorInput(extractorResult, context.getUserId());
List<EvaluatorResult> evaluatorResults = actionEvaluator.execute(evaluatorInput); //并发操作均为访问 List<EvaluatorResult> evaluatorResults = actionEvaluator.execute(evaluatorInput); //并发操作均为访问
if (extractorResult.isCacheEnabled())
updateTendencyCache(evaluatorResults, context.getInput(), extractorResult);
setupActionInfo(evaluatorResults, context); setupActionInfo(evaluatorResults, context);
return null; return null;
}); });
} }
@AfterExecute
private void updateTendencyCache(List<EvaluatorResult> evaluatorResults, String input, ExtractorResult extractorResult) { private void updateTendencyCache(List<EvaluatorResult> evaluatorResults, String input, ExtractorResult extractorResult) {
if (!VectorClient.status) {
return;
}
executor.execute(() -> { executor.execute(() -> {
List<CacheAdjustData> list = new ArrayList<>(); CacheAdjustData data = new CacheAdjustData();
List<CacheAdjustMetaData> list = new ArrayList<>();
List<String> hitTendencies = extractorResult.getTendencies(); List<String> hitTendencies = extractorResult.getTendencies();
for (EvaluatorResult result : evaluatorResults) { for (EvaluatorResult result : evaluatorResults) {
CacheAdjustData data = new CacheAdjustData(); CacheAdjustMetaData metaData = new CacheAdjustMetaData();
data.setTendency(result.getTendency()); metaData.setTendency(result.getTendency());
data.setInput(input); metaData.setPassed(result.isOk());
data.setPassed(result.isOk()); metaData.setHit(hitTendencies.contains(result.getTendency()));
data.setHit(hitTendencies.contains(result.getTendency())); list.add(metaData);
list.add(data);
} }
actionCapability.updateTendencyCache(list); data.setMetaDataList(list);
data.setInput(input);
actionCapability.updateTendencyCache(data);
}); });
} }

View File

@@ -24,7 +24,6 @@ public class ActionExtractor extends AgentRunningSubModule<ExtractorInput, Extra
public ExtractorResult execute(ExtractorInput data) { public ExtractorResult execute(ExtractorInput data) {
ExtractorResult result = new ExtractorResult(); ExtractorResult result = new ExtractorResult();
List<String> tendencyCache = actionCapability.selectTendencyCache(data.getInput()); List<String> tendencyCache = actionCapability.selectTendencyCache(data.getInput());
result.setCacheEnabled(tendencyCache != null);
if (tendencyCache != null && !tendencyCache.isEmpty()) { if (tendencyCache != null && !tendencyCache.isEmpty()) {
result.setTendencies(tendencyCache); result.setTendencies(tendencyCache);
return result; return result;

View File

@@ -7,6 +7,5 @@ import java.util.List;
@Data @Data
public class ExtractorResult { public class ExtractorResult {
private boolean cacheEnabled;
private List<String> tendencies = new ArrayList<>(); private List<String> tendencies = new ArrayList<>();
} }