refactor(action): remove problematic action tendency cache

This commit is contained in:
2026-04-05 22:32:52 +08:00
parent 50db3fa7b2
commit 3b236286b9
5 changed files with 6 additions and 377 deletions

View File

@@ -6,7 +6,6 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.core.action.entity.ExecutableAction; import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.core.action.entity.MetaAction; import work.slhaf.partner.core.action.entity.MetaAction;
import work.slhaf.partner.core.action.entity.MetaActionInfo; import work.slhaf.partner.core.action.entity.MetaActionInfo;
import work.slhaf.partner.core.action.entity.cache.CacheAdjustData;
import work.slhaf.partner.core.action.entity.intervention.MetaIntervention; import work.slhaf.partner.core.action.entity.intervention.MetaIntervention;
import work.slhaf.partner.core.action.runner.RunnerClient; import work.slhaf.partner.core.action.runner.RunnerClient;
@@ -22,10 +21,6 @@ public interface ActionCapability {
Set<ExecutableAction> listActions(@Nullable ExecutableAction.Status status, @Nullable String source); Set<ExecutableAction> listActions(@Nullable ExecutableAction.Status status, @Nullable String source);
List<String> selectTendencyCache(String input);
void updateTendencyCache(CacheAdjustData data);
ExecutorService getExecutor(ActionCore.ExecutorType type); ExecutorService getExecutor(ActionCore.ExecutorType type);
MetaAction loadMetaAction(@NonNull String actionKey); MetaAction loadMetaAction(@NonNull String actionKey);

View File

@@ -6,14 +6,10 @@ import lombok.val;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore; import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod; import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.common.vector.VectorClient;
import work.slhaf.partner.core.PartnerCore; import work.slhaf.partner.core.PartnerCore;
import work.slhaf.partner.core.action.entity.ExecutableAction; import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.core.action.entity.MetaAction; import work.slhaf.partner.core.action.entity.MetaAction;
import work.slhaf.partner.core.action.entity.MetaActionInfo; import work.slhaf.partner.core.action.entity.MetaActionInfo;
import work.slhaf.partner.core.action.entity.cache.ActionCacheData;
import work.slhaf.partner.core.action.entity.cache.CacheAdjustData;
import work.slhaf.partner.core.action.entity.cache.CacheAdjustMetaData;
import work.slhaf.partner.core.action.entity.intervention.InterventionType; import work.slhaf.partner.core.action.entity.intervention.InterventionType;
import work.slhaf.partner.core.action.entity.intervention.MetaIntervention; import work.slhaf.partner.core.action.entity.intervention.MetaIntervention;
import work.slhaf.partner.core.action.exception.MetaActionNotFoundException; import work.slhaf.partner.core.action.exception.MetaActionNotFoundException;
@@ -26,8 +22,6 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@SuppressWarnings("FieldMayBeFinal") @SuppressWarnings("FieldMayBeFinal")
@@ -37,11 +31,10 @@ public class ActionCore extends PartnerCore<ActionCore> {
public static final String BUILTIN_LOCATION = "builtin"; public static final String BUILTIN_LOCATION = "builtin";
public static final String ORIGIN_LOCATION = "origin"; public static final String ORIGIN_LOCATION = "origin";
private final Lock cacheLock = new ReentrantLock();
// 由于当前的执行器逻辑实现,平台线程池大小不得小于 2这里规定为最小为 4 // 由于当前的执行器逻辑实现,平台线程池大小不得小于 2这里规定为最小为 4
private final ExecutorService platformExecutor = Executors private final ExecutorService platformExecutor = Executors.newFixedThreadPool(Math.max(Runtime.getRuntime().availableProcessors(), 4));
.newFixedThreadPool(Math.max(Runtime.getRuntime().availableProcessors(), 4));
private final ExecutorService virtualExecutor = Executors.newVirtualThreadPerTaskExecutor(); private final ExecutorService virtualExecutor = Executors.newVirtualThreadPerTaskExecutor();
/** /**
* 已存在的行动程序,键格式为‘<MCP-ServerName>::<Tool-Name>’,值为 MCP Server 通过 Resources 相关渠道传递的行动程序元信息 * 已存在的行动程序,键格式为‘<MCP-ServerName>::<Tool-Name>’,值为 MCP Server 通过 Resources 相关渠道传递的行动程序元信息
*/ */
@@ -50,10 +43,7 @@ public class ActionCore extends PartnerCore<ActionCore> {
* 持久行动池 * 持久行动池
*/ */
private CopyOnWriteArraySet<ExecutableAction> actionPool = new CopyOnWriteArraySet<>(); private CopyOnWriteArraySet<ExecutableAction> actionPool = new CopyOnWriteArraySet<>();
/**
* 语义缓存与行为倾向映射
*/
private List<ActionCacheData> actionCache = new ArrayList<>();
private RunnerClient runnerClient; private RunnerClient runnerClient;
public ActionCore() throws IOException, ClassNotFoundException { public ActionCore() throws IOException, ClassNotFoundException {
@@ -85,59 +75,6 @@ public class ActionCore extends PartnerCore<ActionCore> {
.collect(Collectors.toSet()); .collect(Collectors.toSet());
} }
/**
* 计算输入内容的语义向量,根据与{@link ActionCacheData#getInputVector()}的相似度挑取缓存,后续将根据评估结果来更新计数
*
* @param input 本次输入内容
* @return 命中的行为倾向集合
*/
@CapabilityMethod
public List<String> selectTendencyCache(String input) {
if (!VectorClient.status) {
return null;
}
VectorClient vectorClient = VectorClient.INSTANCE;
// 计算本次输入的向量
float[] vector = vectorClient.compute(input);
if (vector == null)
return null;
// 与现有缓存比对,将匹配到的收集并返回
return actionCache.parallelStream()
.filter(ActionCacheData::isActivated)
.filter(data -> {
double compared = vectorClient.compare(vector, data.getInputVector());
return compared > data.getThreshold();
})
.map(ActionCacheData::getTendency)
.collect(Collectors.toList());
}
@CapabilityMethod
public void updateTendencyCache(CacheAdjustData data) {
VectorClient vectorClient = VectorClient.INSTANCE;
List<CacheAdjustMetaData> list = data.getMetaDataList();
String input = data.getInput();
float[] inputVector = vectorClient.compute(input);
List<CacheAdjustMetaData> matchAndPassed = new ArrayList<>();
List<CacheAdjustMetaData> matchNotPassed = new ArrayList<>();
List<CacheAdjustMetaData> notMatchPassed = new ArrayList<>();
for (CacheAdjustMetaData metaData : list) {
if (metaData.isHit() && metaData.isPassed()) {
matchAndPassed.add(metaData);
} else if (metaData.isHit()) {
matchNotPassed.add(metaData);
} else if (!metaData.isPassed()) {
notMatchPassed.add(metaData);
}
}
platformExecutor.execute(() -> adjustMatchAndPassed(matchAndPassed, inputVector, input, vectorClient));
platformExecutor.execute(() -> adjustMatchNotPassed(matchNotPassed, vectorClient));
platformExecutor.execute(() -> adjustNotMatchPassed(notMatchPassed, inputVector, input, vectorClient));
}
@CapabilityMethod @CapabilityMethod
public ExecutorService getExecutor(ExecutorType type) { public ExecutorService getExecutor(ExecutorType type) {
return switch (type) { return switch (type) {
@@ -291,95 +228,6 @@ public class ActionCore extends PartnerCore<ActionCore> {
executableAction.getHistory().clear(); executableAction.getHistory().clear();
} }
/**
* 命中缓存且评估通过时
*
* @param matchAndPassed 该类型的带调整缓存信息列表
* @param inputVector 本次输入内容的语义向量
* @param vectorClient 向量客户端
*/
private void adjustMatchAndPassed(List<CacheAdjustMetaData> matchAndPassed, float[] inputVector, String input,
VectorClient vectorClient) {
matchAndPassed.forEach(adjustData -> {
// 获取原始缓存条目
String tendency = adjustData.getTendency();
ActionCacheData primaryCacheData = selectCacheData(tendency);
if (primaryCacheData == null) {
return;
}
primaryCacheData.updateAfterMatchAndPassed(inputVector, vectorClient, input);
});
}
/**
* 针对命中缓存、但评估未通过的条目与输入进行处理
*
* @param matchNotPassed 该类型的带调整缓存信息列表
* @param vectorClient 向量客户端
*/
private void adjustMatchNotPassed(List<CacheAdjustMetaData> matchNotPassed, VectorClient vectorClient) {
List<ActionCacheData> toRemove = new ArrayList<>();
matchNotPassed.forEach(adjustData -> {
// 获取原始缓存条目
String tendency = adjustData.getTendency();
ActionCacheData primaryCacheData = selectCacheData(tendency);
if (primaryCacheData == null) {
return;
}
boolean remove = primaryCacheData.updateAfterMatchNotPassed(vectorClient);
if (remove) {
toRemove.add(primaryCacheData);
}
});
cacheLock.lock();
actionCache.removeAll(toRemove);
cacheLock.unlock();
}
/**
* 针对未命中但评估通过的缓存做出调整:
* <ol>
* <h3>如果存在缓存条目</h3>
* <li>
* 若已生效,但此时未匹配到则说明尚未生效或者阈值、向量{@link ActionCacheData#getInputVector()}存在问题,调低阈值,同时带权移动平均
* </li>
* <li>
* 若未生效,则只增加计数并带权移动平均
* </li>
* </ol>
* 如果不存在缓存条目,则新增并填充字段
*
* @param notMatchPassed 该类型的带调整缓存信息列表
* @param inputVector 本次输入内容的语义向量
* @param input 本次输入内容
* @param vectorClient 向量客户端
*/
private void adjustNotMatchPassed(List<CacheAdjustMetaData> notMatchPassed, float[] inputVector, String input,
VectorClient vectorClient) {
notMatchPassed.forEach(adjustData -> {
// 获取原始缓存条目
String tendency = adjustData.getTendency();
ActionCacheData primaryCacheData = selectCacheData(tendency);
float[] tendencyVector = vectorClient.compute(tendency);
if (primaryCacheData == null) {
actionCache.add(new ActionCacheData(tendency, tendencyVector, inputVector, input));
return;
}
primaryCacheData.updateAfterNotMatchPassed(input, inputVector, tendencyVector, vectorClient);
});
}
private ActionCacheData selectCacheData(String tendency) {
for (ActionCacheData actionCacheData : actionCache) {
if (actionCacheData.getTendency().equals(tendency)) {
return actionCacheData;
}
}
log.warn("[{}] 未找到行为倾向[{}]对应的缓存条目,可能是代码逻辑存在错误", getCoreKey(), tendency);
return null;
}
@Override @Override
protected String getCoreKey() { protected String getCoreKey() {
return "action-core"; return "action-core";

View File

@@ -1,181 +0,0 @@
package work.slhaf.partner.core.action.entity.cache;
import lombok.Data;
import work.slhaf.partner.common.vector.VectorClient;
import java.util.ArrayList;
import java.util.List;
@Data
public class ActionCacheData {
private boolean activated = false;
private int inputMatchCount = 1;
private float[] inputVector;
private float[] tendencyVector;
private String tendency;
private double threshold = 0.75;
private List<String> validSamples = new ArrayList<>();
private int failedCount = 0;
private Type type = Type.PRIMARY;
public ActionCacheData(String tendency, float[] tendencyVector, float[] inputVector, String input) {
this.tendency = tendency;
this.inputVector = inputVector;
this.tendencyVector = tendencyVector;
this.validSamples.add(input);
}
/**
* 命中缓存且评估通过时,根据输入内容的语义向量与现有的输入语义向量进行带权移动平均,以相似度为权重,同时降低失败计数,为零时置为上一级缓存类型{@link ActionCacheData.Type}
*
* @param inputVector 本次输入内容对应的语义向量
* @param vectorClient 向量客户端
* @param input 本次输入内容
*/
public synchronized void updateAfterMatchAndPassed(float[] inputVector, VectorClient vectorClient, String input) {
updateInputVector(inputVector, vectorClient);
addValidSample(input);
reduceFailedCount();
updateType();
addInputMatchCount();
}
private void updateType() {
if (this.failedCount == 0) {
this.type = switch (type) {
case PRIMARY, REBUILD_V1 -> ActionCacheData.Type.PRIMARY;
case REBUILD_V2 -> ActionCacheData.Type.REBUILD_V1;
case REBUILD_V3 -> ActionCacheData.Type.REBUILD_V2;
};
}
}
private void reduceFailedCount() {
this.failedCount = Math.max(this.failedCount - 1, 0);
}
private void addValidSample(String input) {
if (this.validSamples.size() == 12) {
this.validSamples.removeFirst();
}
this.validSamples.add(input);
}
private void updateInputVector(float[] inputVector, VectorClient vectorClient) {
this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
}
/**
* 针对命中缓存、但评估未通过的条目与输入进行处理: 增加失败计数(必要时重建并更新类型等级)、调高阈值(0.02),由于缓存匹配但评估未通过,所以不进行带权移动平均
*
* @param vectorClient 向量客户端
* @return 是否需要删除(已在REBUILD_V3状态且达到最大误判次数的)
*/
public synchronized boolean updateAfterMatchNotPassed(VectorClient vectorClient) {
adjustThreshold();
addFailedCount();
if (this.failedCount < 3) {
return false;
}
if (this.type == Type.REBUILD_V3) {
return true;
}
rebuildAndSwitchType(vectorClient);
return false;
}
private void rebuildAndSwitchType(VectorClient vectorClient) {
this.type = switch (this.type) {
case PRIMARY -> {
//样本顺序反转后,以全部样本重建
this.validSamples = this.validSamples.reversed();
rebuildWithSamples(vectorClient);
yield Type.REBUILD_V1;
}
case REBUILD_V1 -> {
//截取后一半样本,反转后以此重建
List<String> temp = this.validSamples.subList(this.validSamples.size() / 2, this.validSamples.size());
this.validSamples = temp.reversed();
rebuildWithSamples(vectorClient);
yield Type.REBUILD_V2;
}
case REBUILD_V2 -> {
//截取后四分之一样本,反转后以此重建
List<String> temp = this.validSamples.subList(this.validSamples.size() / 4, this.validSamples.size());
this.validSamples = temp.reversed();
rebuildWithSamples(vectorClient);
yield Type.REBUILD_V3;
}
case REBUILD_V3 -> null;
};
//阈值减0.05,防止重建后一直升高
this.threshold = Math.max(this.threshold - 0.05, 0.75);
this.failedCount = 0;
}
private void rebuildWithSamples(VectorClient vectorClient) {
for (int i = 0; i < this.validSamples.size(); i++) {
String sample = this.validSamples.get(i);
if (i == 0) {
this.inputVector = vectorClient.compute(sample);
} else {
float[] newSampleVector = vectorClient.compute(sample);
this.inputVector = vectorClient.weightedAverage(this.inputVector, newSampleVector);
}
}
}
private void addFailedCount() {
this.failedCount = Math.min(this.failedCount + 1, 3);
}
private void adjustThreshold() {
double newThreshold = this.threshold + 0.03;
this.threshold = Math.min(newThreshold, 0.95);
}
/**
* 针对未命中但评估通过的已存在缓存做出调整:
* <ol>
* <li>
* 若已生效,但此时未匹配到则说明阈值或者向量{@link ActionCacheData#getInputVector()}存在问题,调低阈值,同时带权移动平均
* </li>
* <li>
* 若未生效,则只增加计数并带权移动平均
* </li>
* </ol>
*
* @param input 本次输入内容
* @param inputVector 本次输入内容对应的语义向量
* @param tendencyVector 本次倾向对应的语义向量
* @param vectorClient 向量客户端
*/
public synchronized void updateAfterNotMatchPassed(String input, float[] inputVector, float[] tendencyVector, VectorClient vectorClient) {
if (this.activated) {
reduceThreshold();
this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
} else {
addValidSample(input);
this.tendencyVector = vectorClient.weightedAverage(tendencyVector, this.tendencyVector);
addInputMatchCount();
}
}
private void reduceThreshold() {
double newThreshold = this.threshold - 0.02;
this.threshold = Math.max(newThreshold, 0.75);
}
private void addInputMatchCount() {
this.inputMatchCount += 1;
if (inputMatchCount >= 6) {
this.activated = true;
}
}
public enum Type {
PRIMARY, REBUILD_V1, REBUILD_V2, REBUILD_V3
}
}

View File

@@ -8,12 +8,9 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapabili
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.annotation.Init; import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule; import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
import work.slhaf.partner.common.vector.VectorClient;
import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore; import work.slhaf.partner.core.action.ActionCore;
import work.slhaf.partner.core.action.entity.*; import work.slhaf.partner.core.action.entity.*;
import work.slhaf.partner.core.action.entity.cache.CacheAdjustData;
import work.slhaf.partner.core.action.entity.cache.CacheAdjustMetaData;
import work.slhaf.partner.core.cognition.BlockContent; import work.slhaf.partner.core.cognition.BlockContent;
import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.cognition.CommunicationBlockContent; import work.slhaf.partner.core.cognition.CommunicationBlockContent;
@@ -135,33 +132,11 @@ public class ActionPlanner extends AbstractAgentModule.Running<PartnerRunningFlo
EvaluatorInput evaluatorInput = assemblyHelper.buildEvaluatorInput(extractorResult); EvaluatorInput evaluatorInput = assemblyHelper.buildEvaluatorInput(extractorResult);
List<EvaluatorResult> evaluatorResults = actionEvaluator.execute(evaluatorInput); // 并发操作均为访问 List<EvaluatorResult> evaluatorResults = actionEvaluator.execute(evaluatorInput); // 并发操作均为访问
handleEvaluatorResults(evaluatorResults, source, input); handleEvaluatorResults(evaluatorResults, source, input);
updateTendencyCache(evaluatorResults, input, extractorResult);
cognitionCapability.contextWorkspace().expire(TENDENCIES_EVALUATING_BLOCK_NAME, getModuleName()); cognitionCapability.contextWorkspace().expire(TENDENCIES_EVALUATING_BLOCK_NAME, getModuleName());
}); });
} }
private void updateTendencyCache(List<EvaluatorResult> evaluatorResults, String input, ExtractorResult extractorResult) {
if (!VectorClient.status) {
return;
}
executor.execute(() -> {
CacheAdjustData data = new CacheAdjustData();
List<CacheAdjustMetaData> list = new ArrayList<>();
List<String> hitTendencies = extractorResult.getTendencies();
for (EvaluatorResult result : evaluatorResults) {
CacheAdjustMetaData metaData = new CacheAdjustMetaData();
metaData.setTendency(result.getTendency());
metaData.setPassed(result.isOk());
metaData.setHit(hitTendencies.contains(result.getTendency()));
list.add(metaData);
}
data.setMetaDataList(list);
data.setInput(input);
actionCapability.updateTendencyCache(data);
});
}
private void handleEvaluatorResults(List<EvaluatorResult> evaluatorResults, String source, String input) { private void handleEvaluatorResults(List<EvaluatorResult> evaluatorResults, String source, String input) {
List<ExecutableAction> passedActions = new ArrayList<>(); List<ExecutableAction> passedActions = new ArrayList<>();
int approvedExecutableCount = 0; int approvedExecutableCount = 0;

View File

@@ -5,7 +5,6 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapabili
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.model.ActivateModel; import work.slhaf.partner.api.agent.model.ActivateModel;
import work.slhaf.partner.api.agent.model.pojo.Message; import work.slhaf.partner.api.agent.model.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.cognition.ContextBlock; import work.slhaf.partner.core.cognition.ContextBlock;
import work.slhaf.partner.module.action.planner.extractor.entity.ExtractorResult; import work.slhaf.partner.module.action.planner.extractor.entity.ExtractorResult;
@@ -13,25 +12,18 @@ import work.slhaf.partner.module.action.planner.extractor.entity.ExtractorResult
import java.util.List; import java.util.List;
public class ActionExtractor extends AbstractAgentModule.Sub<String, ExtractorResult> implements ActivateModel { public class ActionExtractor extends AbstractAgentModule.Sub<String, ExtractorResult> implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
@InjectCapability @InjectCapability
private CognitionCapability cognitionCapability; private CognitionCapability cognitionCapability;
@Override @Override
public ExtractorResult execute(String input) { public ExtractorResult execute(String input) {
List<String> tendencyCache = actionCapability.selectTendencyCache(input);
if (tendencyCache != null && !tendencyCache.isEmpty()) {
ExtractorResult result = new ExtractorResult();
result.setTendencies(tendencyCache);
return result;
}
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
try { try {
List<Message> messages = List.of( List<Message> messages = List.of(
cognitionCapability.contextWorkspace().resolve(List.of( cognitionCapability.contextWorkspace().resolve(List.of(
ContextBlock.VisibleDomain.ACTION, ContextBlock.VisibleDomain.COGNITION,
ContextBlock.VisibleDomain.COGNITION ContextBlock.VisibleDomain.ACTION
)).encodeToMessage(), )).encodeToMessage(),
new Message(Message.Character.USER, input) new Message(Message.Character.USER, input)
); );