推进 ActionExecutor、确定动态插拔式行动调度的实现思路

- 在 ActionCore 中添加关闭hook,用于正确设置异常中断时执行中任务的状态
- 修正 actionPool 相关注释及用法
- 将 ActionData 中行动链字段调整为 LinkedHashMap 用于更好地支持分组并发及动态调度
- 重构 ActionExecutor 行动链执行逻辑,采用 Phaser 支持动态调度
- 扩展 InputData、Context 字段并调整 GateWay 格式化逻辑以适应特殊输入
This commit is contained in:
2025-10-27 23:43:06 +08:00
parent 83832d2060
commit e35e18f3b7
12 changed files with 125 additions and 67 deletions

View File

@@ -25,7 +25,7 @@ import java.util.stream.Collectors;
public class ActionCore extends PartnerCore<ActionCore> {
/**
* 对应本次交互即将执行或将要放置在行动池的预备任务因此将以本次交互的uuid为键其起到的作用相当于暂时的模块上下文
* 持久行动池以用户id为键存储所有状态的任务
*/
private HashMap<String, List<ActionData>> actionPool = new HashMap<>();
@@ -48,8 +48,26 @@ public class ActionCore extends PartnerCore<ActionCore> {
public ActionCore() throws IOException, ClassNotFoundException {
new ActionWatchService(existedMetaActions, virtualExecutor).launch();
setupShutdownHook();
}
private void setupShutdownHook() {
// 将执行中的行动状态置为失败
List<ActionData> executingActionList = listExecutingAction();
for (ActionData actionData : executingActionList) {
actionData.setStatus(ActionData.ActionStatus.FAILED);
actionData.setResult("由于系统中断而失败");
}
}
private List<ActionData> listExecutingAction() {
return actionPool.values().stream()
.flatMap(Collection::stream)
.filter(action -> action.getStatus() == ActionData.ActionStatus.EXECUTING)
.collect(Collectors.toList());
}
@CapabilityMethod
public synchronized void putPendingActions(String userId, ActionData actionData) {
pendingActions.computeIfAbsent(userId, k -> {

View File

@@ -2,6 +2,7 @@ package work.slhaf.partner.core.action.entity;
import lombok.Data;
import java.util.LinkedHashMap;
import java.util.List;
/**
@@ -12,7 +13,7 @@ public abstract class ActionData {
protected String uuid;
protected String tendency;
protected ActionStatus status;
protected List<MetaAction> actionChain;
protected LinkedHashMap<Integer, List<MetaAction>> actionChain;
protected String result;
protected String reason;
protected String description;

View File

@@ -26,7 +26,7 @@ public class MetaAction implements Comparable<MetaAction>, Runnable {
/**
* 行动结果,包括执行状态和相应内容(执行结果或者错误信息)
*/
private Result result;
private Result result = new Result();
/**
* 执行顺序,升序排列
*/
@@ -88,7 +88,7 @@ public class MetaAction implements Comparable<MetaAction>, Runnable {
@Data
public static class Result {
private boolean success;
private boolean success = true;
private String data;
}

View File

@@ -12,11 +12,9 @@ import work.slhaf.partner.core.action.entity.ActionData;
import work.slhaf.partner.core.action.entity.ImmediateActionData;
import work.slhaf.partner.core.action.entity.MetaAction;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Phaser;
@Slf4j
@AgentSubModule
@@ -28,6 +26,8 @@ public class ActionExecutor extends AgentRunningSubModule<List<ImmediateActionDa
private ExecutorService virtualExecutor;
private ExecutorService platformExecutor;
private HashMap<String, PhaserActionChain> phaserRecorder = new HashMap<>();
@Init
public void init() {
virtualExecutor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
@@ -45,69 +45,63 @@ public class ActionExecutor extends AgentRunningSubModule<List<ImmediateActionDa
private void handleActionData(ImmediateActionData actionData) {
virtualExecutor.execute(() -> {
actionData.setStatus(ActionData.ActionStatus.EXECUTING);
List<MetaAction> actionChain = actionData.getActionChain();
actionChain.sort(MetaAction::compareTo);
LinkedHashMap<Integer, List<MetaAction>> actionChain = actionData.getActionChain();
List<MetaAction> virtual = new ArrayList<>();
List<MetaAction> platform = new ArrayList<>();
int order;
for (int index = 0; index < actionChain.size(); index++) {
MetaAction metaAction = actionChain.get(index);
// 根据io类型放入合适的列表
if (metaAction.isIo()) {
virtual.add(metaAction);
} else {
platform.add(metaAction);
actionChain.forEach((k, v) -> {
for (MetaAction metaAction : v) {
// 根据io类型放入合适的列表
if (metaAction.isIo()) {
virtual.add(metaAction);
} else {
platform.add(metaAction);
}
}
// 记录当前order
order = metaAction.getOrder();
// 如果下一个行动单元的order与当前不同则执行并清空当前组内容
if (actionChain.size() <= (index + 1) || actionChain.get(index + 1).getOrder() != order) {
runGroupAction(virtual, platform, actionChain);
}
}
runGroupAction(virtual, platform, actionChain);
});
});
}
//TODO 考虑是否使用phaser来承担同组的动态任务新增
private void runGroupAction(List<MetaAction> virtual, List<MetaAction> platform, List<MetaAction> actionChain) {
boolean first = true;
do {
CountDownLatch latch = new CountDownLatch(virtual.size() + platform.size());
runGroupAction(virtual, virtualExecutor, actionChain, latch, first);
runGroupAction(platform, platformExecutor, actionChain, latch, first);
try {
latch.await();
} catch (InterruptedException e) {
log.error("[{}] CountDownLatch被中断", modelKey());
}
first = false;
} while (!virtual.isEmpty() || !platform.isEmpty());
// 使用phaser来承担同组的动态任务新增
private void runGroupAction(List<MetaAction> virtual, List<MetaAction> platform, LinkedHashMap<Integer, List<MetaAction>> actionChain) {
Phaser phaser = new Phaser();
phaser.register();
String groupId = UUID.randomUUID().toString();
phaserRecorder.put(groupId, new PhaserActionChain(phaser, actionChain));
runGroupAction(virtual, virtualExecutor, actionChain, phaser);
runGroupAction(platform, platformExecutor, actionChain, phaser);
phaserRecorder.remove(groupId);
phaser.arriveAndAwaitAdvance();
}
private void runGroupAction(List<MetaAction> actions, ExecutorService executor, List<MetaAction> actionChain, CountDownLatch latch, boolean first) {
if (!first && !new HashSet<>(actionChain).containsAll(actions)) {
// 该部分对应LLM新增本组执行单元时将其添加至actionChain记录。对于后续组级别的新增将直接在上一级调用处体现除了注意并发安全外无需额外处理
int index = actionChain.indexOf(actions.getLast());
actionChain.addAll(index, actions);
}
private void runGroupAction(List<MetaAction> actions, ExecutorService executor, LinkedHashMap<Integer, List<MetaAction>> actionChain, Phaser phaser) {
for (MetaAction action : actions) {
phaser.register();
executor.execute(() -> {
boolean success = true;
MetaAction.Result result = action.getResult();
do {
// 该循环对应LLM的调整参数后重试
if (!success) {
//TODO LLM决策是重构参数、执行自对话反思、还是选择向用户求助(通过cognationCore暴露方法可能需要修改其他模块以进行适应)
if (!result.isSuccess()) {
//TODO LLM决策是重构参数、执行自对话反思、还是选择向用户求助(通过cognationCore暴露方法可能需要修改其他模块以进行适应),仅重构参数时无需结束当前循环
// 若使用Phaser作为执行线程与反思、求助等调用流程的同步协调应当需要额外维护Phaser全局字段获取到反思结果或者用户反馈后
// 调用对应的phaser注册任务在ActionExecutor中动态添加任务至actionChain,同时启动异步执行
// 而且由于执行与放入的为同一个MetaAction对象所以执行结果可被当前行动链获取但virtual、executor两个列表似乎不行需要重构执行模式建议将行动链直接重构为LinkedHashMaporder为键
String input = getInput(result.getData());
}
action.run();
success = action.getResult().isSuccess();
} while (!success);
latch.countDown();
} while (!result.isSuccess());
//TODO 将执行结果写入特定对话角色记忆(cognationCore暴露方法)
phaser.arriveAndDeregister();
});
}
}
private String getInput(String data) {
return null;
}
@Override
public String modelKey() {
return "action_executor";
@@ -117,4 +111,7 @@ public class ActionExecutor extends AgentRunningSubModule<List<ImmediateActionDa
public boolean withBasicPrompt() {
return false;
}
private record PhaserActionChain(Phaser phaser, LinkedHashMap<Integer, List<MetaAction>> actionChain) {
}
}

View File

@@ -11,6 +11,7 @@ import work.slhaf.partner.common.vector.VectorClient;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.entity.ActionData;
import work.slhaf.partner.core.action.entity.ImmediateActionData;
import work.slhaf.partner.core.action.entity.MetaAction;
import work.slhaf.partner.core.action.entity.ScheduledActionData;
import work.slhaf.partner.core.action.entity.cache.CacheAdjustData;
import work.slhaf.partner.core.action.entity.cache.CacheAdjustMetaData;
@@ -29,10 +30,7 @@ import work.slhaf.partner.module.modules.action.planner.extractor.entity.Extract
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.*;
import java.util.concurrent.Callable;
/**
@@ -163,8 +161,9 @@ public class ActionPlanner extends PreRunningModule {
@Override
protected HashMap<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
HashMap<String, String> map = new HashMap<>();
setupPendingActions(map, context.getUserId());
setupPreparedActions(map, context.getUuid());
String userId = context.getUserId();
setupPendingActions(map, userId);
setupPreparedActions(map, userId);
return map;
}
@@ -179,8 +178,8 @@ public class ActionPlanner extends PreRunningModule {
}
}
private void setupPreparedActions(HashMap<String, String> map, String uuid) {
List<ActionData> actionData = actionCapability.listPreparedAction(uuid);
private void setupPreparedActions(HashMap<String, String> map, String userId) {
List<ActionData> actionData = actionCapability.listPreparedAction(userId);
if (actionData == null || actionData.isEmpty()) {
map.put("[预备行动] <预备行动信息>", "无预备行动");
return;
@@ -229,10 +228,14 @@ public class ActionPlanner extends PreRunningModule {
}
private ActionData buildMetaActionInfo(EvaluatorResult evaluatorResult) {
LinkedHashMap<Integer, List<MetaAction>> actionChain = new LinkedHashMap<>();
for (MetaAction metaAction : evaluatorResult.getActionChain()) {
actionChain.computeIfAbsent(metaAction.getOrder(), k -> new ArrayList<>()).add(metaAction);
}
return switch (evaluatorResult.getType()) {
case PLANNING -> {
ScheduledActionData actionInfo = new ScheduledActionData();
actionInfo.setActionChain(evaluatorResult.getActionChain());
actionInfo.setActionChain(actionChain);
actionInfo.setScheduleContent(evaluatorResult.getScheduleContent());
actionInfo.setStatus(ActionData.ActionStatus.PREPARE);
actionInfo.setUuid(UUID.randomUUID().toString());
@@ -240,7 +243,7 @@ public class ActionPlanner extends PreRunningModule {
}
case IMMEDIATE -> {
ImmediateActionData actionInfo = new ImmediateActionData();
actionInfo.setActionChain(evaluatorResult.getActionChain());
actionInfo.setActionChain(actionChain);
actionInfo.setStatus(ActionData.ActionStatus.PREPARE);
actionInfo.setUuid(UUID.randomUUID().toString());
yield actionInfo;

View File

@@ -3,6 +3,7 @@ package work.slhaf.partner.runtime.interaction;
import work.slhaf.partner.api.agent.runtime.interaction.AgentInteractionAdapter;
import work.slhaf.partner.runtime.interaction.data.PartnerInputData;
import work.slhaf.partner.runtime.interaction.data.PartnerOutputData;
import work.slhaf.partner.runtime.interaction.data.SpecializedPartnerInputData;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
public class PartnerInteractionAdapter extends AgentInteractionAdapter<PartnerInputData, PartnerOutputData, PartnerRunningFlowContext> {
@@ -33,6 +34,10 @@ public class PartnerInteractionAdapter extends AgentInteractionAdapter<PartnerIn
context.setSingle(inputData.isSingle());
context.setPlatform(inputData.getPlatform());
context.setInput(inputData.getContent());
context.setType(inputData.getInputType());
if (inputData instanceof SpecializedPartnerInputData specializedData) {
context.setPayload(specializedData.getPayload());
}
return context;
}
}

View File

@@ -13,8 +13,7 @@ import work.slhaf.partner.api.agent.runtime.interaction.AgentGateway;
import work.slhaf.partner.api.agent.runtime.interaction.AgentInteractionAdapter;
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.runtime.interaction.data.PartnerInputData;
import work.slhaf.partner.runtime.interaction.data.PartnerOutputData;
import work.slhaf.partner.runtime.interaction.data.*;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.net.InetSocketAddress;
@@ -141,7 +140,12 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<Pa
@Override
public void onMessage(WebSocket webSocket, String s) {
PartnerInputData inputData = JSONObject.parseObject(s, PartnerInputData.class);
JSONObject parsedObject = JSONObject.parseObject(s);
PartnerInputType inputType = parsedObject.getObject(SpecializedPayloadConstant.TYPE, PartnerInputType.class);
PartnerInputData inputData = switch (inputType) {
case NORMAL -> parsedObject.to(PartnerInputData.class);
case SYSTEM, ASSIST_REQUEST, REFLECTION -> parsedObject.to(SpecializedPartnerInputData.class);
};
userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接
receive(inputData);
}

View File

@@ -7,7 +7,8 @@ import work.slhaf.partner.api.agent.runtime.interaction.data.AgentInputData;
@EqualsAndHashCode(callSuper = true)
@Data
public class PartnerInputData extends AgentInputData {
private String userNickName;
private String platform;
private boolean single;
protected String userNickName;
protected String platform;
protected boolean single;
protected PartnerInputType inputType;
}

View File

@@ -0,0 +1,5 @@
package work.slhaf.partner.runtime.interaction.data;
public enum PartnerInputType {
NORMAL, REFLECTION, ASSIST_REQUEST, SYSTEM
}

View File

@@ -0,0 +1,12 @@
package work.slhaf.partner.runtime.interaction.data;
import lombok.Data;
import lombok.EqualsAndHashCode;
import java.util.Map;
@EqualsAndHashCode(callSuper = true)
@Data
public class SpecializedPartnerInputData extends PartnerInputData {
protected Map<String, String> payload;
}

View File

@@ -0,0 +1,7 @@
package work.slhaf.partner.runtime.interaction.data;
public class SpecializedPayloadConstant {
public static final String TASK_ID = "taskId";
public static final String ACTION_ID = "actionId";
public static final String TYPE = "inputType";
}

View File

@@ -5,6 +5,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext;
import work.slhaf.partner.module.common.entity.AppendPromptData;
import work.slhaf.partner.runtime.interaction.data.PartnerInputType;
import work.slhaf.partner.runtime.interaction.data.context.subcontext.CoreContext;
import work.slhaf.partner.runtime.interaction.data.context.subcontext.ModuleContext;
@@ -12,6 +13,7 @@ import java.io.Serial;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
@EqualsAndHashCode(callSuper = true)
@@ -30,6 +32,9 @@ public class PartnerRunningFlowContext extends RunningFlowContext {
protected LocalDateTime dateTime;
protected boolean single;
protected PartnerInputType type;
protected Map<String, String> payload;
protected String input;
protected CoreContext coreContext = new CoreContext();