refactor(memory-selector): refactor memory recalling into async worker

This commit is contained in:
2026-04-25 23:22:43 +08:00
parent 94adf9a368
commit eade39328a
4 changed files with 147 additions and 21 deletions

View File

@@ -0,0 +1,15 @@
package work.slhaf.partner.module.memory.selector;
import lombok.AllArgsConstructor;
import lombok.Data;
import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext;
import java.time.LocalDateTime;
import java.util.List;
@Data
@AllArgsConstructor
public class MemoryInputEntry {
private LocalDateTime receivedDateTime;
private List<RunningFlowContext.InputEntry> inputs;
}

View File

@@ -5,6 +5,8 @@ import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.w3c.dom.Document; import org.w3c.dom.Document;
import org.w3c.dom.Element; import org.w3c.dom.Element;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;
import work.slhaf.partner.core.cognition.BlockContent; import work.slhaf.partner.core.cognition.BlockContent;
import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.cognition.ContextBlock; import work.slhaf.partner.core.cognition.ContextBlock;
@@ -23,7 +25,11 @@ import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorResul
import work.slhaf.partner.runtime.PartnerRunningFlowContext; import work.slhaf.partner.runtime.PartnerRunningFlowContext;
import java.time.LocalDate; import java.time.LocalDate;
import java.time.ZoneId;
import java.util.*; import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@Slf4j @Slf4j
public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFlowContext> { public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFlowContext> {
@@ -33,6 +39,7 @@ public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFl
@InjectCapability @InjectCapability
private CognitionCapability cognitionCapability; private CognitionCapability cognitionCapability;
private final AtomicBoolean memoryCalling = new AtomicBoolean(false);
@InjectModule @InjectModule
private MemoryRuntime memoryRuntime; private MemoryRuntime memoryRuntime;
@@ -40,24 +47,125 @@ public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFl
private MemoryRecallEvaluator memoryRecallEvaluator; private MemoryRecallEvaluator memoryRecallEvaluator;
@InjectModule @InjectModule
private MemoryRecallCueExtractor memoryRecallCueExtractor; private MemoryRecallCueExtractor memoryRecallCueExtractor;
private final Lock inputsLock = new ReentrantLock();
private final List<MemoryInputEntry> collectedInputs = new ArrayList<>();
@InjectCapability
private ActionCapability actionCapability;
@Override @Override
protected void doExecute(@NotNull PartnerRunningFlowContext runningFlowContext) { protected void doExecute(@NotNull PartnerRunningFlowContext runningFlowContext) {
List<RunningFlowContext.InputEntry> snapshotInputs = List.copyOf(runningFlowContext.getInputs()); collectInputs(runningFlowContext);
tryStartMemoryRecallWorker();
}
private void collectInputs(PartnerRunningFlowContext runningFlowContext) {
inputsLock.lock();
try {
collectedInputs.add(new MemoryInputEntry(
runningFlowContext.getFirstInputDateTime(),
List.copyOf(runningFlowContext.getInputs())
));
} finally {
inputsLock.unlock();
}
}
private void tryStartMemoryRecallWorker() {
if (!memoryCalling.compareAndSet(false, true)) {
return;
}
actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL).execute(() -> {
try {
drainMemoryRecall();
} finally {
memoryCalling.set(false);
if (hasCollectedInputs()) {
tryStartMemoryRecallWorker();
}
}
});
}
private void drainMemoryRecall() {
while (true) {
List<MemoryInputEntry> snapshotInputs = drainCollectedInputs();
if (snapshotInputs.isEmpty()) {
return;
}
try {
recallMemory(snapshotInputs);
} catch (Exception e) {
log.error("[MemorySelector] 记忆召回任务执行失败", e);
}
}
}
private void recallMemory(List<MemoryInputEntry> memoryInputEntries) {
ExtractorInput input = new ExtractorInput( ExtractorInput input = new ExtractorInput(
snapshotInputs, memoryInputEntries,
memoryRuntime.getTopicTree(), memoryRuntime.getTopicTree()
runningFlowContext.getFirstInputDateTime().toLocalDate()
); );
ExtractorResult extractorResult = memoryRecallCueExtractor.execute(input); ExtractorResult extractorResult = memoryRecallCueExtractor.execute(input);
if (extractorResult.getMatches().isEmpty()) { if (extractorResult.getMatches().isEmpty()) {
return; return;
} }
List<ActivatedMemorySlice> activatedSlices = selectAndEvaluateMemory(snapshotInputs, extractorResult); List<ActivatedMemorySlice> activatedSlices = selectAndEvaluateMemory(flattenInputs(memoryInputEntries), extractorResult);
updateMemoryContext(activatedSlices); updateMemoryContext(activatedSlices);
} }
private List<MemoryInputEntry> drainCollectedInputs() {
inputsLock.lock();
try {
if (collectedInputs.isEmpty()) {
return List.of();
}
List<MemoryInputEntry> snapshot = new ArrayList<>(collectedInputs);
collectedInputs.clear();
snapshot.sort(Comparator.comparing(MemoryInputEntry::getReceivedDateTime));
return snapshot;
} finally {
inputsLock.unlock();
}
}
private boolean hasCollectedInputs() {
inputsLock.lock();
try {
return !collectedInputs.isEmpty();
} finally {
inputsLock.unlock();
}
}
private List<RunningFlowContext.InputEntry> flattenInputs(List<MemoryInputEntry> memoryInputEntries) {
if (memoryInputEntries.isEmpty()) {
return List.of();
}
long firstEpochMillis = memoryInputEntries.stream()
.map(MemoryInputEntry::getReceivedDateTime)
.mapToLong(this::toEpochMillis)
.min()
.orElseThrow();
return memoryInputEntries.stream()
.flatMap(entry -> {
long entryEpochMillis = toEpochMillis(entry.getReceivedDateTime());
return entry.getInputs().stream()
.map(input -> new RunningFlowContext.InputEntry(
entryEpochMillis + input.getOffsetMillis() - firstEpochMillis,
input.getContent()
));
})
.sorted(Comparator.comparingLong(RunningFlowContext.InputEntry::getOffsetMillis))
.toList();
}
private long toEpochMillis(java.time.LocalDateTime dateTime) {
return dateTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli();
}
private void updateMemoryContext(List<ActivatedMemorySlice> activatedSlices) { private void updateMemoryContext(List<ActivatedMemorySlice> activatedSlices) {
cognitionCapability.contextWorkspace().register(new ContextBlock( cognitionCapability.contextWorkspace().register(new ContextBlock(
buildMemoryFullBlock(activatedSlices), buildMemoryFullBlock(activatedSlices),

View File

@@ -19,6 +19,7 @@ import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorResul
import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatter;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator;
import java.util.List; import java.util.List;
public class MemoryRecallCueExtractor extends AbstractAgentModule.Sub<ExtractorInput, ExtractorResult> implements ActivateModel { public class MemoryRecallCueExtractor extends AbstractAgentModule.Sub<ExtractorInput, ExtractorResult> implements ActivateModel {
@@ -29,12 +30,11 @@ public class MemoryRecallCueExtractor extends AbstractAgentModule.Sub<ExtractorI
你会收到: 你会收到:
- 一条结构化上下文消息,其中包含当前活跃的 communication 域与 memory 域内容; - 一条结构化上下文消息,其中包含当前活跃的 communication 域与 memory 域内容;
- 一条任务消息,其中包含: - 一条任务消息,其中包含:
- new_inputs一组按时间顺序累积的新输入每条输入附带 interval-to-first - memory_input_entries一组按接收时间排序的新输入批次每个批次包含 received_date_time 与该批次内按时间顺序排列的 inputs
- current_date当前日期
- memory_topic_tree当前可用的记忆主题树结构。 - memory_topic_tree当前可用的记忆主题树结构。
你的任务: 你的任务:
- 基于 new_inputs、当前语境与已有记忆主题树提取本次记忆召回最值得尝试的匹配项 - 基于 memory_input_entries、当前语境与已有记忆主题树提取本次记忆召回最值得尝试的匹配项
- 匹配项只允许有两类topic 或 date - 匹配项只允许有两类topic 或 date
- topic 用于表示应优先检索的记忆主题路径; - topic 用于表示应优先检索的记忆主题路径;
- date 用于表示应优先检索的具体日期; - date 用于表示应优先检索的具体日期;
@@ -42,7 +42,7 @@ public class MemoryRecallCueExtractor extends AbstractAgentModule.Sub<ExtractorI
提取原则: 提取原则:
- 你的目标是提取“可用于后续召回”的线索,而不是复述输入内容本身。 - 你的目标是提取“可用于后续召回”的线索,而不是复述输入内容本身。
- new_inputs 应整体理解,不要只抓最后一句;如果多条输入共同收敛到同一记忆方向,应提取更稳定的主题线索。 - memory_input_entries 应整体理解,不要只抓最后一个批次或最后一句;如果多个批次共同收敛到同一记忆方向,应提取更稳定的主题线索。
- communication 域用于判断当前输入是否在承接近期某段对话、某个旧话题或某个已出现过的指代对象。 - communication 域用于判断当前输入是否在承接近期某段对话、某个旧话题或某个已出现过的指代对象。
- memory 域用于辅助判断当前输入与哪些已激活记忆方向明显相关;只有在这种相关性明确时才使用,不要机械复述 memory 域内容。 - memory 域用于辅助判断当前输入与哪些已激活记忆方向明显相关;只有在这种相关性明确时才使用,不要机械复述 memory 域内容。
- memory_topic_tree 是 topic 提取的主要参照topic 应尽量贴近主题树中已有的层级与命名,不要随意发明与主题树无关的新路径。 - memory_topic_tree 是 topic 提取的主要参照topic 应尽量贴近主题树中已有的层级与命名,不要随意发明与主题树无关的新路径。
@@ -60,8 +60,8 @@ public class MemoryRecallCueExtractor extends AbstractAgentModule.Sub<ExtractorI
关于 date 关于 date
- date 表示一个明确的记忆日期。 - date 表示一个明确的记忆日期。
- date 的 text 必须是可被 Java LocalDate.parse 正常解析的日期文本,即 yyyy-MM-dd。 - date 的 text 必须是可被 Java LocalDate.parse 正常解析的日期文本,即 yyyy-MM-dd。
- 只有在输入中存在明确日期,或结合 current_date 后可以稳定推断出具体某一天时,才输出 date。 - 只有在输入中存在明确日期,或结合相关输入批次的 received_date_time 后可以稳定推断出具体某一天时,才输出 date。
- 像“今天”“昨天”“前天”“上周六”这类表达,只有在能够稳定落到某个具体日期时才可输出。 - 像“今天”“昨天”“前天”“上周六”这类表达,只有在能够根据对应批次 received_date_time 稳定落到某个具体日期时才可输出。
- 对于“最近”“前几天”“那段时间”“之前”“上次”这类无法稳定定位到某一天的表达,不要输出 date。 - 对于“最近”“前几天”“那段时间”“之前”“上次”这类无法稳定定位到某一天的表达,不要输出 date。
何时应提取 topic 何时应提取 topic
@@ -72,7 +72,7 @@ public class MemoryRecallCueExtractor extends AbstractAgentModule.Sub<ExtractorI
何时应提取 date 何时应提取 date
- 当前输入明确提到了某个具体日期; - 当前输入明确提到了某个具体日期;
- 当前输入使用相对日期表达,但结合 current_date 可以稳定还原到具体某一天; - 当前输入使用相对日期表达,但结合对应批次的 received_date_time 可以稳定还原到具体某一天;
- 当前输入中的回忆目标明显依赖某个特定日期,且该日期能够被明确确定。 - 当前输入中的回忆目标明显依赖某个特定日期,且该日期能够被明确确定。
何时不应轻易输出: 何时不应轻易输出:
@@ -135,15 +135,20 @@ public class MemoryRecallCueExtractor extends AbstractAgentModule.Sub<ExtractorI
return new TaskBlock() { return new TaskBlock() {
@Override @Override
protected void fillXml(@NotNull Document document, @NotNull Element root) { protected void fillXml(@NotNull Document document, @NotNull Element root) {
appendChildElement(document, root, "new_inputs", (inputsElement) -> { appendChildElement(document, root, "memory_input_entries", (entriesElement) -> {
appendListElement(document, inputsElement, "inputs", "input", input.getInputs(), (inputElement, entry) -> { appendRepeatedElements(document, entriesElement, "memory_input_entry", input.getMemoryInputEntries().stream()
inputElement.setAttribute("interval-to-first", String.valueOf(entry.getOffsetMillis())); .sorted(Comparator.comparing(work.slhaf.partner.module.memory.selector.MemoryInputEntry::getReceivedDateTime))
inputElement.setTextContent(entry.getContent()); .toList(), (entryElement, memoryInputEntry) -> {
appendTextElement(document, entryElement, "received_date_time", memoryInputEntry.getReceivedDateTime().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME));
appendListElement(document, entryElement, "inputs", "input", memoryInputEntry.getInputs(), (inputElement, entry) -> {
inputElement.setAttribute("interval-to-first", String.valueOf(entry.getOffsetMillis()));
inputElement.setTextContent(entry.getContent());
return Unit.INSTANCE;
});
return Unit.INSTANCE; return Unit.INSTANCE;
}); });
return Unit.INSTANCE; return Unit.INSTANCE;
}); });
appendTextElement(document, root, "current_date", input.getDate().format(DateTimeFormatter.ofPattern("yyyy-MM-dd")));
appendTextElement(document, root, "memory_topic_tree", input.getTopic_tree()); appendTextElement(document, root, "memory_topic_tree", input.getTopic_tree());
} }
}.encodeToMessage(); }.encodeToMessage();

View File

@@ -2,15 +2,13 @@ package work.slhaf.partner.module.memory.selector.extractor.entity;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext; import work.slhaf.partner.module.memory.selector.MemoryInputEntry;
import java.time.LocalDate;
import java.util.List; import java.util.List;
@Data @Data
@AllArgsConstructor @AllArgsConstructor
public class ExtractorInput { public class ExtractorInput {
private List<RunningFlowContext.InputEntry> inputs; private List<MemoryInputEntry> memoryInputEntries;
private String topic_tree; private String topic_tree;
private LocalDate date;
} }