refactor(memory): refactor memory selector into asynchronous module

This commit is contained in:
2026-03-28 22:41:48 +08:00
parent 09f90d8ad5
commit db20e0ca78
3 changed files with 94 additions and 24 deletions

View File

@@ -1,13 +1,14 @@
package work.slhaf.partner.module.memory.selector; package work.slhaf.partner.module.memory.selector;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule; import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;
import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException; import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
import work.slhaf.partner.core.memory.exception.UnExistedTopicException; import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
@@ -15,23 +16,28 @@ import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
import work.slhaf.partner.module.memory.selector.evaluator.SliceSelectEvaluator; import work.slhaf.partner.module.memory.selector.evaluator.SliceSelectEvaluator;
import work.slhaf.partner.module.memory.selector.evaluator.entity.EvaluatorInput; import work.slhaf.partner.module.memory.selector.evaluator.entity.EvaluatorInput;
import work.slhaf.partner.module.memory.selector.extractor.MemorySelectExtractor; import work.slhaf.partner.module.memory.selector.extractor.MemorySelectExtractor;
import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorInput;
import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorMatchData; import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorMatchData;
import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorResult; import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.time.LocalDate; import java.time.LocalDate;
import java.util.ArrayList; import java.time.LocalDateTime;
import java.util.Collection; import java.time.ZonedDateTime;
import java.util.LinkedHashMap; import java.util.*;
import java.util.List; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFlowContext> { public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFlowContext> {
@InjectCapability
private MemoryCapability memoryCapability;
@InjectCapability @InjectCapability
private CognitionCapability cognitionCapability; private CognitionCapability cognitionCapability;
@InjectCapability
private ActionCapability actionCapability;
@InjectModule @InjectModule
private MemoryRuntime memoryRuntime; private MemoryRuntime memoryRuntime;
@InjectModule @InjectModule
@@ -39,31 +45,88 @@ public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFl
@InjectModule @InjectModule
private MemorySelectExtractor memorySelectExtractor; private MemorySelectExtractor memorySelectExtractor;
private AtomicBoolean memoryCalling = new AtomicBoolean(false);
private Map<LocalDateTime, String> collectedInputs = new HashMap<>();
private Lock inputsLock = new ReentrantLock();
@Override @Override
public void execute(PartnerRunningFlowContext runningFlowContext) { public void execute(@NotNull PartnerRunningFlowContext runningFlowContext) {
ExtractorResult extractorResult = memorySelectExtractor.execute(runningFlowContext); inputsLock.lock();
try {
collectedInputs.put(ZonedDateTime.now().toLocalDateTime(), runningFlowContext.getInput());
} finally {
inputsLock.unlock();
}
tryStartMemoryRecallWorker();
}
private void tryStartMemoryRecallWorker() {
if (!memoryCalling.compareAndSet(false, true)) {
return;
}
actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL).execute(() -> {
try {
drainMemoryRecall();
} finally {
memoryCalling.set(false);
// 防止竞态worker 退出前后,刚好来了新输入,但没有线程负责再拉起 worker
if (!collectedInputs.isEmpty()) {
tryStartMemoryRecallWorker();
}
}
});
}
private void drainMemoryRecall() {
while (true) {
Map<LocalDateTime, String> snapshotInputs;
inputsLock.lock();
try {
if (collectedInputs.isEmpty()) {
return;
}
snapshotInputs = new HashMap<>(collectedInputs);
collectedInputs.clear();
} finally {
inputsLock.unlock();
}
ExtractorInput input = new ExtractorInput(
snapshotInputs,
memoryRuntime.getTopicTree(),
snapshotInputs.keySet()
.stream()
.max(LocalDateTime::compareTo)
.orElseThrow()
.toLocalDate()
);
ExtractorResult extractorResult = memorySelectExtractor.execute(input);
if (extractorResult.isRecall() || !extractorResult.getMatches().isEmpty()) { if (extractorResult.isRecall() || !extractorResult.getMatches().isEmpty()) {
memoryCapability.clearActivatedSlices(); List<ActivatedMemorySlice> activatedSlices = selectAndEvaluateMemory(snapshotInputs, extractorResult);
List<ActivatedMemorySlice> activatedSlices = selectAndEvaluateMemory(runningFlowContext, extractorResult); updateMemoryContext(activatedSlices);
memoryCapability.updateActivatedSlices(activatedSlices); }
} }
} }
private List<ActivatedMemorySlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, private void updateMemoryContext(List<ActivatedMemorySlice> activatedSlices) {
ExtractorResult extractorResult) { // TODO
}
private List<ActivatedMemorySlice> selectAndEvaluateMemory(Map<LocalDateTime, String> snapshotInputs, ExtractorResult extractorResult) {
log.debug("[MemorySelector] 触发记忆回溯..."); log.debug("[MemorySelector] 触发记忆回溯...");
LinkedHashMap<String, ActivatedMemorySlice> candidates = new LinkedHashMap<>(); LinkedHashMap<String, ActivatedMemorySlice> candidates = new LinkedHashMap<>();
setMemoryCandidates(candidates, extractorResult.getMatches()); setMemoryCandidates(candidates, extractorResult.getMatches());
removeDuplicateSlice(candidates.values()); removeDuplicateSlice(candidates.values());
EvaluatorInput evaluatorInput = EvaluatorInput.builder() EvaluatorInput evaluatorInput = EvaluatorInput.builder()
.input(runningFlowContext.getInput()) .inputs(snapshotInputs)
.memorySlices(new ArrayList<>(candidates.values())) .memorySlices(new ArrayList<>(candidates.values()))
.messages(cognitionCapability.getChatMessages())
.build(); .build();
log.debug("[MemorySelector] 切片评估输入: {}", JSONObject.toJSONString(evaluatorInput)); return sliceSelectEvaluator.execute(evaluatorInput);
List<ActivatedMemorySlice> memorySlices = sliceSelectEvaluator.execute(evaluatorInput);
log.debug("[MemorySelector] 切片评估结果: {}", JSONObject.toJSONString(memorySlices));
return memorySlices;
} }
private void setMemoryCandidates(LinkedHashMap<String, ActivatedMemorySlice> candidates, private void setMemoryCandidates(LinkedHashMap<String, ActivatedMemorySlice> candidates,

View File

@@ -1,5 +1,6 @@
package work.slhaf.partner.module.memory.selector.extractor; package work.slhaf.partner.module.memory.selector.extractor;
import kotlin.Unit;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
@@ -58,7 +59,11 @@ public class MemorySelectExtractor extends AbstractAgentModule.Sub<ExtractorInpu
return new TaskBlock() { return new TaskBlock() {
@Override @Override
protected void fillXml(@NotNull Document document, @NotNull Element root) { protected void fillXml(@NotNull Document document, @NotNull Element root) {
appendTextElement(document, root, "latest_input", input.getInput()); appendListElement(document, root, "new_inputs", "input", input.getInputs().entrySet(), (inputElement, input) -> {
inputElement.setAttribute("datetime", input.getKey().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
inputElement.setTextContent(input.getValue());
return Unit.INSTANCE;
});
appendTextElement(document, root, "current_date", input.getDate().format(DateTimeFormatter.ofPattern("yyyy-MM-dd"))); appendTextElement(document, root, "current_date", input.getDate().format(DateTimeFormatter.ofPattern("yyyy-MM-dd")));
appendTextElement(document, root, "memory_topic_tree", input.getTopic_tree()); appendTextElement(document, root, "memory_topic_tree", input.getTopic_tree());
} }

View File

@@ -4,11 +4,13 @@ import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import java.time.LocalDate; import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.Map;
@Data @Data
@AllArgsConstructor @AllArgsConstructor
public class ExtractorInput { public class ExtractorInput {
private String input; private Map<LocalDateTime, String> inputs;
private String topic_tree; private String topic_tree;
private LocalDate date; private LocalDate date;
} }