refactor(memory): centralize memory recording and retrieval logic

This commit is contained in:
2026-03-10 20:42:46 +08:00
parent 027e8bddc0
commit ee1a033c1b
3 changed files with 90 additions and 78 deletions

View File

@@ -1,6 +1,5 @@
package work.slhaf.partner.module.modules.memory.runtime;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
@@ -14,6 +13,7 @@ import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.SliceRef;
@@ -33,7 +33,6 @@ import java.util.concurrent.locks.ReentrantLock;
import static work.slhaf.partner.common.Constant.Path.MEMORY_DATA;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class MemoryRuntime extends AbstractAgentModule.Standalone {
@@ -63,7 +62,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
}
}
public void bindTopic(String topicPath, SliceRef sliceRef) {
private void bindTopic(String topicPath, SliceRef sliceRef) {
String normalizedPath = normalizeTopicPath(topicPath);
runtimeLock.lock();
try {
@@ -79,7 +78,21 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
}
}
public void indexMemoryUnit(MemoryUnit memoryUnit) {
public void recordMemory(MemoryUnit memoryUnit, String topicPath, List<String> relatedTopicPaths, String dialogSummary) {
memoryCapability.saveMemoryUnit(memoryUnit);
MemorySlice memorySlice = memoryUnit.getSlices().getFirst();
SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId());
indexMemoryUnit(memoryUnit);
bindTopic(topicPath, sliceRef);
if (relatedTopicPaths != null) {
for (String relatedTopicPath : relatedTopicPaths) {
bindTopic(relatedTopicPath, sliceRef);
}
}
updateDialogMap(LocalDateTime.now(), dialogSummary);
}
private void indexMemoryUnit(MemoryUnit memoryUnit) {
runtimeLock.lock();
try {
for (CopyOnWriteArrayList<SliceRef> refs : dateIndex.values()) {
@@ -100,7 +113,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
}
}
public List<SliceRef> findByTopicPath(String topicPath) {
private List<SliceRef> findByTopicPath(String topicPath) {
String normalizedPath = normalizeTopicPath(topicPath);
List<SliceRef> refs = topicSlices.get(normalizedPath);
if (refs == null || refs.isEmpty()) {
@@ -109,7 +122,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
return new ArrayList<>(refs);
}
public List<SliceRef> findByDate(LocalDate date) {
private List<SliceRef> findByDate(LocalDate date) {
List<SliceRef> refs = dateIndex.get(date);
if (refs == null || refs.isEmpty()) {
throw new UnExistedDateIndexException("不存在的日期索引: " + date);
@@ -117,7 +130,15 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
return new ArrayList<>(refs);
}
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
public List<ActivatedMemorySlice> queryActivatedMemoryByTopicPath(String topicPath) {
return buildActivatedMemorySlices(findByTopicPath(topicPath));
}
public List<ActivatedMemorySlice> queryActivatedMemoryByDate(LocalDate date) {
return buildActivatedMemorySlices(findByDate(date));
}
private void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
runtimeLock.lock();
try {
List<LocalDateTime> keysToRemove = new ArrayList<>();
@@ -136,10 +157,6 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
}
}
public HashMap<LocalDateTime, String> getDialogMap() {
return dialogMap;
}
public String getDialogMapStr() {
StringBuilder str = new StringBuilder();
dialogMap.entrySet().stream()
@@ -151,6 +168,10 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
return str.toString();
}
public boolean containsDialogSummary(String summary) {
return dialogMap.containsValue(summary);
}
public String getTopicTree() {
TopicTreeNode root = new TopicTreeNode();
for (Map.Entry<String, CopyOnWriteArrayList<SliceRef>> entry : topicSlices.entrySet()) {
@@ -171,6 +192,50 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
return stringBuilder.toString();
}
private List<ActivatedMemorySlice> buildActivatedMemorySlices(List<SliceRef> refs) {
List<ActivatedMemorySlice> slices = new ArrayList<>();
for (SliceRef ref : refs) {
ActivatedMemorySlice slice = buildActivatedMemorySlice(ref);
if (slice != null) {
slices.add(slice);
}
}
return slices;
}
private ActivatedMemorySlice buildActivatedMemorySlice(SliceRef ref) {
MemoryUnit memoryUnit = memoryCapability.getMemoryUnit(ref.getUnitId());
MemorySlice memorySlice = memoryCapability.getMemorySlice(ref.getUnitId(), ref.getSliceId());
if (memoryUnit == null || memorySlice == null) {
return null;
}
List<work.slhaf.partner.api.chat.pojo.Message> messages = sliceMessages(memoryUnit, memorySlice);
LocalDate date = Instant.ofEpochMilli(memorySlice.getTimestamp())
.atZone(ZoneId.systemDefault())
.toLocalDate();
return ActivatedMemorySlice.builder()
.unitId(ref.getUnitId())
.sliceId(ref.getSliceId())
.summary(memorySlice.getSummary())
.timestamp(memorySlice.getTimestamp())
.date(date)
.messages(messages)
.build();
}
private List<work.slhaf.partner.api.chat.pojo.Message> sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) {
List<work.slhaf.partner.api.chat.pojo.Message> conversationMessages = memoryUnit.getConversationMessages();
if (conversationMessages == null || conversationMessages.isEmpty()) {
return List.of();
}
int start = Math.max(0, memorySlice.getStartIndex());
int end = Math.min(conversationMessages.size() - 1, memorySlice.getEndIndex());
if (start > end) {
return List.of();
}
return new ArrayList<>(conversationMessages.subList(start, end + 1));
}
private void printSubTopicsTreeFormat(TopicTreeNode node, String prefix, StringBuilder stringBuilder) {
List<Map.Entry<String, TopicTreeNode>> entries = new ArrayList<>(node.children.entrySet());
for (int i = 0; i < entries.size(); i++) {

View File

@@ -6,15 +6,11 @@ import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime;
import work.slhaf.partner.module.modules.memory.selector.evaluator.SliceSelectEvaluator;
import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorInput;
@@ -23,9 +19,7 @@ import work.slhaf.partner.module.modules.memory.selector.extractor.entity.Extrac
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.time.Instant;
import java.time.LocalDate;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
@@ -85,16 +79,15 @@ public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFl
List<ExtractorMatchData> matches) {
for (ExtractorMatchData match : matches) {
try {
List<SliceRef> refs = switch (match.getType()) {
case ExtractorMatchData.Constant.TOPIC -> memoryRuntime.findByTopicPath(match.getText());
case ExtractorMatchData.Constant.DATE -> memoryRuntime.findByDate(LocalDate.parse(match.getText()));
List<ActivatedMemorySlice> recalledSlices = switch (match.getType()) {
case ExtractorMatchData.Constant.TOPIC ->
memoryRuntime.queryActivatedMemoryByTopicPath(match.getText());
case ExtractorMatchData.Constant.DATE ->
memoryRuntime.queryActivatedMemoryByDate(LocalDate.parse(match.getText()));
default -> List.of();
};
for (SliceRef ref : refs) {
ActivatedMemorySlice recalledSlice = buildActivatedMemorySlice(ref);
if (recalledSlice != null) {
candidates.putIfAbsent(ref.getUnitId() + ":" + ref.getSliceId(), recalledSlice);
}
for (ActivatedMemorySlice recalledSlice : recalledSlices) {
candidates.putIfAbsent(recalledSlice.getUnitId() + ":" + recalledSlice.getSliceId(), recalledSlice);
}
} catch (UnExistedDateIndexException | UnExistedTopicException e) {
log.error("[MemorySelector] 不存在的记忆索引", e);
@@ -103,42 +96,8 @@ public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFl
}
}
private ActivatedMemorySlice buildActivatedMemorySlice(SliceRef ref) {
MemoryUnit memoryUnit = memoryCapability.getMemoryUnit(ref.getUnitId());
MemorySlice memorySlice = memoryCapability.getMemorySlice(ref.getUnitId(), ref.getSliceId());
if (memoryUnit == null || memorySlice == null) {
return null;
}
List<Message> messages = sliceMessages(memoryUnit, memorySlice);
LocalDate date = Instant.ofEpochMilli(memorySlice.getTimestamp())
.atZone(ZoneId.systemDefault())
.toLocalDate();
return ActivatedMemorySlice.builder()
.unitId(ref.getUnitId())
.sliceId(ref.getSliceId())
.summary(memorySlice.getSummary())
.timestamp(memorySlice.getTimestamp())
.date(date)
.messages(messages)
.build();
}
private List<Message> sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) {
List<Message> conversationMessages = memoryUnit.getConversationMessages();
if (conversationMessages == null || conversationMessages.isEmpty()) {
return List.of();
}
int start = Math.max(0, memorySlice.getStartIndex());
int end = Math.min(conversationMessages.size() - 1, memorySlice.getEndIndex());
if (start > end) {
return List.of();
}
return new ArrayList<>(conversationMessages.subList(start, end + 1));
}
private void removeDuplicateSlice(Collection<ActivatedMemorySlice> candidates) {
Collection<String> values = memoryRuntime.getDialogMap().values();
candidates.removeIf(m -> values.contains(m.getSummary()));
candidates.removeIf(m -> memoryRuntime.containsDialogSummary(m.getSummary()));
}
@Override

View File

@@ -15,7 +15,6 @@ import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import work.slhaf.partner.module.common.module.PostRunningAgentModule;
import work.slhaf.partner.module.modules.action.scheduler.ActionScheduler;
import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime;
@@ -25,7 +24,6 @@ import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.Summar
import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -145,26 +143,16 @@ public class MemoryUpdater extends PostRunningAgentModule {
SummarizeResult summarizeResult = summarize(summarizeInput);
log.debug("[MemoryUpdater] 记忆更新-总结流程-输出: {}", JSONObject.toJSONString(summarizeResult));
MemoryUnit memoryUnit = buildMemoryUnit(chatMessages, summarizeResult);
memoryCapability.saveMemoryUnit(memoryUnit);
MemorySlice memorySlice = memoryUnit.getSlices().getFirst();
SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId());
bindTopics(memoryUnit, summarizeResult, sliceRef);
memoryRuntime.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary());
memoryRuntime.recordMemory(
memoryUnit,
summarizeResult.getTopicPath(),
summarizeResult.getRelatedTopicPath(),
summarizeResult.getSummary()
);
lastUpdatedTime = System.currentTimeMillis();
log.debug("[MemoryUpdater] 记忆更新流程结束...");
}
private void bindTopics(MemoryUnit memoryUnit, SummarizeResult summarizeResult, SliceRef sliceRef) {
memoryRuntime.indexMemoryUnit(memoryUnit);
memoryRuntime.bindTopic(summarizeResult.getTopicPath(), sliceRef);
if (summarizeResult.getRelatedTopicPath() == null) {
return;
}
for (String relatedTopicPath : summarizeResult.getRelatedTopicPath()) {
memoryRuntime.bindTopic(relatedTopicPath, sliceRef);
}
}
private List<Message> getCleanedMessages(List<Message> chatMessages) {
return chatMessages.stream()
.map(message -> {