refactor(memory): refactor memory selector into asynchronous module

This commit is contained in:
2026-03-28 22:41:48 +08:00
parent 09f90d8ad5
commit db20e0ca78
3 changed files with 94 additions and 24 deletions

View File

@@ -1,13 +1,14 @@
package work.slhaf.partner.module.memory.selector;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;
import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
@@ -15,23 +16,28 @@ import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
import work.slhaf.partner.module.memory.selector.evaluator.SliceSelectEvaluator;
import work.slhaf.partner.module.memory.selector.evaluator.entity.EvaluatorInput;
import work.slhaf.partner.module.memory.selector.extractor.MemorySelectExtractor;
import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorInput;
import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorMatchData;
import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.time.LocalDateTime;
import java.time.ZonedDateTime;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFlowContext> {
@InjectCapability
private MemoryCapability memoryCapability;
@InjectCapability
private CognitionCapability cognitionCapability;
@InjectCapability
private ActionCapability actionCapability;
@InjectModule
private MemoryRuntime memoryRuntime;
@InjectModule
@@ -39,31 +45,88 @@ public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFl
@InjectModule
private MemorySelectExtractor memorySelectExtractor;
private AtomicBoolean memoryCalling = new AtomicBoolean(false);
private Map<LocalDateTime, String> collectedInputs = new HashMap<>();
private Lock inputsLock = new ReentrantLock();
@Override
public void execute(PartnerRunningFlowContext runningFlowContext) {
ExtractorResult extractorResult = memorySelectExtractor.execute(runningFlowContext);
if (extractorResult.isRecall() || !extractorResult.getMatches().isEmpty()) {
memoryCapability.clearActivatedSlices();
List<ActivatedMemorySlice> activatedSlices = selectAndEvaluateMemory(runningFlowContext, extractorResult);
memoryCapability.updateActivatedSlices(activatedSlices);
public void execute(@NotNull PartnerRunningFlowContext runningFlowContext) {
inputsLock.lock();
try {
collectedInputs.put(ZonedDateTime.now().toLocalDateTime(), runningFlowContext.getInput());
} finally {
inputsLock.unlock();
}
tryStartMemoryRecallWorker();
}
private void tryStartMemoryRecallWorker() {
if (!memoryCalling.compareAndSet(false, true)) {
return;
}
actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL).execute(() -> {
try {
drainMemoryRecall();
} finally {
memoryCalling.set(false);
// 防止竞态worker 退出前后,刚好来了新输入,但没有线程负责再拉起 worker
if (!collectedInputs.isEmpty()) {
tryStartMemoryRecallWorker();
}
}
});
}
private void drainMemoryRecall() {
while (true) {
Map<LocalDateTime, String> snapshotInputs;
inputsLock.lock();
try {
if (collectedInputs.isEmpty()) {
return;
}
snapshotInputs = new HashMap<>(collectedInputs);
collectedInputs.clear();
} finally {
inputsLock.unlock();
}
ExtractorInput input = new ExtractorInput(
snapshotInputs,
memoryRuntime.getTopicTree(),
snapshotInputs.keySet()
.stream()
.max(LocalDateTime::compareTo)
.orElseThrow()
.toLocalDate()
);
ExtractorResult extractorResult = memorySelectExtractor.execute(input);
if (extractorResult.isRecall() || !extractorResult.getMatches().isEmpty()) {
List<ActivatedMemorySlice> activatedSlices = selectAndEvaluateMemory(snapshotInputs, extractorResult);
updateMemoryContext(activatedSlices);
}
}
}
private List<ActivatedMemorySlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext,
ExtractorResult extractorResult) {
private void updateMemoryContext(List<ActivatedMemorySlice> activatedSlices) {
// TODO
}
private List<ActivatedMemorySlice> selectAndEvaluateMemory(Map<LocalDateTime, String> snapshotInputs, ExtractorResult extractorResult) {
log.debug("[MemorySelector] 触发记忆回溯...");
LinkedHashMap<String, ActivatedMemorySlice> candidates = new LinkedHashMap<>();
setMemoryCandidates(candidates, extractorResult.getMatches());
removeDuplicateSlice(candidates.values());
EvaluatorInput evaluatorInput = EvaluatorInput.builder()
.input(runningFlowContext.getInput())
.inputs(snapshotInputs)
.memorySlices(new ArrayList<>(candidates.values()))
.messages(cognitionCapability.getChatMessages())
.build();
log.debug("[MemorySelector] 切片评估输入: {}", JSONObject.toJSONString(evaluatorInput));
List<ActivatedMemorySlice> memorySlices = sliceSelectEvaluator.execute(evaluatorInput);
log.debug("[MemorySelector] 切片评估结果: {}", JSONObject.toJSONString(memorySlices));
return memorySlices;
return sliceSelectEvaluator.execute(evaluatorInput);
}
private void setMemoryCandidates(LinkedHashMap<String, ActivatedMemorySlice> candidates,

View File

@@ -1,5 +1,6 @@
package work.slhaf.partner.module.memory.selector.extractor;
import kotlin.Unit;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.jetbrains.annotations.NotNull;
@@ -58,7 +59,11 @@ public class MemorySelectExtractor extends AbstractAgentModule.Sub<ExtractorInpu
return new TaskBlock() {
@Override
protected void fillXml(@NotNull Document document, @NotNull Element root) {
appendTextElement(document, root, "latest_input", input.getInput());
appendListElement(document, root, "new_inputs", "input", input.getInputs().entrySet(), (inputElement, input) -> {
inputElement.setAttribute("datetime", input.getKey().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
inputElement.setTextContent(input.getValue());
return Unit.INSTANCE;
});
appendTextElement(document, root, "current_date", input.getDate().format(DateTimeFormatter.ofPattern("yyyy-MM-dd")));
appendTextElement(document, root, "memory_topic_tree", input.getTopic_tree());
}

View File

@@ -4,11 +4,13 @@ import lombok.AllArgsConstructor;
import lombok.Data;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.Map;
@Data
@AllArgsConstructor
public class ExtractorInput {
private String input;
private Map<LocalDateTime, String> inputs;
private String topic_tree;
private LocalDate date;
}