diff --git a/Partner-Api/src/main/java/work/slhaf/partner/api/agent/runtime/interaction/flow/AgentRunningFlow.java b/Partner-Api/src/main/java/work/slhaf/partner/api/agent/runtime/interaction/flow/AgentRunningFlow.java index 531cdc88..4ed2a300 100644 --- a/Partner-Api/src/main/java/work/slhaf/partner/api/agent/runtime/interaction/flow/AgentRunningFlow.java +++ b/Partner-Api/src/main/java/work/slhaf/partner/api/agent/runtime/interaction/flow/AgentRunningFlow.java @@ -25,11 +25,7 @@ public class AgentRunningFlow { List moduleList = entry.getValue(); for (MetaModule module : moduleList) { Future future = executor.submit(() -> { - try { - module.getInstance().execute(interactionContext); - } catch (Exception e) { - throw new AgentRuntimeException("模块执行出错: " + module.getName(), e); - } + module.getInstance().execute(interactionContext); }); futures.add(future); } diff --git a/Partner-Api/src/main/java/work/slhaf/partner/api/agent/runtime/interaction/flow/abstracts/AgentRunningModule.java b/Partner-Api/src/main/java/work/slhaf/partner/api/agent/runtime/interaction/flow/abstracts/AgentRunningModule.java index 9d43095c..4dbe0295 100644 --- a/Partner-Api/src/main/java/work/slhaf/partner/api/agent/runtime/interaction/flow/abstracts/AgentRunningModule.java +++ b/Partner-Api/src/main/java/work/slhaf/partner/api/agent/runtime/interaction/flow/abstracts/AgentRunningModule.java @@ -7,14 +7,12 @@ import work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute; import work.slhaf.partner.api.agent.factory.module.annotation.CoreModule; import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext; -import java.io.IOException; - /** * 流程执行模块基类 */ @Slf4j public abstract class AgentRunningModule extends Module { - public abstract void execute(C context) throws IOException, ClassNotFoundException; + public abstract void execute(C context); @BeforeExecute private void beforeLog() { diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/thread/InteractionThreadPoolExecutor.java b/Partner-Main/src/main/java/work/slhaf/partner/common/thread/InteractionThreadPoolExecutor.java index d74d4171..23ea8e97 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/common/thread/InteractionThreadPoolExecutor.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/thread/InteractionThreadPoolExecutor.java @@ -1,12 +1,10 @@ package work.slhaf.partner.common.thread; +import java.util.ArrayList; import java.util.List; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.*; -public class InteractionThreadPoolExecutor { +public class InteractionThreadPoolExecutor { private static InteractionThreadPoolExecutor interactionThreadPoolExecutor; @@ -27,12 +25,32 @@ public class InteractionThreadPoolExecutor { throw new RuntimeException(e); } } - + public void invokeAll(List> tasks) { try { - executorService.invokeAll(tasks); + List> futures = executorService.invokeAll(tasks); + for (Future future : futures) { + future.get(); + } } catch (InterruptedException e) { throw new RuntimeException(e); + } catch (ExecutionException e) { + throw new RuntimeException(e.getCause()); + } + } + + public List invokeAllAndReturn(List> tasks) { + try { + List> futures = executorService.invokeAll(tasks); + List results = new ArrayList<>(); + for (Future future : futures) { + results.add(future.get()); + } + return results; + } catch (InterruptedException e) { + throw new RuntimeException(e); + } catch (ExecutionException e) { + throw new RuntimeException(e.getCause()); } } diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OllamaVectorClient.java b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OllamaVectorClient.java index 84c2cdfa..786eb88b 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OllamaVectorClient.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OllamaVectorClient.java @@ -1,15 +1,13 @@ package work.slhaf.partner.common.vector; -import java.util.Map; - -import com.alibaba.fastjson2.JSONObject; - import cn.hutool.http.HttpRequest; import cn.hutool.http.HttpResponse; +import com.alibaba.fastjson2.JSONObject; import lombok.Data; import lombok.extern.slf4j.Slf4j; import work.slhaf.partner.common.vector.exception.VectorClientExecuteException; -import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException; + +import java.util.Map; @Slf4j public class OllamaVectorClient extends VectorClient { diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OnnxVectorClient.java b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OnnxVectorClient.java index 4a4c6581..12d94905 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OnnxVectorClient.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/OnnxVectorClient.java @@ -1,19 +1,18 @@ package work.slhaf.partner.common.vector; -import java.io.IOException; -import java.nio.file.Path; -import java.util.HashMap; -import java.util.Map; - import ai.djl.huggingface.tokenizers.Encoding; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; import work.slhaf.partner.common.vector.exception.VectorClientExecuteException; import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +@SuppressWarnings("FieldMayBeFinal") public class OnnxVectorClient extends VectorClient { private String tokenizerPath; @@ -57,9 +56,9 @@ public class OnnxVectorClient extends VectorClient { long[] ids = encode.getIds(); long[] attentionMask = encode.getAttentionMask(); - long[][] inputIdsBatch = { ids }; - long[][] attentionMaskBatch = { attentionMask }; - long[][] tokenTypeIdsBatch = { new long[ids.length] }; // 初始化全 0 + long[][] inputIdsBatch = {ids}; + long[][] attentionMaskBatch = {attentionMask}; + long[][] tokenTypeIdsBatch = {new long[ids.length]}; // 初始化全 0 for (int i = 0; i < ids.length; i++) tokenTypeIdsBatch[0][i] = 0; diff --git a/Partner-Main/src/main/java/work/slhaf/partner/common/vector/VectorClient.java b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/VectorClient.java index 76b3ebfd..26229d4b 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/common/vector/VectorClient.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/common/vector/VectorClient.java @@ -1,16 +1,15 @@ package work.slhaf.partner.common.vector; +import lombok.extern.slf4j.Slf4j; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.ops.transforms.Transforms; - -import lombok.extern.slf4j.Slf4j; import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager; import work.slhaf.partner.common.config.Config.VectorConfig; +import work.slhaf.partner.common.config.PartnerAgentConfigManager; import work.slhaf.partner.common.exception.ServiceLoadFailedException; import work.slhaf.partner.common.vector.exception.VectorClientExecuteException; import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException; -import work.slhaf.partner.common.config.PartnerAgentConfigManager; @Slf4j public abstract class VectorClient { diff --git a/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCapability.java b/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCapability.java index 25fa174a..c75bd8b7 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCapability.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCapability.java @@ -1,6 +1,7 @@ package work.slhaf.partner.core.action; import work.slhaf.partner.api.agent.factory.capability.annotation.Capability; +import work.slhaf.partner.core.action.entity.CacheAdjustData; import work.slhaf.partner.core.action.entity.MetaActionInfo; import java.util.List; @@ -19,5 +20,7 @@ public interface ActionCapability { void putPendingActions(String userId, MetaActionInfo metaActionInfo); - List computeActionCache(String input); + List selectTendencyCache(String input); + + void updateTendencyCache(List list); } diff --git a/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCore.java b/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCore.java index ee69d8a1..e412d02d 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCore.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/core/action/ActionCore.java @@ -1,21 +1,20 @@ package work.slhaf.partner.core.action; -import lombok.Getter; -import lombok.Setter; import work.slhaf.partner.api.agent.factory.capability.annotation.Capability; import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod; -import work.slhaf.partner.common.util.VectorUtil; +import work.slhaf.partner.common.vector.VectorClient; import work.slhaf.partner.core.PartnerCore; import work.slhaf.partner.core.action.entity.ActionCacheData; +import work.slhaf.partner.core.action.entity.CacheAdjustData; import work.slhaf.partner.core.action.entity.MetaActionInfo; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.stream.Collectors; -@Setter -@Getter +@SuppressWarnings("FieldMayBeFinal") @Capability(value = "action") public class ActionCore extends PartnerCore { @@ -34,7 +33,6 @@ public class ActionCore extends PartnerCore { */ private List actionCache = new ArrayList<>(); - //TODO 添加语义缓存,可借由简单向量模型,设想以向量结果为键、行动倾向为值 public ActionCore() throws IOException, ClassNotFoundException { } @@ -80,19 +78,70 @@ public class ActionCore extends PartnerCore { return pendingActions.get(userId); } + /** + * 计算输入内容的语义向量,根据与{@link ActionCacheData#getInputVector()}的相似度挑取缓存,后续将根据评估结果来更新计数 + * + * @param input 本次输入内容 + * @return 命中的行为倾向集合 + */ @CapabilityMethod - public List computeActionCache(String input){ + public List selectTendencyCache(String input) { + if (!VectorClient.status) { + return null; + } + VectorClient vectorClient = VectorClient.INSTANCE; //计算本次输入的向量 - float[] vector = VectorUtil.compute(input); + float[] vector = vectorClient.compute(input); if (vector == null) return null; - //与现有缓存比对,如果存在,则使缓存计数+1 - actionCache.stream() + //与现有缓存比对,将匹配到的收集并返回 + return actionCache.parallelStream() .filter(ActionCacheData::isActivated) - .forEach(data -> { - double compared = VectorUtil.compare(vector, data.getInputVector()); - }); + .filter(data -> { + double compared = vectorClient.compare(vector, data.getInputVector()); + return compared > data.getThreshold(); + }) + .map(ActionCacheData::getTendency) + .collect(Collectors.toList()); + } + + @CapabilityMethod + public void updateTendencyCache(List list) { + List matchAndPassed = new ArrayList<>(); + List matchNotPassed = new ArrayList<>(); + List notMatchPassed = new ArrayList<>(); + + for (CacheAdjustData data : list) { + if (data.isHit() && data.isPassed()) { + matchAndPassed.add(data); + } else if (data.isHit()) { + matchNotPassed.add(data); + } else if (!data.isPassed()) { + notMatchPassed.add(data); + } + } + + VectorClient vectorClient = VectorClient.INSTANCE; + adjustMatchAndPassed(matchAndPassed, vectorClient); + adjustMatchNotPassed(matchNotPassed, vectorClient); + adjustNotMatchPassed(notMatchPassed, vectorClient); + } + + /** + * 命中缓存且评估通过时,根据输入内容的语义向量与现有的输入语义向量进行带权移动平均,以相似度为权重 + * + * @param matchAndPassed 该类型的带调整缓存信息列表 + * @param vectorClient 向量客户端 + */ + private void adjustMatchAndPassed(List matchAndPassed, VectorClient vectorClient) { + + } + + private void adjustMatchNotPassed(List matchNotPassed, VectorClient vectorClient) { + + } + + private void adjustNotMatchPassed(List notMatchPassed, VectorClient vectorClient) { - return null; } @Override diff --git a/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/ActionCacheData.java b/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/ActionCacheData.java index abc0a24a..ce859d6b 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/ActionCacheData.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/ActionCacheData.java @@ -7,11 +7,19 @@ import java.util.List; @Data public class ActionCacheData { + private boolean activated; + private int inputMatchCount; + private float[] inputVector; private float[] tendencyVector; private String tendency; - private int inputMatchCount; - private boolean activated; - private List validSamples = new ArrayList<>(); private double threshold; + + private List validSamples = new ArrayList<>(); + private int failedCount; + private Type type; + + enum Type { + PRIMARY, REBUILD_V1, REBUILD_V2, REBUILD_V3 + } } diff --git a/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/CacheAdjustData.java b/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/CacheAdjustData.java new file mode 100644 index 00000000..24a06ba3 --- /dev/null +++ b/Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/CacheAdjustData.java @@ -0,0 +1,11 @@ +package work.slhaf.partner.core.action.entity; + +import lombok.Data; + +@Data +public class CacheAdjustData { + private String input; + private String tendency; + private boolean passed; + private boolean hit; +} diff --git a/Partner-Main/src/main/java/work/slhaf/partner/core/memory/MemoryCapability.java b/Partner-Main/src/main/java/work/slhaf/partner/core/memory/MemoryCapability.java index 0b62a9e8..c918fdb6 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/core/memory/MemoryCapability.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/core/memory/MemoryCapability.java @@ -5,7 +5,6 @@ import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; import work.slhaf.partner.core.memory.pojo.MemoryResult; import work.slhaf.partner.core.memory.pojo.MemorySlice; -import java.io.IOException; import java.time.LocalDate; import java.time.LocalDateTime; import java.util.HashMap; @@ -45,7 +44,7 @@ public interface MemoryCapability { MemoryResult selectMemory(String topicPathStr); - MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException; + MemoryResult selectMemory(LocalDate date); void insertSlice(MemorySlice memorySlice, String topicPath); diff --git a/Partner-Main/src/main/java/work/slhaf/partner/module/common/module/PostRunningModule.java b/Partner-Main/src/main/java/work/slhaf/partner/module/common/module/PostRunningModule.java index 9bfc1277..f3c5f01f 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/module/common/module/PostRunningModule.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/module/common/module/PostRunningModule.java @@ -3,12 +3,10 @@ package work.slhaf.partner.module.common.module; import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; -import java.io.IOException; - public abstract class PostRunningModule extends AgentRunningModule { @Override - public final void execute(PartnerRunningFlowContext context) throws IOException, ClassNotFoundException { + public final void execute(PartnerRunningFlowContext context) { boolean trigger = context.getModuleContext().getExtraContext().getBoolean("post_process_trigger"); if (!trigger) { return; diff --git a/Partner-Main/src/main/java/work/slhaf/partner/module/common/module/PreRunningModule.java b/Partner-Main/src/main/java/work/slhaf/partner/module/common/module/PreRunningModule.java index f72e4629..49f7097c 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/module/common/module/PreRunningModule.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/module/common/module/PreRunningModule.java @@ -5,7 +5,6 @@ import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunn import work.slhaf.partner.module.common.entity.AppendPromptData; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; -import java.io.IOException; import java.util.HashMap; /** @@ -33,13 +32,13 @@ public abstract class PreRunningModule extends AgentRunningModule evaluatorResults = actionEvaluator.execute(evaluatorInput); - setupPreparedActionInfo(evaluatorResults, context); + List evaluatorResults = actionEvaluator.execute(evaluatorInput); //并发操作均为访问 + if (extractorResult.isCacheEnabled()) + updateTendencyCache(evaluatorResults, context.getInput(), extractorResult); + setupActionInfo(evaluatorResults, context); return null; }); } + private void updateTendencyCache(List evaluatorResults, String input, ExtractorResult extractorResult) { + executor.execute(() -> { + List list = new ArrayList<>(); + List hitTendencies = extractorResult.getTendencies(); + for (EvaluatorResult result : evaluatorResults) { + CacheAdjustData data = new CacheAdjustData(); + data.setTendency(result.getTendency()); + data.setInput(input); + data.setPassed(result.isOk()); + data.setHit(hitTendencies.contains(result.getTendency())); + list.add(data); + } + actionCapability.updateTendencyCache(list); + }); + + } + /** * 待确认行动的判断任务 * @@ -119,13 +138,12 @@ public class ActionPlanner extends PreRunningModule { } - private void setupPreparedActionInfo(List evaluatorResults, PartnerRunningFlowContext context) { + private void setupActionInfo(List evaluatorResults, PartnerRunningFlowContext context) { for (EvaluatorResult evaluatorResult : evaluatorResults) { + MetaActionInfo metaActionInfo = assemblyHelper.buildMetaActionInfo(evaluatorResult); if (evaluatorResult.isNeedConfirm()) { - MetaActionInfo metaActionInfo = assemblyHelper.buildMetaActionInfo(evaluatorResult); actionCapability.putPendingActions(context.getUserId(), metaActionInfo); } else { - MetaActionInfo metaActionInfo = assemblyHelper.buildMetaActionInfo(evaluatorResult); actionCapability.putPreparedAction(context.getUuid(), metaActionInfo); } } @@ -192,7 +210,7 @@ public class ActionPlanner extends PreRunningModule { return input; } - public EvaluatorInput buildEvaluatorInput(ExtractorResult extractorResult, String userId) { + private EvaluatorInput buildEvaluatorInput(ExtractorResult extractorResult, String userId) { EvaluatorInput input = new EvaluatorInput(); input.setTendencies(extractorResult.getTendencies()); input.setUser(perceiveCapability.getUser(userId)); @@ -201,7 +219,7 @@ public class ActionPlanner extends PreRunningModule { return input; } - public MetaActionInfo buildMetaActionInfo(EvaluatorResult evaluatorResult) { + private MetaActionInfo buildMetaActionInfo(EvaluatorResult evaluatorResult) { return switch (evaluatorResult.getType()) { case PLANNING -> { ScheduledActionInfo actionInfo = new ScheduledActionInfo(); @@ -221,7 +239,7 @@ public class ActionPlanner extends PreRunningModule { }; } - public ConfirmerInput buildConfirmerInput(PartnerRunningFlowContext context) { + private ConfirmerInput buildConfirmerInput(PartnerRunningFlowContext context) { ConfirmerInput confirmerInput = new ConfirmerInput(); confirmerInput.setInput(context.getInput()); List pendingActions = actionCapability.listPendingAction(context.getUserId()); diff --git a/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/ActionEvaluator.java b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/ActionEvaluator.java index b655fb26..5fd56b9b 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/ActionEvaluator.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/ActionEvaluator.java @@ -1,20 +1,67 @@ package work.slhaf.partner.module.modules.action.planner.evaluator; +import cn.hutool.core.bean.BeanUtil; +import com.alibaba.fastjson2.JSONObject; import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule; +import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel; import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule; +import work.slhaf.partner.api.chat.pojo.ChatResponse; +import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor; +import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorBatchInput; import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorInput; import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorResult; +import java.util.ArrayList; import java.util.List; +import java.util.concurrent.Callable; @AgentSubModule public class ActionEvaluator extends AgentRunningSubModule> implements ActivateModel { + private InteractionThreadPoolExecutor executor; + + @Init + public void init() { + executor = InteractionThreadPoolExecutor.getInstance(); + } + + /** + * 对输入的行为倾向进行评估,并根据评估结果,对缓存做出调整 + * + * @param data 评估输入内容,包含提取/命中缓存的行动倾向、近几条聊天记录,正在生效的记忆切片内容 + * @return 评估结果集合,包含 + */ @Override public List execute(EvaluatorInput data) { + List batchInputs = buildEvaluatorBatchInput(data); + List> tasks = getTasks(batchInputs); + return executor.invokeAllAndReturn(tasks); + } - return null; + + private List> getTasks(List batchInputs) { + List> list = new ArrayList<>(); + for (EvaluatorBatchInput batchInput : batchInputs) { + list.add(() -> { + ChatResponse response = this.singleChat(JSONObject.toJSONString(batchInput)); + EvaluatorResult evaluatorResult = JSONObject.parseObject(response.getMessage(), EvaluatorResult.class); + evaluatorResult.setTendency(batchInput.getTendency()); + return evaluatorResult; + }); + } + return list; + } + + private List buildEvaluatorBatchInput(EvaluatorInput data) { + List list = new ArrayList<>(); + for (String tendency : data.getTendencies()) { + EvaluatorBatchInput temp = new EvaluatorBatchInput(); + BeanUtil.copyProperties(data, temp); + temp.setTendency(tendency); + list.add(temp); + } + return list; } @Override diff --git a/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/entity/EvaluatorBatchInput.java b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/entity/EvaluatorBatchInput.java new file mode 100644 index 00000000..33114b9a --- /dev/null +++ b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/entity/EvaluatorBatchInput.java @@ -0,0 +1,14 @@ +package work.slhaf.partner.module.modules.action.planner.evaluator.entity; + +import lombok.Data; +import work.slhaf.partner.api.chat.pojo.Message; +import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; + +import java.util.List; + +@Data +public class EvaluatorBatchInput { + private List recentMessages; + private List activatedSlices; + private String tendency; +} diff --git a/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/entity/EvaluatorResult.java b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/entity/EvaluatorResult.java index e1d302df..797edc1b 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/entity/EvaluatorResult.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/entity/EvaluatorResult.java @@ -11,4 +11,5 @@ public class EvaluatorResult { private ActionType type; private String scheduleContent; private ActionData actionData; + private String tendency; } diff --git a/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/extractor/ActionExtractor.java b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/extractor/ActionExtractor.java index 6618ad6f..a9f19f43 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/extractor/ActionExtractor.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/extractor/ActionExtractor.java @@ -1,21 +1,18 @@ package work.slhaf.partner.module.modules.action.planner.extractor; -import java.util.List; - import com.alibaba.fastjson2.JSONObject; - import lombok.extern.slf4j.Slf4j; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule; import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel; import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule; -import work.slhaf.partner.api.chat.constant.ChatConstant; import work.slhaf.partner.api.chat.pojo.ChatResponse; -import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput; import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorResult; +import java.util.List; + @Slf4j @AgentSubModule public class ActionExtractor extends AgentRunningSubModule implements ActivateModel { @@ -25,17 +22,17 @@ public class ActionExtractor extends AgentRunningSubModule tendencyCache = actionCapability.computeActionCache(data.getInput()); - if ( tendencyCache == null || !tendencyCache.isEmpty()) { - ExtractorResult result = new ExtractorResult(); + ExtractorResult result = new ExtractorResult(); + List tendencyCache = actionCapability.selectTendencyCache(data.getInput()); + result.setCacheEnabled(tendencyCache != null); + if (tendencyCache != null && !tendencyCache.isEmpty()) { + result.setTendencies(tendencyCache); return result; } for (int i = 0; i < 3; i++) { try { - this.chatMessages().add(new Message(ChatConstant.Character.USER, JSONObject.toJSONString(data))); - ChatResponse response = this.chat(); + ChatResponse response = this.singleChat(JSONObject.toJSONString(data)); return JSONObject.parseObject(response.getMessage(), ExtractorResult.class); } catch (Exception e) { log.error("[ActionExtractor] 提取信息出错", e); diff --git a/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/extractor/entity/ExtractorResult.java b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/extractor/entity/ExtractorResult.java index fd13d86c..5d8791d5 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/extractor/entity/ExtractorResult.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/extractor/entity/ExtractorResult.java @@ -7,5 +7,6 @@ import java.util.List; @Data public class ExtractorResult { + private boolean cacheEnabled; private List tendencies = new ArrayList<>(); } diff --git a/Partner-Main/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java index 4dc3fd9c..d35bec48 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java @@ -22,7 +22,6 @@ import work.slhaf.partner.module.modules.memory.selector.extractor.entity.Extrac import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorResult; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; -import java.io.IOException; import java.time.LocalDate; import java.util.ArrayList; import java.util.Collection; @@ -32,7 +31,7 @@ import java.util.List; @EqualsAndHashCode(callSuper = true) @Data @Slf4j -@AgentModule(name="memory_selector",order=2) +@AgentModule(name = "memory_selector", order = 2) public class MemorySelector extends PreRunningModule { @InjectCapability @@ -46,7 +45,7 @@ public class MemorySelector extends PreRunningModule { private MemorySelectExtractor memorySelectExtractor; @Override - public void doExecute(PartnerRunningFlowContext runningFlowContext) throws IOException, ClassNotFoundException { + public void doExecute(PartnerRunningFlowContext runningFlowContext) { String userId = runningFlowContext.getUserId(); //获取主题路径 ExtractorResult extractorResult = memorySelectExtractor.execute(runningFlowContext); @@ -58,7 +57,7 @@ public class MemorySelector extends PreRunningModule { setModuleContextRecall(runningFlowContext); } - private List selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, ExtractorResult extractorResult) throws IOException, ClassNotFoundException { + private List selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, ExtractorResult extractorResult) { log.debug("[MemorySelector] 触发记忆回溯..."); //查找切片 String userId = runningFlowContext.getUserId(); @@ -86,7 +85,7 @@ public class MemorySelector extends PreRunningModule { } - private void setMemoryResultList(List memoryResultList, List matches, String userId) throws IOException, ClassNotFoundException { + private void setMemoryResultList(List memoryResultList, List matches, String userId) { for (ExtractorMatchData match : matches) { try { MemoryResult memoryResult = switch (match.getType()) { diff --git a/Partner-Main/src/main/java/work/slhaf/partner/module/modules/process/PostprocessExecutor.java b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/process/PostprocessExecutor.java index 8148abd3..15d6a827 100644 --- a/Partner-Main/src/main/java/work/slhaf/partner/module/modules/process/PostprocessExecutor.java +++ b/Partner-Main/src/main/java/work/slhaf/partner/module/modules/process/PostprocessExecutor.java @@ -9,8 +9,6 @@ import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunn import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; -import java.io.IOException; - @EqualsAndHashCode(callSuper = true) @Slf4j @Data @@ -23,7 +21,7 @@ public class PostprocessExecutor extends AgentRunningModule= POST_PROCESS_TRIGGER_ROLL_LIMIT; context.getModuleContext().getExtraContext().put("post_process_trigger", trigger); log.debug("[PostprocessExecutor] 是否执行后处理: {}", trigger);