mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
推进 ActionExecutor、确定动态插拔式行动调度的实现思路
- 在 ActionCore 中添加关闭hook,用于正确设置异常中断时执行中任务的状态 - 修正 actionPool 相关注释及用法 - 将 ActionData 中行动链字段调整为 LinkedHashMap 用于更好地支持分组并发及动态调度 - 重构 ActionExecutor 行动链执行逻辑,采用 Phaser 支持动态调度 - 扩展 InputData、Context 字段并调整 GateWay 格式化逻辑以适应特殊输入
This commit is contained in:
@@ -25,7 +25,7 @@ import java.util.stream.Collectors;
|
||||
public class ActionCore extends PartnerCore<ActionCore> {
|
||||
|
||||
/**
|
||||
* 对应本次交互即将执行或将要放置在行动池的预备任务,因此将以本次交互的uuid为键,其起到的作用相当于暂时的模块上下文
|
||||
* 持久行动池,以用户id为键存储所有状态的任务
|
||||
*/
|
||||
private HashMap<String, List<ActionData>> actionPool = new HashMap<>();
|
||||
|
||||
@@ -48,8 +48,26 @@ public class ActionCore extends PartnerCore<ActionCore> {
|
||||
|
||||
public ActionCore() throws IOException, ClassNotFoundException {
|
||||
new ActionWatchService(existedMetaActions, virtualExecutor).launch();
|
||||
setupShutdownHook();
|
||||
}
|
||||
|
||||
private void setupShutdownHook() {
|
||||
// 将执行中的行动状态置为失败
|
||||
List<ActionData> executingActionList = listExecutingAction();
|
||||
for (ActionData actionData : executingActionList) {
|
||||
actionData.setStatus(ActionData.ActionStatus.FAILED);
|
||||
actionData.setResult("由于系统中断而失败");
|
||||
}
|
||||
}
|
||||
|
||||
private List<ActionData> listExecutingAction() {
|
||||
return actionPool.values().stream()
|
||||
.flatMap(Collection::stream)
|
||||
.filter(action -> action.getStatus() == ActionData.ActionStatus.EXECUTING)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
|
||||
@CapabilityMethod
|
||||
public synchronized void putPendingActions(String userId, ActionData actionData) {
|
||||
pendingActions.computeIfAbsent(userId, k -> {
|
||||
|
||||
@@ -2,6 +2,7 @@ package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
@@ -12,7 +13,7 @@ public abstract class ActionData {
|
||||
protected String uuid;
|
||||
protected String tendency;
|
||||
protected ActionStatus status;
|
||||
protected List<MetaAction> actionChain;
|
||||
protected LinkedHashMap<Integer, List<MetaAction>> actionChain;
|
||||
protected String result;
|
||||
protected String reason;
|
||||
protected String description;
|
||||
|
||||
@@ -26,7 +26,7 @@ public class MetaAction implements Comparable<MetaAction>, Runnable {
|
||||
/**
|
||||
* 行动结果,包括执行状态和相应内容(执行结果或者错误信息)
|
||||
*/
|
||||
private Result result;
|
||||
private Result result = new Result();
|
||||
/**
|
||||
* 执行顺序,升序排列
|
||||
*/
|
||||
@@ -88,7 +88,7 @@ public class MetaAction implements Comparable<MetaAction>, Runnable {
|
||||
|
||||
@Data
|
||||
public static class Result {
|
||||
private boolean success;
|
||||
private boolean success = true;
|
||||
private String data;
|
||||
}
|
||||
|
||||
|
||||
@@ -12,11 +12,9 @@ import work.slhaf.partner.core.action.entity.ActionData;
|
||||
import work.slhaf.partner.core.action.entity.ImmediateActionData;
|
||||
import work.slhaf.partner.core.action.entity.MetaAction;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Phaser;
|
||||
|
||||
@Slf4j
|
||||
@AgentSubModule
|
||||
@@ -28,6 +26,8 @@ public class ActionExecutor extends AgentRunningSubModule<List<ImmediateActionDa
|
||||
private ExecutorService virtualExecutor;
|
||||
private ExecutorService platformExecutor;
|
||||
|
||||
private HashMap<String, PhaserActionChain> phaserRecorder = new HashMap<>();
|
||||
|
||||
@Init
|
||||
public void init() {
|
||||
virtualExecutor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
|
||||
@@ -45,69 +45,63 @@ public class ActionExecutor extends AgentRunningSubModule<List<ImmediateActionDa
|
||||
private void handleActionData(ImmediateActionData actionData) {
|
||||
virtualExecutor.execute(() -> {
|
||||
actionData.setStatus(ActionData.ActionStatus.EXECUTING);
|
||||
List<MetaAction> actionChain = actionData.getActionChain();
|
||||
actionChain.sort(MetaAction::compareTo);
|
||||
LinkedHashMap<Integer, List<MetaAction>> actionChain = actionData.getActionChain();
|
||||
List<MetaAction> virtual = new ArrayList<>();
|
||||
List<MetaAction> platform = new ArrayList<>();
|
||||
int order;
|
||||
for (int index = 0; index < actionChain.size(); index++) {
|
||||
MetaAction metaAction = actionChain.get(index);
|
||||
// 根据io类型放入合适的列表
|
||||
if (metaAction.isIo()) {
|
||||
virtual.add(metaAction);
|
||||
} else {
|
||||
platform.add(metaAction);
|
||||
actionChain.forEach((k, v) -> {
|
||||
for (MetaAction metaAction : v) {
|
||||
// 根据io类型放入合适的列表
|
||||
if (metaAction.isIo()) {
|
||||
virtual.add(metaAction);
|
||||
} else {
|
||||
platform.add(metaAction);
|
||||
}
|
||||
}
|
||||
// 记录当前order
|
||||
order = metaAction.getOrder();
|
||||
// 如果下一个行动单元的order与当前不同,则执行并清空当前组内容
|
||||
if (actionChain.size() <= (index + 1) || actionChain.get(index + 1).getOrder() != order) {
|
||||
runGroupAction(virtual, platform, actionChain);
|
||||
}
|
||||
}
|
||||
runGroupAction(virtual, platform, actionChain);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
//TODO 考虑是否使用phaser来承担同组的动态任务新增
|
||||
private void runGroupAction(List<MetaAction> virtual, List<MetaAction> platform, List<MetaAction> actionChain) {
|
||||
boolean first = true;
|
||||
do {
|
||||
CountDownLatch latch = new CountDownLatch(virtual.size() + platform.size());
|
||||
runGroupAction(virtual, virtualExecutor, actionChain, latch, first);
|
||||
runGroupAction(platform, platformExecutor, actionChain, latch, first);
|
||||
try {
|
||||
latch.await();
|
||||
} catch (InterruptedException e) {
|
||||
log.error("[{}] CountDownLatch被中断", modelKey());
|
||||
}
|
||||
first = false;
|
||||
} while (!virtual.isEmpty() || !platform.isEmpty());
|
||||
// 使用phaser来承担同组的动态任务新增
|
||||
private void runGroupAction(List<MetaAction> virtual, List<MetaAction> platform, LinkedHashMap<Integer, List<MetaAction>> actionChain) {
|
||||
Phaser phaser = new Phaser();
|
||||
phaser.register();
|
||||
String groupId = UUID.randomUUID().toString();
|
||||
phaserRecorder.put(groupId, new PhaserActionChain(phaser, actionChain));
|
||||
runGroupAction(virtual, virtualExecutor, actionChain, phaser);
|
||||
runGroupAction(platform, platformExecutor, actionChain, phaser);
|
||||
phaserRecorder.remove(groupId);
|
||||
phaser.arriveAndAwaitAdvance();
|
||||
}
|
||||
|
||||
private void runGroupAction(List<MetaAction> actions, ExecutorService executor, List<MetaAction> actionChain, CountDownLatch latch, boolean first) {
|
||||
if (!first && !new HashSet<>(actionChain).containsAll(actions)) {
|
||||
// 该部分对应LLM新增本组执行单元时,将其添加至actionChain记录。对于后续组级别的新增,将直接在上一级调用处体现,除了注意并发安全外无需额外处理
|
||||
int index = actionChain.indexOf(actions.getLast());
|
||||
actionChain.addAll(index, actions);
|
||||
}
|
||||
private void runGroupAction(List<MetaAction> actions, ExecutorService executor, LinkedHashMap<Integer, List<MetaAction>> actionChain, Phaser phaser) {
|
||||
for (MetaAction action : actions) {
|
||||
phaser.register();
|
||||
executor.execute(() -> {
|
||||
boolean success = true;
|
||||
MetaAction.Result result = action.getResult();
|
||||
do {
|
||||
// 该循环对应LLM的调整参数后重试
|
||||
if (!success) {
|
||||
//TODO LLM决策是重构参数、执行自对话反思、还是选择向用户求助(通过cognationCore暴露方法,可能需要修改其他模块以进行适应)
|
||||
if (!result.isSuccess()) {
|
||||
//TODO LLM决策是重构参数、执行自对话反思、还是选择向用户求助(通过cognationCore暴露方法,可能需要修改其他模块以进行适应),仅重构参数时无需结束当前循环
|
||||
// 若使用Phaser作为执行线程与反思、求助等调用流程的同步协调,应当需要额外维护Phaser全局字段,获取到反思结果或者用户反馈后,
|
||||
// 调用对应的phaser注册任务,在ActionExecutor中动态添加任务至actionChain,同时启动异步执行
|
||||
// 而且由于执行与放入的为同一个MetaAction对象,所以执行结果可被当前行动链获取,但virtual、executor两个列表似乎不行,需要重构执行模式,建议将行动链直接重构为LinkedHashMap,order为键
|
||||
String input = getInput(result.getData());
|
||||
|
||||
}
|
||||
action.run();
|
||||
success = action.getResult().isSuccess();
|
||||
} while (!success);
|
||||
latch.countDown();
|
||||
} while (!result.isSuccess());
|
||||
//TODO 将执行结果写入特定对话角色记忆(cognationCore暴露方法)
|
||||
phaser.arriveAndDeregister();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private String getInput(String data) {
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String modelKey() {
|
||||
return "action_executor";
|
||||
@@ -117,4 +111,7 @@ public class ActionExecutor extends AgentRunningSubModule<List<ImmediateActionDa
|
||||
public boolean withBasicPrompt() {
|
||||
return false;
|
||||
}
|
||||
|
||||
private record PhaserActionChain(Phaser phaser, LinkedHashMap<Integer, List<MetaAction>> actionChain) {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import work.slhaf.partner.common.vector.VectorClient;
|
||||
import work.slhaf.partner.core.action.ActionCapability;
|
||||
import work.slhaf.partner.core.action.entity.ActionData;
|
||||
import work.slhaf.partner.core.action.entity.ImmediateActionData;
|
||||
import work.slhaf.partner.core.action.entity.MetaAction;
|
||||
import work.slhaf.partner.core.action.entity.ScheduledActionData;
|
||||
import work.slhaf.partner.core.action.entity.cache.CacheAdjustData;
|
||||
import work.slhaf.partner.core.action.entity.cache.CacheAdjustMetaData;
|
||||
@@ -29,10 +30,7 @@ import work.slhaf.partner.module.modules.action.planner.extractor.entity.Extract
|
||||
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorResult;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
/**
|
||||
@@ -163,8 +161,9 @@ public class ActionPlanner extends PreRunningModule {
|
||||
@Override
|
||||
protected HashMap<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
|
||||
HashMap<String, String> map = new HashMap<>();
|
||||
setupPendingActions(map, context.getUserId());
|
||||
setupPreparedActions(map, context.getUuid());
|
||||
String userId = context.getUserId();
|
||||
setupPendingActions(map, userId);
|
||||
setupPreparedActions(map, userId);
|
||||
return map;
|
||||
}
|
||||
|
||||
@@ -179,8 +178,8 @@ public class ActionPlanner extends PreRunningModule {
|
||||
}
|
||||
}
|
||||
|
||||
private void setupPreparedActions(HashMap<String, String> map, String uuid) {
|
||||
List<ActionData> actionData = actionCapability.listPreparedAction(uuid);
|
||||
private void setupPreparedActions(HashMap<String, String> map, String userId) {
|
||||
List<ActionData> actionData = actionCapability.listPreparedAction(userId);
|
||||
if (actionData == null || actionData.isEmpty()) {
|
||||
map.put("[预备行动] <预备行动信息>", "无预备行动");
|
||||
return;
|
||||
@@ -229,10 +228,14 @@ public class ActionPlanner extends PreRunningModule {
|
||||
}
|
||||
|
||||
private ActionData buildMetaActionInfo(EvaluatorResult evaluatorResult) {
|
||||
LinkedHashMap<Integer, List<MetaAction>> actionChain = new LinkedHashMap<>();
|
||||
for (MetaAction metaAction : evaluatorResult.getActionChain()) {
|
||||
actionChain.computeIfAbsent(metaAction.getOrder(), k -> new ArrayList<>()).add(metaAction);
|
||||
}
|
||||
return switch (evaluatorResult.getType()) {
|
||||
case PLANNING -> {
|
||||
ScheduledActionData actionInfo = new ScheduledActionData();
|
||||
actionInfo.setActionChain(evaluatorResult.getActionChain());
|
||||
actionInfo.setActionChain(actionChain);
|
||||
actionInfo.setScheduleContent(evaluatorResult.getScheduleContent());
|
||||
actionInfo.setStatus(ActionData.ActionStatus.PREPARE);
|
||||
actionInfo.setUuid(UUID.randomUUID().toString());
|
||||
@@ -240,7 +243,7 @@ public class ActionPlanner extends PreRunningModule {
|
||||
}
|
||||
case IMMEDIATE -> {
|
||||
ImmediateActionData actionInfo = new ImmediateActionData();
|
||||
actionInfo.setActionChain(evaluatorResult.getActionChain());
|
||||
actionInfo.setActionChain(actionChain);
|
||||
actionInfo.setStatus(ActionData.ActionStatus.PREPARE);
|
||||
actionInfo.setUuid(UUID.randomUUID().toString());
|
||||
yield actionInfo;
|
||||
|
||||
@@ -3,6 +3,7 @@ package work.slhaf.partner.runtime.interaction;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.AgentInteractionAdapter;
|
||||
import work.slhaf.partner.runtime.interaction.data.PartnerInputData;
|
||||
import work.slhaf.partner.runtime.interaction.data.PartnerOutputData;
|
||||
import work.slhaf.partner.runtime.interaction.data.SpecializedPartnerInputData;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
public class PartnerInteractionAdapter extends AgentInteractionAdapter<PartnerInputData, PartnerOutputData, PartnerRunningFlowContext> {
|
||||
@@ -33,6 +34,10 @@ public class PartnerInteractionAdapter extends AgentInteractionAdapter<PartnerIn
|
||||
context.setSingle(inputData.isSingle());
|
||||
context.setPlatform(inputData.getPlatform());
|
||||
context.setInput(inputData.getContent());
|
||||
context.setType(inputData.getInputType());
|
||||
if (inputData instanceof SpecializedPartnerInputData specializedData) {
|
||||
context.setPayload(specializedData.getPayload());
|
||||
}
|
||||
return context;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,8 +13,7 @@ import work.slhaf.partner.api.agent.runtime.interaction.AgentGateway;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.AgentInteractionAdapter;
|
||||
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
||||
import work.slhaf.partner.runtime.interaction.data.PartnerInputData;
|
||||
import work.slhaf.partner.runtime.interaction.data.PartnerOutputData;
|
||||
import work.slhaf.partner.runtime.interaction.data.*;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.net.InetSocketAddress;
|
||||
@@ -141,7 +140,12 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<Pa
|
||||
|
||||
@Override
|
||||
public void onMessage(WebSocket webSocket, String s) {
|
||||
PartnerInputData inputData = JSONObject.parseObject(s, PartnerInputData.class);
|
||||
JSONObject parsedObject = JSONObject.parseObject(s);
|
||||
PartnerInputType inputType = parsedObject.getObject(SpecializedPayloadConstant.TYPE, PartnerInputType.class);
|
||||
PartnerInputData inputData = switch (inputType) {
|
||||
case NORMAL -> parsedObject.to(PartnerInputData.class);
|
||||
case SYSTEM, ASSIST_REQUEST, REFLECTION -> parsedObject.to(SpecializedPartnerInputData.class);
|
||||
};
|
||||
userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接
|
||||
receive(inputData);
|
||||
}
|
||||
|
||||
@@ -7,7 +7,8 @@ import work.slhaf.partner.api.agent.runtime.interaction.data.AgentInputData;
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class PartnerInputData extends AgentInputData {
|
||||
private String userNickName;
|
||||
private String platform;
|
||||
private boolean single;
|
||||
protected String userNickName;
|
||||
protected String platform;
|
||||
protected boolean single;
|
||||
protected PartnerInputType inputType;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
package work.slhaf.partner.runtime.interaction.data;
|
||||
|
||||
public enum PartnerInputType {
|
||||
NORMAL, REFLECTION, ASSIST_REQUEST, SYSTEM
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package work.slhaf.partner.runtime.interaction.data;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class SpecializedPartnerInputData extends PartnerInputData {
|
||||
protected Map<String, String> payload;
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package work.slhaf.partner.runtime.interaction.data;
|
||||
|
||||
public class SpecializedPayloadConstant {
|
||||
public static final String TASK_ID = "taskId";
|
||||
public static final String ACTION_ID = "actionId";
|
||||
public static final String TYPE = "inputType";
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext;
|
||||
import work.slhaf.partner.module.common.entity.AppendPromptData;
|
||||
import work.slhaf.partner.runtime.interaction.data.PartnerInputType;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.subcontext.CoreContext;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.subcontext.ModuleContext;
|
||||
|
||||
@@ -12,6 +13,7 @@ import java.io.Serial;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@@ -30,6 +32,9 @@ public class PartnerRunningFlowContext extends RunningFlowContext {
|
||||
protected LocalDateTime dateTime;
|
||||
protected boolean single;
|
||||
|
||||
protected PartnerInputType type;
|
||||
protected Map<String, String> payload;
|
||||
|
||||
protected String input;
|
||||
|
||||
protected CoreContext coreContext = new CoreContext();
|
||||
|
||||
Reference in New Issue
Block a user