mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
推进 Action 模块语义缓存机制
- 完善缓存命中部分; - 调整 ActionExtractor 以适配缓存逻辑 - 缓存更新大致框架待填充具体更新逻辑;
This commit is contained in:
@@ -25,11 +25,7 @@ public class AgentRunningFlow<C extends RunningFlowContext> {
|
|||||||
List<MetaModule> moduleList = entry.getValue();
|
List<MetaModule> moduleList = entry.getValue();
|
||||||
for (MetaModule module : moduleList) {
|
for (MetaModule module : moduleList) {
|
||||||
Future<?> future = executor.submit(() -> {
|
Future<?> future = executor.submit(() -> {
|
||||||
try {
|
|
||||||
module.getInstance().execute(interactionContext);
|
module.getInstance().execute(interactionContext);
|
||||||
} catch (Exception e) {
|
|
||||||
throw new AgentRuntimeException("模块执行出错: " + module.getName(), e);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
futures.add(future);
|
futures.add(future);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,14 +7,12 @@ import work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute;
|
|||||||
import work.slhaf.partner.api.agent.factory.module.annotation.CoreModule;
|
import work.slhaf.partner.api.agent.factory.module.annotation.CoreModule;
|
||||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext;
|
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 流程执行模块基类
|
* 流程执行模块基类
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class AgentRunningModule<C extends RunningFlowContext> extends Module {
|
public abstract class AgentRunningModule<C extends RunningFlowContext> extends Module {
|
||||||
public abstract void execute(C context) throws IOException, ClassNotFoundException;
|
public abstract void execute(C context);
|
||||||
|
|
||||||
@BeforeExecute
|
@BeforeExecute
|
||||||
private void beforeLog() {
|
private void beforeLog() {
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
package work.slhaf.partner.common.thread;
|
package work.slhaf.partner.common.thread;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.Callable;
|
import java.util.concurrent.*;
|
||||||
import java.util.concurrent.ExecutorService;
|
|
||||||
import java.util.concurrent.Executors;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
|
|
||||||
public class InteractionThreadPoolExecutor {
|
public class InteractionThreadPoolExecutor {
|
||||||
|
|
||||||
@@ -30,9 +28,29 @@ public class InteractionThreadPoolExecutor {
|
|||||||
|
|
||||||
public <T> void invokeAll(List<Callable<T>> tasks) {
|
public <T> void invokeAll(List<Callable<T>> tasks) {
|
||||||
try {
|
try {
|
||||||
executorService.invokeAll(tasks);
|
List<Future<T>> futures = executorService.invokeAll(tasks);
|
||||||
|
for (Future<T> future : futures) {
|
||||||
|
future.get();
|
||||||
|
}
|
||||||
} catch (InterruptedException e) {
|
} catch (InterruptedException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
|
} catch (ExecutionException e) {
|
||||||
|
throw new RuntimeException(e.getCause());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public <T> List<T> invokeAllAndReturn(List<Callable<T>> tasks) {
|
||||||
|
try {
|
||||||
|
List<Future<T>> futures = executorService.invokeAll(tasks);
|
||||||
|
List<T> results = new ArrayList<>();
|
||||||
|
for (Future<T> future : futures) {
|
||||||
|
results.add(future.get());
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
} catch (ExecutionException e) {
|
||||||
|
throw new RuntimeException(e.getCause());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,13 @@
|
|||||||
package work.slhaf.partner.common.vector;
|
package work.slhaf.partner.common.vector;
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import com.alibaba.fastjson2.JSONObject;
|
|
||||||
|
|
||||||
import cn.hutool.http.HttpRequest;
|
import cn.hutool.http.HttpRequest;
|
||||||
import cn.hutool.http.HttpResponse;
|
import cn.hutool.http.HttpResponse;
|
||||||
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||||
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class OllamaVectorClient extends VectorClient {
|
public class OllamaVectorClient extends VectorClient {
|
||||||
|
|||||||
@@ -1,19 +1,18 @@
|
|||||||
package work.slhaf.partner.common.vector;
|
package work.slhaf.partner.common.vector;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.nio.file.Path;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import ai.djl.huggingface.tokenizers.Encoding;
|
import ai.djl.huggingface.tokenizers.Encoding;
|
||||||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
|
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
|
||||||
import ai.onnxruntime.OnnxTensor;
|
import ai.onnxruntime.OnnxTensor;
|
||||||
import ai.onnxruntime.OrtEnvironment;
|
import ai.onnxruntime.OrtEnvironment;
|
||||||
import ai.onnxruntime.OrtException;
|
|
||||||
import ai.onnxruntime.OrtSession;
|
import ai.onnxruntime.OrtSession;
|
||||||
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||||
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
||||||
|
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@SuppressWarnings("FieldMayBeFinal")
|
||||||
public class OnnxVectorClient extends VectorClient {
|
public class OnnxVectorClient extends VectorClient {
|
||||||
|
|
||||||
private String tokenizerPath;
|
private String tokenizerPath;
|
||||||
|
|||||||
@@ -1,16 +1,15 @@
|
|||||||
package work.slhaf.partner.common.vector;
|
package work.slhaf.partner.common.vector;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||||
import work.slhaf.partner.common.config.Config.VectorConfig;
|
import work.slhaf.partner.common.config.Config.VectorConfig;
|
||||||
|
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||||
import work.slhaf.partner.common.exception.ServiceLoadFailedException;
|
import work.slhaf.partner.common.exception.ServiceLoadFailedException;
|
||||||
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||||
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
||||||
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class VectorClient {
|
public abstract class VectorClient {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package work.slhaf.partner.core.action;
|
package work.slhaf.partner.core.action;
|
||||||
|
|
||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
||||||
|
import work.slhaf.partner.core.action.entity.CacheAdjustData;
|
||||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -19,5 +20,7 @@ public interface ActionCapability {
|
|||||||
|
|
||||||
void putPendingActions(String userId, MetaActionInfo metaActionInfo);
|
void putPendingActions(String userId, MetaActionInfo metaActionInfo);
|
||||||
|
|
||||||
List<String> computeActionCache(String input);
|
List<String> selectTendencyCache(String input);
|
||||||
|
|
||||||
|
void updateTendencyCache(List<CacheAdjustData> list);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,20 @@
|
|||||||
package work.slhaf.partner.core.action;
|
package work.slhaf.partner.core.action;
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
|
||||||
import work.slhaf.partner.common.util.VectorUtil;
|
import work.slhaf.partner.common.vector.VectorClient;
|
||||||
import work.slhaf.partner.core.PartnerCore;
|
import work.slhaf.partner.core.PartnerCore;
|
||||||
import work.slhaf.partner.core.action.entity.ActionCacheData;
|
import work.slhaf.partner.core.action.entity.ActionCacheData;
|
||||||
|
import work.slhaf.partner.core.action.entity.CacheAdjustData;
|
||||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Setter
|
@SuppressWarnings("FieldMayBeFinal")
|
||||||
@Getter
|
|
||||||
@Capability(value = "action")
|
@Capability(value = "action")
|
||||||
public class ActionCore extends PartnerCore<ActionCore> {
|
public class ActionCore extends PartnerCore<ActionCore> {
|
||||||
|
|
||||||
@@ -34,7 +33,6 @@ public class ActionCore extends PartnerCore<ActionCore> {
|
|||||||
*/
|
*/
|
||||||
private List<ActionCacheData> actionCache = new ArrayList<>();
|
private List<ActionCacheData> actionCache = new ArrayList<>();
|
||||||
|
|
||||||
//TODO 添加语义缓存,可借由简单向量模型,设想以向量结果为键、行动倾向为值
|
|
||||||
public ActionCore() throws IOException, ClassNotFoundException {
|
public ActionCore() throws IOException, ClassNotFoundException {
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,20 +78,71 @@ public class ActionCore extends PartnerCore<ActionCore> {
|
|||||||
return pendingActions.get(userId);
|
return pendingActions.get(userId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 计算输入内容的语义向量,根据与{@link ActionCacheData#getInputVector()}的相似度挑取缓存,后续将根据评估结果来更新计数
|
||||||
|
*
|
||||||
|
* @param input 本次输入内容
|
||||||
|
* @return 命中的行为倾向集合
|
||||||
|
*/
|
||||||
@CapabilityMethod
|
@CapabilityMethod
|
||||||
public List<String> computeActionCache(String input){
|
public List<String> selectTendencyCache(String input) {
|
||||||
//计算本次输入的向量
|
if (!VectorClient.status) {
|
||||||
float[] vector = VectorUtil.compute(input);
|
|
||||||
if (vector == null) return null;
|
|
||||||
//与现有缓存比对,如果存在,则使缓存计数+1
|
|
||||||
actionCache.stream()
|
|
||||||
.filter(ActionCacheData::isActivated)
|
|
||||||
.forEach(data -> {
|
|
||||||
double compared = VectorUtil.compare(vector, data.getInputVector());
|
|
||||||
});
|
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
VectorClient vectorClient = VectorClient.INSTANCE;
|
||||||
|
//计算本次输入的向量
|
||||||
|
float[] vector = vectorClient.compute(input);
|
||||||
|
if (vector == null) return null;
|
||||||
|
//与现有缓存比对,将匹配到的收集并返回
|
||||||
|
return actionCache.parallelStream()
|
||||||
|
.filter(ActionCacheData::isActivated)
|
||||||
|
.filter(data -> {
|
||||||
|
double compared = vectorClient.compare(vector, data.getInputVector());
|
||||||
|
return compared > data.getThreshold();
|
||||||
|
})
|
||||||
|
.map(ActionCacheData::getTendency)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@CapabilityMethod
|
||||||
|
public void updateTendencyCache(List<CacheAdjustData> list) {
|
||||||
|
List<CacheAdjustData> matchAndPassed = new ArrayList<>();
|
||||||
|
List<CacheAdjustData> matchNotPassed = new ArrayList<>();
|
||||||
|
List<CacheAdjustData> notMatchPassed = new ArrayList<>();
|
||||||
|
|
||||||
|
for (CacheAdjustData data : list) {
|
||||||
|
if (data.isHit() && data.isPassed()) {
|
||||||
|
matchAndPassed.add(data);
|
||||||
|
} else if (data.isHit()) {
|
||||||
|
matchNotPassed.add(data);
|
||||||
|
} else if (!data.isPassed()) {
|
||||||
|
notMatchPassed.add(data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
VectorClient vectorClient = VectorClient.INSTANCE;
|
||||||
|
adjustMatchAndPassed(matchAndPassed, vectorClient);
|
||||||
|
adjustMatchNotPassed(matchNotPassed, vectorClient);
|
||||||
|
adjustNotMatchPassed(notMatchPassed, vectorClient);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 命中缓存且评估通过时,根据输入内容的语义向量与现有的输入语义向量进行带权移动平均,以相似度为权重
|
||||||
|
*
|
||||||
|
* @param matchAndPassed 该类型的带调整缓存信息列表
|
||||||
|
* @param vectorClient 向量客户端
|
||||||
|
*/
|
||||||
|
private void adjustMatchAndPassed(List<CacheAdjustData> matchAndPassed, VectorClient vectorClient) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void adjustMatchNotPassed(List<CacheAdjustData> matchNotPassed, VectorClient vectorClient) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void adjustNotMatchPassed(List<CacheAdjustData> notMatchPassed, VectorClient vectorClient) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected String getCoreKey() {
|
protected String getCoreKey() {
|
||||||
|
|||||||
@@ -7,11 +7,19 @@ import java.util.List;
|
|||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ActionCacheData {
|
public class ActionCacheData {
|
||||||
|
private boolean activated;
|
||||||
|
private int inputMatchCount;
|
||||||
|
|
||||||
private float[] inputVector;
|
private float[] inputVector;
|
||||||
private float[] tendencyVector;
|
private float[] tendencyVector;
|
||||||
private String tendency;
|
private String tendency;
|
||||||
private int inputMatchCount;
|
|
||||||
private boolean activated;
|
|
||||||
private List<float[]> validSamples = new ArrayList<>();
|
|
||||||
private double threshold;
|
private double threshold;
|
||||||
|
|
||||||
|
private List<String> validSamples = new ArrayList<>();
|
||||||
|
private int failedCount;
|
||||||
|
private Type type;
|
||||||
|
|
||||||
|
enum Type {
|
||||||
|
PRIMARY, REBUILD_V1, REBUILD_V2, REBUILD_V3
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package work.slhaf.partner.core.action.entity;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class CacheAdjustData {
|
||||||
|
private String input;
|
||||||
|
private String tendency;
|
||||||
|
private boolean passed;
|
||||||
|
private boolean hit;
|
||||||
|
}
|
||||||
@@ -5,7 +5,6 @@ import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
|||||||
import work.slhaf.partner.core.memory.pojo.MemoryResult;
|
import work.slhaf.partner.core.memory.pojo.MemoryResult;
|
||||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.time.LocalDate;
|
import java.time.LocalDate;
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDateTime;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -45,7 +44,7 @@ public interface MemoryCapability {
|
|||||||
|
|
||||||
MemoryResult selectMemory(String topicPathStr);
|
MemoryResult selectMemory(String topicPathStr);
|
||||||
|
|
||||||
MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException;
|
MemoryResult selectMemory(LocalDate date);
|
||||||
|
|
||||||
void insertSlice(MemorySlice memorySlice, String topicPath);
|
void insertSlice(MemorySlice memorySlice, String topicPath);
|
||||||
|
|
||||||
|
|||||||
@@ -3,12 +3,10 @@ package work.slhaf.partner.module.common.module;
|
|||||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
|
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
|
||||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
public abstract class PostRunningModule extends AgentRunningModule<PartnerRunningFlowContext> {
|
public abstract class PostRunningModule extends AgentRunningModule<PartnerRunningFlowContext> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final void execute(PartnerRunningFlowContext context) throws IOException, ClassNotFoundException {
|
public final void execute(PartnerRunningFlowContext context) {
|
||||||
boolean trigger = context.getModuleContext().getExtraContext().getBoolean("post_process_trigger");
|
boolean trigger = context.getModuleContext().getExtraContext().getBoolean("post_process_trigger");
|
||||||
if (!trigger) {
|
if (!trigger) {
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunn
|
|||||||
import work.slhaf.partner.module.common.entity.AppendPromptData;
|
import work.slhaf.partner.module.common.entity.AppendPromptData;
|
||||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -33,13 +32,13 @@ public abstract class PreRunningModule extends AgentRunningModule<PartnerRunning
|
|||||||
protected abstract String moduleName();
|
protected abstract String moduleName();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final void execute(PartnerRunningFlowContext context) throws IOException, ClassNotFoundException {
|
public final void execute(PartnerRunningFlowContext context) {
|
||||||
doExecute(context); // 子类实现差异化逻辑
|
doExecute(context); // 子类实现差异化逻辑
|
||||||
setAppendedPrompt(context); // 通用逻辑
|
setAppendedPrompt(context); // 通用逻辑
|
||||||
setActiveModule(context); // 通用逻辑
|
setActiveModule(context); // 通用逻辑
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract void doExecute(PartnerRunningFlowContext context) throws IOException, ClassNotFoundException;
|
protected abstract void doExecute(PartnerRunningFlowContext context);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,12 +82,31 @@ public class ActionPlanner extends PreRunningModule {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
EvaluatorInput evaluatorInput = assemblyHelper.buildEvaluatorInput(extractorResult, context.getUserId());
|
EvaluatorInput evaluatorInput = assemblyHelper.buildEvaluatorInput(extractorResult, context.getUserId());
|
||||||
List<EvaluatorResult> evaluatorResults = actionEvaluator.execute(evaluatorInput);
|
List<EvaluatorResult> evaluatorResults = actionEvaluator.execute(evaluatorInput); //并发操作均为访问
|
||||||
setupPreparedActionInfo(evaluatorResults, context);
|
if (extractorResult.isCacheEnabled())
|
||||||
|
updateTendencyCache(evaluatorResults, context.getInput(), extractorResult);
|
||||||
|
setupActionInfo(evaluatorResults, context);
|
||||||
return null;
|
return null;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void updateTendencyCache(List<EvaluatorResult> evaluatorResults, String input, ExtractorResult extractorResult) {
|
||||||
|
executor.execute(() -> {
|
||||||
|
List<CacheAdjustData> list = new ArrayList<>();
|
||||||
|
List<String> hitTendencies = extractorResult.getTendencies();
|
||||||
|
for (EvaluatorResult result : evaluatorResults) {
|
||||||
|
CacheAdjustData data = new CacheAdjustData();
|
||||||
|
data.setTendency(result.getTendency());
|
||||||
|
data.setInput(input);
|
||||||
|
data.setPassed(result.isOk());
|
||||||
|
data.setHit(hitTendencies.contains(result.getTendency()));
|
||||||
|
list.add(data);
|
||||||
|
}
|
||||||
|
actionCapability.updateTendencyCache(list);
|
||||||
|
});
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 待确认行动的判断任务
|
* 待确认行动的判断任务
|
||||||
*
|
*
|
||||||
@@ -119,13 +138,12 @@ public class ActionPlanner extends PreRunningModule {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private void setupPreparedActionInfo(List<EvaluatorResult> evaluatorResults, PartnerRunningFlowContext context) {
|
private void setupActionInfo(List<EvaluatorResult> evaluatorResults, PartnerRunningFlowContext context) {
|
||||||
for (EvaluatorResult evaluatorResult : evaluatorResults) {
|
for (EvaluatorResult evaluatorResult : evaluatorResults) {
|
||||||
if (evaluatorResult.isNeedConfirm()) {
|
|
||||||
MetaActionInfo metaActionInfo = assemblyHelper.buildMetaActionInfo(evaluatorResult);
|
MetaActionInfo metaActionInfo = assemblyHelper.buildMetaActionInfo(evaluatorResult);
|
||||||
|
if (evaluatorResult.isNeedConfirm()) {
|
||||||
actionCapability.putPendingActions(context.getUserId(), metaActionInfo);
|
actionCapability.putPendingActions(context.getUserId(), metaActionInfo);
|
||||||
} else {
|
} else {
|
||||||
MetaActionInfo metaActionInfo = assemblyHelper.buildMetaActionInfo(evaluatorResult);
|
|
||||||
actionCapability.putPreparedAction(context.getUuid(), metaActionInfo);
|
actionCapability.putPreparedAction(context.getUuid(), metaActionInfo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -192,7 +210,7 @@ public class ActionPlanner extends PreRunningModule {
|
|||||||
return input;
|
return input;
|
||||||
}
|
}
|
||||||
|
|
||||||
public EvaluatorInput buildEvaluatorInput(ExtractorResult extractorResult, String userId) {
|
private EvaluatorInput buildEvaluatorInput(ExtractorResult extractorResult, String userId) {
|
||||||
EvaluatorInput input = new EvaluatorInput();
|
EvaluatorInput input = new EvaluatorInput();
|
||||||
input.setTendencies(extractorResult.getTendencies());
|
input.setTendencies(extractorResult.getTendencies());
|
||||||
input.setUser(perceiveCapability.getUser(userId));
|
input.setUser(perceiveCapability.getUser(userId));
|
||||||
@@ -201,7 +219,7 @@ public class ActionPlanner extends PreRunningModule {
|
|||||||
return input;
|
return input;
|
||||||
}
|
}
|
||||||
|
|
||||||
public MetaActionInfo buildMetaActionInfo(EvaluatorResult evaluatorResult) {
|
private MetaActionInfo buildMetaActionInfo(EvaluatorResult evaluatorResult) {
|
||||||
return switch (evaluatorResult.getType()) {
|
return switch (evaluatorResult.getType()) {
|
||||||
case PLANNING -> {
|
case PLANNING -> {
|
||||||
ScheduledActionInfo actionInfo = new ScheduledActionInfo();
|
ScheduledActionInfo actionInfo = new ScheduledActionInfo();
|
||||||
@@ -221,7 +239,7 @@ public class ActionPlanner extends PreRunningModule {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
public ConfirmerInput buildConfirmerInput(PartnerRunningFlowContext context) {
|
private ConfirmerInput buildConfirmerInput(PartnerRunningFlowContext context) {
|
||||||
ConfirmerInput confirmerInput = new ConfirmerInput();
|
ConfirmerInput confirmerInput = new ConfirmerInput();
|
||||||
confirmerInput.setInput(context.getInput());
|
confirmerInput.setInput(context.getInput());
|
||||||
List<MetaActionInfo> pendingActions = actionCapability.listPendingAction(context.getUserId());
|
List<MetaActionInfo> pendingActions = actionCapability.listPendingAction(context.getUserId());
|
||||||
|
|||||||
@@ -1,20 +1,67 @@
|
|||||||
package work.slhaf.partner.module.modules.action.planner.evaluator;
|
package work.slhaf.partner.module.modules.action.planner.evaluator;
|
||||||
|
|
||||||
|
import cn.hutool.core.bean.BeanUtil;
|
||||||
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
|
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
|
||||||
|
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
|
||||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
|
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
||||||
|
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
||||||
|
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorBatchInput;
|
||||||
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorInput;
|
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorInput;
|
||||||
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorResult;
|
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorResult;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.concurrent.Callable;
|
||||||
|
|
||||||
@AgentSubModule
|
@AgentSubModule
|
||||||
public class ActionEvaluator extends AgentRunningSubModule<EvaluatorInput, List<EvaluatorResult>> implements ActivateModel {
|
public class ActionEvaluator extends AgentRunningSubModule<EvaluatorInput, List<EvaluatorResult>> implements ActivateModel {
|
||||||
|
|
||||||
|
private InteractionThreadPoolExecutor executor;
|
||||||
|
|
||||||
|
@Init
|
||||||
|
public void init() {
|
||||||
|
executor = InteractionThreadPoolExecutor.getInstance();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 对输入的行为倾向进行评估,并根据评估结果,对缓存做出调整
|
||||||
|
*
|
||||||
|
* @param data 评估输入内容,包含提取/命中缓存的行动倾向、近几条聊天记录,正在生效的记忆切片内容
|
||||||
|
* @return 评估结果集合,包含
|
||||||
|
*/
|
||||||
@Override
|
@Override
|
||||||
public List<EvaluatorResult> execute(EvaluatorInput data) {
|
public List<EvaluatorResult> execute(EvaluatorInput data) {
|
||||||
|
List<EvaluatorBatchInput> batchInputs = buildEvaluatorBatchInput(data);
|
||||||
|
List<Callable<EvaluatorResult>> tasks = getTasks(batchInputs);
|
||||||
|
return executor.invokeAllAndReturn(tasks);
|
||||||
|
}
|
||||||
|
|
||||||
return null;
|
|
||||||
|
private List<Callable<EvaluatorResult>> getTasks(List<EvaluatorBatchInput> batchInputs) {
|
||||||
|
List<Callable<EvaluatorResult>> list = new ArrayList<>();
|
||||||
|
for (EvaluatorBatchInput batchInput : batchInputs) {
|
||||||
|
list.add(() -> {
|
||||||
|
ChatResponse response = this.singleChat(JSONObject.toJSONString(batchInput));
|
||||||
|
EvaluatorResult evaluatorResult = JSONObject.parseObject(response.getMessage(), EvaluatorResult.class);
|
||||||
|
evaluatorResult.setTendency(batchInput.getTendency());
|
||||||
|
return evaluatorResult;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<EvaluatorBatchInput> buildEvaluatorBatchInput(EvaluatorInput data) {
|
||||||
|
List<EvaluatorBatchInput> list = new ArrayList<>();
|
||||||
|
for (String tendency : data.getTendencies()) {
|
||||||
|
EvaluatorBatchInput temp = new EvaluatorBatchInput();
|
||||||
|
BeanUtil.copyProperties(data, temp);
|
||||||
|
temp.setTendency(tendency);
|
||||||
|
list.add(temp);
|
||||||
|
}
|
||||||
|
return list;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
package work.slhaf.partner.module.modules.action.planner.evaluator.entity;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
|
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class EvaluatorBatchInput {
|
||||||
|
private List<Message> recentMessages;
|
||||||
|
private List<EvaluatedSlice> activatedSlices;
|
||||||
|
private String tendency;
|
||||||
|
}
|
||||||
@@ -11,4 +11,5 @@ public class EvaluatorResult {
|
|||||||
private ActionType type;
|
private ActionType type;
|
||||||
private String scheduleContent;
|
private String scheduleContent;
|
||||||
private ActionData actionData;
|
private ActionData actionData;
|
||||||
|
private String tendency;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,18 @@
|
|||||||
package work.slhaf.partner.module.modules.action.planner.extractor;
|
package work.slhaf.partner.module.modules.action.planner.extractor;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import com.alibaba.fastjson2.JSONObject;
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
|
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
|
||||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
|
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
|
||||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
|
||||||
import work.slhaf.partner.core.action.ActionCapability;
|
import work.slhaf.partner.core.action.ActionCapability;
|
||||||
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput;
|
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput;
|
||||||
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorResult;
|
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorResult;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@AgentSubModule
|
@AgentSubModule
|
||||||
public class ActionExtractor extends AgentRunningSubModule<ExtractorInput, ExtractorResult> implements ActivateModel {
|
public class ActionExtractor extends AgentRunningSubModule<ExtractorInput, ExtractorResult> implements ActivateModel {
|
||||||
@@ -25,17 +22,17 @@ public class ActionExtractor extends AgentRunningSubModule<ExtractorInput, Extra
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ExtractorResult execute(ExtractorInput data) {
|
public ExtractorResult execute(ExtractorInput data) {
|
||||||
// TODO 添加语义缓存判断
|
|
||||||
List<String> tendencyCache = actionCapability.computeActionCache(data.getInput());
|
|
||||||
if ( tendencyCache == null || !tendencyCache.isEmpty()) {
|
|
||||||
ExtractorResult result = new ExtractorResult();
|
ExtractorResult result = new ExtractorResult();
|
||||||
|
List<String> tendencyCache = actionCapability.selectTendencyCache(data.getInput());
|
||||||
|
result.setCacheEnabled(tendencyCache != null);
|
||||||
|
if (tendencyCache != null && !tendencyCache.isEmpty()) {
|
||||||
|
result.setTendencies(tendencyCache);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
try {
|
try {
|
||||||
this.chatMessages().add(new Message(ChatConstant.Character.USER, JSONObject.toJSONString(data)));
|
ChatResponse response = this.singleChat(JSONObject.toJSONString(data));
|
||||||
ChatResponse response = this.chat();
|
|
||||||
return JSONObject.parseObject(response.getMessage(), ExtractorResult.class);
|
return JSONObject.parseObject(response.getMessage(), ExtractorResult.class);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("[ActionExtractor] 提取信息出错", e);
|
log.error("[ActionExtractor] 提取信息出错", e);
|
||||||
|
|||||||
@@ -7,5 +7,6 @@ import java.util.List;
|
|||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ExtractorResult {
|
public class ExtractorResult {
|
||||||
|
private boolean cacheEnabled;
|
||||||
private List<String> tendencies = new ArrayList<>();
|
private List<String> tendencies = new ArrayList<>();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ import work.slhaf.partner.module.modules.memory.selector.extractor.entity.Extrac
|
|||||||
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorResult;
|
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorResult;
|
||||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.time.LocalDate;
|
import java.time.LocalDate;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
@@ -46,7 +45,7 @@ public class MemorySelector extends PreRunningModule {
|
|||||||
private MemorySelectExtractor memorySelectExtractor;
|
private MemorySelectExtractor memorySelectExtractor;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doExecute(PartnerRunningFlowContext runningFlowContext) throws IOException, ClassNotFoundException {
|
public void doExecute(PartnerRunningFlowContext runningFlowContext) {
|
||||||
String userId = runningFlowContext.getUserId();
|
String userId = runningFlowContext.getUserId();
|
||||||
//获取主题路径
|
//获取主题路径
|
||||||
ExtractorResult extractorResult = memorySelectExtractor.execute(runningFlowContext);
|
ExtractorResult extractorResult = memorySelectExtractor.execute(runningFlowContext);
|
||||||
@@ -58,7 +57,7 @@ public class MemorySelector extends PreRunningModule {
|
|||||||
setModuleContextRecall(runningFlowContext);
|
setModuleContextRecall(runningFlowContext);
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<EvaluatedSlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, ExtractorResult extractorResult) throws IOException, ClassNotFoundException {
|
private List<EvaluatedSlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, ExtractorResult extractorResult) {
|
||||||
log.debug("[MemorySelector] 触发记忆回溯...");
|
log.debug("[MemorySelector] 触发记忆回溯...");
|
||||||
//查找切片
|
//查找切片
|
||||||
String userId = runningFlowContext.getUserId();
|
String userId = runningFlowContext.getUserId();
|
||||||
@@ -86,7 +85,7 @@ public class MemorySelector extends PreRunningModule {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) throws IOException, ClassNotFoundException {
|
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) {
|
||||||
for (ExtractorMatchData match : matches) {
|
for (ExtractorMatchData match : matches) {
|
||||||
try {
|
try {
|
||||||
MemoryResult memoryResult = switch (match.getType()) {
|
MemoryResult memoryResult = switch (match.getType()) {
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunn
|
|||||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
@@ -23,7 +21,7 @@ public class PostprocessExecutor extends AgentRunningModule<PartnerRunningFlowCo
|
|||||||
private CognationCapability cognationCapability;
|
private CognationCapability cognationCapability;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void execute(PartnerRunningFlowContext context) throws IOException, ClassNotFoundException {
|
public void execute(PartnerRunningFlowContext context) {
|
||||||
boolean trigger = cognationCapability.getChatMessages().size() >= POST_PROCESS_TRIGGER_ROLL_LIMIT;
|
boolean trigger = cognationCapability.getChatMessages().size() >= POST_PROCESS_TRIGGER_ROLL_LIMIT;
|
||||||
context.getModuleContext().getExtraContext().put("post_process_trigger", trigger);
|
context.getModuleContext().getExtraContext().put("post_process_trigger", trigger);
|
||||||
log.debug("[PostprocessExecutor] 是否执行后处理: {}", trigger);
|
log.debug("[PostprocessExecutor] 是否执行后处理: {}", trigger);
|
||||||
|
|||||||
Reference in New Issue
Block a user