mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
Action 模块语义缓存机制实现完毕,支持三种情况的语义缓存相关行为: 命中缓存且评估通过、命中缓存但评估未通过、未命中缓存但评估通过。将在评估过后步入主模块之前,进行异步更新操作(借助@AfterExecute注解,通过虚拟线程进入异步流程,在真正调用处使用平台线程加速计算)
This commit is contained in:
@@ -64,4 +64,25 @@ public abstract class VectorClient {
|
|||||||
return Transforms.cosineSim(a1, a2);
|
return Transforms.cosineSim(a1, a2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public float[] weightedAverage(float[] newVector, float[] primaryVector) {
|
||||||
|
try (INDArray primary = Nd4j.create(primaryVector);
|
||||||
|
INDArray latest = Nd4j.create(newVector)) {
|
||||||
|
|
||||||
|
// 1️⃣ 计算余弦相似度
|
||||||
|
double similarity = Transforms.cosineSim(primary, latest);
|
||||||
|
|
||||||
|
// 2️⃣ 根据相似度决定更新比例 α(差异越大,新输入影响越强)
|
||||||
|
double alpha = (1.0 - similarity) * 0.5;
|
||||||
|
alpha = Math.max(0.05, Math.min(alpha, 0.5));
|
||||||
|
|
||||||
|
// 3️⃣ 按比例混合旧向量与新向量
|
||||||
|
INDArray updated = primary.mul(1 - alpha).add(latest.mul(alpha));
|
||||||
|
|
||||||
|
// 4️⃣ 归一化结果(保持方向空间一致)
|
||||||
|
updated = updated.div(updated.norm2Number());
|
||||||
|
|
||||||
|
return updated.toFloatVector();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -22,5 +22,5 @@ public interface ActionCapability {
|
|||||||
|
|
||||||
List<String> selectTendencyCache(String input);
|
List<String> selectTendencyCache(String input);
|
||||||
|
|
||||||
void updateTendencyCache(List<CacheAdjustData> list);
|
void updateTendencyCache(CacheAdjustData data);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,28 @@
|
|||||||
package work.slhaf.partner.core.action;
|
package work.slhaf.partner.core.action;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
|
||||||
import work.slhaf.partner.common.vector.VectorClient;
|
import work.slhaf.partner.common.vector.VectorClient;
|
||||||
import work.slhaf.partner.core.PartnerCore;
|
import work.slhaf.partner.core.PartnerCore;
|
||||||
import work.slhaf.partner.core.action.entity.ActionCacheData;
|
import work.slhaf.partner.core.action.entity.ActionCacheData;
|
||||||
import work.slhaf.partner.core.action.entity.CacheAdjustData;
|
import work.slhaf.partner.core.action.entity.CacheAdjustData;
|
||||||
|
import work.slhaf.partner.core.action.entity.CacheAdjustMetaData;
|
||||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.concurrent.Executor;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.locks.Lock;
|
||||||
|
import java.util.concurrent.locks.ReentrantLock;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@SuppressWarnings("FieldMayBeFinal")
|
@SuppressWarnings("FieldMayBeFinal")
|
||||||
@Capability(value = "action")
|
@Capability(value = "action")
|
||||||
|
@Slf4j
|
||||||
public class ActionCore extends PartnerCore<ActionCore> {
|
public class ActionCore extends PartnerCore<ActionCore> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -33,6 +40,10 @@ public class ActionCore extends PartnerCore<ActionCore> {
|
|||||||
*/
|
*/
|
||||||
private List<ActionCacheData> actionCache = new ArrayList<>();
|
private List<ActionCacheData> actionCache = new ArrayList<>();
|
||||||
|
|
||||||
|
private Lock cacheLock = new ReentrantLock();
|
||||||
|
|
||||||
|
private Executor executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
|
||||||
|
|
||||||
public ActionCore() throws IOException, ClassNotFoundException {
|
public ActionCore() throws IOException, ClassNotFoundException {
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,43 +116,116 @@ public class ActionCore extends PartnerCore<ActionCore> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@CapabilityMethod
|
@CapabilityMethod
|
||||||
public void updateTendencyCache(List<CacheAdjustData> list) {
|
public void updateTendencyCache(CacheAdjustData data) {
|
||||||
List<CacheAdjustData> matchAndPassed = new ArrayList<>();
|
|
||||||
List<CacheAdjustData> matchNotPassed = new ArrayList<>();
|
|
||||||
List<CacheAdjustData> notMatchPassed = new ArrayList<>();
|
|
||||||
|
|
||||||
for (CacheAdjustData data : list) {
|
|
||||||
if (data.isHit() && data.isPassed()) {
|
|
||||||
matchAndPassed.add(data);
|
|
||||||
} else if (data.isHit()) {
|
|
||||||
matchNotPassed.add(data);
|
|
||||||
} else if (!data.isPassed()) {
|
|
||||||
notMatchPassed.add(data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
VectorClient vectorClient = VectorClient.INSTANCE;
|
VectorClient vectorClient = VectorClient.INSTANCE;
|
||||||
adjustMatchAndPassed(matchAndPassed, vectorClient);
|
List<CacheAdjustMetaData> list = data.getMetaDataList();
|
||||||
adjustMatchNotPassed(matchNotPassed, vectorClient);
|
String input = data.getInput();
|
||||||
adjustNotMatchPassed(notMatchPassed, vectorClient);
|
float[] inputVector = vectorClient.compute(input);
|
||||||
|
|
||||||
|
List<CacheAdjustMetaData> matchAndPassed = new ArrayList<>();
|
||||||
|
List<CacheAdjustMetaData> matchNotPassed = new ArrayList<>();
|
||||||
|
List<CacheAdjustMetaData> notMatchPassed = new ArrayList<>();
|
||||||
|
|
||||||
|
for (CacheAdjustMetaData metaData : list) {
|
||||||
|
if (metaData.isHit() && metaData.isPassed()) {
|
||||||
|
matchAndPassed.add(metaData);
|
||||||
|
} else if (metaData.isHit()) {
|
||||||
|
matchNotPassed.add(metaData);
|
||||||
|
} else if (!metaData.isPassed()) {
|
||||||
|
notMatchPassed.add(metaData);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
executor.execute(() -> adjustMatchAndPassed(matchAndPassed, inputVector, input, vectorClient));
|
||||||
|
executor.execute(() -> adjustMatchNotPassed(matchNotPassed, vectorClient));
|
||||||
|
executor.execute(() -> adjustNotMatchPassed(notMatchPassed, inputVector, input, vectorClient));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 命中缓存且评估通过时,根据输入内容的语义向量与现有的输入语义向量进行带权移动平均,以相似度为权重
|
* 命中缓存且评估通过时
|
||||||
*
|
*
|
||||||
* @param matchAndPassed 该类型的带调整缓存信息列表
|
* @param matchAndPassed 该类型的带调整缓存信息列表
|
||||||
|
* @param inputVector 本次输入内容的语义向量
|
||||||
* @param vectorClient 向量客户端
|
* @param vectorClient 向量客户端
|
||||||
*/
|
*/
|
||||||
private void adjustMatchAndPassed(List<CacheAdjustData> matchAndPassed, VectorClient vectorClient) {
|
private void adjustMatchAndPassed(List<CacheAdjustMetaData> matchAndPassed, float[] inputVector, String input, VectorClient vectorClient) {
|
||||||
|
matchAndPassed.forEach(adjustData -> {
|
||||||
|
//获取原始缓存条目
|
||||||
|
String tendency = adjustData.getTendency();
|
||||||
|
ActionCacheData primaryCacheData = selectCacheData(tendency);
|
||||||
|
if (primaryCacheData == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
primaryCacheData.updateAfterMatchAndPassed(inputVector, vectorClient, input);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private void adjustMatchNotPassed(List<CacheAdjustData> matchNotPassed, VectorClient vectorClient) {
|
/**
|
||||||
|
* 针对命中缓存、但评估未通过的条目与输入进行处理
|
||||||
|
*
|
||||||
|
* @param matchNotPassed 该类型的带调整缓存信息列表
|
||||||
|
* @param vectorClient 向量客户端
|
||||||
|
*/
|
||||||
|
private void adjustMatchNotPassed(List<CacheAdjustMetaData> matchNotPassed, VectorClient vectorClient) {
|
||||||
|
List<ActionCacheData> toRemove = new ArrayList<>();
|
||||||
|
matchNotPassed.forEach(adjustData -> {
|
||||||
|
//获取原始缓存条目
|
||||||
|
String tendency = adjustData.getTendency();
|
||||||
|
ActionCacheData primaryCacheData = selectCacheData(tendency);
|
||||||
|
if (primaryCacheData == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
boolean remove = primaryCacheData.updateAfterMatchNotPassed(vectorClient);
|
||||||
|
if (remove) {
|
||||||
|
toRemove.add(primaryCacheData);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void adjustNotMatchPassed(List<CacheAdjustData> notMatchPassed, VectorClient vectorClient) {
|
});
|
||||||
|
cacheLock.lock();
|
||||||
|
actionCache.removeAll(toRemove);
|
||||||
|
cacheLock.unlock();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 针对未命中但评估通过的缓存做出调整:
|
||||||
|
* <ol>
|
||||||
|
* <h3>如果存在缓存条目</h3>
|
||||||
|
* <li>
|
||||||
|
* 若已生效,但此时未匹配到则说明尚未生效或者阈值、向量{@link ActionCacheData#getInputVector()}存在问题,调低阈值,同时带权移动平均
|
||||||
|
* </li>
|
||||||
|
* <li>
|
||||||
|
* 若未生效,则只增加计数并带权移动平均
|
||||||
|
* </li>
|
||||||
|
* </ol>
|
||||||
|
* 如果不存在缓存条目,则新增并填充字段
|
||||||
|
*
|
||||||
|
* @param notMatchPassed 该类型的带调整缓存信息列表
|
||||||
|
* @param inputVector 本次输入内容的语义向量
|
||||||
|
* @param input 本次输入内容
|
||||||
|
* @param vectorClient 向量客户端
|
||||||
|
*/
|
||||||
|
private void adjustNotMatchPassed(List<CacheAdjustMetaData> notMatchPassed, float[] inputVector, String input, VectorClient vectorClient) {
|
||||||
|
notMatchPassed.forEach(adjustData -> {
|
||||||
|
//获取原始缓存条目
|
||||||
|
String tendency = adjustData.getTendency();
|
||||||
|
ActionCacheData primaryCacheData = selectCacheData(tendency);
|
||||||
|
float[] tendencyVector = vectorClient.compute(tendency);
|
||||||
|
if (primaryCacheData == null) {
|
||||||
|
actionCache.add(new ActionCacheData(tendency, tendencyVector, inputVector, input));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
primaryCacheData.updateAfterNotMatchPassed(input, inputVector, tendencyVector, vectorClient);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private ActionCacheData selectCacheData(String tendency) {
|
||||||
|
for (ActionCacheData actionCacheData : actionCache) {
|
||||||
|
if (actionCacheData.getTendency().equals(tendency)) {
|
||||||
|
return actionCacheData;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.warn("[{}] 未找到行为倾向[{}]对应的缓存条目,可能是代码逻辑存在错误", getCoreKey(), tendency);
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -1,25 +1,181 @@
|
|||||||
package work.slhaf.partner.core.action.entity;
|
package work.slhaf.partner.core.action.entity;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import work.slhaf.partner.common.vector.VectorClient;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ActionCacheData {
|
public class ActionCacheData {
|
||||||
private boolean activated;
|
private boolean activated = false;
|
||||||
private int inputMatchCount;
|
private int inputMatchCount = 1;
|
||||||
|
|
||||||
private float[] inputVector;
|
private float[] inputVector;
|
||||||
private float[] tendencyVector;
|
private float[] tendencyVector;
|
||||||
private String tendency;
|
private String tendency;
|
||||||
private double threshold;
|
private double threshold = 0.75;
|
||||||
|
|
||||||
private List<String> validSamples = new ArrayList<>();
|
private List<String> validSamples = new ArrayList<>();
|
||||||
private int failedCount;
|
private int failedCount = 0;
|
||||||
private Type type;
|
private Type type = Type.PRIMARY;
|
||||||
|
|
||||||
enum Type {
|
public ActionCacheData(String tendency, float[] tendencyVector, float[] inputVector, String input) {
|
||||||
|
this.tendency = tendency;
|
||||||
|
this.inputVector = inputVector;
|
||||||
|
this.tendencyVector = tendencyVector;
|
||||||
|
this.validSamples.add(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 命中缓存且评估通过时,根据输入内容的语义向量与现有的输入语义向量进行带权移动平均,以相似度为权重,同时降低失败计数,为零时置为上一级缓存类型{@link ActionCacheData.Type}
|
||||||
|
*
|
||||||
|
* @param inputVector 本次输入内容对应的语义向量
|
||||||
|
* @param vectorClient 向量客户端
|
||||||
|
* @param input 本次输入内容
|
||||||
|
*/
|
||||||
|
public synchronized void updateAfterMatchAndPassed(float[] inputVector, VectorClient vectorClient, String input) {
|
||||||
|
updateInputVector(inputVector, vectorClient);
|
||||||
|
addValidSample(input);
|
||||||
|
reduceFailedCount();
|
||||||
|
updateType();
|
||||||
|
addInputMatchCount();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateType() {
|
||||||
|
if (this.failedCount == 0) {
|
||||||
|
this.type = switch (type) {
|
||||||
|
case PRIMARY, REBUILD_V1 -> ActionCacheData.Type.PRIMARY;
|
||||||
|
case REBUILD_V2 -> ActionCacheData.Type.REBUILD_V1;
|
||||||
|
case REBUILD_V3 -> ActionCacheData.Type.REBUILD_V2;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void reduceFailedCount() {
|
||||||
|
this.failedCount = Math.max(this.failedCount - 1, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addValidSample(String input) {
|
||||||
|
if (this.validSamples.size() == 12) {
|
||||||
|
this.validSamples.removeFirst();
|
||||||
|
}
|
||||||
|
this.validSamples.add(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateInputVector(float[] inputVector, VectorClient vectorClient) {
|
||||||
|
this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 针对命中缓存、但评估未通过的条目与输入进行处理: 增加失败计数(必要时重建并更新类型等级)、调高阈值(0.02),由于缓存匹配但评估未通过,所以不进行带权移动平均
|
||||||
|
*
|
||||||
|
* @param vectorClient 向量客户端
|
||||||
|
* @return 是否需要删除(已在REBUILD_V3状态且达到最大误判次数的)
|
||||||
|
*/
|
||||||
|
public synchronized boolean updateAfterMatchNotPassed(VectorClient vectorClient) {
|
||||||
|
adjustThreshold();
|
||||||
|
addFailedCount();
|
||||||
|
if (this.failedCount < 3) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (this.type == Type.REBUILD_V3) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
rebuildAndSwitchType(vectorClient);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void rebuildAndSwitchType(VectorClient vectorClient) {
|
||||||
|
this.type = switch (this.type) {
|
||||||
|
case PRIMARY -> {
|
||||||
|
//样本顺序反转后,以全部样本重建
|
||||||
|
this.validSamples = this.validSamples.reversed();
|
||||||
|
rebuildWithSamples(vectorClient);
|
||||||
|
yield Type.REBUILD_V1;
|
||||||
|
}
|
||||||
|
case REBUILD_V1 -> {
|
||||||
|
//截取后一半样本,反转后以此重建
|
||||||
|
List<String> temp = this.validSamples.subList(this.validSamples.size() / 2, this.validSamples.size());
|
||||||
|
this.validSamples = temp.reversed();
|
||||||
|
rebuildWithSamples(vectorClient);
|
||||||
|
yield Type.REBUILD_V2;
|
||||||
|
}
|
||||||
|
case REBUILD_V2 -> {
|
||||||
|
//截取后四分之一样本,反转后以此重建
|
||||||
|
List<String> temp = this.validSamples.subList(this.validSamples.size() / 4, this.validSamples.size());
|
||||||
|
this.validSamples = temp.reversed();
|
||||||
|
rebuildWithSamples(vectorClient);
|
||||||
|
yield Type.REBUILD_V3;
|
||||||
|
}
|
||||||
|
case REBUILD_V3 -> null;
|
||||||
|
};
|
||||||
|
//阈值减0.05,防止重建后一直升高
|
||||||
|
this.threshold = Math.max(this.threshold - 0.05, 0.75);
|
||||||
|
this.failedCount = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void rebuildWithSamples(VectorClient vectorClient) {
|
||||||
|
for (int i = 0; i < this.validSamples.size(); i++) {
|
||||||
|
String sample = this.validSamples.get(i);
|
||||||
|
if (i == 0) {
|
||||||
|
this.inputVector = vectorClient.compute(sample);
|
||||||
|
} else {
|
||||||
|
float[] newSampleVector = vectorClient.compute(sample);
|
||||||
|
this.inputVector = vectorClient.weightedAverage(this.inputVector, newSampleVector);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addFailedCount() {
|
||||||
|
this.failedCount = Math.min(this.failedCount + 1, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void adjustThreshold() {
|
||||||
|
double newThreshold = this.threshold + 0.03;
|
||||||
|
this.threshold = Math.min(newThreshold, 0.95);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 针对未命中但评估通过的已存在缓存做出调整:
|
||||||
|
* <ol>
|
||||||
|
* <li>
|
||||||
|
* 若已生效,但此时未匹配到则说明阈值或者向量{@link ActionCacheData#getInputVector()}存在问题,调低阈值,同时带权移动平均
|
||||||
|
* </li>
|
||||||
|
* <li>
|
||||||
|
* 若未生效,则只增加计数并带权移动平均
|
||||||
|
* </li>
|
||||||
|
* </ol>
|
||||||
|
*
|
||||||
|
* @param input 本次输入内容
|
||||||
|
* @param inputVector 本次输入内容对应的语义向量
|
||||||
|
* @param tendencyVector 本次倾向对应的语义向量
|
||||||
|
* @param vectorClient 向量客户端
|
||||||
|
*/
|
||||||
|
public synchronized void updateAfterNotMatchPassed(String input, float[] inputVector, float[] tendencyVector, VectorClient vectorClient) {
|
||||||
|
if (this.activated) {
|
||||||
|
reduceThreshold();
|
||||||
|
this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
|
||||||
|
} else {
|
||||||
|
addValidSample(input);
|
||||||
|
this.tendencyVector = vectorClient.weightedAverage(tendencyVector, this.tendencyVector);
|
||||||
|
addInputMatchCount();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void reduceThreshold() {
|
||||||
|
double newThreshold = this.threshold - 0.02;
|
||||||
|
this.threshold = Math.max(newThreshold, 0.75);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addInputMatchCount() {
|
||||||
|
this.inputMatchCount += 1;
|
||||||
|
if (inputMatchCount >= 6) {
|
||||||
|
this.activated = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public enum Type {
|
||||||
PRIMARY, REBUILD_V1, REBUILD_V2, REBUILD_V3
|
PRIMARY, REBUILD_V1, REBUILD_V2, REBUILD_V3
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ package work.slhaf.partner.core.action.entity;
|
|||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class CacheAdjustData {
|
public class CacheAdjustData {
|
||||||
private String input;
|
private String input;
|
||||||
private String tendency;
|
private List<CacheAdjustMetaData> metaDataList;
|
||||||
private boolean passed;
|
|
||||||
private boolean hit;
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package work.slhaf.partner.core.action.entity;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class CacheAdjustMetaData {
|
||||||
|
private String tendency;
|
||||||
|
private boolean passed;
|
||||||
|
private boolean hit;
|
||||||
|
}
|
||||||
@@ -1,11 +1,13 @@
|
|||||||
package work.slhaf.partner.module.modules.action.planner;
|
package work.slhaf.partner.module.modules.action.planner;
|
||||||
|
|
||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||||
|
import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute;
|
||||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
|
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
|
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
|
||||||
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
|
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
||||||
|
import work.slhaf.partner.common.vector.VectorClient;
|
||||||
import work.slhaf.partner.core.action.ActionCapability;
|
import work.slhaf.partner.core.action.ActionCapability;
|
||||||
import work.slhaf.partner.core.action.entity.*;
|
import work.slhaf.partner.core.action.entity.*;
|
||||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||||
@@ -83,26 +85,30 @@ public class ActionPlanner extends PreRunningModule {
|
|||||||
}
|
}
|
||||||
EvaluatorInput evaluatorInput = assemblyHelper.buildEvaluatorInput(extractorResult, context.getUserId());
|
EvaluatorInput evaluatorInput = assemblyHelper.buildEvaluatorInput(extractorResult, context.getUserId());
|
||||||
List<EvaluatorResult> evaluatorResults = actionEvaluator.execute(evaluatorInput); //并发操作均为访问
|
List<EvaluatorResult> evaluatorResults = actionEvaluator.execute(evaluatorInput); //并发操作均为访问
|
||||||
if (extractorResult.isCacheEnabled())
|
|
||||||
updateTendencyCache(evaluatorResults, context.getInput(), extractorResult);
|
|
||||||
setupActionInfo(evaluatorResults, context);
|
setupActionInfo(evaluatorResults, context);
|
||||||
return null;
|
return null;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@AfterExecute
|
||||||
private void updateTendencyCache(List<EvaluatorResult> evaluatorResults, String input, ExtractorResult extractorResult) {
|
private void updateTendencyCache(List<EvaluatorResult> evaluatorResults, String input, ExtractorResult extractorResult) {
|
||||||
|
if (!VectorClient.status) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
executor.execute(() -> {
|
executor.execute(() -> {
|
||||||
List<CacheAdjustData> list = new ArrayList<>();
|
CacheAdjustData data = new CacheAdjustData();
|
||||||
|
List<CacheAdjustMetaData> list = new ArrayList<>();
|
||||||
List<String> hitTendencies = extractorResult.getTendencies();
|
List<String> hitTendencies = extractorResult.getTendencies();
|
||||||
for (EvaluatorResult result : evaluatorResults) {
|
for (EvaluatorResult result : evaluatorResults) {
|
||||||
CacheAdjustData data = new CacheAdjustData();
|
CacheAdjustMetaData metaData = new CacheAdjustMetaData();
|
||||||
data.setTendency(result.getTendency());
|
metaData.setTendency(result.getTendency());
|
||||||
data.setInput(input);
|
metaData.setPassed(result.isOk());
|
||||||
data.setPassed(result.isOk());
|
metaData.setHit(hitTendencies.contains(result.getTendency()));
|
||||||
data.setHit(hitTendencies.contains(result.getTendency()));
|
list.add(metaData);
|
||||||
list.add(data);
|
|
||||||
}
|
}
|
||||||
actionCapability.updateTendencyCache(list);
|
data.setMetaDataList(list);
|
||||||
|
data.setInput(input);
|
||||||
|
actionCapability.updateTendencyCache(data);
|
||||||
});
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ public class ActionExtractor extends AgentRunningSubModule<ExtractorInput, Extra
|
|||||||
public ExtractorResult execute(ExtractorInput data) {
|
public ExtractorResult execute(ExtractorInput data) {
|
||||||
ExtractorResult result = new ExtractorResult();
|
ExtractorResult result = new ExtractorResult();
|
||||||
List<String> tendencyCache = actionCapability.selectTendencyCache(data.getInput());
|
List<String> tendencyCache = actionCapability.selectTendencyCache(data.getInput());
|
||||||
result.setCacheEnabled(tendencyCache != null);
|
|
||||||
if (tendencyCache != null && !tendencyCache.isEmpty()) {
|
if (tendencyCache != null && !tendencyCache.isEmpty()) {
|
||||||
result.setTendencies(tendencyCache);
|
result.setTendencies(tendencyCache);
|
||||||
return result;
|
return result;
|
||||||
|
|||||||
@@ -7,6 +7,5 @@ import java.util.List;
|
|||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ExtractorResult {
|
public class ExtractorResult {
|
||||||
private boolean cacheEnabled;
|
|
||||||
private List<String> tendencies = new ArrayList<>();
|
private List<String> tendencies = new ArrayList<>();
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user