mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
refactor(runtime): support collect context by source and interrupt same-source running flow by module order
This commit is contained in:
@@ -69,7 +69,7 @@ public class ActionPlanner extends AbstractAgentModule.Running<PartnerRunningFlo
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void doExecute(@NotNull PartnerRunningFlowContext context) {
|
protected void doExecute(@NotNull PartnerRunningFlowContext context) {
|
||||||
String input = context.getInput();
|
String input = context.encodeInputsXml();
|
||||||
Result<ExtractorResult> result = actionExtractor.execute(input)
|
Result<ExtractorResult> result = actionExtractor.execute(input)
|
||||||
.onFailure(exp -> {
|
.onFailure(exp -> {
|
||||||
ExceptionReporterHandler.INSTANCE.report(exp, ContextExceptionReporter.REPORTER_NAME);
|
ExceptionReporterHandler.INSTANCE.report(exp, ContextExceptionReporter.REPORTER_NAME);
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
你接下来收到的消息固定分为三个区段:
|
你接下来收到的消息固定分为三个区段:
|
||||||
1. system message 是 Head, 用于说明整个输入结构与输出要求。
|
1. system message 是 Head, 用于说明整个输入结构与输出要求。
|
||||||
2. <context> 区段只承载 type=CONTEXT 的上下文块, 其中每个子块都带有独立来源, 仅作为理解当前状态与辅助决策的依据。
|
2. <context> 区段只承载 type=CONTEXT 的上下文块, 其中每个子块都带有独立来源, 仅作为理解当前状态与辅助决策的依据。
|
||||||
3. Conversation 区段是对话轨迹; 最新的一条 user message 会使用 <input> 结构, 其中 <content> 是本轮用户原始输入, 其他子标签是输入元信息与 type=SUPPLY 的补充块, 补充块会按 blockName 分区。
|
3. Conversation 区段是对话轨迹; 最新的一条 user message 会使用 <input> 结构, 其中 <inputs> 承载本轮按时间顺序排列的输入序列, 每个 <input> 节点会带有相对首条输入的时间间隔属性, 其他子标签是输入元信息与 type=SUPPLY 的补充块, 补充块会按 blockName 分区。
|
||||||
你必须综合 Context 与 Conversation 回答最新输入, 不要把 XML 标签当作需要原样复述给用户的内容。
|
你必须综合 Context 与 Conversation 回答最新输入, 不要把 XML 标签当作需要原样复述给用户的内容。
|
||||||
直接输出最终回应内容即可, 不需要额外包装为 JSON。
|
直接输出最终回应内容即可, 不需要额外包装为 JSON。
|
||||||
""";
|
""";
|
||||||
@@ -129,7 +129,7 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
Element root = document.createElement("input");
|
Element root = document.createElement("input");
|
||||||
document.appendChild(root);
|
document.appendChild(root);
|
||||||
|
|
||||||
appendTextElement(document, root, "content", runningFlowContext.getInput());
|
runningFlowContext.appendInputsXml(document, root);
|
||||||
appendTextElement(document, root, "source", runningFlowContext.getSource());
|
appendTextElement(document, root, "source", runningFlowContext.getSource());
|
||||||
for (Map.Entry<String, String> entry : runningFlowContext.getAdditionalUserInfo().entrySet()) {
|
for (Map.Entry<String, String> entry : runningFlowContext.getAdditionalUserInfo().entrySet()) {
|
||||||
appendTextElement(document, root, sanitizeTagName(entry.getKey()), entry.getValue());
|
appendTextElement(document, root, sanitizeTagName(entry.getKey()), entry.getValue());
|
||||||
@@ -159,7 +159,7 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
}
|
}
|
||||||
|
|
||||||
private String formatConversationUserMessage(PartnerRunningFlowContext runningFlowContext) {
|
private String formatConversationUserMessage(PartnerRunningFlowContext runningFlowContext) {
|
||||||
return "[" + runningFlowContext.getSource() + "]" + ": " + runningFlowContext.getInput();
|
return "[" + runningFlowContext.getSource() + "]" + ": " + runningFlowContext.formatInputsForHistory();
|
||||||
}
|
}
|
||||||
|
|
||||||
private Document newDocument() throws Exception {
|
private Document newDocument() throws Exception {
|
||||||
|
|||||||
@@ -6,14 +6,13 @@ import lombok.EqualsAndHashCode;
|
|||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
import org.w3c.dom.Document;
|
import org.w3c.dom.Document;
|
||||||
import org.w3c.dom.Element;
|
import org.w3c.dom.Element;
|
||||||
import work.slhaf.partner.core.action.ActionCapability;
|
|
||||||
import work.slhaf.partner.core.action.ActionCore;
|
|
||||||
import work.slhaf.partner.core.cognition.BlockContent;
|
import work.slhaf.partner.core.cognition.BlockContent;
|
||||||
import work.slhaf.partner.core.cognition.CognitionCapability;
|
import work.slhaf.partner.core.cognition.CognitionCapability;
|
||||||
import work.slhaf.partner.core.cognition.ContextBlock;
|
import work.slhaf.partner.core.cognition.ContextBlock;
|
||||||
import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability;
|
import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability;
|
||||||
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.framework.agent.factory.component.annotation.InjectModule;
|
import work.slhaf.partner.framework.agent.factory.component.annotation.InjectModule;
|
||||||
|
import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext;
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||||
import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
|
import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
|
||||||
import work.slhaf.partner.module.memory.runtime.exception.MemoryLookupException;
|
import work.slhaf.partner.module.memory.runtime.exception.MemoryLookupException;
|
||||||
@@ -26,12 +25,7 @@ import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorResul
|
|||||||
import work.slhaf.partner.runtime.PartnerRunningFlowContext;
|
import work.slhaf.partner.runtime.PartnerRunningFlowContext;
|
||||||
|
|
||||||
import java.time.LocalDate;
|
import java.time.LocalDate;
|
||||||
import java.time.LocalDateTime;
|
|
||||||
import java.time.ZonedDateTime;
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
|
||||||
import java.util.concurrent.locks.Lock;
|
|
||||||
import java.util.concurrent.locks.ReentrantLock;
|
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Data
|
@Data
|
||||||
@@ -42,8 +36,6 @@ public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFl
|
|||||||
|
|
||||||
@InjectCapability
|
@InjectCapability
|
||||||
private CognitionCapability cognitionCapability;
|
private CognitionCapability cognitionCapability;
|
||||||
@InjectCapability
|
|
||||||
private ActionCapability actionCapability;
|
|
||||||
|
|
||||||
@InjectModule
|
@InjectModule
|
||||||
private MemoryRuntime memoryRuntime;
|
private MemoryRuntime memoryRuntime;
|
||||||
@@ -52,73 +44,22 @@ public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFl
|
|||||||
@InjectModule
|
@InjectModule
|
||||||
private MemorySelectExtractor memorySelectExtractor;
|
private MemorySelectExtractor memorySelectExtractor;
|
||||||
|
|
||||||
private AtomicBoolean memoryCalling = new AtomicBoolean(false);
|
|
||||||
private Map<LocalDateTime, String> collectedInputs = new HashMap<>();
|
|
||||||
private Lock inputsLock = new ReentrantLock();
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void doExecute(@NotNull PartnerRunningFlowContext runningFlowContext) {
|
protected void doExecute(@NotNull PartnerRunningFlowContext runningFlowContext) {
|
||||||
inputsLock.lock();
|
List<RunningFlowContext.InputEntry> snapshotInputs = List.copyOf(runningFlowContext.getInputs());
|
||||||
try {
|
|
||||||
collectedInputs.put(ZonedDateTime.now().toLocalDateTime(), runningFlowContext.getInput());
|
|
||||||
} finally {
|
|
||||||
inputsLock.unlock();
|
|
||||||
}
|
|
||||||
|
|
||||||
tryStartMemoryRecallWorker();
|
|
||||||
}
|
|
||||||
|
|
||||||
private void tryStartMemoryRecallWorker() {
|
|
||||||
if (!memoryCalling.compareAndSet(false, true)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL).execute(() -> {
|
|
||||||
try {
|
|
||||||
drainMemoryRecall();
|
|
||||||
} finally {
|
|
||||||
memoryCalling.set(false);
|
|
||||||
|
|
||||||
// 防止竞态:worker 退出前后,刚好来了新输入,但没有线程负责再拉起 worker
|
|
||||||
if (!collectedInputs.isEmpty()) {
|
|
||||||
tryStartMemoryRecallWorker();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private void drainMemoryRecall() {
|
|
||||||
while (true) {
|
|
||||||
Map<LocalDateTime, String> snapshotInputs;
|
|
||||||
|
|
||||||
inputsLock.lock();
|
|
||||||
try {
|
|
||||||
if (collectedInputs.isEmpty()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
snapshotInputs = new HashMap<>(collectedInputs);
|
|
||||||
collectedInputs.clear();
|
|
||||||
} finally {
|
|
||||||
inputsLock.unlock();
|
|
||||||
}
|
|
||||||
|
|
||||||
ExtractorInput input = new ExtractorInput(
|
ExtractorInput input = new ExtractorInput(
|
||||||
snapshotInputs,
|
snapshotInputs,
|
||||||
memoryRuntime.getTopicTree(),
|
memoryRuntime.getTopicTree(),
|
||||||
snapshotInputs.keySet()
|
runningFlowContext.getFirstInputDateTime().toLocalDate()
|
||||||
.stream()
|
|
||||||
.max(LocalDateTime::compareTo)
|
|
||||||
.orElseThrow()
|
|
||||||
.toLocalDate()
|
|
||||||
);
|
);
|
||||||
|
|
||||||
ExtractorResult extractorResult = memorySelectExtractor.execute(input);
|
ExtractorResult extractorResult = memorySelectExtractor.execute(input);
|
||||||
if (!extractorResult.getMatches().isEmpty()) {
|
if (extractorResult.getMatches().isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
List<ActivatedMemorySlice> activatedSlices = selectAndEvaluateMemory(snapshotInputs, extractorResult);
|
List<ActivatedMemorySlice> activatedSlices = selectAndEvaluateMemory(snapshotInputs, extractorResult);
|
||||||
updateMemoryContext(activatedSlices);
|
updateMemoryContext(activatedSlices);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void updateMemoryContext(List<ActivatedMemorySlice> activatedSlices) {
|
private void updateMemoryContext(List<ActivatedMemorySlice> activatedSlices) {
|
||||||
cognitionCapability.contextWorkspace().register(new ContextBlock(
|
cognitionCapability.contextWorkspace().register(new ContextBlock(
|
||||||
@@ -205,7 +146,7 @@ public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFl
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<ActivatedMemorySlice> selectAndEvaluateMemory(Map<LocalDateTime, String> snapshotInputs, ExtractorResult extractorResult) {
|
private List<ActivatedMemorySlice> selectAndEvaluateMemory(List<RunningFlowContext.InputEntry> snapshotInputs, ExtractorResult extractorResult) {
|
||||||
LinkedHashMap<String, ActivatedMemorySlice> candidates = new LinkedHashMap<>();
|
LinkedHashMap<String, ActivatedMemorySlice> candidates = new LinkedHashMap<>();
|
||||||
setMemoryCandidates(candidates, extractorResult.getMatches());
|
setMemoryCandidates(candidates, extractorResult.getMatches());
|
||||||
EvaluatorInput evaluatorInput = EvaluatorInput.builder()
|
EvaluatorInput evaluatorInput = EvaluatorInput.builder()
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import work.slhaf.partner.module.memory.selector.evaluator.entity.EvaluatorBatch
|
|||||||
import work.slhaf.partner.module.memory.selector.evaluator.entity.EvaluatorBatchResult;
|
import work.slhaf.partner.module.memory.selector.evaluator.entity.EvaluatorBatchResult;
|
||||||
import work.slhaf.partner.module.memory.selector.evaluator.entity.EvaluatorInput;
|
import work.slhaf.partner.module.memory.selector.evaluator.entity.EvaluatorInput;
|
||||||
|
|
||||||
import java.time.format.DateTimeFormatter;
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
@@ -89,9 +88,12 @@ public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput
|
|||||||
return new TaskBlock() {
|
return new TaskBlock() {
|
||||||
@Override
|
@Override
|
||||||
protected void fillXml(@NotNull Document document, @NotNull Element root) {
|
protected void fillXml(@NotNull Document document, @NotNull Element root) {
|
||||||
appendListElement(document, root, "new_inputs", "input", batchInput.getInputs().entrySet(), (inputElement, input) -> {
|
appendChildElement(document, root, "new_inputs", (inputsElement) -> {
|
||||||
inputElement.setAttribute("datetime", input.getKey().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
|
appendListElement(document, inputsElement, "inputs", "input", batchInput.getInputs(), (inputElement, entry) -> {
|
||||||
inputElement.setTextContent(input.getValue());
|
inputElement.setAttribute("interval-to-first", String.valueOf(entry.getOffsetMillis()));
|
||||||
|
inputElement.setTextContent(entry.getContent());
|
||||||
|
return Unit.INSTANCE;
|
||||||
|
});
|
||||||
return Unit.INSTANCE;
|
return Unit.INSTANCE;
|
||||||
});
|
});
|
||||||
appendChildElement(document, root, "memory_slice", (element) -> {
|
appendChildElement(document, root, "memory_slice", (element) -> {
|
||||||
|
|||||||
@@ -2,14 +2,14 @@ package work.slhaf.partner.module.memory.selector.evaluator.entity;
|
|||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext;
|
||||||
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
|
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
|
||||||
|
|
||||||
import java.time.LocalDateTime;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class EvaluatorBatchInput {
|
public class EvaluatorBatchInput {
|
||||||
private Map<LocalDateTime, String> inputs;
|
private List<RunningFlowContext.InputEntry> inputs;
|
||||||
private ActivatedMemorySlice activatedMemorySlice;
|
private ActivatedMemorySlice activatedMemorySlice;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,15 +2,14 @@ package work.slhaf.partner.module.memory.selector.evaluator.entity;
|
|||||||
|
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext;
|
||||||
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
|
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
|
||||||
|
|
||||||
import java.time.LocalDateTime;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
public class EvaluatorInput {
|
public class EvaluatorInput {
|
||||||
private Map<LocalDateTime, String> inputs;
|
private List<RunningFlowContext.InputEntry> inputs;
|
||||||
private List<ActivatedMemorySlice> memorySlices;
|
private List<ActivatedMemorySlice> memorySlices;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,9 +57,12 @@ public class MemorySelectExtractor extends AbstractAgentModule.Sub<ExtractorInpu
|
|||||||
return new TaskBlock() {
|
return new TaskBlock() {
|
||||||
@Override
|
@Override
|
||||||
protected void fillXml(@NotNull Document document, @NotNull Element root) {
|
protected void fillXml(@NotNull Document document, @NotNull Element root) {
|
||||||
appendListElement(document, root, "new_inputs", "input", input.getInputs().entrySet(), (inputElement, input) -> {
|
appendChildElement(document, root, "new_inputs", (inputsElement) -> {
|
||||||
inputElement.setAttribute("datetime", input.getKey().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
|
appendListElement(document, inputsElement, "inputs", "input", input.getInputs(), (inputElement, entry) -> {
|
||||||
inputElement.setTextContent(input.getValue());
|
inputElement.setAttribute("interval-to-first", String.valueOf(entry.getOffsetMillis()));
|
||||||
|
inputElement.setTextContent(entry.getContent());
|
||||||
|
return Unit.INSTANCE;
|
||||||
|
});
|
||||||
return Unit.INSTANCE;
|
return Unit.INSTANCE;
|
||||||
});
|
});
|
||||||
appendTextElement(document, root, "current_date", input.getDate().format(DateTimeFormatter.ofPattern("yyyy-MM-dd")));
|
appendTextElement(document, root, "current_date", input.getDate().format(DateTimeFormatter.ofPattern("yyyy-MM-dd")));
|
||||||
|
|||||||
@@ -2,15 +2,15 @@ package work.slhaf.partner.module.memory.selector.extractor.entity;
|
|||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext;
|
||||||
|
|
||||||
import java.time.LocalDate;
|
import java.time.LocalDate;
|
||||||
import java.time.LocalDateTime;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class ExtractorInput {
|
public class ExtractorInput {
|
||||||
private Map<LocalDateTime, String> inputs;
|
private List<RunningFlowContext.InputEntry> inputs;
|
||||||
private String topic_tree;
|
private String topic_tree;
|
||||||
private LocalDate date;
|
private LocalDate date;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,12 @@ import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext
|
|||||||
|
|
||||||
class PartnerRunningFlowContext private constructor(
|
class PartnerRunningFlowContext private constructor(
|
||||||
override val source: String,
|
override val source: String,
|
||||||
override val input: String,
|
inputs: List<InputEntry>,
|
||||||
) : RunningFlowContext() {
|
firstInputEpochMillis: Long,
|
||||||
|
additionalUserInfo: Map<String, String> = emptyMap(),
|
||||||
|
skippedModules: Set<String> = emptySet(),
|
||||||
|
target: String = source
|
||||||
|
) : RunningFlowContext(inputs, firstInputEpochMillis, additionalUserInfo, skippedModules, target) {
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
|
||||||
@@ -27,15 +31,39 @@ class PartnerRunningFlowContext private constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
@JvmStatic
|
@JvmStatic
|
||||||
fun fromUser(userId: String, input: String) = PartnerRunningFlowContext(
|
fun fromUser(userId: String, input: String, receivedAtMillis: Long = System.currentTimeMillis()) =
|
||||||
|
PartnerRunningFlowContext(
|
||||||
SourceTag.buildUserSource(userId),
|
SourceTag.buildUserSource(userId),
|
||||||
input
|
listOf(InputEntry(0L, input)),
|
||||||
|
receivedAtMillis
|
||||||
)
|
)
|
||||||
|
|
||||||
@JvmStatic
|
@JvmStatic
|
||||||
fun fromSelf(input: String) = PartnerRunningFlowContext(SourceTag.buildAgentSource(), input).apply {
|
fun fromSelf(input: String, receivedAtMillis: Long = System.currentTimeMillis()) =
|
||||||
|
PartnerRunningFlowContext(
|
||||||
|
SourceTag.buildAgentSource(),
|
||||||
|
listOf(InputEntry(0L, input)),
|
||||||
|
receivedAtMillis
|
||||||
|
).apply {
|
||||||
putUserInfo(InfoKeys.PLATFORM, SOURCE_SELF_PLATFORM)
|
putUserInfo(InfoKeys.PLATFORM, SOURCE_SELF_PLATFORM)
|
||||||
putUserInfo(InfoKeys.NICKNAME, SOURCE_SELF_NICKNAME)
|
putUserInfo(InfoKeys.NICKNAME, SOURCE_SELF_NICKNAME)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun copyWith(
|
||||||
|
inputs: List<InputEntry>,
|
||||||
|
firstInputEpochMillis: Long,
|
||||||
|
additionalUserInfo: Map<String, String>,
|
||||||
|
skippedModules: Set<String>,
|
||||||
|
target: String
|
||||||
|
): RunningFlowContext {
|
||||||
|
return PartnerRunningFlowContext(
|
||||||
|
source = source,
|
||||||
|
inputs = inputs,
|
||||||
|
firstInputEpochMillis = firstInputEpochMillis,
|
||||||
|
additionalUserInfo = additionalUserInfo,
|
||||||
|
skippedModules = skippedModules,
|
||||||
|
target = target
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -15,18 +15,36 @@ import work.slhaf.partner.framework.agent.interaction.data.InteractionEvent
|
|||||||
import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext
|
import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext
|
||||||
import work.slhaf.partner.framework.agent.support.Result
|
import work.slhaf.partner.framework.agent.support.Result
|
||||||
import java.nio.file.Path
|
import java.nio.file.Path
|
||||||
|
import java.util.*
|
||||||
|
|
||||||
object AgentRuntime : Configurable, ConfigRegistration<ModuleMaskConfig> {
|
object AgentRuntime : Configurable, ConfigRegistration<ModuleMaskConfig> {
|
||||||
|
|
||||||
|
private const val DEFAULT_LOG_CHANNEL = "log_channel"
|
||||||
|
|
||||||
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
|
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
|
||||||
|
|
||||||
private val channel = Channel<RunningFlowContext>(Channel.UNLIMITED)
|
private val wakeSignal = Channel<Unit>(Channel.UNLIMITED)
|
||||||
private val responseChannels = mutableMapOf<String, ResponseChannel>(
|
private val stateLock = Any()
|
||||||
LogChannel.channelName to LogChannel
|
|
||||||
)
|
/**
|
||||||
|
* 按照 source 分开存储的最新的 context,input 聚合、其余信息按照最新输入
|
||||||
|
*/
|
||||||
|
private val latestContextsBySource = LinkedHashMap<String, RunningFlowContext>()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* source 队列,其中元素不会重复,触发唤醒信号时,从该队列取出 source 并处理对应的 context
|
||||||
|
*/
|
||||||
|
private val sourceQueue = ArrayDeque<String>()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 与对应 source 的最新 context 对应,用于记录 context 版本状态
|
||||||
|
*/
|
||||||
|
private val sourceVersions = mutableMapOf<String, Long>()
|
||||||
|
|
||||||
|
private val responseChannels = mutableMapOf<String, ResponseChannel>()
|
||||||
|
|
||||||
@Volatile
|
@Volatile
|
||||||
private var defaultChannel: String = LogChannel.channelName
|
private var defaultChannel: String = DEFAULT_LOG_CHANNEL
|
||||||
|
|
||||||
@Volatile
|
@Volatile
|
||||||
private var runningModules: Map<Int, List<AbstractAgentModule.Running<RunningFlowContext>>> = emptyMap()
|
private var runningModules: Map<Int, List<AbstractAgentModule.Running<RunningFlowContext>>> = emptyMap()
|
||||||
@@ -34,13 +52,20 @@ object AgentRuntime : Configurable, ConfigRegistration<ModuleMaskConfig> {
|
|||||||
@Volatile
|
@Volatile
|
||||||
private var maskedModules: Set<String> = emptySet()
|
private var maskedModules: Set<String> = emptySet()
|
||||||
|
|
||||||
|
@Volatile
|
||||||
|
private var currentExecutingSource: String? = null
|
||||||
|
|
||||||
|
@Volatile
|
||||||
|
private var currentExecutingContext: RunningFlowContext? = null
|
||||||
|
|
||||||
init {
|
init {
|
||||||
register()
|
register()
|
||||||
scope.launch {
|
scope.launch {
|
||||||
for (ctx in channel) {
|
for (@Suppress("UNUSED_VARIABLE") ignored in wakeSignal) {
|
||||||
executeTurn(ctx)
|
drainQueue()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
responseChannels.putIfAbsent(DEFAULT_LOG_CHANNEL, LogChannel)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun registerResponseChannel(channelName: String, responseChannel: ResponseChannel) {
|
fun registerResponseChannel(channelName: String, responseChannel: ResponseChannel) {
|
||||||
@@ -48,7 +73,7 @@ object AgentRuntime : Configurable, ConfigRegistration<ModuleMaskConfig> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun unregisterResponseChannel(channelName: String) {
|
fun unregisterResponseChannel(channelName: String) {
|
||||||
if (channelName == LogChannel.channelName) {
|
if (channelName == DEFAULT_LOG_CHANNEL) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
responseChannels.remove(channelName)
|
responseChannels.remove(channelName)
|
||||||
@@ -64,26 +89,95 @@ object AgentRuntime : Configurable, ConfigRegistration<ModuleMaskConfig> {
|
|||||||
fun response(event: InteractionEvent, channelName: String = defaultChannel) {
|
fun response(event: InteractionEvent, channelName: String = defaultChannel) {
|
||||||
val channel = responseChannels[channelName]
|
val channel = responseChannels[channelName]
|
||||||
if (channel == null) {
|
if (channel == null) {
|
||||||
responseChannels[defaultChannel]?.response(event) ?: LogChannel.response(event)
|
responseChannels[defaultChannel]?.response(event)
|
||||||
|
?: responseChannels[DEFAULT_LOG_CHANNEL]?.response(event)
|
||||||
|
?: LogChannel.response(event)
|
||||||
} else {
|
} else {
|
||||||
channel.response(event)
|
channel.response(event)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <C : RunningFlowContext> submit(context: C) = runBlocking {
|
fun <C : RunningFlowContext> submit(context: C) = runBlocking {
|
||||||
channel.send(context)
|
synchronized(stateLock) {
|
||||||
|
val source = context.source
|
||||||
|
latestContextsBySource[source] = latestContextsBySource[source]?.mergedWith(context) ?: context
|
||||||
|
sourceVersions[source] = (sourceVersions[source] ?: 0L) + 1L
|
||||||
|
if (!sourceQueue.contains(source)) {
|
||||||
|
sourceQueue.addLast(source)
|
||||||
|
}
|
||||||
|
if (currentExecutingSource == source) {
|
||||||
|
currentExecutingContext?.status?.interrupted = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wakeSignal.send(Unit)
|
||||||
}
|
}
|
||||||
|
|
||||||
private suspend fun executeTurn(runningFlowContext: RunningFlowContext) {
|
private suspend fun drainQueue() {
|
||||||
|
while (true) {
|
||||||
|
val source = synchronized(stateLock) {
|
||||||
|
sourceQueue.firstOrNull()
|
||||||
|
} ?: return
|
||||||
|
executeSource(source)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private suspend fun executeSource(source: String) {
|
||||||
|
while (true) {
|
||||||
|
val execution = synchronized(stateLock) {
|
||||||
|
val context = latestContextsBySource[source] ?: run {
|
||||||
|
sourceQueue.remove(source)
|
||||||
|
sourceVersions.remove(source)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
currentExecutingSource = source
|
||||||
|
currentExecutingContext = context
|
||||||
|
context.status.interrupted = false
|
||||||
|
SourceExecution(context, sourceVersions[source] ?: 0L)
|
||||||
|
}
|
||||||
|
|
||||||
|
val interrupted = executeTurn(execution.context)
|
||||||
|
|
||||||
|
val shouldRetry = synchronized(stateLock) {
|
||||||
|
currentExecutingSource = null
|
||||||
|
currentExecutingContext = null
|
||||||
|
val latestContext = latestContextsBySource[source]
|
||||||
|
val latestVersion = sourceVersions[source] ?: execution.version
|
||||||
|
when {
|
||||||
|
latestContext == null -> {
|
||||||
|
sourceQueue.remove(source)
|
||||||
|
sourceVersions.remove(source)
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
interrupted || latestVersion != execution.version -> true
|
||||||
|
else -> {
|
||||||
|
latestContextsBySource.remove(source)
|
||||||
|
sourceQueue.remove(source)
|
||||||
|
sourceVersions.remove(source)
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!shouldRetry) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private suspend fun executeTurn(runningFlowContext: RunningFlowContext): Boolean {
|
||||||
if (runningModules.isEmpty()) {
|
if (runningModules.isEmpty()) {
|
||||||
refreshRunningModules()
|
refreshRunningModules()
|
||||||
}
|
}
|
||||||
|
|
||||||
for (modules in runningModules.values) {
|
for (modules in runningModules.values) {
|
||||||
|
if (runningFlowContext.status.interrupted) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
executeOrder(modules, runningFlowContext)
|
executeOrder(modules, runningFlowContext)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return runningFlowContext.status.interrupted
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun refreshRunningModules() {
|
private fun refreshRunningModules() {
|
||||||
@@ -102,6 +196,9 @@ object AgentRuntime : Configurable, ConfigRegistration<ModuleMaskConfig> {
|
|||||||
coroutineScope {
|
coroutineScope {
|
||||||
val jobs = modules.map { module ->
|
val jobs = modules.map { module ->
|
||||||
async {
|
async {
|
||||||
|
if (runningFlowContext.status.interrupted) {
|
||||||
|
return@async
|
||||||
|
}
|
||||||
if (runningFlowContext.skippedModules.contains(module.moduleName)) {
|
if (runningFlowContext.skippedModules.contains(module.moduleName)) {
|
||||||
return@async
|
return@async
|
||||||
}
|
}
|
||||||
@@ -144,6 +241,10 @@ object AgentRuntime : Configurable, ConfigRegistration<ModuleMaskConfig> {
|
|||||||
refreshRunningModules()
|
refreshRunningModules()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private data class SourceExecution(
|
||||||
|
val context: RunningFlowContext,
|
||||||
|
val version: Long
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
data class ModuleMaskConfig(
|
data class ModuleMaskConfig(
|
||||||
|
|||||||
@@ -1,39 +1,60 @@
|
|||||||
package work.slhaf.partner.framework.agent.interaction.flow
|
package work.slhaf.partner.framework.agent.interaction.flow
|
||||||
|
|
||||||
import com.alibaba.fastjson2.JSONObject
|
import com.alibaba.fastjson2.JSONObject
|
||||||
|
import org.w3c.dom.Document
|
||||||
|
import org.w3c.dom.Element
|
||||||
|
import java.time.Instant
|
||||||
import java.time.LocalDateTime
|
import java.time.LocalDateTime
|
||||||
import java.time.ZonedDateTime
|
import java.time.ZoneId
|
||||||
import java.util.*
|
import java.util.*
|
||||||
|
import kotlin.math.min
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 流程上下文
|
* 流程上下文
|
||||||
*/
|
*/
|
||||||
abstract class RunningFlowContext {
|
abstract class RunningFlowContext protected constructor(
|
||||||
|
inputs: List<InputEntry>,
|
||||||
|
val firstInputEpochMillis: Long,
|
||||||
|
additionalUserInfo: Map<String, String> = emptyMap(),
|
||||||
|
skippedModules: Set<String> = emptySet(),
|
||||||
|
target: String = ""
|
||||||
|
) {
|
||||||
/**
|
/**
|
||||||
* 消息来源: 由谁发出
|
* 消息来源: 由谁发出
|
||||||
*/
|
*/
|
||||||
abstract val source: String
|
abstract val source: String
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 消息内容
|
* 输入序列
|
||||||
*/
|
*/
|
||||||
abstract val input: String
|
val inputs: List<InputEntry> = inputs.sortedBy { it.offsetMillis }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 兼容旧路径的纯文本输入表示,按时间顺序换行拼接
|
||||||
|
*/
|
||||||
|
val input: String
|
||||||
|
get() = formatInputsForHistory()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 消息回应对象,默认与 source 一致
|
* 消息回应对象,默认与 source 一致
|
||||||
*/
|
*/
|
||||||
var target = source
|
var target: String = target
|
||||||
|
|
||||||
private val _additionalUserInfo = mutableMapOf<String, String>()
|
private val _additionalUserInfo = additionalUserInfo.toMutableMap()
|
||||||
val additionalUserInfo: Map<String, String>
|
val additionalUserInfo: Map<String, String>
|
||||||
get() = _additionalUserInfo
|
get() = _additionalUserInfo
|
||||||
|
|
||||||
private val _skippedModules = mutableSetOf<String>()
|
private val _skippedModules = skippedModules.toMutableSet()
|
||||||
val skippedModules: Set<String>
|
val skippedModules: Set<String>
|
||||||
get() = _skippedModules
|
get() = _skippedModules
|
||||||
|
|
||||||
val status = Status()
|
val status = Status()
|
||||||
|
|
||||||
|
val firstInputDateTime: LocalDateTime
|
||||||
|
get() = Instant.ofEpochMilli(firstInputEpochMillis)
|
||||||
|
.atZone(ZoneId.systemDefault())
|
||||||
|
.toLocalDateTime()
|
||||||
|
|
||||||
fun addSkippedModule(moduleName: String) {
|
fun addSkippedModule(moduleName: String) {
|
||||||
_skippedModules.add(moduleName)
|
_skippedModules.add(moduleName)
|
||||||
}
|
}
|
||||||
@@ -45,14 +66,104 @@ abstract class RunningFlowContext {
|
|||||||
fun putUserInfo(key: String, value: Any) {
|
fun putUserInfo(key: String, value: Any) {
|
||||||
_additionalUserInfo[key] = try {
|
_additionalUserInfo[key] = try {
|
||||||
JSONObject.toJSONString(value)
|
JSONObject.toJSONString(value)
|
||||||
} catch (e: Exception) {
|
} catch (_: Exception) {
|
||||||
value.toString()
|
value.toString()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun formatInputsForHistory(): String = inputs.joinToString("\n") { it.content }
|
||||||
|
|
||||||
|
@JvmOverloads
|
||||||
|
fun appendInputsXml(
|
||||||
|
document: Document,
|
||||||
|
parent: Element,
|
||||||
|
containerTagName: String = "inputs",
|
||||||
|
inputTagName: String = "input",
|
||||||
|
intervalAttributeName: String = "interval-to-first"
|
||||||
|
) {
|
||||||
|
val inputsElement = document.createElement(containerTagName)
|
||||||
|
parent.appendChild(inputsElement)
|
||||||
|
inputs.forEach { entry ->
|
||||||
|
val inputElement = document.createElement(inputTagName)
|
||||||
|
inputElement.setAttribute(intervalAttributeName, entry.offsetMillis.toString())
|
||||||
|
inputElement.textContent = entry.content
|
||||||
|
inputsElement.appendChild(inputElement)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun encodeInputsXml(): String {
|
||||||
|
val builder = StringBuilder()
|
||||||
|
builder.append("<inputs>")
|
||||||
|
inputs.forEach { entry ->
|
||||||
|
builder.append("<input interval-to-first=\"")
|
||||||
|
.append(escapeXml(entry.offsetMillis.toString()))
|
||||||
|
.append("\">")
|
||||||
|
.append(escapeXml(entry.content))
|
||||||
|
.append("</input>")
|
||||||
|
}
|
||||||
|
builder.append("</inputs>")
|
||||||
|
return builder.toString()
|
||||||
|
}
|
||||||
|
|
||||||
|
fun mergedWith(other: RunningFlowContext): RunningFlowContext {
|
||||||
|
require(source == other.source) {
|
||||||
|
"Unable to merge RunningFlowContext from different source: $source != ${other.source}"
|
||||||
|
}
|
||||||
|
val mergedFirstEpochMillis = min(firstInputEpochMillis, other.firstInputEpochMillis)
|
||||||
|
val mergedInputs = buildList(inputs.size + other.inputs.size) {
|
||||||
|
addAll(normalizeInputs(this@RunningFlowContext, mergedFirstEpochMillis))
|
||||||
|
addAll(normalizeInputs(other, mergedFirstEpochMillis))
|
||||||
|
}.sortedBy { it.offsetMillis }
|
||||||
|
|
||||||
|
val mergedAdditionalUserInfo = LinkedHashMap<String, String>(_additionalUserInfo)
|
||||||
|
mergedAdditionalUserInfo.putAll(other.additionalUserInfo)
|
||||||
|
|
||||||
|
val mergedSkippedModules = LinkedHashSet<String>(_skippedModules)
|
||||||
|
mergedSkippedModules.addAll(other.skippedModules)
|
||||||
|
|
||||||
|
return copyWith(
|
||||||
|
inputs = mergedInputs,
|
||||||
|
firstInputEpochMillis = mergedFirstEpochMillis,
|
||||||
|
additionalUserInfo = mergedAdditionalUserInfo,
|
||||||
|
skippedModules = mergedSkippedModules,
|
||||||
|
target = other.target.ifBlank { target }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract fun copyWith(
|
||||||
|
inputs: List<InputEntry>,
|
||||||
|
firstInputEpochMillis: Long,
|
||||||
|
additionalUserInfo: Map<String, String>,
|
||||||
|
skippedModules: Set<String>,
|
||||||
|
target: String
|
||||||
|
): RunningFlowContext
|
||||||
|
|
||||||
|
private fun normalizeInputs(context: RunningFlowContext, firstEpochMillis: Long): List<InputEntry> {
|
||||||
|
return context.inputs.map { entry ->
|
||||||
|
InputEntry(
|
||||||
|
offsetMillis = context.firstInputEpochMillis + entry.offsetMillis - firstEpochMillis,
|
||||||
|
content = entry.content
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun escapeXml(value: String): String {
|
||||||
|
return value
|
||||||
|
.replace("&", "&")
|
||||||
|
.replace("<", "<")
|
||||||
|
.replace(">", ">")
|
||||||
|
.replace("\"", """)
|
||||||
|
.replace("'", "'")
|
||||||
|
}
|
||||||
|
|
||||||
|
data class InputEntry(
|
||||||
|
val offsetMillis: Long,
|
||||||
|
val content: String
|
||||||
|
)
|
||||||
|
|
||||||
class Info {
|
class Info {
|
||||||
val uuid = UUID.randomUUID().toString()
|
val uuid = UUID.randomUUID().toString()
|
||||||
val dateTime: LocalDateTime = ZonedDateTime.now().toLocalDateTime()
|
val dateTime: LocalDateTime = LocalDateTime.now()
|
||||||
}
|
}
|
||||||
|
|
||||||
class Status {
|
class Status {
|
||||||
@@ -62,6 +173,12 @@ abstract class RunningFlowContext {
|
|||||||
val ok: Boolean
|
val ok: Boolean
|
||||||
get() = errors.isEmpty()
|
get() = errors.isEmpty()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 模块边界上的协作式打断标记
|
||||||
|
*/
|
||||||
|
@Volatile
|
||||||
|
var interrupted: Boolean = false
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 本次执行时收集到的异常信息
|
* 本次执行时收集到的异常信息
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -0,0 +1,224 @@
|
|||||||
|
package work.slhaf.partner.framework.agent.interaction
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.AfterEach
|
||||||
|
import org.junit.jupiter.api.Assertions.assertEquals
|
||||||
|
import org.junit.jupiter.api.Assertions.assertTrue
|
||||||
|
import org.junit.jupiter.api.BeforeEach
|
||||||
|
import org.junit.jupiter.api.Test
|
||||||
|
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule
|
||||||
|
import work.slhaf.partner.framework.agent.factory.context.AgentContext
|
||||||
|
import work.slhaf.partner.framework.agent.factory.context.ModuleContextData
|
||||||
|
import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext
|
||||||
|
import java.time.ZonedDateTime
|
||||||
|
import java.util.concurrent.CopyOnWriteArrayList
|
||||||
|
import java.util.concurrent.CountDownLatch
|
||||||
|
import java.util.concurrent.TimeUnit
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger
|
||||||
|
|
||||||
|
class AgentRuntimeTest {
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
fun setUp() {
|
||||||
|
resetAgentRuntime()
|
||||||
|
clearModules()
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
fun tearDown() {
|
||||||
|
resetAgentRuntime()
|
||||||
|
clearModules()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `running flow context preserves offsets and xml encoding`() {
|
||||||
|
val first = TestRunningFlowContext.of("source-a", "first", 1_000L)
|
||||||
|
val second = TestRunningFlowContext.of("source-a", "second", 1_250L)
|
||||||
|
|
||||||
|
val merged = first.mergedWith(second)
|
||||||
|
|
||||||
|
assertEquals(listOf(0L, 250L), merged.inputs.map { it.offsetMillis })
|
||||||
|
assertEquals("first\nsecond", merged.input)
|
||||||
|
assertEquals(
|
||||||
|
"<inputs><input interval-to-first=\"0\">first</input><input interval-to-first=\"250\">second</input></inputs>",
|
||||||
|
merged.encodeInputsXml()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `agent runtime keeps source queue in first arrival order`() {
|
||||||
|
val recorder = RecordingModule(order = 1, expectedExecutions = 2)
|
||||||
|
registerModule("queue-recorder", recorder)
|
||||||
|
|
||||||
|
AgentRuntime.submit(TestRunningFlowContext.of("source-a", "alpha"))
|
||||||
|
AgentRuntime.submit(TestRunningFlowContext.of("source-b", "beta"))
|
||||||
|
|
||||||
|
assertTrue(recorder.latch.await(5, TimeUnit.SECONDS))
|
||||||
|
assertEquals(listOf("source-a", "source-b"), recorder.sources)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `agent runtime interrupts current source and reruns from chain head with merged context`() {
|
||||||
|
val blocking = BlockingModule()
|
||||||
|
val finalizer = RecordingModule(order = 2, expectedExecutions = 1)
|
||||||
|
registerModule("blocking-module", blocking)
|
||||||
|
registerModule("finalizer-module", finalizer)
|
||||||
|
|
||||||
|
AgentRuntime.submit(TestRunningFlowContext.of("source-a", "first", 1_000L))
|
||||||
|
assertTrue(blocking.firstExecutionStarted.await(5, TimeUnit.SECONDS))
|
||||||
|
|
||||||
|
AgentRuntime.submit(TestRunningFlowContext.of("source-a", "second", 1_300L))
|
||||||
|
blocking.releaseFirstExecution.countDown()
|
||||||
|
|
||||||
|
assertTrue(finalizer.latch.await(5, TimeUnit.SECONDS))
|
||||||
|
waitUntil { blocking.seenInputSizes.size >= 2 }
|
||||||
|
|
||||||
|
assertEquals(listOf(1, 2), blocking.seenInputSizes)
|
||||||
|
assertEquals(listOf(2), finalizer.inputSizes)
|
||||||
|
assertEquals(listOf("first\nsecond"), finalizer.historyInputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun registerModule(name: String, module: AbstractAgentModule.Running<*>) {
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
AgentContext.addModule(
|
||||||
|
name,
|
||||||
|
ModuleContextData.Running(
|
||||||
|
module.javaClass,
|
||||||
|
module,
|
||||||
|
ZonedDateTime.now(),
|
||||||
|
null,
|
||||||
|
module.order()
|
||||||
|
) as ModuleContextData<AbstractAgentModule>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun clearModules() {
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
val modules = AgentContext.modules as MutableMap<String, ModuleContextData<AbstractAgentModule>>
|
||||||
|
modules.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun resetAgentRuntime() {
|
||||||
|
setPrivateField("runningModules", emptyMap<Int, List<AbstractAgentModule.Running<RunningFlowContext>>>())
|
||||||
|
setPrivateField("maskedModules", emptySet<String>())
|
||||||
|
setPrivateField("currentExecutingSource", null)
|
||||||
|
setPrivateField("currentExecutingContext", null)
|
||||||
|
getPrivateMutableMap<String, RunningFlowContext>("latestContextsBySource").clear()
|
||||||
|
getPrivateMutableMap<String, Long>("sourceVersions").clear()
|
||||||
|
getPrivateDeque<String>("sourceQueue").clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun waitUntil(timeoutMillis: Long = 5_000L, condition: () -> Boolean) {
|
||||||
|
val deadline = System.currentTimeMillis() + timeoutMillis
|
||||||
|
while (System.currentTimeMillis() < deadline) {
|
||||||
|
if (condition()) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Thread.sleep(20L)
|
||||||
|
}
|
||||||
|
error("Condition was not satisfied within $timeoutMillis ms")
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun setPrivateField(fieldName: String, value: Any?) {
|
||||||
|
val field = AgentRuntime::class.java.getDeclaredField(fieldName)
|
||||||
|
field.isAccessible = true
|
||||||
|
field.set(AgentRuntime, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
private fun <K, V> getPrivateMutableMap(fieldName: String): MutableMap<K, V> {
|
||||||
|
val field = AgentRuntime::class.java.getDeclaredField(fieldName)
|
||||||
|
field.isAccessible = true
|
||||||
|
return field.get(AgentRuntime) as MutableMap<K, V>
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
private fun <T> getPrivateDeque(fieldName: String): java.util.ArrayDeque<T> {
|
||||||
|
val field = AgentRuntime::class.java.getDeclaredField(fieldName)
|
||||||
|
field.isAccessible = true
|
||||||
|
return field.get(AgentRuntime) as java.util.ArrayDeque<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
private class RecordingModule(
|
||||||
|
private val order: Int,
|
||||||
|
expectedExecutions: Int
|
||||||
|
) : AbstractAgentModule.Running<TestRunningFlowContext>() {
|
||||||
|
val sources = CopyOnWriteArrayList<String>()
|
||||||
|
val inputSizes = CopyOnWriteArrayList<Int>()
|
||||||
|
val historyInputs = CopyOnWriteArrayList<String>()
|
||||||
|
val latch = CountDownLatch(expectedExecutions)
|
||||||
|
|
||||||
|
init {
|
||||||
|
moduleName = "recording-$order"
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun doExecute(context: TestRunningFlowContext) {
|
||||||
|
sources.add(context.source)
|
||||||
|
inputSizes.add(context.inputs.size)
|
||||||
|
historyInputs.add(context.input)
|
||||||
|
latch.countDown()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun order(): Int = order
|
||||||
|
}
|
||||||
|
|
||||||
|
private class BlockingModule : AbstractAgentModule.Running<TestRunningFlowContext>() {
|
||||||
|
val seenInputSizes = CopyOnWriteArrayList<Int>()
|
||||||
|
val firstExecutionStarted = CountDownLatch(1)
|
||||||
|
val releaseFirstExecution = CountDownLatch(1)
|
||||||
|
private val invocationCount = AtomicInteger(0)
|
||||||
|
|
||||||
|
init {
|
||||||
|
moduleName = "blocking"
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun doExecute(context: TestRunningFlowContext) {
|
||||||
|
seenInputSizes.add(context.inputs.size)
|
||||||
|
if (invocationCount.getAndIncrement() == 0) {
|
||||||
|
firstExecutionStarted.countDown()
|
||||||
|
releaseFirstExecution.await(5, TimeUnit.SECONDS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun order(): Int = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
private class TestRunningFlowContext private constructor(
|
||||||
|
override val source: String,
|
||||||
|
inputs: List<InputEntry>,
|
||||||
|
firstInputEpochMillis: Long,
|
||||||
|
target: String = source
|
||||||
|
) : RunningFlowContext(inputs, firstInputEpochMillis, target = target) {
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
fun of(
|
||||||
|
source: String,
|
||||||
|
input: String,
|
||||||
|
receivedAtMillis: Long = System.currentTimeMillis()
|
||||||
|
): TestRunningFlowContext {
|
||||||
|
return TestRunningFlowContext(
|
||||||
|
source = source,
|
||||||
|
inputs = listOf(InputEntry(0L, input)),
|
||||||
|
firstInputEpochMillis = receivedAtMillis
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun copyWith(
|
||||||
|
inputs: List<InputEntry>,
|
||||||
|
firstInputEpochMillis: Long,
|
||||||
|
additionalUserInfo: Map<String, String>,
|
||||||
|
skippedModules: Set<String>,
|
||||||
|
target: String
|
||||||
|
): RunningFlowContext {
|
||||||
|
return TestRunningFlowContext(
|
||||||
|
source = source,
|
||||||
|
inputs = inputs,
|
||||||
|
firstInputEpochMillis = firstInputEpochMillis,
|
||||||
|
target = target
|
||||||
|
).apply {
|
||||||
|
additionalUserInfo.forEach(::putUserInfo)
|
||||||
|
skippedModules.forEach(::addSkippedModule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user