refactor(modules): refactor modules base class into AbstractAgentModule and remove unused @slf4j annotation

This commit is contained in:
2026-02-20 17:17:49 +08:00
parent 38c618a222
commit c47d2b2285
29 changed files with 90 additions and 537 deletions

View File

@@ -1,10 +1,9 @@
package work.slhaf.partner.module.common.module; package work.slhaf.partner.module.common.module;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentRunningModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
public abstract class PostRunningAbstractAgentModuleAbstract extends AbstractAgentRunningModule<PartnerRunningFlowContext> { public abstract class PostRunningAbstractAgentModuleAbstract extends AbstractAgentModule.Running<PartnerRunningFlowContext> {
@Override @Override
public final void execute(PartnerRunningFlowContext context) { public final void execute(PartnerRunningFlowContext context) {
boolean trigger = context.getModuleContext().getExtraContext().getBoolean("post_process_trigger"); boolean trigger = context.getModuleContext().getExtraContext().getBoolean("post_process_trigger");
@@ -13,8 +12,6 @@ public abstract class PostRunningAbstractAgentModuleAbstract extends AbstractAge
} }
doExecute(context); doExecute(context);
} }
public abstract void doExecute(PartnerRunningFlowContext context); public abstract void doExecute(PartnerRunningFlowContext context);
protected abstract boolean relyOnMessage(); protected abstract boolean relyOnMessage();
} }

View File

@@ -1,17 +1,14 @@
package work.slhaf.partner.module.common.module; package work.slhaf.partner.module.common.module;
import lombok.extern.slf4j.Slf4j; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentRunningModule;
import work.slhaf.partner.module.common.entity.AppendPromptData; import work.slhaf.partner.module.common.entity.AppendPromptData;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.Map; import java.util.Map;
/** /**
* 前置模块抽象类 * 前置模块抽象类
*/ */
@Slf4j public abstract class PreRunningAbstractAgentModuleAbstract extends AbstractAgentModule.Running<PartnerRunningFlowContext> {
public abstract class PreRunningAbstractAgentModuleAbstract extends AbstractAgentRunningModule<PartnerRunningFlowContext> {
private synchronized void setAppendedPrompt(PartnerRunningFlowContext context) { private synchronized void setAppendedPrompt(PartnerRunningFlowContext context) {
AppendPromptData data = new AppendPromptData(); AppendPromptData data = new AppendPromptData();
data.setModuleName(moduleName()); data.setModuleName(moduleName());
@@ -19,25 +16,19 @@ public abstract class PreRunningAbstractAgentModuleAbstract extends AbstractAgen
data.setAppendedPrompt(map); data.setAppendedPrompt(map);
context.setAppendedPrompt(data); context.setAppendedPrompt(data);
} }
private synchronized void setActiveModule(PartnerRunningFlowContext context) { private synchronized void setActiveModule(PartnerRunningFlowContext context) {
context.getCoreContext().addActiveModule(moduleName()); context.getCoreContext().addActiveModule(moduleName());
} }
protected abstract Map<String, String> getPromptDataMap(PartnerRunningFlowContext context); protected abstract Map<String, String> getPromptDataMap(PartnerRunningFlowContext context);
/** /**
* 用于在CoreModule接收到的模块Prompt中标识模块名称 * 用于在CoreModule接收到的模块Prompt中标识模块名称
*/ */
protected abstract String moduleName(); protected abstract String moduleName();
@Override @Override
public final void execute(PartnerRunningFlowContext context) { public final void execute(PartnerRunningFlowContext context) {
doExecute(context); // 子类实现差异化逻辑 doExecute(context); // 子类实现差异化逻辑
setAppendedPrompt(context); // 通用逻辑 setAppendedPrompt(context); // 通用逻辑
setActiveModule(context); // 通用逻辑 setActiveModule(context); // 通用逻辑
} }
protected abstract void doExecute(PartnerRunningFlowContext context); protected abstract void doExecute(PartnerRunningFlowContext context);
} }

View File

@@ -2,7 +2,6 @@ package work.slhaf.partner.module.modules.action.dispatcher;
import lombok.val; import lombok.val;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule; import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.core.action.ActionCapability;
@@ -19,25 +18,18 @@ import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowCon
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
@AgentRunningModule(name = "action_dispatcher", order = 7)
public class ActionDispatcher extends PostRunningAbstractAgentModuleAbstract { public class ActionDispatcher extends PostRunningAbstractAgentModuleAbstract {
@InjectCapability @InjectCapability
private ActionCapability actionCapability; private ActionCapability actionCapability;
@InjectModule @InjectModule
private ActionExecutor actionExecutor; private ActionExecutor actionExecutor;
@InjectModule @InjectModule
private ActionScheduler actionScheduler; private ActionScheduler actionScheduler;
private ExecutorService executor; private ExecutorService executor;
@Init @Init
public void init() { public void init() {
executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL); executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
} }
@Override @Override
public void doExecute(PartnerRunningFlowContext context) { public void doExecute(PartnerRunningFlowContext context) {
// 只需要处理prepared action因为pending action在用户确认后也将变为prepared action // 只需要处理prepared action因为pending action在用户确认后也将变为prepared action
@@ -61,10 +53,13 @@ public class ActionDispatcher extends PostRunningAbstractAgentModuleAbstract {
actionScheduler.execute(scheduledActions); actionScheduler.execute(scheduledActions);
}); });
} }
@Override @Override
protected boolean relyOnMessage() { protected boolean relyOnMessage() {
return false; return false;
} }
@Override
public int order() {
return 7;
}
} }

View File

@@ -1,52 +1,40 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor; package work.slhaf.partner.module.modules.action.dispatcher.executor;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.val; import lombok.val;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.CorrectorInput; import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.CorrectorInput;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.CorrectorResult; import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.CorrectorResult;
/** /**
* 负责在单组行动执行后,根据行动意图与结果检查后续行动是否符合目的,必要时直接调整行动链,或发起自对话请求进行干预 * 负责在单组行动执行后,根据行动意图与结果检查后续行动是否符合目的,必要时直接调整行动链,或发起自对话请求进行干预
*/ */
@AgentSubModule public class ActionCorrector extends AbstractAgentModule.Sub<CorrectorInput, CorrectorResult> implements ActivateModel {
public class ActionCorrector extends AbstractAgentSubModule<CorrectorInput, CorrectorResult> implements ActivateModel {
@Override @Override
public CorrectorResult execute(CorrectorInput input) { public CorrectorResult execute(CorrectorInput input) {
val prompt = buildPrompt(input); val prompt = buildPrompt(input);
val chatResponse = singleChat(prompt); val chatResponse = singleChat(prompt);
return JSONObject.parseObject(chatResponse.getMessage(), CorrectorResult.class); return JSONObject.parseObject(chatResponse.getMessage(), CorrectorResult.class);
} }
private String buildPrompt(CorrectorInput input) { private String buildPrompt(CorrectorInput input) {
val prompt = new JSONObject(); val prompt = new JSONObject();
prompt.put("[行动来源]", input.getSource()); prompt.put("[行动来源]", input.getSource());
prompt.put("[行动倾向]", input.getTendency()); prompt.put("[行动倾向]", input.getTendency());
prompt.put("[行动描述]", input.getDescription()); prompt.put("[行动描述]", input.getDescription());
prompt.put("[行动原因]", input.getReason()); prompt.put("[行动原因]", input.getReason());
val messages = prompt.putArray("[近期对话]"); val messages = prompt.putArray("[近期对话]");
messages.addAll(input.getRecentMessages()); messages.addAll(input.getRecentMessages());
val memory = prompt.putArray("[已激活记忆]"); val memory = prompt.putArray("[已激活记忆]");
memory.addAll(input.getActivatedSlices()); memory.addAll(input.getActivatedSlices());
val history = prompt.putArray("[已执行情况]"); val history = prompt.putArray("[已执行情况]");
history.addAll(input.getHistory()); history.addAll(input.getHistory());
return prompt.toJSONString(); return prompt.toJSONString();
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "action_corrector"; return "action_corrector";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return false; return false;
} }
} }

View File

@@ -1,10 +1,8 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor; package work.slhaf.partner.module.modules.action.dispatcher.executor;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule; import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.core.action.ActionCapability;
@@ -26,17 +24,13 @@ import java.util.concurrent.Phaser;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@Slf4j public class ActionExecutor extends AbstractAgentModule.Sub<ActionExecutorInput, Void> {
@AgentSubModule
public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput, Void> {
@InjectCapability @InjectCapability
private ActionCapability actionCapability; private ActionCapability actionCapability;
@InjectCapability @InjectCapability
private MemoryCapability memoryCapability; private MemoryCapability memoryCapability;
@InjectCapability @InjectCapability
private CognationCapability cognationCapability; private CognationCapability cognationCapability;
@InjectModule @InjectModule
private ParamsExtractor paramsExtractor; private ParamsExtractor paramsExtractor;
@InjectModule @InjectModule
@@ -45,20 +39,16 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
private ActionCorrector actionCorrector; private ActionCorrector actionCorrector;
@InjectModule @InjectModule
private ActionScheduler actionScheduler; private ActionScheduler actionScheduler;
private ExecutorService virtualExecutor; private ExecutorService virtualExecutor;
private ExecutorService platformExecutor; private ExecutorService platformExecutor;
private RunnerClient runnerClient; private RunnerClient runnerClient;
private final AssemblyHelper assemblyHelper = new AssemblyHelper(); private final AssemblyHelper assemblyHelper = new AssemblyHelper();
@Init @Init
public void init() { public void init() {
virtualExecutor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL); virtualExecutor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
platformExecutor = actionCapability.getExecutor(ActionCore.ExecutorType.PLATFORM); platformExecutor = actionCapability.getExecutor(ActionCore.ExecutorType.PLATFORM);
runnerClient = actionCapability.runnerClient(); runnerClient = actionCapability.runnerClient();
} }
/** /**
* 执行行动 * 执行行动
* *
@@ -85,62 +75,51 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
val phaser = new Phaser(); val phaser = new Phaser();
val phaserRecord = actionCapability.putPhaserRecord(phaser, executableAction); val phaserRecord = actionCapability.putPhaserRecord(phaser, executableAction);
executableAction.setStatus(Status.EXECUTING); executableAction.setStatus(Status.EXECUTING);
// 开始执行 // 开始执行
val stageCursor = new Object() { val stageCursor = new Object() {
int stageCount; int stageCount;
boolean executingStageUpdated; boolean executingStageUpdated;
boolean stageCountUpdated; boolean stageCountUpdated;
void init() { void init() {
stageCount = 0; stageCount = 0;
executingStageUpdated = false; executingStageUpdated = false;
stageCountUpdated = false; stageCountUpdated = false;
update(); update();
} }
void requestAdvance() { void requestAdvance() {
if (!stageCountUpdated) { if (!stageCountUpdated) {
stageCount++; stageCount++;
stageCountUpdated = true; stageCountUpdated = true;
} }
if (stageCount < actionChain.size() && !executingStageUpdated) { if (stageCount < actionChain.size() && !executingStageUpdated) {
update(); update();
executingStageUpdated = true; executingStageUpdated = true;
} }
} }
boolean next() { boolean next() {
executingStageUpdated = false; executingStageUpdated = false;
stageCountUpdated = false; stageCountUpdated = false;
return stageCount < actionChain.size(); return stageCount < actionChain.size();
} }
void update() { void update() {
val orderList = new ArrayList<>(actionChain.keySet()); val orderList = new ArrayList<>(actionChain.keySet());
orderList.sort(Integer::compareTo); orderList.sort(Integer::compareTo);
executableAction.setExecutingStage(orderList.get(stageCount)); executableAction.setExecutingStage(orderList.get(stageCount));
} }
}; };
stageCursor.init(); stageCursor.init();
do { do {
val metaActions = actionChain.get(executableAction.getExecutingStage()); val metaActions = actionChain.get(executableAction.getExecutingStage());
val listeningRecord = executeAndListening(metaActions, phaserRecord, source); val listeningRecord = executeAndListening(metaActions, phaserRecord, source);
phaser.awaitAdvance(listeningRecord.phase()); phaser.awaitAdvance(listeningRecord.phase());
// synchronized 同步防止 accepting 循环间、phase guard 判定后发生 stage 推进 // synchronized 同步防止 accepting 循环间、phase guard 判定后发生 stage 推进
// 导致新行动的 phaser 投放阶段错乱无法阻塞的场景 // 导致新行动的 phaser 投放阶段错乱无法阻塞的场景
// 该 synchronized 将阶段推进与 accepting 监听 loop 捆绑为互斥的原子事件,避免了细粒度的 phaser 阶段竞态问题 // 该 synchronized 将阶段推进与 accepting 监听 loop 捆绑为互斥的原子事件,避免了细粒度的 phaser 阶段竞态问题
synchronized (listeningRecord.accepting()) { synchronized (listeningRecord.accepting()) {
listeningRecord.accepting().set(false); listeningRecord.accepting().set(false);
// 立即尝试推进,本次推进中,如果前方仍有未执行 stage将执行一次阶段推进 // 立即尝试推进,本次推进中,如果前方仍有未执行 stage将执行一次阶段推进
stageCursor.requestAdvance(); stageCursor.requestAdvance();
} }
try { try {
// 针对行动链进行修正,修正需要传入执行历史、行动目标等内容 // 针对行动链进行修正,修正需要传入执行历史、行动目标等内容
// 如果后续运行 corrector 触发频率较高,可考虑增加重试机制 // 如果后续运行 corrector 触发频率较高,可考虑增加重试机制
@@ -149,12 +128,10 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
actionCapability.handleInterventions(correctorResult.getMetaInterventionList(), executableAction); actionCapability.handleInterventions(correctorResult.getMetaInterventionList(), executableAction);
} catch (Exception ignored) { } catch (Exception ignored) {
} }
// 第二次尝试进行阶段推进,本次负责补充上一次在不存在 stage时但 corrector 执行期间发生了 actionChain 的插入事件 // 第二次尝试进行阶段推进,本次负责补充上一次在不存在 stage时但 corrector 执行期间发生了 actionChain 的插入事件
// 如果第一次已经推进完毕,本次将会跳过 // 如果第一次已经推进完毕,本次将会跳过
stageCursor.requestAdvance(); stageCursor.requestAdvance();
} while (stageCursor.next()); } while (stageCursor.next());
// 结束 // 结束
actionCapability.removePhaserRecord(phaser); actionCapability.removePhaserRecord(phaser);
if (executableAction.getStatus() != Status.FAILED) { if (executableAction.getStatus() != Status.FAILED) {
@@ -165,20 +142,15 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
} else { } else {
executableAction.setStatus(Status.SUCCESS); executableAction.setStatus(Status.SUCCESS);
} }
// TODO 执行过后需要回写至任务上下文recentCompletedTask同时触发自对话信号进行确认并记录以及是否通知用户触发与否需要机制进行匹配在模块链路可增加 interaction gate 门控,判断此次对话作用于谁、由谁发出、何种性质、是否需要回应等) // TODO 执行过后需要回写至任务上下文recentCompletedTask同时触发自对话信号进行确认并记录以及是否通知用户触发与否需要机制进行匹配在模块链路可增加 interaction gate 门控,判断此次对话作用于谁、由谁发出、何种性质、是否需要回应等)
} }
}); });
} }
return null; return null;
} }
private MetaActionsListeningRecord executeAndListening(List<MetaAction> metaActions, PhaserRecord phaserRecord, String source) { private MetaActionsListeningRecord executeAndListening(List<MetaAction> metaActions, PhaserRecord phaserRecord, String source) {
AtomicBoolean accepting = new AtomicBoolean(true); AtomicBoolean accepting = new AtomicBoolean(true);
AtomicInteger cursor = new AtomicInteger(); AtomicInteger cursor = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(1); CountDownLatch latch = new CountDownLatch(1);
val phaser = phaserRecord.phaser(); val phaser = phaserRecord.phaser();
val phase = phaser.register(); val phase = phaser.register();
@@ -187,27 +159,22 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
while (accepting.get()) { while (accepting.get()) {
synchronized (accepting) { synchronized (accepting) {
MetaAction next = null; MetaAction next = null;
synchronized (metaActions) { synchronized (metaActions) {
if (cursor.get() < metaActions.size()) { if (cursor.get() < metaActions.size()) {
next = metaActions.get(cursor.getAndIncrement()); next = metaActions.get(cursor.getAndIncrement());
} }
} }
if (next == null) { if (next == null) {
Thread.onSpinWait(); Thread.onSpinWait();
continue; continue;
} }
if (phaser.getPhase() != phase) { if (phaser.getPhase() != phase) {
metaActions.remove(next); metaActions.remove(next);
log.warn("行动阶段已推进,丢弃该行动: {}", next); log.warn("行动阶段已推进,丢弃该行动: {}", next);
continue; continue;
} }
ExecutorService executor = next.getIo() ? virtualExecutor : platformExecutor; ExecutorService executor = next.getIo() ? virtualExecutor : platformExecutor;
executor.execute(buildMataActionTask(next, phaserRecord, source)); executor.execute(buildMataActionTask(next, phaserRecord, source));
if (first) { if (first) {
phaser.arriveAndDeregister(); phaser.arriveAndDeregister();
latch.countDown(); latch.countDown();
@@ -223,7 +190,6 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
} }
return new MetaActionsListeningRecord(accepting, phase); return new MetaActionsListeningRecord(accepting, phase);
} }
private Runnable buildMataActionTask(MetaAction metaAction, PhaserRecord phaserRecord, String source) { private Runnable buildMataActionTask(MetaAction metaAction, PhaserRecord phaserRecord, String source) {
val phaser = phaserRecord.phaser(); val phaser = phaserRecord.phaser();
phaser.register(); phaser.register();
@@ -238,7 +204,6 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
val additionalContext = actionData.getAdditionalContext().get(executingStage); val additionalContext = actionData.getAdditionalContext().get(executingStage);
val extractorInput = assemblyHelper.buildExtractorInput(metaAction, source, historyActionResults, additionalContext); val extractorInput = assemblyHelper.buildExtractorInput(metaAction, source, historyActionResults, additionalContext);
val extractorResult = paramsExtractor.execute(extractorInput); val extractorResult = paramsExtractor.execute(extractorInput);
if (extractorResult.isOk()) { if (extractorResult.isOk()) {
metaAction.getParams().putAll(extractorResult.getParams()); metaAction.getParams().putAll(extractorResult.getParams());
runnerClient.submit(metaAction); runnerClient.submit(metaAction);
@@ -275,16 +240,12 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
} }
}; };
} }
private record MetaActionsListeningRecord(AtomicBoolean accepting, int phase) { private record MetaActionsListeningRecord(AtomicBoolean accepting, int phase) {
} }
@SuppressWarnings("InnerClassMayBeStatic") @SuppressWarnings("InnerClassMayBeStatic")
private class AssemblyHelper { private class AssemblyHelper {
private AssemblyHelper() { private AssemblyHelper() {
} }
private RepairerInput buildRepairerInput(List<HistoryAction> historyActionsResults, MetaAction action, String userId) { private RepairerInput buildRepairerInput(List<HistoryAction> historyActionsResults, MetaAction action, String userId) {
RepairerInput input = new RepairerInput(); RepairerInput input = new RepairerInput();
MetaActionInfo metaActionInfo = actionCapability.loadMetaActionInfo(action.getKey()); MetaActionInfo metaActionInfo = actionCapability.loadMetaActionInfo(action.getKey());
@@ -295,7 +256,6 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
input.setUserId(userId); input.setUserId(userId);
return input; return input;
} }
private ExtractorInput buildExtractorInput(MetaAction action, String source, List<HistoryAction> historyActionResults, private ExtractorInput buildExtractorInput(MetaAction action, String source, List<HistoryAction> historyActionResults,
List<String> additionalContext) { List<String> additionalContext) {
ExtractorInput input = new ExtractorInput(); ExtractorInput input = new ExtractorInput();
@@ -306,7 +266,6 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
input.setAdditionalContext(additionalContext); input.setAdditionalContext(additionalContext);
return input; return input;
} }
private CorrectorInput buildCorrectorInput(ExecutableAction executableAction, String source) { private CorrectorInput buildCorrectorInput(ExecutableAction executableAction, String source) {
return CorrectorInput.builder() return CorrectorInput.builder()
.tendency(executableAction.getTendency()) .tendency(executableAction.getTendency())
@@ -320,5 +279,4 @@ public class ActionExecutor extends AbstractAgentSubModule<ActionExecutorInput,
.build(); .build();
} }
} }
} }

View File

@@ -4,11 +4,9 @@ import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import com.alibaba.fastjson2.TypeReference; import com.alibaba.fastjson2.TypeReference;
import lombok.Data; import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule; import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse; import work.slhaf.partner.api.chat.pojo.ChatResponse;
@@ -28,7 +26,6 @@ import java.util.List;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
/** /**
* 负责识别行动链的修复 * 负责识别行动链的修复
* <ol> * <ol>
@@ -43,26 +40,19 @@ import java.util.concurrent.atomic.AtomicInteger;
* </li> * </li>
* </ol> * </ol>
*/ */
@Slf4j public class ActionRepairer extends AbstractAgentModule.Sub<RepairerInput, RepairerResult> implements ActivateModel {
@AgentSubModule
public class ActionRepairer extends AbstractAgentSubModule<RepairerInput, RepairerResult> implements ActivateModel {
@InjectCapability @InjectCapability
private ActionCapability actionCapability; private ActionCapability actionCapability;
@InjectCapability @InjectCapability
private CognationCapability cognationCapability; private CognationCapability cognationCapability;
@InjectModule @InjectModule
private DynamicActionGenerator dynamicActionGenerator; private DynamicActionGenerator dynamicActionGenerator;
private final AssemblyHelper assemblyHelper = new AssemblyHelper(); private final AssemblyHelper assemblyHelper = new AssemblyHelper();
private RunnerClient runnerClient; private RunnerClient runnerClient;
@Init @Init
void init() { void init() {
runnerClient = actionCapability.runnerClient(); runnerClient = actionCapability.runnerClient();
} }
@Override @Override
public RepairerResult execute(RepairerInput data) { public RepairerResult execute(RepairerInput data) {
RepairerResult result; RepairerResult result;
@@ -92,7 +82,6 @@ public class ActionRepairer extends AbstractAgentSubModule<RepairerInput, Repair
} }
return result; return result;
} }
/** /**
* 负责根据输入内容进行行动单元的参数信息修复 * 负责根据输入内容进行行动单元的参数信息修复
* *
@@ -107,7 +96,6 @@ public class ActionRepairer extends AbstractAgentSubModule<RepairerInput, Repair
result.setStatus(RepairerStatus.FAILED); result.setStatus(RepairerStatus.FAILED);
return result; return result;
} }
runnerClient.submit(tempAction); runnerClient.submit(tempAction);
// 根据 tempAction 的执行状态设置修复结果 // 根据 tempAction 的执行状态设置修复结果
Result actionResult = tempAction.getResult(); Result actionResult = tempAction.getResult();
@@ -115,12 +103,10 @@ public class ActionRepairer extends AbstractAgentSubModule<RepairerInput, Repair
result.setStatus(RepairerStatus.FAILED); result.setStatus(RepairerStatus.FAILED);
return result; return result;
} }
result.setStatus(RepairerStatus.OK); result.setStatus(RepairerStatus.OK);
result.getFixedData().add(actionResult.getData()); result.getFixedData().add(actionResult.getData());
return result; return result;
} }
/** /**
* 负责根据输入内容进行行动单元的参数信息修复 * 负责根据输入内容进行行动单元的参数信息修复
* *
@@ -161,50 +147,41 @@ public class ActionRepairer extends AbstractAgentSubModule<RepairerInput, Repair
} }
return result; return result;
} }
private RepairerResult handleUserInteraction(String acquireContent) { private RepairerResult handleUserInteraction(String acquireContent) {
RepairerResult result = new RepairerResult(); RepairerResult result = new RepairerResult();
result.setStatus(RepairerStatus.ACQUIRE); result.setStatus(RepairerStatus.ACQUIRE);
// 发送自对话请求 // 发送自对话请求
return result; return result;
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "action_repairer"; return "action_repairer";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return false; return false;
} }
@SuppressWarnings("InnerClassMayBeStatic") @SuppressWarnings("InnerClassMayBeStatic")
@Data @Data
private class RepairerData { private class RepairerData {
private RepairerType repairerType; private RepairerType repairerType;
private String data; private String data;
} }
private enum RepairerType { private enum RepairerType {
ACTION_GENERATION, ACTION_GENERATION,
ACTION_INVOCATION, ACTION_INVOCATION,
USER_INTERACTION USER_INTERACTION
} }
@SuppressWarnings("InnerClassMayBeStatic") @SuppressWarnings("InnerClassMayBeStatic")
private class AssemblyHelper { private class AssemblyHelper {
private AssemblyHelper() { private AssemblyHelper() {
} }
private String buildPrompt(RepairerInput data, String specialInstruction) { private String buildPrompt(RepairerInput data, String specialInstruction) {
JSONObject prompt = new JSONObject(); JSONObject prompt = new JSONObject();
JSONObject actionData = prompt.putObject("[本次行动信息]"); JSONObject actionData = prompt.putObject("[本次行动信息]");
actionData.put("[行动描述]", data.getActionDescription()); actionData.put("[行动描述]", data.getActionDescription());
JSONObject actionParamsData = actionData.putObject("[行动参数说明]"); JSONObject actionParamsData = actionData.putObject("[行动参数说明]");
actionParamsData.putAll(data.getParams()); actionParamsData.putAll(data.getParams());
JSONArray historyData = prompt.putArray("[历史行动执行结果]"); JSONArray historyData = prompt.putArray("[历史行动执行结果]");
data.getHistoryActionResults().forEach(historyAction -> { data.getHistoryActionResults().forEach(historyAction -> {
JSONObject historyItem = new JSONObject(); JSONObject historyItem = new JSONObject();
@@ -213,16 +190,12 @@ public class ActionRepairer extends AbstractAgentSubModule<RepairerInput, Repair
historyItem.put("[行动结果]", historyAction.result()); historyItem.put("[行动结果]", historyAction.result());
historyData.add(historyItem); historyData.add(historyItem);
}); });
JSONArray messageData = prompt.putArray("[最近消息列表]"); JSONArray messageData = prompt.putArray("[最近消息列表]");
messageData.addAll(data.getRecentMessages()); messageData.addAll(data.getRecentMessages());
if (specialInstruction != null) { if (specialInstruction != null) {
prompt.put("[特殊指令]", specialInstruction); prompt.put("[特殊指令]", specialInstruction);
} }
return prompt.toString(); return prompt.toString();
} }
} }
} }

View File

@@ -1,11 +1,9 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor; package work.slhaf.partner.module.modules.action.dispatcher.executor;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.val; import lombok.val;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.chat.pojo.ChatResponse; import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.common.util.ExtractUtil; import work.slhaf.partner.common.util.ExtractUtil;
@@ -15,24 +13,18 @@ import work.slhaf.partner.core.action.entity.MetaAction;
import work.slhaf.partner.core.action.runner.RunnerClient; import work.slhaf.partner.core.action.runner.RunnerClient;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.GeneratorInput; import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.GeneratorInput;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.GeneratorResult; import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.GeneratorResult;
/** /**
* 负责依据输入内容生成可执行的动态行动单元,并选择是否持久化至 SandboxRunner 容器内 * 负责依据输入内容生成可执行的动态行动单元,并选择是否持久化至 SandboxRunner 容器内
*/ */
@AgentSubModule public class DynamicActionGenerator extends AbstractAgentModule.Sub<GeneratorInput, GeneratorResult>
public class DynamicActionGenerator extends AbstractAgentSubModule<GeneratorInput, GeneratorResult>
implements ActivateModel { implements ActivateModel {
@InjectCapability @InjectCapability
private ActionCapability actionCapability; private ActionCapability actionCapability;
private RunnerClient runnerClient; private RunnerClient runnerClient;
@Init @Init
void init() { void init() {
runnerClient = actionCapability.runnerClient(); runnerClient = actionCapability.runnerClient();
} }
@Override @Override
public GeneratorResult execute(GeneratorInput input) { public GeneratorResult execute(GeneratorInput input) {
GeneratorResult result = new GeneratorResult(); GeneratorResult result = new GeneratorResult();
@@ -40,12 +32,10 @@ public class DynamicActionGenerator extends AbstractAgentSubModule<GeneratorInpu
// 由于 SCRIPT 类型程序都是在 SandboxRunner 内部的磁盘上加载然后执行的, // 由于 SCRIPT 类型程序都是在 SandboxRunner 内部的磁盘上加载然后执行的,
// 所以此处的输入内容也只需要指定输入参数、临时key、是否持久化即可路径将按照指定规则统一构建不可交给LLM生成 // 所以此处的输入内容也只需要指定输入参数、临时key、是否持久化即可路径将按照指定规则统一构建不可交给LLM生成
String prompt = buildPrompt(input); String prompt = buildPrompt(input);
// 响应结果需要包含几个特殊数据: 依赖项、代码内容、是否序列化、响应数据释义 // 响应结果需要包含几个特殊数据: 依赖项、代码内容、是否序列化、响应数据释义
ChatResponse response = this.singleChat(prompt); ChatResponse response = this.singleChat(prompt);
GeneratedData generatorData = JSONObject GeneratedData generatorData = JSONObject
.parseObject(ExtractUtil.extractJson(response.getMessage()), GeneratedData.class); .parseObject(ExtractUtil.extractJson(response.getMessage()), GeneratedData.class);
val location = runnerClient.buildTmpPath(input.getActionName(), generatorData.getCodeType()); val location = runnerClient.buildTmpPath(input.getActionName(), generatorData.getCodeType());
MetaAction tempAction = new MetaAction( MetaAction tempAction = new MetaAction(
input.getActionName(), input.getActionName(),
@@ -66,11 +56,9 @@ public class DynamicActionGenerator extends AbstractAgentSubModule<GeneratorInpu
} }
return result; return result;
} }
private void waitingSerialize() { private void waitingSerialize() {
throw new UnsupportedOperationException("Unimplemented method 'waitingSerialize'"); throw new UnsupportedOperationException("Unimplemented method 'waitingSerialize'");
} }
private String buildPrompt(GeneratorInput data) { private String buildPrompt(GeneratorInput data) {
JSONObject prompt = new JSONObject(); JSONObject prompt = new JSONObject();
prompt.put("[行动描述]", data.getDescription()); prompt.put("[行动描述]", data.getDescription());
@@ -78,12 +66,10 @@ public class DynamicActionGenerator extends AbstractAgentSubModule<GeneratorInpu
prompt.putObject("[行动参数描述]").putAll(data.getParamsDescription()); prompt.putObject("[行动参数描述]").putAll(data.getParamsDescription());
return prompt.toString(); return prompt.toString();
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "dynamic_generator"; return "dynamic_generator";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return false; return false;

View File

@@ -2,10 +2,8 @@ package work.slhaf.partner.module.modules.action.dispatcher.executor;
import com.alibaba.fastjson2.JSONArray; import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse; import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.core.action.entity.MetaActionInfo; import work.slhaf.partner.core.action.entity.MetaActionInfo;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.ExtractorInput; import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.ExtractorInput;
@@ -14,14 +12,10 @@ import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.Histo
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
/** /**
* 负责依据输入内容进行行动单元的参数信息提取 * 负责依据输入内容进行行动单元的参数信息提取
*/ */
@Slf4j public class ParamsExtractor extends AbstractAgentModule.Sub<ExtractorInput, ExtractorResult> implements ActivateModel {
@AgentSubModule
public class ParamsExtractor extends AbstractAgentSubModule<ExtractorInput, ExtractorResult> implements ActivateModel {
@Override @Override
public ExtractorResult execute(ExtractorInput input) { public ExtractorResult execute(ExtractorInput input) {
String prompt = buildPrompt(input); String prompt = buildPrompt(input);
@@ -37,15 +31,12 @@ public class ParamsExtractor extends AbstractAgentSubModule<ExtractorInput, Extr
} }
return result; return result;
} }
private String buildPrompt(ExtractorInput input) { private String buildPrompt(ExtractorInput input) {
JSONObject prompt = new JSONObject(); JSONObject prompt = new JSONObject();
JSONObject actionData = prompt.putObject("[本次行动信息]"); JSONObject actionData = prompt.putObject("[本次行动信息]");
MetaActionInfo actionInfo = input.getMetaActionInfo(); MetaActionInfo actionInfo = input.getMetaActionInfo();
actionData.put("[行动描述]", actionInfo.getDescription()); actionData.put("[行动描述]", actionInfo.getDescription());
actionData.put("[行动参数说明]", actionInfo.getParams()); actionData.put("[行动参数说明]", actionInfo.getParams());
JSONArray historyData = prompt.putArray("[历史行动执行结果]"); JSONArray historyData = prompt.putArray("[历史行动执行结果]");
List<HistoryAction> historyActions = input.getHistoryActionResults(); List<HistoryAction> historyActions = input.getHistoryActionResults();
for (HistoryAction historyAction : historyActions) { for (HistoryAction historyAction : historyActions) {
@@ -55,21 +46,16 @@ public class ParamsExtractor extends AbstractAgentSubModule<ExtractorInput, Extr
historyItem.put("[行动结果]", historyAction.result()); historyItem.put("[行动结果]", historyAction.result());
historyData.add(historyItem); historyData.add(historyItem);
} }
JSONArray messageData = prompt.putArray("[最近消息列表]"); JSONArray messageData = prompt.putArray("[最近消息列表]");
messageData.addAll(input.getRecentMessages()); messageData.addAll(input.getRecentMessages());
return prompt.toString(); return prompt.toString();
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "params_extractor"; return "params_extractor";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return false; return false;
} }
} }

View File

@@ -12,8 +12,7 @@ import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.sync.withLock
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule
import work.slhaf.partner.api.agent.factory.module.annotation.Init import work.slhaf.partner.api.agent.factory.module.annotation.Init
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule
import work.slhaf.partner.core.action.ActionCapability import work.slhaf.partner.core.action.ActionCapability
@@ -30,17 +29,13 @@ import java.time.temporal.ChronoUnit
import java.util.stream.Collectors import java.util.stream.Collectors
import kotlin.jvm.optionals.getOrNull import kotlin.jvm.optionals.getOrNull
@AgentSubModule class ActionScheduler : AbstractAgentModule.Sub<Set<Schedulable>, Void?>() {
class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
@InjectCapability @InjectCapability
private lateinit var actionCapability: ActionCapability private lateinit var actionCapability: ActionCapability
@InjectModule @InjectModule
private lateinit var actionExecutor: ActionExecutor private lateinit var actionExecutor: ActionExecutor
private lateinit var timeWheel: TimeWheel private lateinit var timeWheel: TimeWheel
private val schedulerScope = private val schedulerScope =
CoroutineScope(Dispatchers.Default + SupervisorJob() + CoroutineName("ActionScheduler")) CoroutineScope(Dispatchers.Default + SupervisorJob() + CoroutineName("ActionScheduler"))
@@ -58,7 +53,6 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
.map { it as SchedulableExecutableAction } .map { it as SchedulableExecutableAction }
.collect(Collectors.toSet()) .collect(Collectors.toSet())
} }
val onTrigger: (Set<Schedulable>) -> Unit = { schedulableSet -> val onTrigger: (Set<Schedulable>) -> Unit = { schedulableSet ->
val executableActions = mutableSetOf<SchedulableExecutableAction>() val executableActions = mutableSetOf<SchedulableExecutableAction>()
val stateActions = mutableSetOf<StateAction>() val stateActions = mutableSetOf<StateAction>()
@@ -72,12 +66,9 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL) actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL)
.execute { stateActions.forEach { it.trigger.onTrigger() } } .execute { stateActions.forEach { it.trigger.onTrigger() } }
} }
timeWheel = TimeWheel(listScheduledActions, onTrigger) timeWheel = TimeWheel(listScheduledActions, onTrigger)
} }
loadScheduledActions() loadScheduledActions()
setupShutdownHook() setupShutdownHook()
} }
@@ -88,10 +79,9 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
}) })
} }
override fun execute(schedulableSet: Set<Schedulable>?): Void? { override fun execute(input: Set<Schedulable>): Void? {
schedulerScope.launch { schedulerScope.launch {
schedulableSet?.run { for (schedulableData in input) {
for (schedulableData in schedulableSet) {
log.debug("New data to schedule: {}", schedulableData) log.debug("New data to schedule: {}", schedulableData)
timeWheel.schedule(schedulableData) timeWheel.schedule(schedulableData)
if (schedulableData is SchedulableExecutableAction) { if (schedulableData is SchedulableExecutableAction) {
@@ -99,7 +89,6 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
} }
} }
} }
}
return null return null
} }
@@ -107,16 +96,13 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
val listSource: () -> Set<Schedulable>, val listSource: () -> Set<Schedulable>,
val onTrigger: (toTrigger: Set<Schedulable>) -> Unit val onTrigger: (toTrigger: Set<Schedulable>) -> Unit
) : Closeable { ) : Closeable {
private val schedulableGroupByHour = Array<MutableSet<Schedulable>>(24) { mutableSetOf() } private val schedulableGroupByHour = Array<MutableSet<Schedulable>>(24) { mutableSetOf() }
private val wheel = Array<MutableSet<Schedulable>>(60 * 60) { mutableSetOf() } private val wheel = Array<MutableSet<Schedulable>>(60 * 60) { mutableSetOf() }
private var recordHour: Int = -1 private var recordHour: Int = -1
private var recordDay: Int = -1 private var recordDay: Int = -1
private val state = MutableStateFlow(WheelState.SLEEPING) private val state = MutableStateFlow(WheelState.SLEEPING)
private val wheelActionsLock = Mutex() private val wheelActionsLock = Mutex()
private val timeWheelScope = CoroutineScope(SupervisorJob() + Dispatchers.Default + CoroutineName("TimeWheel")) private val timeWheelScope = CoroutineScope(SupervisorJob() + Dispatchers.Default + CoroutineName("TimeWheel"))
private val cronDefinition: CronDefinition = CronDefinitionBuilder.instanceDefinitionFor(CronType.QUARTZ) private val cronDefinition: CronDefinition = CronDefinitionBuilder.instanceDefinitionFor(CronType.QUARTZ)
private val cronParser: CronParser = CronParser(cronDefinition) private val cronParser: CronParser = CronParser(cronDefinition)
@@ -136,23 +122,19 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
return@checkThenExecute return@checkThenExecute
} }
log.debug("Action next execution time: {}", parseToZonedDateTime) log.debug("Action next execution time: {}", parseToZonedDateTime)
val hour = parseToZonedDateTime.hour val hour = parseToZonedDateTime.hour
schedulableGroupByHour[hour].add(schedulableData) schedulableGroupByHour[hour].add(schedulableData)
log.debug("Action scheduled at {}", hour) log.debug("Action scheduled at {}", hour)
if (it.hour == hour) { if (it.hour == hour) {
val wheelOffset = parseToZonedDateTime.minute * 60 + parseToZonedDateTime.second val wheelOffset = parseToZonedDateTime.minute * 60 + parseToZonedDateTime.second
wheel[wheelOffset].add(schedulableData) wheel[wheelOffset].add(schedulableData)
state.value = WheelState.ACTIVE state.value = WheelState.ACTIVE
log.debug("Action scheduled at wheel offset {}", wheelOffset) log.debug("Action scheduled at wheel offset {}", wheelOffset)
} }
} }
} }
private fun wheel() { private fun wheel() {
data class WheelStepResult( data class WheelStepResult(
val toTrigger: Set<Schedulable>?, val toTrigger: Set<Schedulable>?,
val shouldBreak: Boolean val shouldBreak: Boolean
@@ -178,26 +160,20 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
suspend fun CoroutineScope.wheel(launchingTime: ZonedDateTime, primaryTickAdvanceTime: Long) { suspend fun CoroutineScope.wheel(launchingTime: ZonedDateTime, primaryTickAdvanceTime: Long) {
val launchingHour = launchingTime.hour val launchingHour = launchingTime.hour
var tick = launchingTime.minute * 60 + launchingTime.second var tick = launchingTime.minute * 60 + launchingTime.second
// 让节拍器从“启动时刻的下一秒”开始(避免立即 step=0 // 让节拍器从“启动时刻的下一秒”开始(避免立即 step=0
var nextTickNanos = primaryTickAdvanceTime + 1_000_000_000L var nextTickNanos = primaryTickAdvanceTime + 1_000_000_000L
while (isActive) { while (isActive) {
// 1) 计算落后多少秒:至少 1正常推进也可能 >1追赶 // 1) 计算落后多少秒:至少 1正常推进也可能 >1追赶
val now0 = System.nanoTime() val now0 = System.nanoTime()
val lagNanos = now0 - nextTickNanos val lagNanos = now0 - nextTickNanos
val step = if (lagNanos < 0) 1 else (lagNanos / 1_000_000_000L).toInt() + 1 val step = if (lagNanos < 0) 1 else (lagNanos / 1_000_000_000L).toInt() + 1
val previousTick = tick val previousTick = tick
tick = (tick + step).coerceAtMost(wheel.lastIndex) tick = (tick + step).coerceAtMost(wheel.lastIndex)
// 2) 推进节拍器:按“理论秒”前进 step 次 // 2) 推进节拍器:按“理论秒”前进 step 次
nextTickNanos += step.toLong() * 1_000_000_000L nextTickNanos += step.toLong() * 1_000_000_000L
val stepResult = run { val stepResult = run {
var shouldBreak = false var shouldBreak = false
var toTrigger: Set<Schedulable>? = null var toTrigger: Set<Schedulable>? = null
checkThenExecute(false) { checkThenExecute(false) {
if (it.hour != launchingHour) { if (it.hour != launchingHour) {
shouldBreak = true shouldBreak = true
@@ -210,30 +186,23 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
) )
return@checkThenExecute return@checkThenExecute
} }
toTrigger = collectToTrigger(tick, previousTick, launchingHour) toTrigger = collectToTrigger(tick, previousTick, launchingHour)
if (tick >= wheel.lastIndex || schedulableGroupByHour[launchingHour].isEmpty()) { if (tick >= wheel.lastIndex || schedulableGroupByHour[launchingHour].isEmpty()) {
state.value = WheelState.SLEEPING state.value = WheelState.SLEEPING
shouldBreak = true shouldBreak = true
} }
} }
WheelStepResult(toTrigger, shouldBreak) WheelStepResult(toTrigger, shouldBreak)
} }
stepResult.toTrigger?.let { trigger -> stepResult.toTrigger?.let { trigger ->
timeWheelScope.launch { timeWheelScope.launch {
onTrigger(trigger) onTrigger(trigger)
} }
} }
if (stepResult.shouldBreak) { if (stepResult.shouldBreak) {
log.debug("Wheel stopped at tick {}", tick) log.debug("Wheel stopped at tick {}", tick)
break break
} }
// 3) 精确睡到下一次理论 tick用最新 nanoTime // 3) 精确睡到下一次理论 tick用最新 nanoTime
val now1 = System.nanoTime() val now1 = System.nanoTime()
val sleepNanos = nextTickNanos - now1 val sleepNanos = nextTickNanos - now1
@@ -255,7 +224,6 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
} }
log.debug("Waiting ended at {}", ZonedDateTime.now()) log.debug("Waiting ended at {}", ZonedDateTime.now())
} }
timeWheelScope.launch { timeWheelScope.launch {
while (isActive) { while (isActive) {
// 判断是否该步入下一小时 // 判断是否该步入下一小时
@@ -270,14 +238,12 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
// 而启动时无触发保障,此时一并初始化 tick 推进时间,足以应对 check 与 wheel 间的这段时间间隔 // 而启动时无触发保障,此时一并初始化 tick 推进时间,足以应对 check 与 wheel 间的这段时间间隔
primaryTickAdvanceTime = System.nanoTime() primaryTickAdvanceTime = System.nanoTime()
} }
// 如果该时无任务则等待,插入事件可提前唤醒 // 如果该时无任务则等待,插入事件可提前唤醒
if (shouldWait!!) { if (shouldWait!!) {
// 计算距离下一小时的时间,等待 // 计算距离下一小时的时间,等待
currentTime?.let { wait(it) } currentTime?.let { wait(it) }
continue continue
} }
// 唤醒进行时间轮循环 // 唤醒进行时间轮循环
wheel(currentTime!!, primaryTickAdvanceTime!!) wheel(currentTime!!, primaryTickAdvanceTime!!)
} }
@@ -303,11 +269,9 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
logFailedStatus(schedulableData) logFailedStatus(schedulableData)
continue continue
} }
load(nextExecutingTime, schedulableData) load(nextExecutingTime, schedulableData)
} }
} }
repair() repair()
runLoading() runLoading()
} }
@@ -319,13 +283,11 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
wheel[secondsTime].add(schedulableData) wheel[secondsTime].add(schedulableData)
log.debug("Action loaded to hour: {}", schedulableData) log.debug("Action loaded to hour: {}", schedulableData)
} }
val repair: () -> Unit = { val repair: () -> Unit = {
for (set in wheel) { for (set in wheel) {
set.clear() set.clear()
} }
} }
loadActions(schedulableGroupByHour[currentTime.hour], currentTime, load, repair) loadActions(schedulableGroupByHour[currentTime.hour], currentTime, load, repair)
} }
@@ -335,13 +297,11 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
schedulableGroupByHour[latestExecutingTime.hour].add(schedulableData) schedulableGroupByHour[latestExecutingTime.hour].add(schedulableData)
log.debug("Action loaded to day: {}", schedulableData) log.debug("Action loaded to day: {}", schedulableData)
} }
val repair: () -> Unit = { val repair: () -> Unit = {
for (set in schedulableGroupByHour) { for (set in schedulableGroupByHour) {
set.clear() set.clear()
} }
} }
loadActions(listSource(), currentTime, load, repair) loadActions(listSource(), currentTime, load, repair)
} }
@@ -360,7 +320,6 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
} }
val now = ZonedDateTime.now() val now = ZonedDateTime.now()
if (finallyToExecute) { if (finallyToExecute) {
refreshIfNeeded(now) refreshIfNeeded(now)
then(now) then(now)
@@ -398,9 +357,7 @@ class ActionScheduler : AbstractAgentSubModule<Set<Schedulable>, Void>() {
else else
executionTime executionTime
} }
} }
} }
private fun logFailedStatus(scheduleData: Schedulable) { private fun logFailedStatus(scheduleData: Schedulable) {

View File

@@ -5,7 +5,6 @@ import com.alibaba.fastjson2.JSONObject;
import lombok.val; import lombok.val;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule; import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore; import work.slhaf.partner.core.action.ActionCore;
@@ -30,40 +29,32 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Stream; import java.util.stream.Stream;
/** /**
* 负责识别潜在的行动干预信息,作用于正在进行或已存在的行动池中内容 * 负责识别潜在的行动干预信息,作用于正在进行或已存在的行动池中内容
*/ */
@AgentRunningModule(name = "action_identifier", order = 2)
public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract implements ActivateModel { public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract implements ActivateModel {
@InjectModule @InjectModule
private InterventionRecognizer interventionRecognizer; private InterventionRecognizer interventionRecognizer;
@InjectModule @InjectModule
private InterventionEvaluator interventionEvaluator; private InterventionEvaluator interventionEvaluator;
@InjectCapability @InjectCapability
private ActionCapability actionCapability; private ActionCapability actionCapability;
@InjectCapability @InjectCapability
private CognationCapability cognationCapability; private CognationCapability cognationCapability;
@InjectCapability @InjectCapability
private MemoryCapability memoryCapability; private MemoryCapability memoryCapability;
private final AssemblyHelper assemblyHelper = new AssemblyHelper(); private final AssemblyHelper assemblyHelper = new AssemblyHelper();
private final PromptHelper promptHelper = new PromptHelper(); private final PromptHelper promptHelper = new PromptHelper();
/** /**
* 键: 本次调用uuid * 键: 本次调用uuid
* 值本次调用对应的prompt * 值本次调用对应的prompt
*/ */
private final Map<String, Map<String, String>> interventionPrompt = new HashMap<>(); private final Map<String, Map<String, String>> interventionPrompt = new HashMap<>();
@Override @Override
protected void doExecute(PartnerRunningFlowContext context) { protected void doExecute(PartnerRunningFlowContext context) {
// 综合当前正在进行的行动链信息、用户交互历史、激活的记忆切片,尝试识别出是否存在行动干预意图 // 综合当前正在进行的行动链信息、用户交互历史、激活的记忆切片,尝试识别出是否存在行动干预意图
// 首先通过recognizer进行快速意图识别识别成功则步入评估阶段评估成功则直接作用于目标行动链 // 首先通过recognizer进行快速意图识别识别成功则步入评估阶段评估成功则直接作用于目标行动链
// 进行快速意图识别时必须结合近期对话与进行中行动链情况 // 进行快速意图识别时必须结合近期对话与进行中行动链情况
// 干预意图识别 // 干预意图识别
String uuid = context.getUuid(); String uuid = context.getUuid();
String userId = context.getUserId(); String userId = context.getUserId();
@@ -73,19 +64,16 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
promptHelper.setupNoInterventionPrompt(uuid); promptHelper.setupNoInterventionPrompt(uuid);
return; return;
} }
// 干预意图评估 // 干预意图评估
EvaluatorResult evaluatorResult = interventionEvaluator EvaluatorResult evaluatorResult = interventionEvaluator
.execute(assemblyHelper.buildEvaluatorInput(recognizerResult, userId)); .execute(assemblyHelper.buildEvaluatorInput(recognizerResult, userId));
List<EvaluatedInterventionData> executingDataList = evaluatorResult.getExecutingDataList(); List<EvaluatedInterventionData> executingDataList = evaluatorResult.getExecutingDataList();
List<EvaluatedInterventionData> preparedDataList = evaluatorResult.getPreparedDataList(); List<EvaluatedInterventionData> preparedDataList = evaluatorResult.getPreparedDataList();
// 意图评估结果处理 // 意图评估结果处理
if (evaluatorResult.isOk()) { if (evaluatorResult.isOk()) {
// 对存在异常ActionKey的评估结果列表进行过滤 // 对存在异常ActionKey的评估结果列表进行过滤
invalidActionKeysFilter(executingDataList); invalidActionKeysFilter(executingDataList);
invalidActionKeysFilter(preparedDataList); invalidActionKeysFilter(preparedDataList);
// 同步写入prompt异步处理干预行为异步在处理流程中体现 // 同步写入prompt异步处理干预行为异步在处理流程中体现
promptHelper.setupInterventionPrompt(uuid, executingDataList, preparedDataList); promptHelper.setupInterventionPrompt(uuid, executingDataList, preparedDataList);
handleInterventions(executingDataList, recognizerResult.getExecutingInterventions()); handleInterventions(executingDataList, recognizerResult.getExecutingInterventions());
@@ -93,9 +81,7 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
} else { } else {
promptHelper.setupInterventionIgnoredPrompt(uuid, executingDataList, preparedDataList); promptHelper.setupInterventionIgnoredPrompt(uuid, executingDataList, preparedDataList);
} }
} }
private void handleInterventions(List<EvaluatedInterventionData> interventionDataList, Map<String, ExecutableAction> interventionDataMap) { private void handleInterventions(List<EvaluatedInterventionData> interventionDataList, Map<String, ExecutableAction> interventionDataMap) {
val executor = actionCapability.getExecutor(ActionCore.ExecutorType.PLATFORM); val executor = actionCapability.getExecutor(ActionCore.ExecutorType.PLATFORM);
executor.execute(() -> { executor.execute(() -> {
@@ -106,11 +92,8 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
} }
}); });
} }
private void invalidActionKeysFilter(List<EvaluatedInterventionData> interventions) { private void invalidActionKeysFilter(List<EvaluatedInterventionData> interventions) {
List<EvaluatedInterventionData> toRemove = new ArrayList<>(); List<EvaluatedInterventionData> toRemove = new ArrayList<>();
for (EvaluatedInterventionData intervention : interventions) { for (EvaluatedInterventionData intervention : interventions) {
List<MetaIntervention> interventionData = intervention.getMetaInterventionList(); List<MetaIntervention> interventionData = intervention.getMetaInterventionList();
List<String> actions = new ArrayList<>(); List<String> actions = new ArrayList<>();
@@ -121,16 +104,13 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
if (!actionCapability.checkExists(actions.toArray(String[]::new))) { if (!actionCapability.checkExists(actions.toArray(String[]::new))) {
toRemove.add(intervention); toRemove.add(intervention);
} }
// 针对 REBUILD 类型进行特殊校验, REBUILD 类型必须满足所有 MetaIntervention 的类型均为 REBUILD // 针对 REBUILD 类型进行特殊校验, REBUILD 类型必须满足所有 MetaIntervention 的类型均为 REBUILD
if (!checkRebuildType(interventionData)) { if (!checkRebuildType(interventionData)) {
toRemove.add(intervention); toRemove.add(intervention);
} }
} }
interventions.removeAll(toRemove); interventions.removeAll(toRemove);
} }
private boolean checkRebuildType(List<MetaIntervention> interventionData) { private boolean checkRebuildType(List<MetaIntervention> interventionData) {
boolean hasRebuild = false; boolean hasRebuild = false;
for (MetaIntervention meta : interventionData) { for (MetaIntervention meta : interventionData) {
@@ -141,34 +121,27 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
return false; return false;
} }
} }
return true; return true;
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "action_identifier"; return "action_identifier";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return false; return false;
} }
@Override @Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) { protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
return interventionPrompt.remove(context.getUuid()); return interventionPrompt.remove(context.getUuid());
} }
@Override @Override
protected String moduleName() { protected String moduleName() {
return "[行动干预识别模块]"; return "[行动干预识别模块]";
} }
private final class AssemblyHelper { private final class AssemblyHelper {
private AssemblyHelper() { private AssemblyHelper() {
} }
private RecognizerInput buildRecognizerInput(String userId, String input) { private RecognizerInput buildRecognizerInput(String userId, String input) {
RecognizerInput recognizerInput = new RecognizerInput(); RecognizerInput recognizerInput = new RecognizerInput();
recognizerInput.setInput(input); recognizerInput.setInput(input);
@@ -179,7 +152,6 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
recognizerInput.setPreparedActions(actionCapability.listActions(ExecutableAction.Status.PREPARE, userId).stream().toList()); recognizerInput.setPreparedActions(actionCapability.listActions(ExecutableAction.Status.PREPARE, userId).stream().toList());
return recognizerInput; return recognizerInput;
} }
private EvaluatorInput buildEvaluatorInput(RecognizerResult recognizerResult, String userId) { private EvaluatorInput buildEvaluatorInput(RecognizerResult recognizerResult, String userId) {
EvaluatorInput input = new EvaluatorInput(); EvaluatorInput input = new EvaluatorInput();
input.setExecutingInterventions(recognizerResult.getExecutingInterventions()); input.setExecutingInterventions(recognizerResult.getExecutingInterventions());
@@ -189,22 +161,17 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
return input; return input;
} }
} }
private final class PromptHelper { private final class PromptHelper {
private PromptHelper() { private PromptHelper() {
} }
private void setupInterventionIgnoredPrompt(String uuid, List<EvaluatedInterventionData> executingDataList, List<EvaluatedInterventionData> preparedDataList) { private void setupInterventionIgnoredPrompt(String uuid, List<EvaluatedInterventionData> executingDataList, List<EvaluatedInterventionData> preparedDataList) {
List<EvaluatedInterventionData> total = Stream.concat(executingDataList.stream(), preparedDataList.stream()).toList(); List<EvaluatedInterventionData> total = Stream.concat(executingDataList.stream(), preparedDataList.stream()).toList();
JSONArray reasons = new JSONArray(); JSONArray reasons = new JSONArray();
for (EvaluatedInterventionData data : total) { for (EvaluatedInterventionData data : total) {
JSONObject reason = reasons.addObject(); JSONObject reason = reasons.addObject();
reason.put("[干预倾向]", data.getTendency()); reason.put("[干预倾向]", data.getTendency());
reason.put("[未采用原因]", data.getDescription()); reason.put("[未采用原因]", data.getDescription());
} }
synchronized (interventionPrompt) { synchronized (interventionPrompt) {
interventionPrompt.put(uuid, Map.of( interventionPrompt.put(uuid, Map.of(
"[识别状态] <是否识别到干预已存在行动的意图>", "识别到,但都未采用", "[识别状态] <是否识别到干预已存在行动的意图>", "识别到,但都未采用",
@@ -212,12 +179,10 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
"[干预行动] <将对已存在行动做出的行为>", "无行为")); "[干预行动] <将对已存在行动做出的行为>", "无行为"));
} }
} }
private void setupInterventionPrompt(String uuid, List<EvaluatedInterventionData> executingDataList, private void setupInterventionPrompt(String uuid, List<EvaluatedInterventionData> executingDataList,
List<EvaluatedInterventionData> preparedDataList) { List<EvaluatedInterventionData> preparedDataList) {
JSONArray contents = new JSONArray(); JSONArray contents = new JSONArray();
List<EvaluatedInterventionData> temp = Stream.concat(executingDataList.stream(), preparedDataList.stream()).toList(); List<EvaluatedInterventionData> temp = Stream.concat(executingDataList.stream(), preparedDataList.stream()).toList();
for (EvaluatedInterventionData data : temp) { for (EvaluatedInterventionData data : temp) {
if (!data.isOk()) { if (!data.isOk()) {
continue; continue;
@@ -226,7 +191,6 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
JSONObject newElement = contents.addObject(); JSONObject newElement = contents.addObject();
newElement.put("[干预倾向]", tendency); newElement.put("[干预倾向]", tendency);
JSONArray changes = newElement.putArray("[行动链变动情况]"); JSONArray changes = newElement.putArray("[行动链变动情况]");
for (MetaIntervention intervention : data.getMetaInterventionList()) { for (MetaIntervention intervention : data.getMetaInterventionList()) {
JSONObject change = changes.addObject(); JSONObject change = changes.addObject();
change.put("[干预类型]", intervention.getType()); change.put("[干预类型]", intervention.getType());
@@ -234,18 +198,21 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
change.putArray("[干预内容]").addAll(intervention.getActions()); change.putArray("[干预内容]").addAll(intervention.getActions());
} }
} }
synchronized (interventionPrompt) { synchronized (interventionPrompt) {
interventionPrompt.put(uuid, Map.of( interventionPrompt.put(uuid, Map.of(
"[识别状态] <是否识别到干预已存在行动的意图>", "识别到,将采用", "[识别状态] <是否识别到干预已存在行动的意图>", "识别到,将采用",
"[干预内容] <将对已存在行动做出的行为>", contents.toString())); "[干预内容] <将对已存在行动做出的行为>", contents.toString()));
} }
} }
private void setupNoInterventionPrompt(String uuid) { private void setupNoInterventionPrompt(String uuid) {
interventionPrompt.put(uuid, Map.of( interventionPrompt.put(uuid, Map.of(
"[识别状态] <是否识别到干预已存在行动的意图>", "未识别到干预意图", "[识别状态] <是否识别到干预已存在行动的意图>", "未识别到干预意图",
"[干预行动] <将对已存在行动做出的行为>", "无行动")); "[干预行动] <将对已存在行动做出的行为>", "无行动"));
} }
} }
@Override
public int order() {
return 2;
}
} }

View File

@@ -1,11 +1,9 @@
package work.slhaf.partner.module.modules.action.interventor.evaluator; package work.slhaf.partner.module.modules.action.interventor.evaluator;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse; import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.core.action.ActionCapability;
@@ -21,14 +19,10 @@ import java.util.Map;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
@Slf4j public class InterventionEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, EvaluatorResult>
@AgentSubModule
public class InterventionEvaluator extends AbstractAgentSubModule<EvaluatorInput, EvaluatorResult>
implements ActivateModel { implements ActivateModel {
@InjectCapability @InjectCapability
private ActionCapability actionCapability; private ActionCapability actionCapability;
/** /**
* 基于干预意图、记忆切片、交互上下文、已有行动程序综合评估,尝试评估并选取出合适的行动程序,交付给 ActionInterventor * 基于干预意图、记忆切片、交互上下文、已有行动程序综合评估,尝试评估并选取出合适的行动程序,交付给 ActionInterventor
*/ */
@@ -39,30 +33,24 @@ public class InterventionEvaluator extends AbstractAgentSubModule<EvaluatorInput
Map<String, ExecutableAction> executingInterventions = input.getExecutingInterventions(); Map<String, ExecutableAction> executingInterventions = input.getExecutingInterventions();
Map<String, ExecutableAction> preparedInterventions = input.getPreparedInterventions(); Map<String, ExecutableAction> preparedInterventions = input.getPreparedInterventions();
CountDownLatch latch = new CountDownLatch(executingInterventions.size() + preparedInterventions.size()); CountDownLatch latch = new CountDownLatch(executingInterventions.size() + preparedInterventions.size());
// 创建结果容器 // 创建结果容器
EvaluatorResult result = new EvaluatorResult(); EvaluatorResult result = new EvaluatorResult();
List<EvaluatedInterventionData> executingDataList = result.getExecutingDataList(); List<EvaluatedInterventionData> executingDataList = result.getExecutingDataList();
List<EvaluatedInterventionData> preparedDataList = result.getPreparedDataList(); List<EvaluatedInterventionData> preparedDataList = result.getPreparedDataList();
// 并发评估 // 并发评估
evaluateIntervention(executingDataList, executingInterventions, input, executor, latch); evaluateIntervention(executingDataList, executingInterventions, input, executor, latch);
evaluateIntervention(preparedDataList, preparedInterventions, input, executor, latch); evaluateIntervention(preparedDataList, preparedInterventions, input, executor, latch);
try { try {
latch.await(); latch.await();
} catch (InterruptedException e) { } catch (InterruptedException e) {
log.warn("CountDownLatch阻塞已中断"); log.warn("CountDownLatch阻塞已中断");
} }
return result; return result;
} }
private void evaluateIntervention(List<EvaluatedInterventionData> evaluatedDataList, Map<String, ExecutableAction> interventionMap, EvaluatorInput input, ExecutorService executor, CountDownLatch latch) { private void evaluateIntervention(List<EvaluatedInterventionData> evaluatedDataList, Map<String, ExecutableAction> interventionMap, EvaluatorInput input, ExecutorService executor, CountDownLatch latch) {
interventionMap.forEach((tendency, actionData) -> executor.execute(() -> { interventionMap.forEach((tendency, actionData) -> executor.execute(() -> {
try { try {
String prompt = buildPrompt(input.getRecentMessages(), input.getActivatedSlices(), actionData, tendency); String prompt = buildPrompt(input.getRecentMessages(), input.getActivatedSlices(), actionData, tendency);
ChatResponse response = this.singleChat(prompt); ChatResponse response = this.singleChat(prompt);
EvaluatedInterventionData evaluatedData = JSONObject.parseObject(response.getMessage(), EvaluatedInterventionData evaluatedData = JSONObject.parseObject(response.getMessage(),
EvaluatedInterventionData.class); EvaluatedInterventionData.class);
@@ -76,7 +64,6 @@ public class InterventionEvaluator extends AbstractAgentSubModule<EvaluatorInput
} }
})); }));
} }
private String buildPrompt(List<Message> recentMessages, List<EvaluatedSlice> activatedSlices, private String buildPrompt(List<Message> recentMessages, List<EvaluatedSlice> activatedSlices,
ExecutableAction executableAction, String tendency) { ExecutableAction executableAction, String tendency) {
JSONObject json = new JSONObject(); JSONObject json = new JSONObject();
@@ -86,12 +73,10 @@ public class InterventionEvaluator extends AbstractAgentSubModule<EvaluatorInput
json.put("将干预的行动", JSONObject.toJSONString(executableAction)); json.put("将干预的行动", JSONObject.toJSONString(executableAction));
return json.toJSONString(); return json.toJSONString();
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "intervention_evaluator"; return "intervention_evaluator";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return false; return false;

View File

@@ -1,11 +1,9 @@
package work.slhaf.partner.module.modules.action.interventor.recognizer; package work.slhaf.partner.module.modules.action.interventor.recognizer;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse; import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore; import work.slhaf.partner.core.action.ActionCore;
@@ -19,13 +17,9 @@ import java.util.Map;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
@Slf4j public class InterventionRecognizer extends AbstractAgentModule.Sub<RecognizerInput, RecognizerResult> implements ActivateModel {
@AgentSubModule
public class InterventionRecognizer extends AbstractAgentSubModule<RecognizerInput, RecognizerResult> implements ActivateModel {
@InjectCapability @InjectCapability
private ActionCapability actionCapability; private ActionCapability actionCapability;
@Override @Override
public RecognizerResult execute(RecognizerInput input) { public RecognizerResult execute(RecognizerInput input) {
// 获取必须数据 // 获取必须数据
@@ -33,16 +27,13 @@ public class InterventionRecognizer extends AbstractAgentSubModule<RecognizerInp
List<ExecutableAction> executingActions = input.getExecutingActions(); List<ExecutableAction> executingActions = input.getExecutingActions();
List<ExecutableAction> preparedActions = input.getPreparedActions(); List<ExecutableAction> preparedActions = input.getPreparedActions();
CountDownLatch countDownLatch = new CountDownLatch(executingActions.size() + preparedActions.size()); CountDownLatch countDownLatch = new CountDownLatch(executingActions.size() + preparedActions.size());
// 创建结果容器 // 创建结果容器
RecognizerResult recognizerResult = new RecognizerResult(); RecognizerResult recognizerResult = new RecognizerResult();
Map<String, ExecutableAction> executingInterventions = recognizerResult.getExecutingInterventions(); Map<String, ExecutableAction> executingInterventions = recognizerResult.getExecutingInterventions();
Map<String, ExecutableAction> preparedInterventions = recognizerResult.getPreparedInterventions(); Map<String, ExecutableAction> preparedInterventions = recognizerResult.getPreparedInterventions();
// 执行识别操作 // 执行识别操作
recognizeIntervention(executingInterventions, executingActions, executor, input, countDownLatch); recognizeIntervention(executingInterventions, executingActions, executor, input, countDownLatch);
recognizeIntervention(preparedInterventions, preparedActions, executor, input, countDownLatch); recognizeIntervention(preparedInterventions, preparedActions, executor, input, countDownLatch);
try { try {
countDownLatch.await(); countDownLatch.await();
} catch (InterruptedException e) { } catch (InterruptedException e) {
@@ -50,7 +41,6 @@ public class InterventionRecognizer extends AbstractAgentSubModule<RecognizerInp
} }
return recognizerResult; return recognizerResult;
} }
private void recognizeIntervention(Map<String, ExecutableAction> interventionsMap, List<ExecutableAction> actions, ExecutorService executor, RecognizerInput input, CountDownLatch latch) { private void recognizeIntervention(Map<String, ExecutableAction> interventionsMap, List<ExecutableAction> actions, ExecutorService executor, RecognizerInput input, CountDownLatch latch) {
for (ExecutableAction data : actions) { for (ExecutableAction data : actions) {
executor.execute(() -> { executor.execute(() -> {
@@ -71,30 +61,24 @@ public class InterventionRecognizer extends AbstractAgentSubModule<RecognizerInp
}); });
} }
} }
private String buildPrompt(ExecutableAction executableAction, RecognizerInput input) { private String buildPrompt(ExecutableAction executableAction, RecognizerInput input) {
JSONObject json = new JSONObject(); JSONObject json = new JSONObject();
JSONObject actionInfo = json.putObject("行动信息"); JSONObject actionInfo = json.putObject("行动信息");
actionInfo.put("行动倾向", executableAction.getTendency()); actionInfo.put("行动倾向", executableAction.getTendency());
actionInfo.put("行动原因", executableAction.getReason()); actionInfo.put("行动原因", executableAction.getReason());
actionInfo.put("行动描述", executableAction.getDescription()); actionInfo.put("行动描述", executableAction.getDescription());
actionInfo.put("行动状态", executableAction.getStatus()); actionInfo.put("行动状态", executableAction.getStatus());
actionInfo.put("行动来源", executableAction.getSource()); actionInfo.put("行动来源", executableAction.getSource());
JSONObject interactionInfo = json.putObject("交互信息"); JSONObject interactionInfo = json.putObject("交互信息");
interactionInfo.put("用户输入", input.getInput()); interactionInfo.put("用户输入", input.getInput());
interactionInfo.put("当前对话", input.getRecentMessages()); interactionInfo.put("当前对话", input.getRecentMessages());
interactionInfo.put("近期对话", input.getUserDialogMapStr()); interactionInfo.put("近期对话", input.getUserDialogMapStr());
return json.toString(); return json.toString();
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "intervention_recognizer"; return "intervention_recognizer";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return false; return false;

View File

@@ -1,10 +1,8 @@
package work.slhaf.partner.module.modules.action.planner; package work.slhaf.partner.module.modules.action.planner;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule; import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.api.chat.pojo.Message;
@@ -33,14 +31,10 @@ import java.util.*;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
/** /**
* 负责针对本次输入生成基础的行动计划,在主模型传达意愿后,执行行动或者放入计划池 * 负责针对本次输入生成基础的行动计划,在主模型传达意愿后,执行行动或者放入计划池
*/ */
@Slf4j
@AgentRunningModule(name = "action_planner", order = 2)
public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract { public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
@InjectCapability @InjectCapability
private CognationCapability cognationCapability; private CognationCapability cognationCapability;
@InjectCapability @InjectCapability
@@ -49,23 +43,18 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
private PerceiveCapability perceiveCapability; private PerceiveCapability perceiveCapability;
@InjectCapability @InjectCapability
private MemoryCapability memoryCapability; private MemoryCapability memoryCapability;
@InjectModule @InjectModule
private ActionEvaluator actionEvaluator; private ActionEvaluator actionEvaluator;
@InjectModule @InjectModule
private ActionExtractor actionExtractor; private ActionExtractor actionExtractor;
@InjectModule @InjectModule
private ActionConfirmer actionConfirmer; private ActionConfirmer actionConfirmer;
private ExecutorService executor; private ExecutorService executor;
private final ActionAssemblyHelper assemblyHelper = new ActionAssemblyHelper(); private final ActionAssemblyHelper assemblyHelper = new ActionAssemblyHelper();
@Init @Init
public void init() { public void init() {
executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL); executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
} }
@Override @Override
protected void doExecute(PartnerRunningFlowContext context) { protected void doExecute(PartnerRunningFlowContext context) {
try { try {
@@ -77,7 +66,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
log.error("执行异常", e); log.error("执行异常", e);
} }
} }
/** /**
* 新的提取与评估任务 * 新的提取与评估任务
* *
@@ -98,7 +86,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
return null; return null;
}); });
} }
private void updateTendencyCache(List<EvaluatorResult> evaluatorResults, String input, private void updateTendencyCache(List<EvaluatorResult> evaluatorResults, String input,
ExtractorResult extractorResult) { ExtractorResult extractorResult) {
if (!VectorClient.status) { if (!VectorClient.status) {
@@ -119,9 +106,7 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
data.setInput(input); data.setInput(input);
actionCapability.updateTendencyCache(data); actionCapability.updateTendencyCache(data);
}); });
} }
/** /**
* 待确认行动的判断任务 * 待确认行动的判断任务
* *
@@ -136,7 +121,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
return null; return null;
}); });
} }
private void setupConfirmedActionInfo(PartnerRunningFlowContext context, ConfirmerResult result) { private void setupConfirmedActionInfo(PartnerRunningFlowContext context, ConfirmerResult result) {
// TODO 需考虑未确认任务的失效或者拒绝时机在action core中实现 // TODO 需考虑未确认任务的失效或者拒绝时机在action core中实现
List<String> uuids = result.getUuids(); List<String> uuids = result.getUuids();
@@ -150,7 +134,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
} }
} }
} }
private void putActionData(List<EvaluatorResult> evaluatorResults, PartnerRunningFlowContext context) { private void putActionData(List<EvaluatorResult> evaluatorResults, PartnerRunningFlowContext context) {
for (EvaluatorResult evaluatorResult : evaluatorResults) { for (EvaluatorResult evaluatorResult : evaluatorResults) {
ExecutableAction executableAction = assemblyHelper.buildActionData(evaluatorResult, context.getUserId()); ExecutableAction executableAction = assemblyHelper.buildActionData(evaluatorResult, context.getUserId());
@@ -161,7 +144,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
} }
} }
} }
@Override @Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) { protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
HashMap<String, String> map = new HashMap<>(); HashMap<String, String> map = new HashMap<>();
@@ -170,7 +152,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
setupPreparedActions(map, userId); setupPreparedActions(map, userId);
return map; return map;
} }
private void setupPendingActions(HashMap<String, String> map, String userId) { private void setupPendingActions(HashMap<String, String> map, String userId) {
List<ExecutableAction> executableActionData = actionCapability.listPendingAction(userId); List<ExecutableAction> executableActionData = actionCapability.listPendingAction(userId);
if (executableActionData == null || executableActionData.isEmpty()) { if (executableActionData == null || executableActionData.isEmpty()) {
@@ -181,7 +162,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
map.put("[待确认行动 " + (i + 1) + " ] <等待用户确认的行动信息>", generateActionStr(executableActionData.get(i))); map.put("[待确认行动 " + (i + 1) + " ] <等待用户确认的行动信息>", generateActionStr(executableActionData.get(i)));
} }
} }
private void setupPreparedActions(HashMap<String, String> map, String userId) { private void setupPreparedActions(HashMap<String, String> map, String userId) {
val preparedActions = actionCapability.listActions(ExecutableAction.Status.PREPARE, userId).stream().toList(); val preparedActions = actionCapability.listActions(ExecutableAction.Status.PREPARE, userId).stream().toList();
if (preparedActions.isEmpty()) { if (preparedActions.isEmpty()) {
@@ -192,22 +172,18 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
map.put("[预备行动 " + (i + 1) + " ] <预备执行或放入计划池的行动信息>", generateActionStr(preparedActions.get(i))); map.put("[预备行动 " + (i + 1) + " ] <预备执行或放入计划池的行动信息>", generateActionStr(preparedActions.get(i)));
} }
} }
private String generateActionStr(ExecutableAction executableAction) { private String generateActionStr(ExecutableAction executableAction) {
return "<行动倾向>" + " : " + executableAction.getTendency() + return "<行动倾向>" + " : " + executableAction.getTendency() +
"<行动原因>" + " : " + executableAction.getReason() + "<行动原因>" + " : " + executableAction.getReason() +
"<工具描述>" + " : " + executableAction.getDescription(); "<工具描述>" + " : " + executableAction.getDescription();
} }
@Override @Override
protected String moduleName() { protected String moduleName() {
return "[行动模块]"; return "[行动模块]";
} }
private final class ActionAssemblyHelper { private final class ActionAssemblyHelper {
private ActionAssemblyHelper() { private ActionAssemblyHelper() {
} }
private ExtractorInput buildExtractorInput(PartnerRunningFlowContext context) { private ExtractorInput buildExtractorInput(PartnerRunningFlowContext context) {
ExtractorInput input = new ExtractorInput(); ExtractorInput input = new ExtractorInput();
input.setInput(context.getInput()); input.setInput(context.getInput());
@@ -221,7 +197,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
input.setRecentMessages(recentMessages); input.setRecentMessages(recentMessages);
return input; return input;
} }
private EvaluatorInput buildEvaluatorInput(ExtractorResult extractorResult, String userId) { private EvaluatorInput buildEvaluatorInput(ExtractorResult extractorResult, String userId) {
EvaluatorInput input = new EvaluatorInput(); EvaluatorInput input = new EvaluatorInput();
input.setTendencies(extractorResult.getTendencies()); input.setTendencies(extractorResult.getTendencies());
@@ -230,7 +205,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
input.setActivatedSlices(memoryCapability.getActivatedSlices(userId)); input.setActivatedSlices(memoryCapability.getActivatedSlices(userId));
return input; return input;
} }
private ExecutableAction buildActionData(EvaluatorResult evaluatorResult, String userId) { private ExecutableAction buildActionData(EvaluatorResult evaluatorResult, String userId) {
Map<Integer, List<MetaAction>> actionChain = getActionChain(evaluatorResult); Map<Integer, List<MetaAction>> actionChain = getActionChain(evaluatorResult);
return switch (evaluatorResult.getType()) { return switch (evaluatorResult.getType()) {
@@ -252,7 +226,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
); );
}; };
} }
private @NotNull Map<Integer, List<MetaAction>> getActionChain(EvaluatorResult evaluatorResult) { private @NotNull Map<Integer, List<MetaAction>> getActionChain(EvaluatorResult evaluatorResult) {
Map<Integer, List<MetaAction>> actionChain = new HashMap<>(); Map<Integer, List<MetaAction>> actionChain = new HashMap<>();
Map<Integer, List<String>> primaryActionChain = evaluatorResult.getPrimaryActionChain(); Map<Integer, List<String>> primaryActionChain = evaluatorResult.getPrimaryActionChain();
@@ -265,7 +238,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
}); });
return actionChain; return actionChain;
} }
private void fixDependencies(Map<Integer, List<String>> primaryActionChain) { private void fixDependencies(Map<Integer, List<String>> primaryActionChain) {
// 先将 primaryActionChain 的节点序号修正为从1开始依次增大 // 先将 primaryActionChain 的节点序号修正为从1开始依次增大
fixOrder(primaryActionChain); fixOrder(primaryActionChain);
@@ -291,7 +263,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
if (checkDependenciesExist(lastOrder, preActions, primaryActionChain)) { if (checkDependenciesExist(lastOrder, preActions, primaryActionChain)) {
continue; continue;
} }
// 如果存在前置依赖,则将其放置在当前order之前的位置, // 如果存在前置依赖,则将其放置在当前order之前的位置,
// 放置位置优先选择已存在的上一节点,如果不存在(行动链的头节点时)则需要向行动链新增order // 放置位置优先选择已存在的上一节点,如果不存在(行动链的头节点时)则需要向行动链新增order
// 不需要检查行动链的当前节点的已存在 Action 是否为新 Action 的依赖项,因为这些 Action 实际来自 LLM // 不需要检查行动链的当前节点的已存在 Action 是否为新 Action 的依赖项,因为这些 Action 实际来自 LLM
@@ -309,7 +280,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
fixedOrders.addAll(tempOrders); fixedOrders.addAll(tempOrders);
} while (fixed.getAndSet(false)); } while (fixed.getAndSet(false));
} }
private void fixOrder(Map<Integer, List<String>> primaryActionChain) { private void fixOrder(Map<Integer, List<String>> primaryActionChain) {
Map<Integer, List<String>> tempChain = new HashMap<>(primaryActionChain); Map<Integer, List<String>> tempChain = new HashMap<>(primaryActionChain);
primaryActionChain.clear(); primaryActionChain.clear();
@@ -318,7 +288,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
primaryActionChain.put(i, tempChain.get(i)); primaryActionChain.put(i, tempChain.get(i));
} }
} }
private boolean checkDependenciesExist(int lastOrder, List<String> preActions, private boolean checkDependenciesExist(int lastOrder, List<String> preActions,
Map<Integer, List<String>> primaryActionChain) { Map<Integer, List<String>> primaryActionChain) {
if (!primaryActionChain.containsKey(lastOrder)) { if (!primaryActionChain.containsKey(lastOrder)) {
@@ -328,7 +297,6 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
//noinspection SlowListContainsAll //noinspection SlowListContainsAll
return existActions.containsAll(preActions); return existActions.containsAll(preActions);
} }
private ConfirmerInput buildConfirmerInput(PartnerRunningFlowContext context) { private ConfirmerInput buildConfirmerInput(PartnerRunningFlowContext context) {
ConfirmerInput confirmerInput = new ConfirmerInput(); ConfirmerInput confirmerInput = new ConfirmerInput();
confirmerInput.setInput(context.getInput()); confirmerInput.setInput(context.getInput());
@@ -337,4 +305,9 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
return confirmerInput; return confirmerInput;
} }
} }
@Override
public int order() {
return 2;
}
} }

View File

@@ -2,11 +2,9 @@ package work.slhaf.partner.module.modules.action.planner.confirmer;
import com.alibaba.fastjson2.JSONArray; import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse; import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.core.action.ActionCapability;
@@ -21,22 +19,16 @@ import java.util.concurrent.ExecutorService;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson; import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
@Slf4j public class ActionConfirmer extends AbstractAgentModule.Sub<ConfirmerInput, ConfirmerResult> implements ActivateModel {
@AgentSubModule
public class ActionConfirmer extends AbstractAgentSubModule<ConfirmerInput, ConfirmerResult> implements ActivateModel {
@InjectCapability @InjectCapability
private ActionCapability actionCapability; private ActionCapability actionCapability;
@Override @Override
public ConfirmerResult execute(ConfirmerInput data) { public ConfirmerResult execute(ConfirmerInput data) {
List<ExecutableAction> executableActionList = data.getExecutableActionData(); List<ExecutableAction> executableActionList = data.getExecutableActionData();
ExecutorService executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL); ExecutorService executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
CountDownLatch latch = new CountDownLatch(executableActionList.size()); CountDownLatch latch = new CountDownLatch(executableActionList.size());
ConfirmerResult result = new ConfirmerResult(); ConfirmerResult result = new ConfirmerResult();
List<String> uuids = result.getUuids(); List<String> uuids = result.getUuids();
for (ExecutableAction executableAction : executableActionList) { for (ExecutableAction executableAction : executableActionList) {
executor.execute(() -> { executor.execute(() -> {
try { try {
@@ -61,28 +53,22 @@ public class ActionConfirmer extends AbstractAgentSubModule<ConfirmerInput, Conf
} }
return result; return result;
} }
private String buildPrompt(ExecutableAction data, String input, List<Message> recentMessages) { private String buildPrompt(ExecutableAction data, String input, List<Message> recentMessages) {
JSONObject prompt = new JSONObject(); JSONObject prompt = new JSONObject();
prompt.put("[用户输入]", input); prompt.put("[用户输入]", input);
JSONObject actionData = prompt.putObject("[行动数据]"); JSONObject actionData = prompt.putObject("[行动数据]");
actionData.put("[行动倾向]", data.getTendency()); actionData.put("[行动倾向]", data.getTendency());
actionData.put("[行动原因]", data.getReason()); actionData.put("[行动原因]", data.getReason());
actionData.put("[行动来源]", data.getSource()); actionData.put("[行动来源]", data.getSource());
actionData.put("[行动描述]", data.getDescription()); actionData.put("[行动描述]", data.getDescription());
JSONArray messageData = prompt.putArray("[近期对话]"); JSONArray messageData = prompt.putArray("[近期对话]");
messageData.addAll(recentMessages); messageData.addAll(recentMessages);
return prompt.toString(); return prompt.toString();
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "action-confirmer"; return "action-confirmer";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return false; return false;

View File

@@ -4,9 +4,8 @@ import cn.hutool.core.bean.BeanUtil;
import com.alibaba.fastjson2.JSONArray; import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.chat.pojo.ChatResponse; import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor; import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
@@ -22,19 +21,14 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
@AgentSubModule public class ActionEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, List<EvaluatorResult>> implements ActivateModel {
public class ActionEvaluator extends AbstractAgentSubModule<EvaluatorInput, List<EvaluatorResult>> implements ActivateModel {
@InjectCapability @InjectCapability
private ActionCapability actionCapability; private ActionCapability actionCapability;
private InteractionThreadPoolExecutor executor; private InteractionThreadPoolExecutor executor;
@Init @Init
public void init() { public void init() {
executor = InteractionThreadPoolExecutor.getInstance(); executor = InteractionThreadPoolExecutor.getInstance();
} }
/** /**
* 对输入的行为倾向进行评估,并根据评估结果,对缓存做出调整 * 对输入的行为倾向进行评估,并根据评估结果,对缓存做出调整
* *
@@ -47,7 +41,6 @@ public class ActionEvaluator extends AbstractAgentSubModule<EvaluatorInput, List
List<Callable<EvaluatorResult>> tasks = getTasks(batchInputs); List<Callable<EvaluatorResult>> tasks = getTasks(batchInputs);
return executor.invokeAllAndReturn(tasks); return executor.invokeAllAndReturn(tasks);
} }
private List<Callable<EvaluatorResult>> getTasks(List<EvaluatorBatchInput> batchInputs) { private List<Callable<EvaluatorResult>> getTasks(List<EvaluatorBatchInput> batchInputs) {
List<Callable<EvaluatorResult>> list = new ArrayList<>(); List<Callable<EvaluatorResult>> list = new ArrayList<>();
for (EvaluatorBatchInput batchInput : batchInputs) { for (EvaluatorBatchInput batchInput : batchInputs) {
@@ -60,7 +53,6 @@ public class ActionEvaluator extends AbstractAgentSubModule<EvaluatorInput, List
} }
return list; return list;
} }
private List<EvaluatorBatchInput> buildEvaluatorBatchInput(EvaluatorInput data) { private List<EvaluatorBatchInput> buildEvaluatorBatchInput(EvaluatorInput data) {
List<EvaluatorBatchInput> list = new ArrayList<>(); List<EvaluatorBatchInput> list = new ArrayList<>();
for (String tendency : data.getTendencies()) { for (String tendency : data.getTendencies()) {
@@ -74,30 +66,23 @@ public class ActionEvaluator extends AbstractAgentSubModule<EvaluatorInput, List
} }
return list; return list;
} }
private String buildPrompt(EvaluatorBatchInput batchInput) { private String buildPrompt(EvaluatorBatchInput batchInput) {
JSONObject prompt = new JSONObject(); JSONObject prompt = new JSONObject();
prompt.put("[行动倾向]", batchInput.getTendency()); prompt.put("[行动倾向]", batchInput.getTendency());
JSONArray memoryData = prompt.putArray("[相关记忆切片]"); JSONArray memoryData = prompt.putArray("[相关记忆切片]");
for (EvaluatedSlice evaluatedSlice : batchInput.getActivatedSlices()) { for (EvaluatedSlice evaluatedSlice : batchInput.getActivatedSlices()) {
JSONObject memory = memoryData.addObject(); JSONObject memory = memoryData.addObject();
memory.put("[日期]", evaluatedSlice.getDate()); memory.put("[日期]", evaluatedSlice.getDate());
memory.put("[摘要]", evaluatedSlice.getSummary()); memory.put("[摘要]", evaluatedSlice.getSummary());
} }
JSONObject availableActionData = prompt.putObject("[可用行动单元]"); JSONObject availableActionData = prompt.putObject("[可用行动单元]");
availableActionData.putAll(batchInput.getAvailableActions()); availableActionData.putAll(batchInput.getAvailableActions());
return prompt.toString(); return prompt.toString();
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "action_evaluator"; return "action_evaluator";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return true; return true;

View File

@@ -1,11 +1,9 @@
package work.slhaf.partner.module.modules.action.planner.extractor; package work.slhaf.partner.module.modules.action.planner.extractor;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse; import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput; import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput;
@@ -13,13 +11,9 @@ import work.slhaf.partner.module.modules.action.planner.extractor.entity.Extract
import java.util.List; import java.util.List;
@Slf4j public class ActionExtractor extends AbstractAgentModule.Sub<ExtractorInput, ExtractorResult> implements ActivateModel {
@AgentSubModule
public class ActionExtractor extends AbstractAgentSubModule<ExtractorInput, ExtractorResult> implements ActivateModel {
@InjectCapability @InjectCapability
private ActionCapability actionCapability; private ActionCapability actionCapability;
@Override @Override
public ExtractorResult execute(ExtractorInput data) { public ExtractorResult execute(ExtractorInput data) {
ExtractorResult result = new ExtractorResult(); ExtractorResult result = new ExtractorResult();
@@ -28,7 +22,6 @@ public class ActionExtractor extends AbstractAgentSubModule<ExtractorInput, Extr
result.setTendencies(tendencyCache); result.setTendencies(tendencyCache);
return result; return result;
} }
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
try { try {
ChatResponse response = this.singleChat(JSONObject.toJSONString(data)); ChatResponse response = this.singleChat(JSONObject.toJSONString(data));
@@ -37,15 +30,12 @@ public class ActionExtractor extends AbstractAgentSubModule<ExtractorInput, Extr
log.error("[ActionExtractor] 提取信息出错", e); log.error("[ActionExtractor] 提取信息出错", e);
} }
} }
return new ExtractorResult(); return new ExtractorResult();
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "action_extractor"; return "action_extractor";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return false; return false;

View File

@@ -3,9 +3,7 @@ package work.slhaf.partner.module.modules.memory.selector;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule; import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.MemoryCapability;
@@ -24,23 +22,17 @@ import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowCon
import java.time.LocalDate; import java.time.LocalDate;
import java.util.*; import java.util.*;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
@Slf4j
@AgentRunningModule(name = "memory_selector", order = 2)
public class MemorySelector extends PreRunningAbstractAgentModuleAbstract { public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
@InjectCapability @InjectCapability
private MemoryCapability memoryCapability; private MemoryCapability memoryCapability;
@InjectCapability @InjectCapability
private CognationCapability cognationCapability; private CognationCapability cognationCapability;
@InjectModule @InjectModule
private SliceSelectEvaluator sliceSelectEvaluator; private SliceSelectEvaluator sliceSelectEvaluator;
@InjectModule @InjectModule
private MemorySelectExtractor memorySelectExtractor; private MemorySelectExtractor memorySelectExtractor;
@Override @Override
public void doExecute(PartnerRunningFlowContext runningFlowContext) { public void doExecute(PartnerRunningFlowContext runningFlowContext) {
String userId = runningFlowContext.getUserId(); String userId = runningFlowContext.getUserId();
@@ -53,7 +45,6 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
} }
setModuleContextRecall(runningFlowContext); setModuleContextRecall(runningFlowContext);
} }
private List<EvaluatedSlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, ExtractorResult extractorResult) { private List<EvaluatedSlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, ExtractorResult extractorResult) {
log.debug("[MemorySelector] 触发记忆回溯..."); log.debug("[MemorySelector] 触发记忆回溯...");
//查找切片 //查找切片
@@ -71,7 +62,6 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
log.debug("[MemorySelector] 切片评估结果: {}", JSONObject.toJSONString(memorySlices)); log.debug("[MemorySelector] 切片评估结果: {}", JSONObject.toJSONString(memorySlices));
return memorySlices; return memorySlices;
} }
private void setModuleContextRecall(PartnerRunningFlowContext runningFlowContext) { private void setModuleContextRecall(PartnerRunningFlowContext runningFlowContext) {
String userId = runningFlowContext.getUserId(); String userId = runningFlowContext.getUserId();
boolean recall = memoryCapability.hasActivatedSlices(userId); boolean recall = memoryCapability.hasActivatedSlices(userId);
@@ -80,8 +70,6 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
runningFlowContext.getModuleContext().getExtraContext().put("recall_count", memoryCapability.getActivatedSlicesSize(userId)); runningFlowContext.getModuleContext().getExtraContext().put("recall_count", memoryCapability.getActivatedSlicesSize(userId));
} }
} }
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) { private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) {
for (ExtractorMatchData match : matches) { for (ExtractorMatchData match : matches) {
try { try {
@@ -101,7 +89,6 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
} }
//清理切片记录 //清理切片记录
memoryCapability.cleanSelectedSliceFilter(); memoryCapability.cleanSelectedSliceFilter();
//根据userInfo过滤是否为私人记忆 //根据userInfo过滤是否为私人记忆
for (MemoryResult memoryResult : memoryResultList) { for (MemoryResult memoryResult : memoryResultList) {
//过滤终点记忆 //过滤终点记忆
@@ -110,25 +97,21 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userId)); memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userId));
} }
} }
private void removeDuplicateSlice(MemoryResult memoryResult) { private void removeDuplicateSlice(MemoryResult memoryResult) {
Collection<String> values = memoryCapability.getDialogMap().values(); Collection<String> values = memoryCapability.getDialogMap().values();
memoryResult.getRelatedMemorySliceResult().removeIf(m -> values.contains(m.getSummary())); memoryResult.getRelatedMemorySliceResult().removeIf(m -> values.contains(m.getSummary()));
memoryResult.getMemorySliceResult().removeIf(m -> values.contains(m.getMemorySlice().getSummary())); memoryResult.getMemorySliceResult().removeIf(m -> values.contains(m.getMemorySlice().getSummary()));
} }
private boolean removeOrNot(MemorySlice memorySlice, String userId) { private boolean removeOrNot(MemorySlice memorySlice, String userId) {
if (memorySlice.isPrivate()) { if (memorySlice.isPrivate()) {
return memorySlice.getStartUserId().equals(userId); return memorySlice.getStartUserId().equals(userId);
} }
return false; return false;
} }
@Override @Override
public String moduleName() { public String moduleName() {
return "[记忆模块]"; return "[记忆模块]";
} }
@Override @Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) { protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
HashMap<String, String> map = new HashMap<>(); HashMap<String, String> map = new HashMap<>();
@@ -137,12 +120,10 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
if (!dialogMapStr.isEmpty()) { if (!dialogMapStr.isEmpty()) {
map.put("[记忆缓存] <你最近两日和所有聊天者的对话记忆印象>", dialogMapStr); map.put("[记忆缓存] <你最近两日和所有聊天者的对话记忆印象>", dialogMapStr);
} }
String userDialogMapStr = memoryCapability.getUserDialogMapStr(userId); String userDialogMapStr = memoryCapability.getUserDialogMapStr(userId);
if (userDialogMapStr != null && !userDialogMapStr.isEmpty() && !cognationCapability.isSingleUser()) { if (userDialogMapStr != null && !userDialogMapStr.isEmpty() && !cognationCapability.isSingleUser()) {
map.put("[用户记忆缓存] <与最新一条消息的发送者的近两天对话记忆印象, 可能与[记忆缓存]稍有重复>", userDialogMapStr); map.put("[用户记忆缓存] <与最新一条消息的发送者的近两天对话记忆印象, 可能与[记忆缓存]稍有重复>", userDialogMapStr);
} }
String sliceStr = memoryCapability.getActivatedSlicesStr(userId); String sliceStr = memoryCapability.getActivatedSlicesStr(userId);
if (sliceStr != null && !sliceStr.isEmpty()) { if (sliceStr != null && !sliceStr.isEmpty()) {
map.put("[记忆切片] <你与最新一条消息的发送者的相关回忆, 不会与[记忆缓存]重复, 如果有重复你也可以指出来>", sliceStr); map.put("[记忆切片] <你与最新一条消息的发送者的相关回忆, 不会与[记忆缓存]重复, 如果有重复你也可以指出来>", sliceStr);
@@ -150,4 +131,8 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract {
return map; return map;
} }
@Override
public int order() {
return 2;
}
} }

View File

@@ -5,10 +5,8 @@ import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor; import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
@@ -27,20 +25,14 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson; import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
@Slf4j public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, List<EvaluatedSlice>> implements ActivateModel {
@AgentSubModule
public class SliceSelectEvaluator extends AbstractAgentSubModule<EvaluatorInput, List<EvaluatedSlice>> implements ActivateModel {
private InteractionThreadPoolExecutor executor; private InteractionThreadPoolExecutor executor;
@Init @Init
public void init() { public void init() {
executor = InteractionThreadPoolExecutor.getInstance(); executor = InteractionThreadPoolExecutor.getInstance();
} }
@Override @Override
public List<EvaluatedSlice> execute(EvaluatorInput evaluatorInput) { public List<EvaluatedSlice> execute(EvaluatorInput evaluatorInput) {
log.debug("[SliceSelectEvaluator] 切片评估模块开始..."); log.debug("[SliceSelectEvaluator] 切片评估模块开始...");
@@ -83,16 +75,13 @@ public class SliceSelectEvaluator extends AbstractAgentSubModule<EvaluatorInput,
return null; return null;
}); });
} }
executor.invokeAll(tasks, 30, TimeUnit.SECONDS); executor.invokeAll(tasks, 30, TimeUnit.SECONDS);
log.debug("[SliceSelectEvaluator] 评估模块结束, 输出队列: {}", queue); log.debug("[SliceSelectEvaluator] 评估模块结束, 输出队列: {}", queue);
List<EvaluatedSlice> temp = queue.stream().toList(); List<EvaluatedSlice> temp = queue.stream().toList();
return new ArrayList<>(temp); return new ArrayList<>(temp);
} }
private void setSliceSummaryList(MemoryResult memoryResult, List<SliceSummary> sliceSummaryList, Map<Long, SliceSummary> map) { private void setSliceSummaryList(MemoryResult memoryResult, List<SliceSummary> sliceSummaryList, Map<Long, SliceSummary> map) {
for (MemorySliceResult memorySliceResult : memoryResult.getMemorySliceResult()) { for (MemorySliceResult memorySliceResult : memoryResult.getMemorySliceResult()) {
SliceSummary sliceSummary = new SliceSummary(); SliceSummary sliceSummary = new SliceSummary();
sliceSummary.setId(memorySliceResult.getMemorySlice().getTimestamp()); sliceSummary.setId(memorySliceResult.getMemorySlice().getTimestamp());
StringBuilder stringBuilder = new StringBuilder(); StringBuilder stringBuilder = new StringBuilder();
@@ -109,29 +98,22 @@ public class SliceSelectEvaluator extends AbstractAgentSubModule<EvaluatorInput,
sliceSummary.setSummary(stringBuilder.toString()); sliceSummary.setSummary(stringBuilder.toString());
Long timestamp = memorySliceResult.getMemorySlice().getTimestamp(); Long timestamp = memorySliceResult.getMemorySlice().getTimestamp();
sliceSummary.setDate(DateUtil.date(timestamp).toLocalDateTime().toLocalDate()); sliceSummary.setDate(DateUtil.date(timestamp).toLocalDateTime().toLocalDate());
sliceSummaryList.add(sliceSummary); sliceSummaryList.add(sliceSummary);
map.put(timestamp, sliceSummary); map.put(timestamp, sliceSummary);
} }
for (MemorySlice memorySlice : memoryResult.getRelatedMemorySliceResult()) { for (MemorySlice memorySlice : memoryResult.getRelatedMemorySliceResult()) {
SliceSummary sliceSummary = new SliceSummary(); SliceSummary sliceSummary = new SliceSummary();
sliceSummary.setId(memorySlice.getTimestamp()); sliceSummary.setId(memorySlice.getTimestamp());
sliceSummary.setSummary(memorySlice.getSummary()); sliceSummary.setSummary(memorySlice.getSummary());
sliceSummaryList.add(sliceSummary); sliceSummaryList.add(sliceSummary);
map.put(memorySlice.getTimestamp(), sliceSummary); map.put(memorySlice.getTimestamp(), sliceSummary);
} }
} }
public String modelKey() { public String modelKey() {
return "slice_evaluator"; return "slice_evaluator";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return false; return false;
} }
} }

View File

@@ -4,11 +4,9 @@ import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage; import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.cognation.CognationCapability;
@@ -24,20 +22,14 @@ import java.util.List;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson; import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath; import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
@Slf4j public class MemorySelectExtractor extends AbstractAgentModule.Sub<PartnerRunningFlowContext, ExtractorResult>
@AgentSubModule
public class MemorySelectExtractor extends AbstractAgentSubModule<PartnerRunningFlowContext, ExtractorResult>
implements ActivateModel { implements ActivateModel {
@InjectCapability @InjectCapability
private MemoryCapability memoryCapability; private MemoryCapability memoryCapability;
@InjectCapability @InjectCapability
private CognationCapability cognationCapability; private CognationCapability cognationCapability;
@Override @Override
public ExtractorResult execute(PartnerRunningFlowContext context) { public ExtractorResult execute(PartnerRunningFlowContext context) {
log.debug("[MemorySelectExtractor] 主题提取模块开始..."); log.debug("[MemorySelectExtractor] 主题提取模块开始...");
@@ -52,7 +44,6 @@ public class MemorySelectExtractor extends AbstractAgentSubModule<PartnerRunning
chatMessages.add(metaMessage.getAssistantMessage()); chatMessages.add(metaMessage.getAssistantMessage());
} }
} }
ExtractorResult extractorResult; ExtractorResult extractorResult;
try { try {
List<EvaluatedSlice> activatedMemorySlices = memoryCapability.getActivatedSlices(context.getUserId()); List<EvaluatedSlice> activatedMemorySlices = memoryCapability.getActivatedSlices(context.getUserId());
@@ -75,7 +66,6 @@ public class MemorySelectExtractor extends AbstractAgentSubModule<PartnerRunning
} }
return fix(extractorResult); return fix(extractorResult);
} }
private ExtractorResult fix(ExtractorResult extractorResult) { private ExtractorResult fix(ExtractorResult extractorResult) {
extractorResult.getMatches().forEach(m -> { extractorResult.getMatches().forEach(m -> {
if (m.getType().equals(ExtractorMatchData.Constant.DATE)) { if (m.getType().equals(ExtractorMatchData.Constant.DATE)) {
@@ -89,15 +79,12 @@ public class MemorySelectExtractor extends AbstractAgentSubModule<PartnerRunning
extractorResult.getMatches().removeIf(m -> m.getText().split("->")[0].isEmpty()); extractorResult.getMatches().removeIf(m -> m.getText().split("->")[0].isEmpty());
return extractorResult; return extractorResult;
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "topic_extractor"; return "topic_extractor";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return false; return false;
} }
} }

View File

@@ -3,9 +3,7 @@ package work.slhaf.partner.module.modules.memory.updater;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule; import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.api.chat.constant.ChatConstant; import work.slhaf.partner.api.chat.constant.ChatConstant;
@@ -29,11 +27,8 @@ import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import static work.slhaf.partner.common.util.ExtractUtil.extractUserId; import static work.slhaf.partner.common.util.ExtractUtil.extractUserId;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
@Slf4j
@AgentRunningModule(name = "memory_updater", order = 7)
public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract { public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
private static final long SCHEDULED_UPDATE_INTERVAL = 10 * 1000; private static final long SCHEDULED_UPDATE_INTERVAL = 10 * 1000;
@@ -52,19 +47,16 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
private SingleSummarizer singleSummarizer; private SingleSummarizer singleSummarizer;
@InjectModule @InjectModule
private TotalSummarizer totalSummarizer; private TotalSummarizer totalSummarizer;
private InteractionThreadPoolExecutor executor; private InteractionThreadPoolExecutor executor;
/** /**
* 用于临时存储完整对话记录在MemoryManager的分离后 * 用于临时存储完整对话记录在MemoryManager的分离后
*/ */
private List<Message> tempMessage; private List<Message> tempMessage;
@Init @Init
public void init() { public void init() {
executor = InteractionThreadPoolExecutor.getInstance(); executor = InteractionThreadPoolExecutor.getInstance();
setScheduledUpdater(); setScheduledUpdater();
} }
private void setScheduledUpdater() { private void setScheduledUpdater() {
executor.execute(() -> { executor.execute(() -> {
log.info("[MemoryUpdater] 记忆自动更新线程启动"); log.info("[MemoryUpdater] 记忆自动更新线程启动");
@@ -88,7 +80,6 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
log.info("[MemoryUpdater] 记忆自动更新线程结束"); log.info("[MemoryUpdater] 记忆自动更新线程结束");
}); });
} }
@Override @Override
public void doExecute(PartnerRunningFlowContext context) { public void doExecute(PartnerRunningFlowContext context) {
if (context.isFinished()) { if (context.isFinished()) {
@@ -118,12 +109,10 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
} }
}); });
} }
@Override @Override
protected boolean relyOnMessage() { protected boolean relyOnMessage() {
return true; return true;
} }
private void updateMemory() { private void updateMemory() {
log.debug("[MemoryUpdater] 记忆更新流程开始..."); log.debug("[MemoryUpdater] 记忆更新流程开始...");
tempMessage = new ArrayList<>(cognationCapability.getChatMessages()); tempMessage = new ArrayList<>(cognationCapability.getChatMessages());
@@ -135,7 +124,6 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
cognationCapability.resetLastUpdatedTime(); cognationCapability.resetLastUpdatedTime();
log.debug("[MemoryUpdater] 记忆更新流程结束..."); log.debug("[MemoryUpdater] 记忆更新流程结束...");
} }
private void updateMultiChatSlices(HashMap<String, String> singleMemorySummary) { private void updateMultiChatSlices(HashMap<String, String> singleMemorySummary) {
// 此时chatMessages中不再包含单聊记录直接执行摘要以及切片插入 // 此时chatMessages中不再包含单聊记录直接执行摘要以及切片插入
// 对剩下的多人聊天记录进行进行摘要 // 对剩下的多人聊天记录进行进行摘要
@@ -161,21 +149,17 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
// 设置involvedUserId // 设置involvedUserId
setInvolvedUserId(userId, memorySlice, chatMessages); setInvolvedUserId(userId, memorySlice, chatMessages);
memoryCapability.insertSlice(memorySlice, summarizeResult.getTopicPath()); memoryCapability.insertSlice(memorySlice, summarizeResult.getTopicPath());
memoryCapability.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary()); memoryCapability.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary());
} else { } else {
log.debug("[MemoryUpdater] 不存在多人聊天记录, 将以单聊总结为对话缓存的主要输入: {}", singleMemorySummary); log.debug("[MemoryUpdater] 不存在多人聊天记录, 将以单聊总结为对话缓存的主要输入: {}", singleMemorySummary);
memoryCapability.updateDialogMap(LocalDateTime.now(), totalSummarizer.execute(singleMemorySummary)); memoryCapability.updateDialogMap(LocalDateTime.now(), totalSummarizer.execute(singleMemorySummary));
} }
log.debug("[MemoryUpdater] 对话缓存更新完毕"); log.debug("[MemoryUpdater] 对话缓存更新完毕");
log.debug("[MemoryUpdater] 多人聊天记忆更新流程结束..."); log.debug("[MemoryUpdater] 多人聊天记忆更新流程结束...");
return null; return null;
}; };
executor.invokeAll(List.of(task)); executor.invokeAll(List.of(task));
} }
private void cleanMessage(List<Message> chatMessages) { private void cleanMessage(List<Message> chatMessages) {
// 清理时间标识 // 清理时间标识
for (Message message : chatMessages) { for (Message message : chatMessages) {
@@ -186,7 +170,6 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
message.setContent(message.getContent().replace("\r\n**" + time, "")); message.setContent(message.getContent().replace("\r\n**" + time, ""));
} }
} }
private void clearChatMessages() { private void clearChatMessages() {
// 不全部清空,保留一部分输入防止上下文割裂 // 不全部清空,保留一部分输入防止上下文割裂
cognationCapability.getMessageLock().lock(); cognationCapability.getMessageLock().lock();
@@ -196,7 +179,6 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
cognationCapability.getChatMessages().addAll(0, temp); cognationCapability.getChatMessages().addAll(0, temp);
cognationCapability.getMessageLock().unlock(); cognationCapability.getMessageLock().unlock();
} }
private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List<Message> chatMessages) { private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List<Message> chatMessages) {
for (Message chatMessage : chatMessages) { for (Message chatMessage : chatMessages) {
if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) { if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) {
@@ -214,7 +196,6 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
memorySlice.getInvolvedUserIds().add(userId); memorySlice.getInvolvedUserIds().add(userId);
} }
} }
private void updateSingleChatSlices(HashMap<String, String> singleMemorySummary) { private void updateSingleChatSlices(HashMap<String, String> singleMemorySummary) {
log.debug("[MemoryUpdater] 单聊记忆更新流程开始..."); log.debug("[MemoryUpdater] 单聊记忆更新流程开始...");
// 更新单聊记忆同时从chatMessages中去掉单聊记忆 // 更新单聊记忆同时从chatMessages中去掉单聊记忆
@@ -247,23 +228,19 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
} }
return null; return null;
}); });
} }
executor.invokeAll(tasks); executor.invokeAll(tasks);
log.debug("[MemoryUpdater] 单聊记忆更新结束..."); log.debug("[MemoryUpdater] 单聊记忆更新结束...");
} }
private SummarizeResult summarize(SummarizeInput summarizeInput) { private SummarizeResult summarize(SummarizeInput summarizeInput) {
singleSummarizer.execute(summarizeInput.getChatMessages()); singleSummarizer.execute(summarizeInput.getChatMessages());
return multiSummarizer.execute(summarizeInput); return multiSummarizer.execute(summarizeInput);
} }
private MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List<Message> chatMessages) { private MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List<Message> chatMessages) {
MemorySlice memorySlice = new MemorySlice(); MemorySlice memorySlice = new MemorySlice();
// 设置 memoryId,timestamp // 设置 memoryId,timestamp
memorySlice.setMemoryId(cognationCapability.getCurrentMemoryId()); memorySlice.setMemoryId(cognationCapability.getCurrentMemoryId());
memorySlice.setTimestamp(System.currentTimeMillis()); memorySlice.setTimestamp(System.currentTimeMillis());
// 补充信息 // 补充信息
memorySlice.setPrivate(summarizeResult.isPrivate()); memorySlice.setPrivate(summarizeResult.isPrivate());
memorySlice.setSummary(summarizeResult.getSummary()); memorySlice.setSummary(summarizeResult.getSummary());
@@ -277,4 +254,9 @@ public class MemoryUpdater extends PostRunningAbstractAgentModuleAbstract {
memorySlice.setRelatedTopics(relatedTopicPathList); memorySlice.setRelatedTopics(relatedTopicPathList);
return memorySlice; return memorySlice;
} }
@Override
public int order() {
return 7;
}
} }

View File

@@ -4,10 +4,8 @@ import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.chat.pojo.ChatResponse; import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeInput; import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeInput;
@@ -18,18 +16,13 @@ import java.util.List;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson; import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath; import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
@Slf4j public class MultiSummarizer extends AbstractAgentModule.Sub<SummarizeInput, SummarizeResult> implements ActivateModel {
@AgentSubModule
public class MultiSummarizer extends AbstractAgentSubModule<SummarizeInput, SummarizeResult> implements ActivateModel {
@Init @Init
public void init() { public void init() {
updateChatClientSettings(); updateChatClientSettings();
} }
@Override @Override
public SummarizeResult execute(SummarizeInput input) { public SummarizeResult execute(SummarizeInput input) {
log.debug("[MemorySummarizer] 整体摘要开始..."); log.debug("[MemorySummarizer] 整体摘要开始...");
@@ -38,12 +31,10 @@ public class MultiSummarizer extends AbstractAgentSubModule<SummarizeInput, Summ
SummarizeResult result = JSONObject.parseObject(extractJson(response.getMessage()), SummarizeResult.class); SummarizeResult result = JSONObject.parseObject(extractJson(response.getMessage()), SummarizeResult.class);
return fix(result); return fix(result);
} }
private SummarizeResult fix(SummarizeResult result) { private SummarizeResult fix(SummarizeResult result) {
if (result == null || result.getTopicPath() == null || result.getTopicPath().isEmpty()) { if (result == null || result.getTopicPath() == null || result.getTopicPath().isEmpty()) {
return result; return result;
} }
String topicPath = fixTopicPath(result.getTopicPath()); String topicPath = fixTopicPath(result.getTopicPath());
List<String> relatedTopicPath = new ArrayList<>(); List<String> relatedTopicPath = new ArrayList<>();
for (String s : result.getRelatedTopicPath()) { for (String s : result.getRelatedTopicPath()) {
@@ -53,15 +44,12 @@ public class MultiSummarizer extends AbstractAgentSubModule<SummarizeInput, Summ
result.setRelatedTopicPath(relatedTopicPath); result.setRelatedTopicPath(relatedTopicPath);
return result; return result;
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "multi_summarizer"; return "multi_summarizer";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return true; return true;
} }
} }

View File

@@ -3,10 +3,8 @@ package work.slhaf.partner.module.modules.memory.updater.summarizer;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.chat.constant.ChatConstant; import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.ChatResponse; import work.slhaf.partner.api.chat.pojo.ChatResponse;
@@ -18,20 +16,14 @@ import java.util.List;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Slf4j
@Data @Data
@AgentSubModule public class SingleSummarizer extends AbstractAgentModule.Sub<List<Message>, Void> implements ActivateModel {
public class SingleSummarizer extends AbstractAgentSubModule<List<Message>, Void> implements ActivateModel {
private InteractionThreadPoolExecutor executor; private InteractionThreadPoolExecutor executor;
@Init @Init
public void init() { public void init() {
this.executor = InteractionThreadPoolExecutor.getInstance(); this.executor = InteractionThreadPoolExecutor.getInstance();
} }
@Override @Override
public Void execute(List<Message> chatMessages) { public Void execute(List<Message> chatMessages) {
log.debug("[MemorySummarizer] 长文本摘要开始..."); log.debug("[MemorySummarizer] 长文本摘要开始...");
@@ -55,7 +47,6 @@ public class SingleSummarizer extends AbstractAgentSubModule<List<Message>, Void
log.debug("[MemorySummarizer] 长文本摘要结束"); log.debug("[MemorySummarizer] 长文本摘要结束");
return null; return null;
} }
private String singleExecute(String primaryContent) { private String singleExecute(String primaryContent) {
try { try {
ChatResponse response = this.singleChat(primaryContent); ChatResponse response = this.singleChat(primaryContent);
@@ -65,15 +56,12 @@ public class SingleSummarizer extends AbstractAgentSubModule<List<Message>, Void
return primaryContent; return primaryContent;
} }
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "single_summarizer"; return "single_summarizer";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return false; return false;
} }
} }

View File

@@ -4,41 +4,31 @@ import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.chat.pojo.ChatResponse; import work.slhaf.partner.api.chat.pojo.ChatResponse;
import java.util.HashMap; import java.util.HashMap;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson; import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
@Slf4j public class TotalSummarizer extends AbstractAgentModule.Sub<HashMap<String, String>, String> implements ActivateModel {
@AgentSubModule
public class TotalSummarizer extends AbstractAgentSubModule<HashMap<String, String>, String> implements ActivateModel {
@Init @Init
public void init() { public void init() {
updateChatClientSettings(); updateChatClientSettings();
} }
public String execute(HashMap<String, String> singleMemorySummary){ public String execute(HashMap<String, String> singleMemorySummary){
ChatResponse response = this.singleChat(JSONUtil.toJsonPrettyStr(singleMemorySummary)); ChatResponse response = this.singleChat(JSONUtil.toJsonPrettyStr(singleMemorySummary));
return JSONObject.parseObject(extractJson(response.getMessage())).getString("content"); return JSONObject.parseObject(extractJson(response.getMessage())).getString("content");
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "total_summarizer"; return "total_summarizer";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return true; return true;
} }
} }

View File

@@ -1,9 +1,7 @@
package work.slhaf.partner.module.modules.perceive.selector; package work.slhaf.partner.module.modules.perceive.selector;
import lombok.Setter; import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.core.perceive.PerceiveCapability; import work.slhaf.partner.core.perceive.PerceiveCapability;
import work.slhaf.partner.core.perceive.pojo.User; import work.slhaf.partner.core.perceive.pojo.User;
import work.slhaf.partner.module.common.module.PreRunningAbstractAgentModuleAbstract; import work.slhaf.partner.module.common.module.PreRunningAbstractAgentModuleAbstract;
@@ -11,19 +9,13 @@ import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowCon
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@Slf4j
@Setter @Setter
@AgentRunningModule(name = "perceive_selector", order = 2)
public class PerceiveSelector extends PreRunningAbstractAgentModuleAbstract { public class PerceiveSelector extends PreRunningAbstractAgentModuleAbstract {
@InjectCapability @InjectCapability
private PerceiveCapability perceiveCapability; private PerceiveCapability perceiveCapability;
@Override @Override
public void doExecute(PartnerRunningFlowContext context) { public void doExecute(PartnerRunningFlowContext context) {
} }
@Override @Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) { protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
HashMap<String, String> map = new HashMap<>(); HashMap<String, String> map = new HashMap<>();
@@ -34,9 +26,13 @@ public class PerceiveSelector extends PreRunningAbstractAgentModuleAbstract {
map.put("[静态记忆] <你关于最新聊天用户的静态记忆>", user.getStaticMemory().toString()); map.put("[静态记忆] <你关于最新聊天用户的静态记忆>", user.getStaticMemory().toString());
return map; return map;
} }
@Override @Override
public String moduleName() { public String moduleName() {
return "[感知模块]"; return "[感知模块]";
} }
@Override
public int order() {
return 2;
}
} }

View File

@@ -2,9 +2,7 @@ package work.slhaf.partner.module.modules.perceive.updater;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule; import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor; import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
@@ -22,34 +20,25 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
/** /**
* 感知更新,异步 * 感知更新,异步
*/ */
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Slf4j
@Data @Data
@AgentRunningModule(name = "perceive_updater", order = 7)
public class PerceiveUpdater extends PostRunningAbstractAgentModuleAbstract { public class PerceiveUpdater extends PostRunningAbstractAgentModuleAbstract {
@InjectCapability @InjectCapability
private PerceiveCapability perceiveCapability; private PerceiveCapability perceiveCapability;
@InjectCapability @InjectCapability
private CognationCapability cognationCapability; private CognationCapability cognationCapability;
@InjectModule @InjectModule
private RelationExtractor relationExtractor; private RelationExtractor relationExtractor;
@InjectModule @InjectModule
private StaticMemoryExtractor staticMemoryExtractor; private StaticMemoryExtractor staticMemoryExtractor;
private InteractionThreadPoolExecutor executor; private InteractionThreadPoolExecutor executor;
@Init @Init
public void init() { public void init() {
this.executor = InteractionThreadPoolExecutor.getInstance(); this.executor = InteractionThreadPoolExecutor.getInstance();
} }
@Override @Override
public void doExecute(PartnerRunningFlowContext context) { public void doExecute(PartnerRunningFlowContext context) {
executor.execute(() -> { executor.execute(() -> {
@@ -69,12 +58,10 @@ public class PerceiveUpdater extends PostRunningAbstractAgentModuleAbstract {
perceiveCapability.updateUser(user); perceiveCapability.updateUser(user);
}); });
} }
@Override @Override
protected boolean relyOnMessage() { protected boolean relyOnMessage() {
return true; return true;
} }
private void runRelationExtractorAction(PartnerRunningFlowContext context, ReentrantLock userLock, User user) { private void runRelationExtractorAction(PartnerRunningFlowContext context, ReentrantLock userLock, User user) {
RelationExtractResult relationExtractResult = relationExtractor.execute(context); RelationExtractResult relationExtractResult = relationExtractor.execute(context);
userLock.lock(); userLock.lock();
@@ -84,7 +71,6 @@ public class PerceiveUpdater extends PostRunningAbstractAgentModuleAbstract {
user.updateRelationChange(relationExtractResult.getRelationChangeHistory()); user.updateRelationChange(relationExtractResult.getRelationChangeHistory());
userLock.unlock(); userLock.unlock();
} }
private void runStaticExtractorAction(PartnerRunningFlowContext context, ReentrantLock userLock, User user) { private void runStaticExtractorAction(PartnerRunningFlowContext context, ReentrantLock userLock, User user) {
HashMap<String, String> newStaticMemory = staticMemoryExtractor.execute(context); HashMap<String, String> newStaticMemory = staticMemoryExtractor.execute(context);
userLock.lock(); userLock.lock();
@@ -92,4 +78,8 @@ public class PerceiveUpdater extends PostRunningAbstractAgentModuleAbstract {
userLock.unlock(); userLock.unlock();
} }
@Override
public int order() {
return 7;
}
} }

View File

@@ -4,9 +4,8 @@ import com.alibaba.fastjson2.JSONObject;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse; import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.cognation.CognationCapability;
@@ -19,20 +18,14 @@ import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowCon
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
@AgentSubModule public class RelationExtractor extends AbstractAgentModule.Sub<PartnerRunningFlowContext, RelationExtractResult> implements ActivateModel {
public class RelationExtractor extends AbstractAgentSubModule<PartnerRunningFlowContext, RelationExtractResult> implements ActivateModel {
@InjectCapability @InjectCapability
private CognationCapability cognationCapability; private CognationCapability cognationCapability;
@InjectCapability @InjectCapability
private PerceiveCapability perceiveCapability; private PerceiveCapability perceiveCapability;
private List<Message> tempMessages; private List<Message> tempMessages;
@Override @Override
public RelationExtractResult execute(PartnerRunningFlowContext context){ public RelationExtractResult execute(PartnerRunningFlowContext context){
tempMessages = new ArrayList<>(cognationCapability.getChatMessages()); tempMessages = new ArrayList<>(cognationCapability.getChatMessages());
@@ -43,8 +36,6 @@ public class RelationExtractor extends AbstractAgentSubModule<PartnerRunningFlow
perceiveCapability.updateUser(user); perceiveCapability.updateUser(user);
return relationExtractResult; return relationExtractResult;
} }
private User getTempUser(PartnerRunningFlowContext context, RelationExtractResult relationExtractResult) { private User getTempUser(PartnerRunningFlowContext context, RelationExtractResult relationExtractResult) {
User user = new User(); User user = new User();
user.setUuid(context.getUserId()); user.setUuid(context.getUserId());
@@ -53,12 +44,10 @@ public class RelationExtractor extends AbstractAgentSubModule<PartnerRunningFlow
user.setAttitude(relationExtractResult.getAttitude()); user.setAttitude(relationExtractResult.getAttitude());
return user; return user;
} }
private RelationExtractResult getRelationResult(RelationExtractInput input) { private RelationExtractResult getRelationResult(RelationExtractInput input) {
ChatResponse response = singleChat(JSONObject.toJSONString(input)); ChatResponse response = singleChat(JSONObject.toJSONString(input));
return JSONObject.parseObject(response.getMessage(), RelationExtractResult.class); return JSONObject.parseObject(response.getMessage(), RelationExtractResult.class);
} }
private RelationExtractInput getRelationInput(String userId) { private RelationExtractInput getRelationInput(String userId) {
HashMap<String,String> map = new HashMap<>(); HashMap<String,String> map = new HashMap<>();
User user = perceiveCapability.getUser(userId); User user = perceiveCapability.getUser(userId);
@@ -72,15 +61,12 @@ public class RelationExtractor extends AbstractAgentSubModule<PartnerRunningFlow
input.setChatMessages(tempMessages); input.setChatMessages(tempMessages);
return input; return input;
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "relation_extractor"; return "relation_extractor";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return true; return true;
} }
} }

View File

@@ -5,9 +5,8 @@ import com.alibaba.fastjson2.JSONObject;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentSubModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse; import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.perceive.PerceiveCapability; import work.slhaf.partner.core.perceive.PerceiveCapability;
@@ -15,17 +14,13 @@ import work.slhaf.partner.module.modules.perceive.updater.static_extractor.entit
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.HashMap; import java.util.HashMap;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
@AgentSubModule public class StaticMemoryExtractor extends AbstractAgentModule.Sub<PartnerRunningFlowContext, HashMap<String, String>> implements ActivateModel {
public class StaticMemoryExtractor extends AbstractAgentSubModule<PartnerRunningFlowContext, HashMap<String, String>> implements ActivateModel {
@InjectCapability @InjectCapability
private CognationCapability cognationCapability; private CognationCapability cognationCapability;
@InjectCapability @InjectCapability
private PerceiveCapability perceiveCapability; private PerceiveCapability perceiveCapability;
@Override @Override
public HashMap<String, String> execute(PartnerRunningFlowContext context) { public HashMap<String, String> execute(PartnerRunningFlowContext context) {
StaticMemoryExtractInput input = StaticMemoryExtractInput.builder() StaticMemoryExtractInput input = StaticMemoryExtractInput.builder()
@@ -33,22 +28,18 @@ public class StaticMemoryExtractor extends AbstractAgentSubModule<PartnerRunning
.messages(cognationCapability.getChatMessages()) .messages(cognationCapability.getChatMessages())
.existedStaticMap(perceiveCapability.getUser(context.getUserId()).getStaticMemory()) .existedStaticMap(perceiveCapability.getUser(context.getUserId()).getStaticMemory())
.build(); .build();
ChatResponse response = singleChat(JSONUtil.toJsonPrettyStr(input)); ChatResponse response = singleChat(JSONUtil.toJsonPrettyStr(input));
JSONObject jsonObject = JSONObject.parseObject(response.getMessage()); JSONObject jsonObject = JSONObject.parseObject(response.getMessage());
HashMap<String, String> result = new HashMap<>(); HashMap<String, String> result = new HashMap<>();
jsonObject.forEach((k, v) -> result.put(k, (String) v)); jsonObject.forEach((k, v) -> result.put(k, (String) v));
return result; return result;
} }
@Override @Override
public String modelKey() { public String modelKey() {
return "static_extractor"; return "static_extractor";
} }
@Override @Override
public boolean withBasicPrompt() { public boolean withBasicPrompt() {
return true; return true;
} }
} }

View File

@@ -1,29 +1,25 @@
package work.slhaf.partner.module.modules.process; package work.slhaf.partner.module.modules.process;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentRunningModule; import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Slf4j
@Data @Data
@AgentRunningModule(name = "postprocess_executor", order = 6) public class PostprocessExecutor extends AbstractAgentModule.Running<PartnerRunningFlowContext> {
public class PostprocessExecutor extends AbstractAgentRunningModule<PartnerRunningFlowContext> {
private static final int POST_PROCESS_TRIGGER_ROLL_LIMIT = 36; private static final int POST_PROCESS_TRIGGER_ROLL_LIMIT = 36;
@InjectCapability @InjectCapability
private CognationCapability cognationCapability; private CognationCapability cognationCapability;
@Override @Override
public void execute(PartnerRunningFlowContext context) { public void execute(PartnerRunningFlowContext context) {
boolean trigger = cognationCapability.getChatMessages().size() >= POST_PROCESS_TRIGGER_ROLL_LIMIT; boolean trigger = cognationCapability.getChatMessages().size() >= POST_PROCESS_TRIGGER_ROLL_LIMIT;
context.getModuleContext().getExtraContext().put("post_process_trigger", trigger); context.getModuleContext().getExtraContext().put("post_process_trigger", trigger);
log.debug("[PostprocessExecutor] 是否执行后处理: {}", trigger); log.debug("[PostprocessExecutor] 是否执行后处理: {}", trigger);
} }
@Override
public int order() {
return 6;
}
} }

View File

@@ -2,9 +2,7 @@ package work.slhaf.partner.module.modules.process;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentRunningModule;
import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.perceive.PerceiveCapability; import work.slhaf.partner.core.perceive.PerceiveCapability;
import work.slhaf.partner.core.perceive.pojo.User; import work.slhaf.partner.core.perceive.pojo.User;
@@ -16,31 +14,24 @@ import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatter;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
@Slf4j
@AgentRunningModule(name = "preprocess_executor", order = 1)
public class PreprocessExecutor extends PreRunningAbstractAgentModuleAbstract { public class PreprocessExecutor extends PreRunningAbstractAgentModuleAbstract {
@InjectCapability @InjectCapability
private CognationCapability cognationCapability; private CognationCapability cognationCapability;
@InjectCapability @InjectCapability
private PerceiveCapability perceiveCapability; private PerceiveCapability perceiveCapability;
@Override @Override
public void doExecute(PartnerRunningFlowContext context) { public void doExecute(PartnerRunningFlowContext context) {
checkAndSetMemoryId(); checkAndSetMemoryId();
getInteractionContext(context); getInteractionContext(context);
} }
private void checkAndSetMemoryId() { private void checkAndSetMemoryId() {
String currentMemoryId = cognationCapability.getCurrentMemoryId(); String currentMemoryId = cognationCapability.getCurrentMemoryId();
if (currentMemoryId == null || cognationCapability.getChatMessages().isEmpty()) { if (currentMemoryId == null || cognationCapability.getChatMessages().isEmpty()) {
cognationCapability.refreshMemoryId(); cognationCapability.refreshMemoryId();
} }
} }
private void getInteractionContext(PartnerRunningFlowContext context) { private void getInteractionContext(PartnerRunningFlowContext context) {
log.debug("[PreprocessExecutor] 预处理原始输入: {}", context); log.debug("[PreprocessExecutor] 预处理原始输入: {}", context);
User user = perceiveCapability.getUser(context.getUserInfo(), context.getPlatform()); User user = perceiveCapability.getUser(context.getUserInfo(), context.getPlatform());
@@ -49,15 +40,12 @@ public class PreprocessExecutor extends PreRunningAbstractAgentModuleAbstract {
} }
String userId = user.getUuid(); String userId = user.getUuid();
context.setUserId(userId); context.setUserId(userId);
String userStr = "[" + context.getUserNickname() + "(" + userId + ")]"; String userStr = "[" + context.getUserNickname() + "(" + userId + ")]";
String input = userStr + " " + context.getInput(); String input = userStr + " " + context.getInput();
context.setInput(input); context.setInput(input);
setCoreContext(context); setCoreContext(context);
log.debug("[PreprocessExecutor] 预处理结果: {}", context); log.debug("[PreprocessExecutor] 预处理结果: {}", context);
} }
@Override @Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) { protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
HashMap<String, String> map = new HashMap<>(); HashMap<String, String> map = new HashMap<>();
@@ -69,12 +57,10 @@ public class PreprocessExecutor extends PreRunningAbstractAgentModuleAbstract {
map.put("其他", "历史对话中将在用户消息的最后一行标注时间"); map.put("其他", "历史对话中将在用户消息的最后一行标注时间");
return map; return map;
} }
@Override @Override
protected String moduleName() { protected String moduleName() {
return "[基础模块]"; return "[基础模块]";
} }
private void setCoreContext(PartnerRunningFlowContext context) { private void setCoreContext(PartnerRunningFlowContext context) {
CoreContext coreContext = context.getCoreContext(); CoreContext coreContext = context.getCoreContext();
coreContext.setText(context.getInput()); coreContext.setText(context.getInput());
@@ -82,4 +68,9 @@ public class PreprocessExecutor extends PreRunningAbstractAgentModuleAbstract {
coreContext.setUserNick(context.getUserNickname()); coreContext.setUserNick(context.getUserNickname());
coreContext.setUserId(context.getUserId()); coreContext.setUserId(context.getUserId());
} }
@Override
public int order() {
return 1;
}
} }