refactor(modules): refactor modules base class into AbstractAgentModule and remove unused @slf4j annotation

This commit is contained in:
2026-02-20 17:17:49 +08:00
parent 38c618a222
commit c47d2b2285
29 changed files with 90 additions and 537 deletions

View File

@@ -1,10 +1,9 @@
package work.slhaf.partner.module.common.module;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
public abstract class PostRunningAbstractAgentModuleAbstract extends AbstractAgentRunningModule<PartnerRunningFlowContext> {
public abstract class PostRunningAbstractAgentModuleAbstract extends AbstractAgentModule.Running<PartnerRunningFlowContext> {
@Override
public final void execute(PartnerRunningFlowContext context) {
boolean trigger = context.getModuleContext().getExtraContext().getBoolean("post_process_trigger");
@@ -13,8 +12,6 @@ public abstract class PostRunningAbstractAgentModuleAbstract extends AbstractAge
}
doExecute(context);
}
public abstract void doExecute(PartnerRunningFlowContext context);
protected abstract boolean relyOnMessage();
}

View File

@@ -1,17 +1,14 @@
package work.slhaf.partner.module.common.module;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.module.common.entity.AppendPromptData;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.Map;
/**
* 前置模块抽象类
*/
@Slf4j
public abstract class PreRunningAbstractAgentModuleAbstract extends AbstractAgentRunningModule<PartnerRunningFlowContext> {
public abstract class PreRunningAbstractAgentModuleAbstract extends AbstractAgentModule.Running<PartnerRunningFlowContext> {
private synchronized void setAppendedPrompt(PartnerRunningFlowContext context) {
AppendPromptData data = new AppendPromptData();
data.setModuleName(moduleName());
@@ -19,25 +16,19 @@ public abstract class PreRunningAbstractAgentModuleAbstract extends AbstractAgen
data.setAppendedPrompt(map);
context.setAppendedPrompt(data);
}
private synchronized void setActiveModule(PartnerRunningFlowContext context) {
context.getCoreContext().addActiveModule(moduleName());
}
protected abstract Map<String, String> getPromptDataMap(PartnerRunningFlowContext context);
/**
* 用于在CoreModule接收到的模块Prompt中标识模块名称
*/
protected abstract String moduleName();
@Override
public final void execute(PartnerRunningFlowContext context) {
doExecute(context); // 子类实现差异化逻辑
setAppendedPrompt(context); // 通用逻辑
setActiveModule(context); // 通用逻辑
}
protected abstract void doExecute(PartnerRunningFlowContext context);
}

View File

@@ -2,7 +2,6 @@ package work.slhaf.partner.module.modules.action.dispatcher;
import lombok.val;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.core.action.ActionCapability;
@@ -19,25 +18,18 @@ import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowCon
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ExecutorService;
@AgentRunningModule(name = "action_dispatcher", order = 7)
public class ActionDispatcher extends PostRunningAbstractAgentModuleAbstract {
@InjectCapability
private ActionCapability actionCapability;
@InjectModule
private ActionExecutor actionExecutor;
@InjectModule
private ActionScheduler actionScheduler;
private ExecutorService executor;
@Init
public void init() {
executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
}
@Override
public void doExecute(PartnerRunningFlowContext context) {
// 只需要处理prepared action因为pending action在用户确认后也将变为prepared action
@@ -61,10 +53,13 @@ public class ActionDispatcher extends PostRunningAbstractAgentModuleAbstract {
actionScheduler.execute(scheduledActions);
});
}
@Override
protected boolean relyOnMessage() {
return false;
}
@Override
public int order() {
return 7;
}
}

View File

@@ -1,52 +1,40 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor;
import com.alibaba.fastjson2.JSONObject;
import lombok.val;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.CorrectorInput;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.CorrectorResult;
/**
* 负责在单组行动执行后,根据行动意图与结果检查后续行动是否符合目的,必要时直接调整行动链,或发起自对话请求进行干预
*/
@AgentSubModule
public class ActionCorrector extends AbstractAgentSubModule<CorrectorInput, CorrectorResult> implements ActivateModel {
public class ActionCorrector extends AbstractAgentModule.Sub<CorrectorInput, CorrectorResult> implements ActivateModel {
@Override
public CorrectorResult execute(CorrectorInput input) {
val prompt = buildPrompt(input);
val chatResponse = singleChat(prompt);
return JSONObject.parseObject(chatResponse.getMessage(), CorrectorResult.class);
}
private String buildPrompt(CorrectorInput input) {
val prompt = new JSONObject();
prompt.put("[行动来源]", input.getSource());
prompt.put("[行动倾向]", input.getTendency());
prompt.put("[行动描述]", input.getDescription());
prompt.put("[行动原因]", input.getReason());
val messages = prompt.putArray("[近期对话]");
messages.addAll(input.getRecentMessages());
val memory = prompt.putArray("[已激活记忆]");
memory.addAll(input.getActivatedSlices());
val history = prompt.putArray("[已执行情况]");
history.addAll(input.getHistory());
return prompt.toJSONString();
}
@Override
public String modelKey() {
return "action_corrector";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -1,10 +1,8 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.core.action.ActionCapability;
@@ -26,17 +24,13 @@ import java.util.concurrent.Phaser;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
@Slf4j
@AgentSubModule
public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput, Void> {
public class ActionExecutor extends AbstractAgentModule.Sub<ActionExecutorInput, Void> {
@InjectCapability
private ActionCapability actionCapability;
@InjectCapability
private MemoryCapability memoryCapability;
@InjectCapability
private CognationCapability cognationCapability;
@InjectModule
private ParamsExtractor paramsExtractor;
@InjectModule
@@ -45,20 +39,16 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
private ActionCorrector actionCorrector;
@InjectModule
private ActionScheduler actionScheduler;
private ExecutorService virtualExecutor;
private ExecutorService platformExecutor;
private RunnerClient runnerClient;
private final AssemblyHelper assemblyHelper = new AssemblyHelper();
@Init
public void init() {
virtualExecutor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
platformExecutor = actionCapability.getExecutor(ActionCore.ExecutorType.PLATFORM);
runnerClient = actionCapability.runnerClient();
}
/**
* 执行行动
*
@@ -85,62 +75,51 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
val phaser = new Phaser();
val phaserRecord = actionCapability.putPhaserRecord(phaser, executableAction);
executableAction.setStatus(Status.EXECUTING);
// 开始执行
val stageCursor = new Object() {
int stageCount;
boolean executingStageUpdated;
boolean stageCountUpdated;
void init() {
stageCount = 0;
executingStageUpdated = false;
stageCountUpdated = false;
update();
}
void requestAdvance() {
if (!stageCountUpdated) {
stageCount++;
stageCountUpdated = true;
}
if (stageCount < actionChain.size() && !executingStageUpdated) {
update();
executingStageUpdated = true;
}
}
boolean next() {
executingStageUpdated = false;
stageCountUpdated = false;
return stageCount < actionChain.size();
}
void update() {
val orderList = new ArrayList<>(actionChain.keySet());
orderList.sort(Integer::compareTo);
executableAction.setExecutingStage(orderList.get(stageCount));
}
};
stageCursor.init();
do {
val metaActions = actionChain.get(executableAction.getExecutingStage());
val listeningRecord = executeAndListening(metaActions, phaserRecord, source);
phaser.awaitAdvance(listeningRecord.phase());
// synchronized 同步防止 accepting 循环间、phase guard 判定后发生 stage 推进
// 导致新行动的 phaser 投放阶段错乱无法阻塞的场景
// 该 synchronized 将阶段推进与 accepting 监听 loop 捆绑为互斥的原子事件,避免了细粒度的 phaser 阶段竞态问题
synchronized (listeningRecord.accepting()) {
listeningRecord.accepting().set(false);
// 立即尝试推进,本次推进中,如果前方仍有未执行 stage将执行一次阶段推进
stageCursor.requestAdvance();
}
try {
// 针对行动链进行修正,修正需要传入执行历史、行动目标等内容
// 如果后续运行 corrector 触发频率较高,可考虑增加重试机制
@@ -149,12 +128,10 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
actionCapability.handleInterventions(correctorResult.getMetaInterventionList(), executableAction);
} catch (Exception ignored) {
}
// 第二次尝试进行阶段推进,本次负责补充上一次在不存在 stage时但 corrector 执行期间发生了 actionChain 的插入事件
// 如果第一次已经推进完毕,本次将会跳过
stageCursor.requestAdvance();
} while (stageCursor.next());
// 结束
actionCapability.removePhaserRecord(phaser);
if (executableAction.getStatus() != Status.FAILED) {
@@ -165,20 +142,15 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
} else {
executableAction.setStatus(Status.SUCCESS);
}
// TODO 执行过后需要回写至任务上下文recentCompletedTask同时触发自对话信号进行确认并记录以及是否通知用户触发与否需要机制进行匹配在模块链路可增加 interaction gate 门控,判断此次对话作用于谁、由谁发出、何种性质、是否需要回应等)
}
});
}
return null;
}
private MetaActionsListeningRecord executeAndListening(List<MetaAction> metaActions, PhaserRecord phaserRecord, String source) {
AtomicBoolean accepting = new AtomicBoolean(true);
AtomicInteger cursor = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(1);
val phaser = phaserRecord.phaser();
val phase = phaser.register();
@@ -187,27 +159,22 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
while (accepting.get()) {
synchronized (accepting) {
MetaAction next = null;
synchronized (metaActions) {
if (cursor.get() < metaActions.size()) {
next = metaActions.get(cursor.getAndIncrement());
}
}
if (next == null) {
Thread.onSpinWait();
continue;
}
if (phaser.getPhase() != phase) {
metaActions.remove(next);
log.warn("行动阶段已推进,丢弃该行动: {}", next);
continue;
}
ExecutorService executor = next.getIo() ? virtualExecutor : platformExecutor;
executor.execute(buildMataActionTask(next, phaserRecord, source));
if (first) {
phaser.arriveAndDeregister();
latch.countDown();
@@ -223,7 +190,6 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
}
return new MetaActionsListeningRecord(accepting, phase);
}
private Runnable buildMataActionTask(MetaAction metaAction, PhaserRecord phaserRecord, String source) {
val phaser = phaserRecord.phaser();
phaser.register();
@@ -238,7 +204,6 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
val additionalContext = actionData.getAdditionalContext().get(executingStage);
val extractorInput = assemblyHelper.buildExtractorInput(metaAction, source, historyActionResults, additionalContext);
val extractorResult = paramsExtractor.execute(extractorInput);
if (extractorResult.isOk()) {
metaAction.getParams().putAll(extractorResult.getParams());
runnerClient.submit(metaAction);
@@ -275,16 +240,12 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
}
};
}
private record MetaActionsListeningRecord(AtomicBoolean accepting, int phase) {
}
@SuppressWarnings("InnerClassMayBeStatic")
private class AssemblyHelper {
private AssemblyHelper() {
}
private RepairerInput buildRepairerInput(List<HistoryAction> historyActionsResults, MetaAction action, String userId) {
RepairerInput input = new RepairerInput();
MetaActionInfo metaActionInfo = actionCapability.loadMetaActionInfo(action.getKey());
@@ -295,7 +256,6 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
input.setUserId(userId);
return input;
}
private ExtractorInput buildExtractorInput(MetaAction action, String source, List<HistoryAction> historyActionResults,
List<String> additionalContext) {
ExtractorInput input = new ExtractorInput();
@@ -306,7 +266,6 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
input.setAdditionalContext(additionalContext);
return input;
}
private CorrectorInput buildCorrectorInput(ExecutableAction executableAction, String source) {
return CorrectorInput.builder()
.tendency(executableAction.getTendency())
@@ -320,5 +279,4 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
.build();
}
}
}

View File

@@ -4,11 +4,9 @@ import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import com.alibaba.fastjson2.TypeReference;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
@@ -28,7 +26,6 @@ import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
/**
* 负责识别行动链的修复
* <ol>
@@ -43,26 +40,19 @@ import java.util.concurrent.atomic.AtomicInteger;
* </li>
* </ol>
*/
@Slf4j
@AgentSubModule
public class ActionRepairer extends AbstractAgentSubModule<RepairerInput, RepairerResult> implements ActivateModel {
public class ActionRepairer extends AbstractAgentModule.Sub<RepairerInput, RepairerResult> implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
@InjectCapability
private CognationCapability cognationCapability;
@InjectModule
private DynamicActionGenerator dynamicActionGenerator;
private final AssemblyHelper assemblyHelper = new AssemblyHelper();
private RunnerClient runnerClient;
@Init
void init() {
runnerClient = actionCapability.runnerClient();
}
@Override
public RepairerResult execute(RepairerInput data) {
RepairerResult result;
@@ -92,7 +82,6 @@ public class ActionRepairer extends AbstractAgentSubModule<RepairerInput, Repair
}
return result;
}
/**
* 负责根据输入内容进行行动单元的参数信息修复
*
@@ -107,7 +96,6 @@ public class ActionRepairer extends AbstractAgentSubModule<RepairerInput, Repair
result.setStatus(RepairerStatus.FAILED);
return result;
}
runnerClient.submit(tempAction);
// 根据 tempAction 的执行状态设置修复结果
Result actionResult = tempAction.getResult();
@@ -115,12 +103,10 @@ public class ActionRepairer extends AbstractAgentSubModule<RepairerInput, Repair
result.setStatus(RepairerStatus.FAILED);
return result;
}
result.setStatus(RepairerStatus.OK);
result.getFixedData().add(actionResult.getData());
return result;
}
/**
* 负责根据输入内容进行行动单元的参数信息修复
*
@@ -161,50 +147,41 @@ public class ActionRepairer extends AbstractAgentSubModule<RepairerInput, Repair
}
return result;
}
private RepairerResult handleUserInteraction(String acquireContent) {
RepairerResult result = new RepairerResult();
result.setStatus(RepairerStatus.ACQUIRE);
// 发送自对话请求
return result;
}
@Override
public String modelKey() {
return "action_repairer";
}
@Override
public boolean withBasicPrompt() {
return false;
}
@SuppressWarnings("InnerClassMayBeStatic")
@Data
private class RepairerData {
private RepairerType repairerType;
private String data;
}
private enum RepairerType {
ACTION_GENERATION,
ACTION_INVOCATION,
USER_INTERACTION
}
@SuppressWarnings("InnerClassMayBeStatic")
private class AssemblyHelper {
private AssemblyHelper() {
}
private String buildPrompt(RepairerInput data, String specialInstruction) {
JSONObject prompt = new JSONObject();
JSONObject actionData = prompt.putObject("[本次行动信息]");
actionData.put("[行动描述]", data.getActionDescription());
JSONObject actionParamsData = actionData.putObject("[行动参数说明]");
actionParamsData.putAll(data.getParams());
JSONArray historyData = prompt.putArray("[历史行动执行结果]");
data.getHistoryActionResults().forEach(historyAction -> {
JSONObject historyItem = new JSONObject();
@@ -213,16 +190,12 @@ public class ActionRepairer extends AbstractAgentSubModule<RepairerInput, Repair
historyItem.put("[行动结果]", historyAction.result());
historyData.add(historyItem);
});
JSONArray messageData = prompt.putArray("[最近消息列表]");
messageData.addAll(data.getRecentMessages());
if (specialInstruction != null) {
prompt.put("[特殊指令]", specialInstruction);
}
return prompt.toString();
}
}
}

View File

@@ -1,11 +1,9 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor;
import com.alibaba.fastjson2.JSONObject;
import lombok.val;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.common.util.ExtractUtil;
@@ -15,24 +13,18 @@ import work.slhaf.partner.core.action.entity.MetaAction;
import work.slhaf.partner.core.action.runner.RunnerClient;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.GeneratorInput;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.GeneratorResult;
/**
* 负责依据输入内容生成可执行的动态行动单元,并选择是否持久化至 SandboxRunner 容器内
*/
@AgentSubModule
public class DynamicActionGenerator extends AbstractAgentSubModule<GeneratorInput, GeneratorResult>
public class DynamicActionGenerator extends AbstractAgentModule.Sub<GeneratorInput, GeneratorResult>
implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
private RunnerClient runnerClient;
@Init
void init() {
runnerClient = actionCapability.runnerClient();
}
@Override
public GeneratorResult execute(GeneratorInput input) {
GeneratorResult result = new GeneratorResult();
@@ -40,12 +32,10 @@ public class DynamicActionGenerator extends AbstractAgentSubModule<GeneratorInpu
// 由于 SCRIPT 类型程序都是在 SandboxRunner 内部的磁盘上加载然后执行的,
// 所以此处的输入内容也只需要指定输入参数、临时key、是否持久化即可路径将按照指定规则统一构建不可交给LLM生成
String prompt = buildPrompt(input);
// 响应结果需要包含几个特殊数据: 依赖项、代码内容、是否序列化、响应数据释义
ChatResponse response = this.singleChat(prompt);
GeneratedData generatorData = JSONObject
.parseObject(ExtractUtil.extractJson(response.getMessage()), GeneratedData.class);
val location = runnerClient.buildTmpPath(input.getActionName(), generatorData.getCodeType());
MetaAction tempAction = new MetaAction(
input.getActionName(),
@@ -66,11 +56,9 @@ public class DynamicActionGenerator extends AbstractAgentSubModule<GeneratorInpu
}
return result;
}
private void waitingSerialize() {
throw new UnsupportedOperationException("Unimplemented method 'waitingSerialize'");
}
private String buildPrompt(GeneratorInput data) {
JSONObject prompt = new JSONObject();
prompt.put("[行动描述]", data.getDescription());
@@ -78,12 +66,10 @@ public class DynamicActionGenerator extends AbstractAgentSubModule<GeneratorInpu
prompt.putObject("[行动参数描述]").putAll(data.getParamsDescription());
return prompt.toString();
}
@Override
public String modelKey() {
return "dynamic_generator";
}
@Override
public boolean withBasicPrompt() {
return false;

View File

@@ -2,10 +2,8 @@ package work.slhaf.partner.module.modules.action.dispatcher.executor;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.core.action.entity.MetaActionInfo;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.ExtractorInput;
@@ -14,14 +12,10 @@ import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.Histo
import java.util.HashMap;
import java.util.List;
/**
* 负责依据输入内容进行行动单元的参数信息提取
*/
@Slf4j
@AgentSubModule
public class ParamsExtractor extends AbstractAgentSubModule<ExtractorInput, ExtractorResult> implements ActivateModel {
public class ParamsExtractor extends AbstractAgentModule.Sub<ExtractorInput, ExtractorResult> implements ActivateModel {
@Override
public ExtractorResult execute(ExtractorInput input) {
String prompt = buildPrompt(input);
@@ -37,15 +31,12 @@ public class ParamsExtractor extends AbstractAgentSubModule<ExtractorInput, Extr
}
return result;
}
private String buildPrompt(ExtractorInput input) {
JSONObject prompt = new JSONObject();
JSONObject actionData = prompt.putObject("[本次行动信息]");
MetaActionInfo actionInfo = input.getMetaActionInfo();
actionData.put("[行动描述]", actionInfo.getDescription());
actionData.put("[行动参数说明]", actionInfo.getParams());
JSONArray historyData = prompt.putArray("[历史行动执行结果]");
List<HistoryAction> historyActions = input.getHistoryActionResults();
for (HistoryAction historyAction : historyActions) {
@@ -55,21 +46,16 @@ public class ParamsExtractor extends AbstractAgentSubModule<ExtractorInput, Extr
historyItem.put("[行动结果]", historyAction.result());
historyData.add(historyItem);
}
JSONArray messageData = prompt.putArray("[最近消息列表]");
messageData.addAll(input.getRecentMessages());
return prompt.toString();
}
@Override
public String modelKey() {
return "params_extractor";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -12,8 +12,7 @@ import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import org.slf4j.LoggerFactory
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule
import work.slhaf.partner.api.agent.factory.module.annotation.Init
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule
import work.slhaf.partner.core.action.ActionCapability
@@ -30,17 +29,13 @@ import java.time.temporal.ChronoUnit
import java.util.stream.Collectors
import kotlin.jvm.optionals.getOrNull
@AgentSubModule
class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
class ActionScheduler : AbstractAgentModule.Sub<Set<Schedulable>, Void?>() {
@InjectCapability
private lateinit var actionCapability: ActionCapability
@InjectModule
private lateinit var actionExecutor: ActionExecutor
private lateinit var timeWheel: TimeWheel
private val schedulerScope =
CoroutineScope(Dispatchers.Default + SupervisorJob() + CoroutineName("ActionScheduler"))
@@ -58,7 +53,6 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
.map { it as SchedulableExecutableAction }
.collect(Collectors.toSet())
}
val onTrigger: (Set<Schedulable>) -> Unit = { schedulableSet ->
val executableActions = mutableSetOf<SchedulableExecutableAction>()
val stateActions = mutableSetOf<StateAction>()
@@ -72,12 +66,9 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL)
.execute { stateActions.forEach { it.trigger.onTrigger() } }
}
timeWheel = TimeWheel(listScheduledActions, onTrigger)
}
loadScheduledActions()
setupShutdownHook()
}
@@ -88,10 +79,9 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
})
}
override fun execute(schedulableSet: Set<Schedulable>?): Void? {
override fun execute(input: Set<Schedulable>): Void? {
schedulerScope.launch {
schedulableSet?.run {
for (schedulableData in schedulableSet) {
for (schedulableData in input) {
log.debug("New data to schedule: {}", schedulableData)
timeWheel.schedule(schedulableData)
if (schedulableData is SchedulableExecutableAction) {
@@ -99,7 +89,6 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
}
}
}
}
return null
}
@@ -107,16 +96,13 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
val listSource: () -> Set<Schedulable>,
val onTrigger: (toTrigger: Set<Schedulable>) -> Unit
) : Closeable {
private val schedulableGroupByHour = Array<MutableSet<Schedulable>>(24) { mutableSetOf() }
private val wheel = Array<MutableSet<Schedulable>>(60 * 60) { mutableSetOf() }
private var recordHour: Int = -1
private var recordDay: Int = -1
private val state = MutableStateFlow(WheelState.SLEEPING)
private val wheelActionsLock = Mutex()
private val timeWheelScope = CoroutineScope(SupervisorJob() + Dispatchers.Default + CoroutineName("TimeWheel"))
private val cronDefinition: CronDefinition = CronDefinitionBuilder.instanceDefinitionFor(CronType.QUARTZ)
private val cronParser: CronParser = CronParser(cronDefinition)
@@ -136,23 +122,19 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
return@checkThenExecute
}
log.debug("Action next execution time: {}", parseToZonedDateTime)
val hour = parseToZonedDateTime.hour
schedulableGroupByHour[hour].add(schedulableData)
log.debug("Action scheduled at {}", hour)
if (it.hour == hour) {
val wheelOffset = parseToZonedDateTime.minute * 60 + parseToZonedDateTime.second
wheel[wheelOffset].add(schedulableData)
state.value = WheelState.ACTIVE
log.debug("Action scheduled at wheel offset {}", wheelOffset)
}
}
}
private fun wheel() {
data class WheelStepResult(
val toTrigger: Set<Schedulable>?,
val shouldBreak: Boolean
@@ -178,26 +160,20 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
suspend fun CoroutineScope.wheel(launchingTime: ZonedDateTime, primaryTickAdvanceTime: Long) {
val launchingHour = launchingTime.hour
var tick = launchingTime.minute * 60 + launchingTime.second
// 让节拍器从“启动时刻的下一秒”开始(避免立即 step=0
var nextTickNanos = primaryTickAdvanceTime + 1_000_000_000L
while (isActive) {
// 1) 计算落后多少秒:至少 1正常推进也可能 >1追赶
val now0 = System.nanoTime()
val lagNanos = now0 - nextTickNanos
val step = if (lagNanos < 0) 1 else (lagNanos / 1_000_000_000L).toInt() + 1
val previousTick = tick
tick = (tick + step).coerceAtMost(wheel.lastIndex)
// 2) 推进节拍器:按“理论秒”前进 step 次
nextTickNanos += step.toLong() * 1_000_000_000L
val stepResult = run {
var shouldBreak = false
var toTrigger: Set<Schedulable>? = null
checkThenExecute(false) {
if (it.hour != launchingHour) {
shouldBreak = true
@@ -210,30 +186,23 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
)
return@checkThenExecute
}
toTrigger = collectToTrigger(tick, previousTick, launchingHour)
if (tick >= wheel.lastIndex || schedulableGroupByHour[launchingHour].isEmpty()) {
state.value = WheelState.SLEEPING
shouldBreak = true
}
}
WheelStepResult(toTrigger, shouldBreak)
}
stepResult.toTrigger?.let { trigger ->
timeWheelScope.launch {
onTrigger(trigger)
}
}
if (stepResult.shouldBreak) {
log.debug("Wheel stopped at tick {}", tick)
break
}
// 3) 精确睡到下一次理论 tick用最新 nanoTime
val now1 = System.nanoTime()
val sleepNanos = nextTickNanos - now1
@@ -255,7 +224,6 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
}
log.debug("Waiting ended at {}", ZonedDateTime.now())
}
timeWheelScope.launch {
while (isActive) {
// 判断是否该步入下一小时
@@ -270,14 +238,12 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
// 而启动时无触发保障,此时一并初始化 tick 推进时间,足以应对 check 与 wheel 间的这段时间间隔
primaryTickAdvanceTime = System.nanoTime()
}
// 如果该时无任务则等待,插入事件可提前唤醒
if (shouldWait!!) {
// 计算距离下一小时的时间,等待
currentTime?.let { wait(it) }
continue
}
// 唤醒进行时间轮循环
wheel(currentTime!!, primaryTickAdvanceTime!!)
}
@@ -303,11 +269,9 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
logFailedStatus(schedulableData)
continue
}
load(nextExecutingTime, schedulableData)
}
}
repair()
runLoading()
}
@@ -319,13 +283,11 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
wheel[secondsTime].add(schedulableData)
log.debug("Action loaded to hour: {}", schedulableData)
}
val repair: () -> Unit = {
for (set in wheel) {
set.clear()
}
}
loadActions(schedulableGroupByHour[currentTime.hour], currentTime, load, repair)
}
@@ -335,13 +297,11 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
schedulableGroupByHour[latestExecutingTime.hour].add(schedulableData)
log.debug("Action loaded to day: {}", schedulableData)
}
val repair: () -> Unit = {
for (set in schedulableGroupByHour) {
set.clear()
}
}
loadActions(listSource(), currentTime, load, repair)
}
@@ -360,7 +320,6 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
}
val now = ZonedDateTime.now()
if (finallyToExecute) {
refreshIfNeeded(now)
then(now)
@@ -398,9 +357,7 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
else
executionTime
}
}
}
private fun logFailedStatus(scheduleData: Schedulable) {

View File

@@ -5,7 +5,6 @@ import com.alibaba.fastjson2.JSONObject;
import lombok.val;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;
@@ -30,40 +29,32 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
/**
* 负责识别潜在的行动干预信息,作用于正在进行或已存在的行动池中内容
*/
@AgentRunningModule(name = "action_identifier", order = 2)
public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract implements ActivateModel {
@InjectModule
private InterventionRecognizer interventionRecognizer;
@InjectModule
private InterventionEvaluator interventionEvaluator;
@InjectCapability
private ActionCapability actionCapability;
@InjectCapability
private CognationCapability cognationCapability;
@InjectCapability
private MemoryCapability memoryCapability;
private final AssemblyHelper assemblyHelper = new AssemblyHelper();
private final PromptHelper promptHelper = new PromptHelper();
/**
* 键: 本次调用uuid
* 值本次调用对应的prompt
*/
private final Map<String, Map<String, String>> interventionPrompt = new HashMap<>();
@Override
protected void doExecute(PartnerRunningFlowContext context) {
// 综合当前正在进行的行动链信息、用户交互历史、激活的记忆切片,尝试识别出是否存在行动干预意图
// 首先通过recognizer进行快速意图识别识别成功则步入评估阶段评估成功则直接作用于目标行动链
// 进行快速意图识别时必须结合近期对话与进行中行动链情况
// 干预意图识别
String uuid = context.getUuid();
String userId = context.getUserId();
@@ -73,19 +64,16 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
promptHelper.setupNoInterventionPrompt(uuid);
return;
}
// 干预意图评估
EvaluatorResult evaluatorResult = interventionEvaluator
.execute(assemblyHelper.buildEvaluatorInput(recognizerResult, userId));
List<EvaluatedInterventionData> executingDataList = evaluatorResult.getExecutingDataList();
List<EvaluatedInterventionData> preparedDataList = evaluatorResult.getPreparedDataList();
// 意图评估结果处理
if (evaluatorResult.isOk()) {
// 对存在异常ActionKey的评估结果列表进行过滤
invalidActionKeysFilter(executingDataList);
invalidActionKeysFilter(preparedDataList);
// 同步写入prompt异步处理干预行为异步在处理流程中体现
promptHelper.setupInterventionPrompt(uuid, executingDataList, preparedDataList);
handleInterventions(executingDataList, recognizerResult.getExecutingInterventions());
@@ -93,9 +81,7 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
} else {
promptHelper.setupInterventionIgnoredPrompt(uuid, executingDataList, preparedDataList);
}
}
private void handleInterventions(List<EvaluatedInterventionData> interventionDataList, Map<String, ExecutableAction> interventionDataMap) {
val executor = actionCapability.getExecutor(ActionCore.ExecutorType.PLATFORM);
executor.execute(() -> {
@@ -106,11 +92,8 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
}
});
}
private void invalidActionKeysFilter(List<EvaluatedInterventionData> interventions) {
List<EvaluatedInterventionData> toRemove = new ArrayList<>();
for (EvaluatedInterventionData intervention : interventions) {
List<MetaIntervention> interventionData = intervention.getMetaInterventionList();
List<String> actions = new ArrayList<>();
@@ -121,16 +104,13 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
if (!actionCapability.checkExists(actions.toArray(String[]::new))) {
toRemove.add(intervention);
}
// 针对 REBUILD 类型进行特殊校验, REBUILD 类型必须满足所有 MetaIntervention 的类型均为 REBUILD
if (!checkRebuildType(interventionData)) {
toRemove.add(intervention);
}
}
interventions.removeAll(toRemove);
}
private boolean checkRebuildType(List<MetaIntervention> interventionData) {
boolean hasRebuild = false;
for (MetaIntervention meta : interventionData) {
@@ -141,34 +121,27 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
return false;
}
}
return true;
}
@Override
public String modelKey() {
return "action_identifier";
}
@Override
public boolean withBasicPrompt() {
return false;
}
@Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
return interventionPrompt.remove(context.getUuid());
}
@Override
protected String moduleName() {
return "[行动干预识别模块]";
}
private final class AssemblyHelper {
private AssemblyHelper() {
}
private RecognizerInput buildRecognizerInput(String userId, String input) {
RecognizerInput recognizerInput = new RecognizerInput();
recognizerInput.setInput(input);
@@ -179,7 +152,6 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
recognizerInput.setPreparedActions(actionCapability.listActions(ExecutableAction.Status.PREPARE, userId).stream().toList());
return recognizerInput;
}
private EvaluatorInput buildEvaluatorInput(RecognizerResult recognizerResult, String userId) {
EvaluatorInput input = new EvaluatorInput();
input.setExecutingInterventions(recognizerResult.getExecutingInterventions());
@@ -189,22 +161,17 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
return input;
}
}
private final class PromptHelper {
private PromptHelper() {
}
private void setupInterventionIgnoredPrompt(String uuid, List<EvaluatedInterventionData> executingDataList, List<EvaluatedInterventionData> preparedDataList) {
List<EvaluatedInterventionData> total = Stream.concat(executingDataList.stream(), preparedDataList.stream()).toList();
JSONArray reasons = new JSONArray();
for (EvaluatedInterventionData data : total) {
JSONObject reason = reasons.addObject();
reason.put("[干预倾向]", data.getTendency());
reason.put("[未采用原因]", data.getDescription());
}
synchronized (interventionPrompt) {
interventionPrompt.put(uuid, Map.of(
"[识别状态] <是否识别到干预已存在行动的意图>", "识别到,但都未采用",
@@ -212,12 +179,10 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
"[干预行动] <将对已存在行动做出的行为>", "无行为"));
}
}
private void setupInterventionPrompt(String uuid, List<EvaluatedInterventionData> executingDataList,
List<EvaluatedInterventionData> preparedDataList) {
JSONArray contents = new JSONArray();
List<EvaluatedInterventionData> temp = Stream.concat(executingDataList.stream(), preparedDataList.stream()).toList();
for (EvaluatedInterventionData data : temp) {
if (!data.isOk()) {
continue;
@@ -226,7 +191,6 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
JSONObject newElement = contents.addObject();
newElement.put("[干预倾向]", tendency);
JSONArray changes = newElement.putArray("[行动链变动情况]");
for (MetaIntervention intervention : data.getMetaInterventionList()) {
JSONObject change = changes.addObject();
change.put("[干预类型]", intervention.getType());
@@ -234,18 +198,21 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
change.putArray("[干预内容]").addAll(intervention.getActions());
}
}
synchronized (interventionPrompt) {
interventionPrompt.put(uuid, Map.of(
"[识别状态] <是否识别到干预已存在行动的意图>", "识别到,将采用",
"[干预内容] <将对已存在行动做出的行为>", contents.toString()));
}
}
private void setupNoInterventionPrompt(String uuid) {
interventionPrompt.put(uuid, Map.of(
"[识别状态] <是否识别到干预已存在行动的意图>", "未识别到干预意图",
"[干预行动] <将对已存在行动做出的行为>", "无行动"));
}
}
@Override
public int order() {
return 2;
}
}

View File

@@ -1,11 +1,9 @@
package work.slhaf.partner.module.modules.action.interventor.evaluator;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
@@ -21,14 +19,10 @@ import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
@Slf4j
@AgentSubModule
public class InterventionEvaluator extends AbstractAgentSubModule<EvaluatorInput, EvaluatorResult>
public class InterventionEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, EvaluatorResult>
implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
/**
* 基于干预意图、记忆切片、交互上下文、已有行动程序综合评估,尝试评估并选取出合适的行动程序,交付给 ActionInterventor
*/
@@ -39,30 +33,24 @@ public class InterventionEvaluator extends AbstractAgentSubModule<EvaluatorInput
Map<String, ExecutableAction> executingInterventions = input.getExecutingInterventions();
Map<String, ExecutableAction> preparedInterventions = input.getPreparedInterventions();
CountDownLatch latch = new CountDownLatch(executingInterventions.size() + preparedInterventions.size());
// 创建结果容器
EvaluatorResult result = new EvaluatorResult();
List<EvaluatedInterventionData> executingDataList = result.getExecutingDataList();
List<EvaluatedInterventionData> preparedDataList = result.getPreparedDataList();
// 并发评估
evaluateIntervention(executingDataList, executingInterventions, input, executor, latch);
evaluateIntervention(preparedDataList, preparedInterventions, input, executor, latch);
try {
latch.await();
} catch (InterruptedException e) {
log.warn("CountDownLatch阻塞已中断");
}
return result;
}
private void evaluateIntervention(List<EvaluatedInterventionData> evaluatedDataList, Map<String, ExecutableAction> interventionMap, EvaluatorInput input, ExecutorService executor, CountDownLatch latch) {
interventionMap.forEach((tendency, actionData) -> executor.execute(() -> {
try {
String prompt = buildPrompt(input.getRecentMessages(), input.getActivatedSlices(), actionData, tendency);
ChatResponse response = this.singleChat(prompt);
EvaluatedInterventionData evaluatedData = JSONObject.parseObject(response.getMessage(),
EvaluatedInterventionData.class);
@@ -76,7 +64,6 @@ public class InterventionEvaluator extends AbstractAgentSubModule<EvaluatorInput
}
}));
}
private String buildPrompt(List<Message> recentMessages, List<EvaluatedSlice> activatedSlices,
ExecutableAction executableAction, String tendency) {
JSONObject json = new JSONObject();
@@ -86,12 +73,10 @@ public class InterventionEvaluator extends AbstractAgentSubModule<EvaluatorInput
json.put("将干预的行动", JSONObject.toJSONString(executableAction));
return json.toJSONString();
}
@Override
public String modelKey() {
return "intervention_evaluator";
}
@Override
public boolean withBasicPrompt() {
return false;

View File

@@ -1,11 +1,9 @@
package work.slhaf.partner.module.modules.action.interventor.recognizer;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;
@@ -19,13 +17,9 @@ import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
@Slf4j
@AgentSubModule
public class InterventionRecognizer extends AbstractAgentSubModule<RecognizerInput, RecognizerResult> implements ActivateModel {
public class InterventionRecognizer extends AbstractAgentModule.Sub<RecognizerInput, RecognizerResult> implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
@Override
public RecognizerResult execute(RecognizerInput input) {
// 获取必须数据
@@ -33,16 +27,13 @@ public class InterventionRecognizer extends AbstractAgentSubModule<RecognizerInp
List<ExecutableAction> executingActions = input.getExecutingActions();
List<ExecutableAction> preparedActions = input.getPreparedActions();
CountDownLatch countDownLatch = new CountDownLatch(executingActions.size() + preparedActions.size());
// 创建结果容器
RecognizerResult recognizerResult = new RecognizerResult();
Map<String, ExecutableAction> executingInterventions = recognizerResult.getExecutingInterventions();
Map<String, ExecutableAction> preparedInterventions = recognizerResult.getPreparedInterventions();
// 执行识别操作
recognizeIntervention(executingInterventions, executingActions, executor, input, countDownLatch);
recognizeIntervention(preparedInterventions, preparedActions, executor, input, countDownLatch);
try {
countDownLatch.await();
} catch (InterruptedException e) {
@@ -50,7 +41,6 @@ public class InterventionRecognizer extends AbstractAgentSubModule<RecognizerInp
}
return recognizerResult;
}
private void recognizeIntervention(Map<String, ExecutableAction> interventionsMap, List<ExecutableAction> actions, ExecutorService executor, RecognizerInput input, CountDownLatch latch) {
for (ExecutableAction data : actions) {
executor.execute(() -> {
@@ -71,30 +61,24 @@ public class InterventionRecognizer extends AbstractAgentSubModule<RecognizerInp
});
}
}
private String buildPrompt(ExecutableAction executableAction, RecognizerInput input) {
JSONObject json = new JSONObject();
JSONObject actionInfo = json.putObject("行动信息");
actionInfo.put("行动倾向", executableAction.getTendency());
actionInfo.put("行动原因", executableAction.getReason());
actionInfo.put("行动描述", executableAction.getDescription());
actionInfo.put("行动状态", executableAction.getStatus());
actionInfo.put("行动来源", executableAction.getSource());
JSONObject interactionInfo = json.putObject("交互信息");
interactionInfo.put("用户输入", input.getInput());
interactionInfo.put("当前对话", input.getRecentMessages());
interactionInfo.put("近期对话", input.getUserDialogMapStr());
return json.toString();
}
@Override
public String modelKey() {
return "intervention_recognizer";
}
@Override
public boolean withBasicPrompt() {
return false;

View File

@@ -1,10 +1,8 @@
package work.slhaf.partner.module.modules.action.planner;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.api.chat.pojo.Message;
@@ -33,14 +31,10 @@ import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* 负责针对本次输入生成基础的行动计划,在主模型传达意愿后,执行行动或者放入计划池
*/
@Slf4j
@AgentRunningModule(name = "action_planner", order = 2)
public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
@InjectCapability
private CognationCapability cognationCapability;
@InjectCapability
@@ -49,23 +43,18 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
private PerceiveCapability perceiveCapability;
@InjectCapability
private MemoryCapability memoryCapability;
@InjectModule
private ActionEvaluator actionEvaluator;
@InjectModule
private ActionExtractor actionExtractor;
@InjectModule
private ActionConfirmer actionConfirmer;
private ExecutorService executor;
private final ActionAssemblyHelper assemblyHelper = new ActionAssemblyHelper();
@Init
public void init() {
executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
}
@Override
protected void doExecute(PartnerRunningFlowContext context) {
try {
@@ -77,7 +66,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
log.error("执行异常", e);
}
}
/**
* 新的提取与评估任务
*
@@ -98,7 +86,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
return null;
});
}
private void updateTendencyCache(List<EvaluatorResult> evaluatorResults, String input,
ExtractorResult extractorResult) {
if (!VectorClient.status) {
@@ -119,9 +106,7 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
data.setInput(input);
actionCapability.updateTendencyCache(data);
});
}
/**
* 待确认行动的判断任务
*
@@ -136,7 +121,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
return null;
});
}
private void setupConfirmedActionInfo(PartnerRunningFlowContext context, ConfirmerResult result) {
// TODO 需考虑未确认任务的失效或者拒绝时机在action core中实现
List<String> uuids = result.getUuids();
@@ -150,7 +134,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
}
}
}
private void putActionData(List<EvaluatorResult> evaluatorResults, PartnerRunningFlowContext context) {
for (EvaluatorResult evaluatorResult : evaluatorResults) {
ExecutableAction executableAction = assemblyHelper.buildActionData(evaluatorResult, context.getUserId());
@@ -161,7 +144,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
}
}
}
@Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
HashMap<String, String> map = new HashMap<>();
@@ -170,7 +152,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
setupPreparedActions(map, userId);
return map;
}
private void setupPendingActions(HashMap<String, String> map, String userId) {
List<ExecutableAction> executableActionData = actionCapability.listPendingAction(userId);
if (executableActionData == null || executableActionData.isEmpty()) {
@@ -181,7 +162,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
map.put("[待确认行动 " + (i + 1) + " ] <等待用户确认的行动信息>", generateActionStr(executableActionData.get(i)));
}
}
private void setupPreparedActions(HashMap<String, String> map, String userId) {
val preparedActions = actionCapability.listActions(ExecutableAction.Status.PREPARE, userId).stream().toList();
if (preparedActions.isEmpty()) {
@@ -192,22 +172,18 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
map.put("[预备行动 " + (i + 1) + " ] <预备执行或放入计划池的行动信息>", generateActionStr(preparedActions.get(i)));
}
}
private String generateActionStr(ExecutableAction executableAction) {
return "<行动倾向>" + " : " + executableAction.getTendency() +
"<行动原因>" + " : " + executableAction.getReason() +
"<工具描述>" + " : " + executableAction.getDescription();
}
@Override
protected String moduleName() {
return "[行动模块]";
}
private final class ActionAssemblyHelper {
private ActionAssemblyHelper() {
}
private ExtractorInput buildExtractorInput(PartnerRunningFlowContext context) {
ExtractorInput input = new ExtractorInput();
input.setInput(context.getInput());
@@ -221,7 +197,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
input.setRecentMessages(recentMessages);
return input;
}
private EvaluatorInput buildEvaluatorInput(ExtractorResult extractorResult, String userId) {
EvaluatorInput input = new EvaluatorInput();
input.setTendencies(extractorResult.getTendencies());
@@ -230,7 +205,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
input.setActivatedSlices(memoryCapability.getActivatedSlices(userId));
return input;
}
private ExecutableAction buildActionData(EvaluatorResult evaluatorResult, String userId) {
Map<Integer, List<MetaAction>> actionChain = getActionChain(evaluatorResult);
return switch (evaluatorResult.getType()) {
@@ -252,7 +226,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
);
};
}
private @NotNull Map<Integer, List<MetaAction>> getActionChain(EvaluatorResult evaluatorResult) {
Map<Integer, List<MetaAction>> actionChain = new HashMap<>();
Map<Integer, List<String>> primaryActionChain = evaluatorResult.getPrimaryActionChain();
@@ -265,7 +238,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
});
return actionChain;
}
private void fixDependencies(Map<Integer, List<String>> primaryActionChain) {
// 先将 primaryActionChain 的节点序号修正为从1开始依次增大
fixOrder(primaryActionChain);
@@ -291,7 +263,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
if (checkDependenciesExist(lastOrder, preActions, primaryActionChain)) {
continue;
}
// 如果存在前置依赖,则将其放置在当前order之前的位置,
// 放置位置优先选择已存在的上一节点,如果不存在(行动链的头节点时)则需要向行动链新增order
// 不需要检查行动链的当前节点的已存在 Action 是否为新 Action 的依赖项,因为这些 Action 实际来自 LLM
@@ -309,7 +280,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
fixedOrders.addAll(tempOrders);
} while (fixed.getAndSet(false));
}
private void fixOrder(Map<Integer, List<String>> primaryActionChain) {
Map<Integer, List<String>> tempChain = new HashMap<>(primaryActionChain);
primaryActionChain.clear();
@@ -318,7 +288,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
primaryActionChain.put(i, tempChain.get(i));
}
}
private boolean checkDependenciesExist(int lastOrder, List<String> preActions,
Map<Integer, List<String>> primaryActionChain) {
if (!primaryActionChain.containsKey(lastOrder)) {
@@ -328,7 +297,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
//noinspection SlowListContainsAll
return existActions.containsAll(preActions);
}
private ConfirmerInput buildConfirmerInput(PartnerRunningFlowContext context) {
ConfirmerInput confirmerInput = new ConfirmerInput();
confirmerInput.setInput(context.getInput());
@@ -337,4 +305,9 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
return confirmerInput;
}
}
@Override
public int order() {
return 2;
}
}

View File

@@ -2,11 +2,9 @@ package work.slhaf.partner.module.modules.action.planner.confirmer;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
@@ -21,22 +19,16 @@ import java.util.concurrent.ExecutorService;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
@Slf4j
@AgentSubModule
public class ActionConfirmer extends AbstractAgentSubModule<ConfirmerInput, ConfirmerResult> implements ActivateModel {
public class ActionConfirmer extends AbstractAgentModule.Sub<ConfirmerInput, ConfirmerResult> implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
@Override
public ConfirmerResult execute(ConfirmerInput data) {
List<ExecutableAction> executableActionList = data.getExecutableActionData();
ExecutorService executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
CountDownLatch latch = new CountDownLatch(executableActionList.size());
ConfirmerResult result = new ConfirmerResult();
List<String> uuids = result.getUuids();
for (ExecutableAction executableAction : executableActionList) {
executor.execute(() -> {
try {
@@ -61,28 +53,22 @@ public class ActionConfirmer extends AbstractAgentSubModule<ConfirmerInput, Conf
}
return result;
}
private String buildPrompt(ExecutableAction data, String input, List<Message> recentMessages) {
JSONObject prompt = new JSONObject();
prompt.put("[用户输入]", input);
JSONObject actionData = prompt.putObject("[行动数据]");
actionData.put("[行动倾向]", data.getTendency());
actionData.put("[行动原因]", data.getReason());
actionData.put("[行动来源]", data.getSource());
actionData.put("[行动描述]", data.getDescription());
JSONArray messageData = prompt.putArray("[近期对话]");
messageData.addAll(recentMessages);
return prompt.toString();
}
@Override
public String modelKey() {
return "action-confirmer";
}
@Override
public boolean withBasicPrompt() {
return false;

View File

@@ -4,9 +4,8 @@ import cn.hutool.core.bean.BeanUtil;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
@@ -22,19 +21,14 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
@AgentSubModule
public class ActionEvaluator extends AbstractAgentSubModule<EvaluatorInput, List<EvaluatorResult>> implements ActivateModel {
public class ActionEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, List<EvaluatorResult>> implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
private InteractionThreadPoolExecutor executor;
@Init
public void init() {
executor = InteractionThreadPoolExecutor.getInstance();
}
/**
* 对输入的行为倾向进行评估,并根据评估结果,对缓存做出调整
*
@@ -47,7 +41,6 @@ public class ActionEvaluator extends AbstractAgentSubModule<EvaluatorInput, List
List<Callable<EvaluatorResult>> tasks = getTasks(batchInputs);
return executor.invokeAllAndReturn(tasks);
}
private List<Callable<EvaluatorResult>> getTasks(List<EvaluatorBatchInput> batchInputs) {
List<Callable<EvaluatorResult>> list = new ArrayList<>();
for (EvaluatorBatchInput batchInput : batchInputs) {
@@ -60,7 +53,6 @@ public class ActionEvaluator extends AbstractAgentSubModule<EvaluatorInput, List
}
return list;
}
private List<EvaluatorBatchInput> buildEvaluatorBatchInput(EvaluatorInput data) {
List<EvaluatorBatchInput> list = new ArrayList<>();
for (String tendency : data.getTendencies()) {
@@ -74,30 +66,23 @@ public class ActionEvaluator extends AbstractAgentSubModule<EvaluatorInput, List
}
return list;
}
private String buildPrompt(EvaluatorBatchInput batchInput) {
JSONObject prompt = new JSONObject();
prompt.put("[行动倾向]", batchInput.getTendency());
JSONArray memoryData = prompt.putArray("[相关记忆切片]");
for (EvaluatedSlice evaluatedSlice : batchInput.getActivatedSlices()) {
JSONObject memory = memoryData.addObject();
memory.put("[日期]", evaluatedSlice.getDate());
memory.put("[摘要]", evaluatedSlice.getSummary());
}
JSONObject availableActionData = prompt.putObject("[可用行动单元]");
availableActionData.putAll(batchInput.getAvailableActions());
return prompt.toString();
}
@Override
public String modelKey() {
return "action_evaluator";
}
@Override
public boolean withBasicPrompt() {
return true;

View File

@@ -1,11 +1,9 @@
package work.slhaf.partner.module.modules.action.planner.extractor;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput;
@@ -13,13 +11,9 @@ import work.slhaf.partner.module.modules.action.planner.extractor.entity.Extract
import java.util.List;
@Slf4j
@AgentSubModule
public class ActionExtractor extends AbstractAgentSubModule<ExtractorInput, ExtractorResult> implements ActivateModel {
public class ActionExtractor extends AbstractAgentModule.Sub<ExtractorInput, ExtractorResult> implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
@Override
public ExtractorResult execute(ExtractorInput data) {
ExtractorResult result = new ExtractorResult();
@@ -28,7 +22,6 @@ public class ActionExtractor extends AbstractAgentSubModule<ExtractorInput, Extr
result.setTendencies(tendencyCache);
return result;
}
for (int i = 0; i < 3; i++) {
try {
ChatResponse response = this.singleChat(JSONObject.toJSONString(data));
@@ -37,15 +30,12 @@ public class ActionExtractor extends AbstractAgentSubModule<ExtractorInput, Extr
log.error("[ActionExtractor] 提取信息出错", e);
}
}
return new ExtractorResult();
}
@Override
public String modelKey() {
return "action_extractor";
}
@Override
public boolean withBasicPrompt() {
return false;

View File

@@ -3,9 +3,7 @@ package work.slhaf.partner.module.modules.memory.selector;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
@@ -24,23 +22,17 @@ import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowCon
import java.time.LocalDate;
import java.util.*;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@AgentRunningModule(name = "memory_selector", order = 2)
public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
@InjectCapability
private MemoryCapability memoryCapability;
@InjectCapability
private CognationCapability cognationCapability;
@InjectModule
private SliceSelectEvaluator sliceSelectEvaluator;
@InjectModule
private MemorySelectExtractor memorySelectExtractor;
@Override
public void doExecute(PartnerRunningFlowContext runningFlowContext) {
String userId = runningFlowContext.getUserId();
@@ -53,7 +45,6 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
}
setModuleContextRecall(runningFlowContext);
}
private List<EvaluatedSlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, ExtractorResult extractorResult) {
log.debug("[MemorySelector] 触发记忆回溯...");
//查找切片
@@ -71,7 +62,6 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
log.debug("[MemorySelector] 切片评估结果: {}", JSONObject.toJSONString(memorySlices));
return memorySlices;
}
private void setModuleContextRecall(PartnerRunningFlowContext runningFlowContext) {
String userId = runningFlowContext.getUserId();
boolean recall = memoryCapability.hasActivatedSlices(userId);
@@ -80,8 +70,6 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
runningFlowContext.getModuleContext().getExtraContext().put("recall_count", memoryCapability.getActivatedSlicesSize(userId));
}
}
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) {
for (ExtractorMatchData match : matches) {
try {
@@ -101,7 +89,6 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
}
//清理切片记录
memoryCapability.cleanSelectedSliceFilter();
//根据userInfo过滤是否为私人记忆
for (MemoryResult memoryResult : memoryResultList) {
//过滤终点记忆
@@ -110,25 +97,21 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userId));
}
}
private void removeDuplicateSlice(MemoryResult memoryResult) {
Collection<String> values = memoryCapability.getDialogMap().values();
memoryResult.getRelatedMemorySliceResult().removeIf(m -> values.contains(m.getSummary()));
memoryResult.getMemorySliceResult().removeIf(m -> values.contains(m.getMemorySlice().getSummary()));
}
private boolean removeOrNot(MemorySlice memorySlice, String userId) {
if (memorySlice.isPrivate()) {
return memorySlice.getStartUserId().equals(userId);
}
return false;
}
@Override
public String moduleName() {
return "[记忆模块]";
}
@Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
HashMap<String, String> map = new HashMap<>();
@@ -137,12 +120,10 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
if (!dialogMapStr.isEmpty()) {
map.put("[记忆缓存] <你最近两日和所有聊天者的对话记忆印象>", dialogMapStr);
}
String userDialogMapStr = memoryCapability.getUserDialogMapStr(userId);
if (userDialogMapStr != null && !userDialogMapStr.isEmpty() && !cognationCapability.isSingleUser()) {
map.put("[用户记忆缓存] <与最新一条消息的发送者的近两天对话记忆印象, 可能与[记忆缓存]稍有重复>", userDialogMapStr);
}
String sliceStr = memoryCapability.getActivatedSlicesStr(userId);
if (sliceStr != null && !sliceStr.isEmpty()) {
map.put("[记忆切片] <你与最新一条消息的发送者的相关回忆, 不会与[记忆缓存]重复, 如果有重复你也可以指出来>", sliceStr);
@@ -150,4 +131,8 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
return map;
}
@Override
public int order() {
return 2;
}
}

View File

@@ -5,10 +5,8 @@ import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
@@ -27,20 +25,14 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@AgentSubModule
public class SliceSelectEvaluator extends AbstractAgentSubModule<EvaluatorInput, List<EvaluatedSlice>> implements ActivateModel {
public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, List<EvaluatedSlice>> implements ActivateModel {
private InteractionThreadPoolExecutor executor;
@Init
public void init() {
executor = InteractionThreadPoolExecutor.getInstance();
}
@Override
public List<EvaluatedSlice> execute(EvaluatorInput evaluatorInput) {
log.debug("[SliceSelectEvaluator] 切片评估模块开始...");
@@ -83,16 +75,13 @@ public class SliceSelectEvaluator extends AbstractAgentSubModule<EvaluatorInput,
return null;
});
}
executor.invokeAll(tasks, 30, TimeUnit.SECONDS);
log.debug("[SliceSelectEvaluator] 评估模块结束, 输出队列: {}", queue);
List<EvaluatedSlice> temp = queue.stream().toList();
return new ArrayList<>(temp);
}
private void setSliceSummaryList(MemoryResult memoryResult, List<SliceSummary> sliceSummaryList, Map<Long, SliceSummary> map) {
for (MemorySliceResult memorySliceResult : memoryResult.getMemorySliceResult()) {
SliceSummary sliceSummary = new SliceSummary();
sliceSummary.setId(memorySliceResult.getMemorySlice().getTimestamp());
StringBuilder stringBuilder = new StringBuilder();
@@ -109,29 +98,22 @@ public class SliceSelectEvaluator extends AbstractAgentSubModule<EvaluatorInput,
sliceSummary.setSummary(stringBuilder.toString());
Long timestamp = memorySliceResult.getMemorySlice().getTimestamp();
sliceSummary.setDate(DateUtil.date(timestamp).toLocalDateTime().toLocalDate());
sliceSummaryList.add(sliceSummary);
map.put(timestamp, sliceSummary);
}
for (MemorySlice memorySlice : memoryResult.getRelatedMemorySliceResult()) {
SliceSummary sliceSummary = new SliceSummary();
sliceSummary.setId(memorySlice.getTimestamp());
sliceSummary.setSummary(memorySlice.getSummary());
sliceSummaryList.add(sliceSummary);
map.put(memorySlice.getTimestamp(), sliceSummary);
}
}
public String modelKey() {
return "slice_evaluator";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -4,11 +4,9 @@ import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.core.cognation.CognationCapability;
@@ -24,20 +22,14 @@ import java.util.List;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@AgentSubModule
public class MemorySelectExtractor extends AbstractAgentSubModule<PartnerRunningFlowContext, ExtractorResult>
public class MemorySelectExtractor extends AbstractAgentModule.Sub<PartnerRunningFlowContext, ExtractorResult>
implements ActivateModel {
@InjectCapability
private MemoryCapability memoryCapability;
@InjectCapability
private CognationCapability cognationCapability;
@Override
public ExtractorResult execute(PartnerRunningFlowContext context) {
log.debug("[MemorySelectExtractor] 主题提取模块开始...");
@@ -52,7 +44,6 @@ public class MemorySelectExtractor extends AbstractAgentSubModule<PartnerRunning
chatMessages.add(metaMessage.getAssistantMessage());
}
}
ExtractorResult extractorResult;
try {
List<EvaluatedSlice> activatedMemorySlices = memoryCapability.getActivatedSlices(context.getUserId());
@@ -75,7 +66,6 @@ public class MemorySelectExtractor extends AbstractAgentSubModule<PartnerRunning
}
return fix(extractorResult);
}
private ExtractorResult fix(ExtractorResult extractorResult) {
extractorResult.getMatches().forEach(m -> {
if (m.getType().equals(ExtractorMatchData.Constant.DATE)) {
@@ -89,15 +79,12 @@ public class MemorySelectExtractor extends AbstractAgentSubModule<PartnerRunning
extractorResult.getMatches().removeIf(m -> m.getText().split("->")[0].isEmpty());
return extractorResult;
}
@Override
public String modelKey() {
return "topic_extractor";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -3,9 +3,7 @@ package work.slhaf.partner.module.modules.memory.updater;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.api.chat.constant.ChatConstant;
@@ -29,11 +27,8 @@ import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger;
import static work.slhaf.partner.common.util.ExtractUtil.extractUserId;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@AgentRunningModule(name = "memory_updater", order = 7)
public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
private static final long SCHEDULED_UPDATE_INTERVAL = 10 * 1000;
@@ -52,19 +47,16 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
private SingleSummarizer singleSummarizer;
@InjectModule
private TotalSummarizer totalSummarizer;
private InteractionThreadPoolExecutor executor;
/**
* 用于临时存储完整对话记录在MemoryManager的分离后
*/
private List<Message> tempMessage;
@Init
public void init() {
executor = InteractionThreadPoolExecutor.getInstance();
setScheduledUpdater();
}
private void setScheduledUpdater() {
executor.execute(() -> {
log.info("[MemoryUpdater] 记忆自动更新线程启动");
@@ -88,7 +80,6 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
log.info("[MemoryUpdater] 记忆自动更新线程结束");
});
}
@Override
public void doExecute(PartnerRunningFlowContext context) {
if (context.isFinished()) {
@@ -118,12 +109,10 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
}
});
}
@Override
protected boolean relyOnMessage() {
return true;
}
private void updateMemory() {
log.debug("[MemoryUpdater] 记忆更新流程开始...");
tempMessage = new ArrayList<>(cognationCapability.getChatMessages());
@@ -135,7 +124,6 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
cognationCapability.resetLastUpdatedTime();
log.debug("[MemoryUpdater] 记忆更新流程结束...");
}
private void updateMultiChatSlices(HashMap<String, String> singleMemorySummary) {
// 此时chatMessages中不再包含单聊记录直接执行摘要以及切片插入
// 对剩下的多人聊天记录进行进行摘要
@@ -161,21 +149,17 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
// 设置involvedUserId
setInvolvedUserId(userId, memorySlice, chatMessages);
memoryCapability.insertSlice(memorySlice, summarizeResult.getTopicPath());
memoryCapability.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary());
} else {
log.debug("[MemoryUpdater] 不存在多人聊天记录, 将以单聊总结为对话缓存的主要输入: {}", singleMemorySummary);
memoryCapability.updateDialogMap(LocalDateTime.now(), totalSummarizer.execute(singleMemorySummary));
}
log.debug("[MemoryUpdater] 对话缓存更新完毕");
log.debug("[MemoryUpdater] 多人聊天记忆更新流程结束...");
return null;
};
executor.invokeAll(List.of(task));
}
private void cleanMessage(List<Message> chatMessages) {
// 清理时间标识
for (Message message : chatMessages) {
@@ -186,7 +170,6 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
message.setContent(message.getContent().replace("\r\n**" + time, ""));
}
}
private void clearChatMessages() {
// 不全部清空,保留一部分输入防止上下文割裂
cognationCapability.getMessageLock().lock();
@@ -196,7 +179,6 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
cognationCapability.getChatMessages().addAll(0, temp);
cognationCapability.getMessageLock().unlock();
}
private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List<Message> chatMessages) {
for (Message chatMessage : chatMessages) {
if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) {
@@ -214,7 +196,6 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
memorySlice.getInvolvedUserIds().add(userId);
}
}
private void updateSingleChatSlices(HashMap<String, String> singleMemorySummary) {
log.debug("[MemoryUpdater] 单聊记忆更新流程开始...");
// 更新单聊记忆同时从chatMessages中去掉单聊记忆
@@ -247,23 +228,19 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
}
return null;
});
}
executor.invokeAll(tasks);
log.debug("[MemoryUpdater] 单聊记忆更新结束...");
}
private SummarizeResult summarize(SummarizeInput summarizeInput) {
singleSummarizer.execute(summarizeInput.getChatMessages());
return multiSummarizer.execute(summarizeInput);
}
private MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List<Message> chatMessages) {
MemorySlice memorySlice = new MemorySlice();
// 设置 memoryId,timestamp
memorySlice.setMemoryId(cognationCapability.getCurrentMemoryId());
memorySlice.setTimestamp(System.currentTimeMillis());
// 补充信息
memorySlice.setPrivate(summarizeResult.isPrivate());
memorySlice.setSummary(summarizeResult.getSummary());
@@ -277,4 +254,9 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
memorySlice.setRelatedTopics(relatedTopicPathList);
return memorySlice;
}
@Override
public int order() {
return 7;
}
}

View File

@@ -4,10 +4,8 @@ import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeInput;
@@ -18,18 +16,13 @@ import java.util.List;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@AgentSubModule
public class MultiSummarizer extends AbstractAgentSubModule<SummarizeInput, SummarizeResult> implements ActivateModel {
public class MultiSummarizer extends AbstractAgentModule.Sub<SummarizeInput, SummarizeResult> implements ActivateModel {
@Init
public void init() {
updateChatClientSettings();
}
@Override
public SummarizeResult execute(SummarizeInput input) {
log.debug("[MemorySummarizer] 整体摘要开始...");
@@ -38,12 +31,10 @@ public class MultiSummarizer extends AbstractAgentSubModule<SummarizeInput, Summ
SummarizeResult result = JSONObject.parseObject(extractJson(response.getMessage()), SummarizeResult.class);
return fix(result);
}
private SummarizeResult fix(SummarizeResult result) {
if (result == null || result.getTopicPath() == null || result.getTopicPath().isEmpty()) {
return result;
}
String topicPath = fixTopicPath(result.getTopicPath());
List<String> relatedTopicPath = new ArrayList<>();
for (String s : result.getRelatedTopicPath()) {
@@ -53,15 +44,12 @@ public class MultiSummarizer extends AbstractAgentSubModule<SummarizeInput, Summ
result.setRelatedTopicPath(relatedTopicPath);
return result;
}
@Override
public String modelKey() {
return "multi_summarizer";
}
@Override
public boolean withBasicPrompt() {
return true;
}
}

View File

@@ -3,10 +3,8 @@ package work.slhaf.partner.module.modules.memory.updater.summarizer;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
@@ -18,20 +16,14 @@ import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
@EqualsAndHashCode(callSuper = true)
@Slf4j
@Data
@AgentSubModule
public class SingleSummarizer extends AbstractAgentSubModule<List<Message>, Void> implements ActivateModel {
public class SingleSummarizer extends AbstractAgentModule.Sub<List<Message>, Void> implements ActivateModel {
private InteractionThreadPoolExecutor executor;
@Init
public void init() {
this.executor = InteractionThreadPoolExecutor.getInstance();
}
@Override
public Void execute(List<Message> chatMessages) {
log.debug("[MemorySummarizer] 长文本摘要开始...");
@@ -55,7 +47,6 @@ public class SingleSummarizer extends AbstractAgentSubModule<List<Message>, Void
log.debug("[MemorySummarizer] 长文本摘要结束");
return null;
}
private String singleExecute(String primaryContent) {
try {
ChatResponse response = this.singleChat(primaryContent);
@@ -65,15 +56,12 @@ public class SingleSummarizer extends AbstractAgentSubModule<List<Message>, Void
return primaryContent;
}
}
@Override
public String modelKey() {
return "single_summarizer";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -4,41 +4,31 @@ import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import java.util.HashMap;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@AgentSubModule
public class TotalSummarizer extends AbstractAgentSubModule<HashMap<String, String>, String> implements ActivateModel {
public class TotalSummarizer extends AbstractAgentModule.Sub<HashMap<String, String>, String> implements ActivateModel {
@Init
public void init() {
updateChatClientSettings();
}
public String execute(HashMap<String, String> singleMemorySummary){
ChatResponse response = this.singleChat(JSONUtil.toJsonPrettyStr(singleMemorySummary));
return JSONObject.parseObject(extractJson(response.getMessage())).getString("content");
}
@Override
public String modelKey() {
return "total_summarizer";
}
@Override
public boolean withBasicPrompt() {
return true;
}
}

View File

@@ -1,9 +1,7 @@
package work.slhaf.partner.module.modules.perceive.selector;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.core.perceive.PerceiveCapability;
import work.slhaf.partner.core.perceive.pojo.User;
import work.slhaf.partner.module.common.module.PreRunningAbstractAgentModuleAbstract;
@@ -11,19 +9,13 @@ import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowCon
import java.util.HashMap;
import java.util.Map;
@Slf4j
@Setter
@AgentRunningModule(name = "perceive_selector", order = 2)
public class PerceiveSelector extends PreRunningAbstractAgentModuleAbstract {
@InjectCapability
private PerceiveCapability perceiveCapability;
@Override
public void doExecute(PartnerRunningFlowContext context) {
}
@Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
HashMap<String, String> map = new HashMap<>();
@@ -34,9 +26,13 @@ public class PerceiveSelector extends PreRunningAbstractAgentModuleAbstract {
map.put("[静态记忆] <你关于最新聊天用户的静态记忆>", user.getStaticMemory().toString());
return map;
}
@Override
public String moduleName() {
return "[感知模块]";
}
@Override
public int order() {
return 2;
}
}

View File

@@ -2,9 +2,7 @@ package work.slhaf.partner.module.modules.perceive.updater;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
@@ -22,34 +20,25 @@ import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.locks.ReentrantLock;
/**
* 感知更新,异步
*/
@EqualsAndHashCode(callSuper = true)
@Slf4j
@Data
@AgentRunningModule(name = "perceive_updater", order = 7)
public class PerceiveUpdater extends PostRunningAbstractAgentModuleAbstract {
@InjectCapability
private PerceiveCapability perceiveCapability;
@InjectCapability
private CognationCapability cognationCapability;
@InjectModule
private RelationExtractor relationExtractor;
@InjectModule
private StaticMemoryExtractor staticMemoryExtractor;
private InteractionThreadPoolExecutor executor;
@Init
public void init() {
this.executor = InteractionThreadPoolExecutor.getInstance();
}
@Override
public void doExecute(PartnerRunningFlowContext context) {
executor.execute(() -> {
@@ -69,12 +58,10 @@ public class PerceiveUpdater extends PostRunningAbstractAgentModuleAbstract {
perceiveCapability.updateUser(user);
});
}
@Override
protected boolean relyOnMessage() {
return true;
}
private void runRelationExtractorAction(PartnerRunningFlowContext context, ReentrantLock userLock, User user) {
RelationExtractResult relationExtractResult = relationExtractor.execute(context);
userLock.lock();
@@ -84,7 +71,6 @@ public class PerceiveUpdater extends PostRunningAbstractAgentModuleAbstract {
user.updateRelationChange(relationExtractResult.getRelationChangeHistory());
userLock.unlock();
}
private void runStaticExtractorAction(PartnerRunningFlowContext context, ReentrantLock userLock, User user) {
HashMap<String, String> newStaticMemory = staticMemoryExtractor.execute(context);
userLock.lock();
@@ -92,4 +78,8 @@ public class PerceiveUpdater extends PostRunningAbstractAgentModuleAbstract {
userLock.unlock();
}
@Override
public int order() {
return 7;
}
}

View File

@@ -4,9 +4,8 @@ import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.cognation.CognationCapability;
@@ -19,20 +18,14 @@ import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowCon
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
@AgentSubModule
public class RelationExtractor extends AbstractAgentSubModule<PartnerRunningFlowContext, RelationExtractResult> implements ActivateModel {
public class RelationExtractor extends AbstractAgentModule.Sub<PartnerRunningFlowContext, RelationExtractResult> implements ActivateModel {
@InjectCapability
private CognationCapability cognationCapability;
@InjectCapability
private PerceiveCapability perceiveCapability;
private List<Message> tempMessages;
@Override
public RelationExtractResult execute(PartnerRunningFlowContext context){
tempMessages = new ArrayList<>(cognationCapability.getChatMessages());
@@ -43,8 +36,6 @@ public class RelationExtractor extends AbstractAgentSubModule<PartnerRunningFlow
perceiveCapability.updateUser(user);
return relationExtractResult;
}
private User getTempUser(PartnerRunningFlowContext context, RelationExtractResult relationExtractResult) {
User user = new User();
user.setUuid(context.getUserId());
@@ -53,12 +44,10 @@ public class RelationExtractor extends AbstractAgentSubModule<PartnerRunningFlow
user.setAttitude(relationExtractResult.getAttitude());
return user;
}
private RelationExtractResult getRelationResult(RelationExtractInput input) {
ChatResponse response = singleChat(JSONObject.toJSONString(input));
return JSONObject.parseObject(response.getMessage(), RelationExtractResult.class);
}
private RelationExtractInput getRelationInput(String userId) {
HashMap<String,String> map = new HashMap<>();
User user = perceiveCapability.getUser(userId);
@@ -72,15 +61,12 @@ public class RelationExtractor extends AbstractAgentSubModule<PartnerRunningFlow
input.setChatMessages(tempMessages);
return input;
}
@Override
public String modelKey() {
return "relation_extractor";
}
@Override
public boolean withBasicPrompt() {
return true;
}
}

View File

@@ -5,9 +5,8 @@ import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.perceive.PerceiveCapability;
@@ -15,17 +14,13 @@ import work.slhaf.partner.module.modules.perceive.updater.static_extractor.entit
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.HashMap;
@EqualsAndHashCode(callSuper = true)
@Data
@AgentSubModule
public class StaticMemoryExtractor extends AbstractAgentSubModule<PartnerRunningFlowContext, HashMap<String, String>> implements ActivateModel {
public class StaticMemoryExtractor extends AbstractAgentModule.Sub<PartnerRunningFlowContext, HashMap<String, String>> implements ActivateModel {
@InjectCapability
private CognationCapability cognationCapability;
@InjectCapability
private PerceiveCapability perceiveCapability;
@Override
public HashMap<String, String> execute(PartnerRunningFlowContext context) {
StaticMemoryExtractInput input = StaticMemoryExtractInput.builder()
@@ -33,22 +28,18 @@ public class StaticMemoryExtractor extends AbstractAgentSubModule<PartnerRunning
.messages(cognationCapability.getChatMessages())
.existedStaticMap(perceiveCapability.getUser(context.getUserId()).getStaticMemory())
.build();
ChatResponse response = singleChat(JSONUtil.toJsonPrettyStr(input));
JSONObject jsonObject = JSONObject.parseObject(response.getMessage());
HashMap<String, String> result = new HashMap<>();
jsonObject.forEach((k, v) -> result.put(k, (String) v));
return result;
}
@Override
public String modelKey() {
return "static_extractor";
}
@Override
public boolean withBasicPrompt() {
return true;
}
}

View File

@@ -1,29 +1,25 @@
package work.slhaf.partner.module.modules.process;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
@EqualsAndHashCode(callSuper = true)
@Slf4j
@Data
@AgentRunningModule(name = "postprocess_executor", order = 6)
public class PostprocessExecutor extends AbstractAgentRunningModule<PartnerRunningFlowContext> {
public class PostprocessExecutor extends AbstractAgentModule.Running<PartnerRunningFlowContext> {
private static final int POST_PROCESS_TRIGGER_ROLL_LIMIT = 36;
@InjectCapability
private CognationCapability cognationCapability;
@Override
public void execute(PartnerRunningFlowContext context) {
boolean trigger = cognationCapability.getChatMessages().size() >= POST_PROCESS_TRIGGER_ROLL_LIMIT;
context.getModuleContext().getExtraContext().put("post_process_trigger", trigger);
log.debug("[PostprocessExecutor] 是否执行后处理: {}", trigger);
}
@Override
public int order() {
return 6;
}
}

View File

@@ -2,9 +2,7 @@ package work.slhaf.partner.module.modules.process;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.perceive.PerceiveCapability;
import work.slhaf.partner.core.perceive.pojo.User;
@@ -16,31 +14,24 @@ import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.Map;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@AgentRunningModule(name = "preprocess_executor", order = 1)
public class PreprocessExecutor extends PreRunningAbstractAgentModuleAbstract {
@InjectCapability
private CognationCapability cognationCapability;
@InjectCapability
private PerceiveCapability perceiveCapability;
@Override
public void doExecute(PartnerRunningFlowContext context) {
checkAndSetMemoryId();
getInteractionContext(context);
}
private void checkAndSetMemoryId() {
String currentMemoryId = cognationCapability.getCurrentMemoryId();
if (currentMemoryId == null || cognationCapability.getChatMessages().isEmpty()) {
cognationCapability.refreshMemoryId();
}
}
private void getInteractionContext(PartnerRunningFlowContext context) {
log.debug("[PreprocessExecutor] 预处理原始输入: {}", context);
User user = perceiveCapability.getUser(context.getUserInfo(), context.getPlatform());
@@ -49,15 +40,12 @@ public class PreprocessExecutor extends PreRunningAbstractAgentModuleAbstract {
}
String userId = user.getUuid();
context.setUserId(userId);
String userStr = "[" + context.getUserNickname() + "(" + userId + ")]";
String input = userStr + " " + context.getInput();
context.setInput(input);
setCoreContext(context);
log.debug("[PreprocessExecutor] 预处理结果: {}", context);
}
@Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
HashMap<String, String> map = new HashMap<>();
@@ -69,12 +57,10 @@ public class PreprocessExecutor extends PreRunningAbstractAgentModuleAbstract {
map.put("其他", "历史对话中将在用户消息的最后一行标注时间");
return map;
}
@Override
protected String moduleName() {
return "[基础模块]";
}
private void setCoreContext(PartnerRunningFlowContext context) {
CoreContext coreContext = context.getCoreContext();
coreContext.setText(context.getInput());
@@ -82,4 +68,9 @@ public class PreprocessExecutor extends PreRunningAbstractAgentModuleAbstract {
coreContext.setUserNick(context.getUserNickname());
coreContext.setUserId(context.getUserId());
}
@Override
public int order() {
return 1;
}
}