mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
Action 模块语义缓存机制实现完毕,支持三种情况的语义缓存相关行为: 命中缓存且评估通过、命中缓存但评估未通过、未命中缓存但评估通过。将在评估过后步入主模块之前,进行异步更新操作(借助@AfterExecute注解,通过虚拟线程进入异步流程,在真正调用处使用平台线程加速计算)
This commit is contained in:
@@ -64,4 +64,25 @@ public abstract class VectorClient {
|
||||
return Transforms.cosineSim(a1, a2);
|
||||
}
|
||||
}
|
||||
|
||||
public float[] weightedAverage(float[] newVector, float[] primaryVector) {
|
||||
try (INDArray primary = Nd4j.create(primaryVector);
|
||||
INDArray latest = Nd4j.create(newVector)) {
|
||||
|
||||
// 1️⃣ 计算余弦相似度
|
||||
double similarity = Transforms.cosineSim(primary, latest);
|
||||
|
||||
// 2️⃣ 根据相似度决定更新比例 α(差异越大,新输入影响越强)
|
||||
double alpha = (1.0 - similarity) * 0.5;
|
||||
alpha = Math.max(0.05, Math.min(alpha, 0.5));
|
||||
|
||||
// 3️⃣ 按比例混合旧向量与新向量
|
||||
INDArray updated = primary.mul(1 - alpha).add(latest.mul(alpha));
|
||||
|
||||
// 4️⃣ 归一化结果(保持方向空间一致)
|
||||
updated = updated.div(updated.norm2Number());
|
||||
|
||||
return updated.toFloatVector();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -22,5 +22,5 @@ public interface ActionCapability {
|
||||
|
||||
List<String> selectTendencyCache(String input);
|
||||
|
||||
void updateTendencyCache(List<CacheAdjustData> list);
|
||||
void updateTendencyCache(CacheAdjustData data);
|
||||
}
|
||||
|
||||
@@ -1,21 +1,28 @@
|
||||
package work.slhaf.partner.core.action;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
|
||||
import work.slhaf.partner.common.vector.VectorClient;
|
||||
import work.slhaf.partner.core.PartnerCore;
|
||||
import work.slhaf.partner.core.action.entity.ActionCacheData;
|
||||
import work.slhaf.partner.core.action.entity.CacheAdjustData;
|
||||
import work.slhaf.partner.core.action.entity.CacheAdjustMetaData;
|
||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@SuppressWarnings("FieldMayBeFinal")
|
||||
@Capability(value = "action")
|
||||
@Slf4j
|
||||
public class ActionCore extends PartnerCore<ActionCore> {
|
||||
|
||||
/**
|
||||
@@ -33,6 +40,10 @@ public class ActionCore extends PartnerCore<ActionCore> {
|
||||
*/
|
||||
private List<ActionCacheData> actionCache = new ArrayList<>();
|
||||
|
||||
private Lock cacheLock = new ReentrantLock();
|
||||
|
||||
private Executor executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
|
||||
|
||||
public ActionCore() throws IOException, ClassNotFoundException {
|
||||
}
|
||||
|
||||
@@ -105,43 +116,116 @@ public class ActionCore extends PartnerCore<ActionCore> {
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void updateTendencyCache(List<CacheAdjustData> list) {
|
||||
List<CacheAdjustData> matchAndPassed = new ArrayList<>();
|
||||
List<CacheAdjustData> matchNotPassed = new ArrayList<>();
|
||||
List<CacheAdjustData> notMatchPassed = new ArrayList<>();
|
||||
|
||||
for (CacheAdjustData data : list) {
|
||||
if (data.isHit() && data.isPassed()) {
|
||||
matchAndPassed.add(data);
|
||||
} else if (data.isHit()) {
|
||||
matchNotPassed.add(data);
|
||||
} else if (!data.isPassed()) {
|
||||
notMatchPassed.add(data);
|
||||
}
|
||||
}
|
||||
|
||||
public void updateTendencyCache(CacheAdjustData data) {
|
||||
VectorClient vectorClient = VectorClient.INSTANCE;
|
||||
adjustMatchAndPassed(matchAndPassed, vectorClient);
|
||||
adjustMatchNotPassed(matchNotPassed, vectorClient);
|
||||
adjustNotMatchPassed(notMatchPassed, vectorClient);
|
||||
List<CacheAdjustMetaData> list = data.getMetaDataList();
|
||||
String input = data.getInput();
|
||||
float[] inputVector = vectorClient.compute(input);
|
||||
|
||||
List<CacheAdjustMetaData> matchAndPassed = new ArrayList<>();
|
||||
List<CacheAdjustMetaData> matchNotPassed = new ArrayList<>();
|
||||
List<CacheAdjustMetaData> notMatchPassed = new ArrayList<>();
|
||||
|
||||
for (CacheAdjustMetaData metaData : list) {
|
||||
if (metaData.isHit() && metaData.isPassed()) {
|
||||
matchAndPassed.add(metaData);
|
||||
} else if (metaData.isHit()) {
|
||||
matchNotPassed.add(metaData);
|
||||
} else if (!metaData.isPassed()) {
|
||||
notMatchPassed.add(metaData);
|
||||
}
|
||||
}
|
||||
|
||||
executor.execute(() -> adjustMatchAndPassed(matchAndPassed, inputVector, input, vectorClient));
|
||||
executor.execute(() -> adjustMatchNotPassed(matchNotPassed, vectorClient));
|
||||
executor.execute(() -> adjustNotMatchPassed(notMatchPassed, inputVector, input, vectorClient));
|
||||
}
|
||||
|
||||
/**
|
||||
* 命中缓存且评估通过时,根据输入内容的语义向量与现有的输入语义向量进行带权移动平均,以相似度为权重
|
||||
* 命中缓存且评估通过时
|
||||
*
|
||||
* @param matchAndPassed 该类型的带调整缓存信息列表
|
||||
* @param inputVector 本次输入内容的语义向量
|
||||
* @param vectorClient 向量客户端
|
||||
*/
|
||||
private void adjustMatchAndPassed(List<CacheAdjustData> matchAndPassed, VectorClient vectorClient) {
|
||||
|
||||
private void adjustMatchAndPassed(List<CacheAdjustMetaData> matchAndPassed, float[] inputVector, String input, VectorClient vectorClient) {
|
||||
matchAndPassed.forEach(adjustData -> {
|
||||
//获取原始缓存条目
|
||||
String tendency = adjustData.getTendency();
|
||||
ActionCacheData primaryCacheData = selectCacheData(tendency);
|
||||
if (primaryCacheData == null) {
|
||||
return;
|
||||
}
|
||||
primaryCacheData.updateAfterMatchAndPassed(inputVector, vectorClient, input);
|
||||
});
|
||||
}
|
||||
|
||||
private void adjustMatchNotPassed(List<CacheAdjustData> matchNotPassed, VectorClient vectorClient) {
|
||||
|
||||
/**
|
||||
* 针对命中缓存、但评估未通过的条目与输入进行处理
|
||||
*
|
||||
* @param matchNotPassed 该类型的带调整缓存信息列表
|
||||
* @param vectorClient 向量客户端
|
||||
*/
|
||||
private void adjustMatchNotPassed(List<CacheAdjustMetaData> matchNotPassed, VectorClient vectorClient) {
|
||||
List<ActionCacheData> toRemove = new ArrayList<>();
|
||||
matchNotPassed.forEach(adjustData -> {
|
||||
//获取原始缓存条目
|
||||
String tendency = adjustData.getTendency();
|
||||
ActionCacheData primaryCacheData = selectCacheData(tendency);
|
||||
if (primaryCacheData == null) {
|
||||
return;
|
||||
}
|
||||
boolean remove = primaryCacheData.updateAfterMatchNotPassed(vectorClient);
|
||||
if (remove) {
|
||||
toRemove.add(primaryCacheData);
|
||||
}
|
||||
|
||||
private void adjustNotMatchPassed(List<CacheAdjustData> notMatchPassed, VectorClient vectorClient) {
|
||||
});
|
||||
cacheLock.lock();
|
||||
actionCache.removeAll(toRemove);
|
||||
cacheLock.unlock();
|
||||
}
|
||||
|
||||
/**
|
||||
* 针对未命中但评估通过的缓存做出调整:
|
||||
* <ol>
|
||||
* <h3>如果存在缓存条目</h3>
|
||||
* <li>
|
||||
* 若已生效,但此时未匹配到则说明尚未生效或者阈值、向量{@link ActionCacheData#getInputVector()}存在问题,调低阈值,同时带权移动平均
|
||||
* </li>
|
||||
* <li>
|
||||
* 若未生效,则只增加计数并带权移动平均
|
||||
* </li>
|
||||
* </ol>
|
||||
* 如果不存在缓存条目,则新增并填充字段
|
||||
*
|
||||
* @param notMatchPassed 该类型的带调整缓存信息列表
|
||||
* @param inputVector 本次输入内容的语义向量
|
||||
* @param input 本次输入内容
|
||||
* @param vectorClient 向量客户端
|
||||
*/
|
||||
private void adjustNotMatchPassed(List<CacheAdjustMetaData> notMatchPassed, float[] inputVector, String input, VectorClient vectorClient) {
|
||||
notMatchPassed.forEach(adjustData -> {
|
||||
//获取原始缓存条目
|
||||
String tendency = adjustData.getTendency();
|
||||
ActionCacheData primaryCacheData = selectCacheData(tendency);
|
||||
float[] tendencyVector = vectorClient.compute(tendency);
|
||||
if (primaryCacheData == null) {
|
||||
actionCache.add(new ActionCacheData(tendency, tendencyVector, inputVector, input));
|
||||
return;
|
||||
}
|
||||
primaryCacheData.updateAfterNotMatchPassed(input, inputVector, tendencyVector, vectorClient);
|
||||
});
|
||||
}
|
||||
|
||||
private ActionCacheData selectCacheData(String tendency) {
|
||||
for (ActionCacheData actionCacheData : actionCache) {
|
||||
if (actionCacheData.getTendency().equals(tendency)) {
|
||||
return actionCacheData;
|
||||
}
|
||||
}
|
||||
log.warn("[{}] 未找到行为倾向[{}]对应的缓存条目,可能是代码逻辑存在错误", getCoreKey(), tendency);
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,25 +1,181 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.common.vector.VectorClient;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ActionCacheData {
|
||||
private boolean activated;
|
||||
private int inputMatchCount;
|
||||
private boolean activated = false;
|
||||
private int inputMatchCount = 1;
|
||||
|
||||
private float[] inputVector;
|
||||
private float[] tendencyVector;
|
||||
private String tendency;
|
||||
private double threshold;
|
||||
private double threshold = 0.75;
|
||||
|
||||
private List<String> validSamples = new ArrayList<>();
|
||||
private int failedCount;
|
||||
private Type type;
|
||||
private int failedCount = 0;
|
||||
private Type type = Type.PRIMARY;
|
||||
|
||||
enum Type {
|
||||
public ActionCacheData(String tendency, float[] tendencyVector, float[] inputVector, String input) {
|
||||
this.tendency = tendency;
|
||||
this.inputVector = inputVector;
|
||||
this.tendencyVector = tendencyVector;
|
||||
this.validSamples.add(input);
|
||||
}
|
||||
|
||||
/**
|
||||
* 命中缓存且评估通过时,根据输入内容的语义向量与现有的输入语义向量进行带权移动平均,以相似度为权重,同时降低失败计数,为零时置为上一级缓存类型{@link ActionCacheData.Type}
|
||||
*
|
||||
* @param inputVector 本次输入内容对应的语义向量
|
||||
* @param vectorClient 向量客户端
|
||||
* @param input 本次输入内容
|
||||
*/
|
||||
public synchronized void updateAfterMatchAndPassed(float[] inputVector, VectorClient vectorClient, String input) {
|
||||
updateInputVector(inputVector, vectorClient);
|
||||
addValidSample(input);
|
||||
reduceFailedCount();
|
||||
updateType();
|
||||
addInputMatchCount();
|
||||
}
|
||||
|
||||
private void updateType() {
|
||||
if (this.failedCount == 0) {
|
||||
this.type = switch (type) {
|
||||
case PRIMARY, REBUILD_V1 -> ActionCacheData.Type.PRIMARY;
|
||||
case REBUILD_V2 -> ActionCacheData.Type.REBUILD_V1;
|
||||
case REBUILD_V3 -> ActionCacheData.Type.REBUILD_V2;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private void reduceFailedCount() {
|
||||
this.failedCount = Math.max(this.failedCount - 1, 0);
|
||||
}
|
||||
|
||||
private void addValidSample(String input) {
|
||||
if (this.validSamples.size() == 12) {
|
||||
this.validSamples.removeFirst();
|
||||
}
|
||||
this.validSamples.add(input);
|
||||
}
|
||||
|
||||
private void updateInputVector(float[] inputVector, VectorClient vectorClient) {
|
||||
this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
|
||||
}
|
||||
|
||||
/**
|
||||
* 针对命中缓存、但评估未通过的条目与输入进行处理: 增加失败计数(必要时重建并更新类型等级)、调高阈值(0.02),由于缓存匹配但评估未通过,所以不进行带权移动平均
|
||||
*
|
||||
* @param vectorClient 向量客户端
|
||||
* @return 是否需要删除(已在REBUILD_V3状态且达到最大误判次数的)
|
||||
*/
|
||||
public synchronized boolean updateAfterMatchNotPassed(VectorClient vectorClient) {
|
||||
adjustThreshold();
|
||||
addFailedCount();
|
||||
if (this.failedCount < 3) {
|
||||
return false;
|
||||
}
|
||||
if (this.type == Type.REBUILD_V3) {
|
||||
return true;
|
||||
}
|
||||
rebuildAndSwitchType(vectorClient);
|
||||
return false;
|
||||
}
|
||||
|
||||
private void rebuildAndSwitchType(VectorClient vectorClient) {
|
||||
this.type = switch (this.type) {
|
||||
case PRIMARY -> {
|
||||
//样本顺序反转后,以全部样本重建
|
||||
this.validSamples = this.validSamples.reversed();
|
||||
rebuildWithSamples(vectorClient);
|
||||
yield Type.REBUILD_V1;
|
||||
}
|
||||
case REBUILD_V1 -> {
|
||||
//截取后一半样本,反转后以此重建
|
||||
List<String> temp = this.validSamples.subList(this.validSamples.size() / 2, this.validSamples.size());
|
||||
this.validSamples = temp.reversed();
|
||||
rebuildWithSamples(vectorClient);
|
||||
yield Type.REBUILD_V2;
|
||||
}
|
||||
case REBUILD_V2 -> {
|
||||
//截取后四分之一样本,反转后以此重建
|
||||
List<String> temp = this.validSamples.subList(this.validSamples.size() / 4, this.validSamples.size());
|
||||
this.validSamples = temp.reversed();
|
||||
rebuildWithSamples(vectorClient);
|
||||
yield Type.REBUILD_V3;
|
||||
}
|
||||
case REBUILD_V3 -> null;
|
||||
};
|
||||
//阈值减0.05,防止重建后一直升高
|
||||
this.threshold = Math.max(this.threshold - 0.05, 0.75);
|
||||
this.failedCount = 0;
|
||||
}
|
||||
|
||||
private void rebuildWithSamples(VectorClient vectorClient) {
|
||||
for (int i = 0; i < this.validSamples.size(); i++) {
|
||||
String sample = this.validSamples.get(i);
|
||||
if (i == 0) {
|
||||
this.inputVector = vectorClient.compute(sample);
|
||||
} else {
|
||||
float[] newSampleVector = vectorClient.compute(sample);
|
||||
this.inputVector = vectorClient.weightedAverage(this.inputVector, newSampleVector);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void addFailedCount() {
|
||||
this.failedCount = Math.min(this.failedCount + 1, 3);
|
||||
}
|
||||
|
||||
private void adjustThreshold() {
|
||||
double newThreshold = this.threshold + 0.03;
|
||||
this.threshold = Math.min(newThreshold, 0.95);
|
||||
}
|
||||
|
||||
/**
|
||||
* 针对未命中但评估通过的已存在缓存做出调整:
|
||||
* <ol>
|
||||
* <li>
|
||||
* 若已生效,但此时未匹配到则说明阈值或者向量{@link ActionCacheData#getInputVector()}存在问题,调低阈值,同时带权移动平均
|
||||
* </li>
|
||||
* <li>
|
||||
* 若未生效,则只增加计数并带权移动平均
|
||||
* </li>
|
||||
* </ol>
|
||||
*
|
||||
* @param input 本次输入内容
|
||||
* @param inputVector 本次输入内容对应的语义向量
|
||||
* @param tendencyVector 本次倾向对应的语义向量
|
||||
* @param vectorClient 向量客户端
|
||||
*/
|
||||
public synchronized void updateAfterNotMatchPassed(String input, float[] inputVector, float[] tendencyVector, VectorClient vectorClient) {
|
||||
if (this.activated) {
|
||||
reduceThreshold();
|
||||
this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
|
||||
} else {
|
||||
addValidSample(input);
|
||||
this.tendencyVector = vectorClient.weightedAverage(tendencyVector, this.tendencyVector);
|
||||
addInputMatchCount();
|
||||
}
|
||||
}
|
||||
|
||||
private void reduceThreshold() {
|
||||
double newThreshold = this.threshold - 0.02;
|
||||
this.threshold = Math.max(newThreshold, 0.75);
|
||||
}
|
||||
|
||||
private void addInputMatchCount() {
|
||||
this.inputMatchCount += 1;
|
||||
if (inputMatchCount >= 6) {
|
||||
this.activated = true;
|
||||
}
|
||||
}
|
||||
|
||||
public enum Type {
|
||||
PRIMARY, REBUILD_V1, REBUILD_V2, REBUILD_V3
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,10 @@ package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class CacheAdjustData {
|
||||
private String input;
|
||||
private String tendency;
|
||||
private boolean passed;
|
||||
private boolean hit;
|
||||
private List<CacheAdjustMetaData> metaDataList;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class CacheAdjustMetaData {
|
||||
private String tendency;
|
||||
private boolean passed;
|
||||
private boolean hit;
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
package work.slhaf.partner.module.modules.action.planner;
|
||||
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
||||
import work.slhaf.partner.common.vector.VectorClient;
|
||||
import work.slhaf.partner.core.action.ActionCapability;
|
||||
import work.slhaf.partner.core.action.entity.*;
|
||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||
@@ -83,26 +85,30 @@ public class ActionPlanner extends PreRunningModule {
|
||||
}
|
||||
EvaluatorInput evaluatorInput = assemblyHelper.buildEvaluatorInput(extractorResult, context.getUserId());
|
||||
List<EvaluatorResult> evaluatorResults = actionEvaluator.execute(evaluatorInput); //并发操作均为访问
|
||||
if (extractorResult.isCacheEnabled())
|
||||
updateTendencyCache(evaluatorResults, context.getInput(), extractorResult);
|
||||
setupActionInfo(evaluatorResults, context);
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
@AfterExecute
|
||||
private void updateTendencyCache(List<EvaluatorResult> evaluatorResults, String input, ExtractorResult extractorResult) {
|
||||
if (!VectorClient.status) {
|
||||
return;
|
||||
}
|
||||
executor.execute(() -> {
|
||||
List<CacheAdjustData> list = new ArrayList<>();
|
||||
CacheAdjustData data = new CacheAdjustData();
|
||||
List<CacheAdjustMetaData> list = new ArrayList<>();
|
||||
List<String> hitTendencies = extractorResult.getTendencies();
|
||||
for (EvaluatorResult result : evaluatorResults) {
|
||||
CacheAdjustData data = new CacheAdjustData();
|
||||
data.setTendency(result.getTendency());
|
||||
data.setInput(input);
|
||||
data.setPassed(result.isOk());
|
||||
data.setHit(hitTendencies.contains(result.getTendency()));
|
||||
list.add(data);
|
||||
CacheAdjustMetaData metaData = new CacheAdjustMetaData();
|
||||
metaData.setTendency(result.getTendency());
|
||||
metaData.setPassed(result.isOk());
|
||||
metaData.setHit(hitTendencies.contains(result.getTendency()));
|
||||
list.add(metaData);
|
||||
}
|
||||
actionCapability.updateTendencyCache(list);
|
||||
data.setMetaDataList(list);
|
||||
data.setInput(input);
|
||||
actionCapability.updateTendencyCache(data);
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ public class ActionExtractor extends AgentRunningSubModule<ExtractorInput, Extra
|
||||
public ExtractorResult execute(ExtractorInput data) {
|
||||
ExtractorResult result = new ExtractorResult();
|
||||
List<String> tendencyCache = actionCapability.selectTendencyCache(data.getInput());
|
||||
result.setCacheEnabled(tendencyCache != null);
|
||||
if (tendencyCache != null && !tendencyCache.isEmpty()) {
|
||||
result.setTendencies(tendencyCache);
|
||||
return result;
|
||||
|
||||
@@ -7,6 +7,5 @@ import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ExtractorResult {
|
||||
private boolean cacheEnabled;
|
||||
private List<String> tendencies = new ArrayList<>();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user