From ee1a033c1b995487f7ae416f5ac24003115b1b6f Mon Sep 17 00:00:00 2001 From: slhafzjw Date: Tue, 10 Mar 2026 20:42:46 +0800 Subject: [PATCH] refactor(memory): centralize memory recording and retrieval logic --- .../modules/memory/runtime/MemoryRuntime.java | 87 ++++++++++++++++--- .../memory/selector/MemorySelector.java | 57 ++---------- .../modules/memory/updater/MemoryUpdater.java | 24 ++--- 3 files changed, 90 insertions(+), 78 deletions(-) diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/runtime/MemoryRuntime.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/runtime/MemoryRuntime.java index 55a8f611..d7ed04a0 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/runtime/MemoryRuntime.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/runtime/MemoryRuntime.java @@ -1,6 +1,5 @@ package work.slhaf.partner.module.modules.memory.runtime; -import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; @@ -14,6 +13,7 @@ import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException; import work.slhaf.partner.core.memory.exception.UnExistedTopicException; +import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemoryUnit; import work.slhaf.partner.core.memory.pojo.SliceRef; @@ -33,7 +33,6 @@ import java.util.concurrent.locks.ReentrantLock; import static work.slhaf.partner.common.Constant.Path.MEMORY_DATA; @EqualsAndHashCode(callSuper = true) -@Data @Slf4j public class MemoryRuntime extends AbstractAgentModule.Standalone { @@ -63,7 +62,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone { } } - public void bindTopic(String topicPath, SliceRef sliceRef) { + private void bindTopic(String topicPath, SliceRef sliceRef) { String normalizedPath = normalizeTopicPath(topicPath); runtimeLock.lock(); try { @@ -79,7 +78,21 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone { } } - public void indexMemoryUnit(MemoryUnit memoryUnit) { + public void recordMemory(MemoryUnit memoryUnit, String topicPath, List relatedTopicPaths, String dialogSummary) { + memoryCapability.saveMemoryUnit(memoryUnit); + MemorySlice memorySlice = memoryUnit.getSlices().getFirst(); + SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId()); + indexMemoryUnit(memoryUnit); + bindTopic(topicPath, sliceRef); + if (relatedTopicPaths != null) { + for (String relatedTopicPath : relatedTopicPaths) { + bindTopic(relatedTopicPath, sliceRef); + } + } + updateDialogMap(LocalDateTime.now(), dialogSummary); + } + + private void indexMemoryUnit(MemoryUnit memoryUnit) { runtimeLock.lock(); try { for (CopyOnWriteArrayList refs : dateIndex.values()) { @@ -100,7 +113,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone { } } - public List findByTopicPath(String topicPath) { + private List findByTopicPath(String topicPath) { String normalizedPath = normalizeTopicPath(topicPath); List refs = topicSlices.get(normalizedPath); if (refs == null || refs.isEmpty()) { @@ -109,7 +122,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone { return new ArrayList<>(refs); } - public List findByDate(LocalDate date) { + private List findByDate(LocalDate date) { List refs = dateIndex.get(date); if (refs == null || refs.isEmpty()) { throw new UnExistedDateIndexException("不存在的日期索引: " + date); @@ -117,7 +130,15 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone { return new ArrayList<>(refs); } - public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) { + public List queryActivatedMemoryByTopicPath(String topicPath) { + return buildActivatedMemorySlices(findByTopicPath(topicPath)); + } + + public List queryActivatedMemoryByDate(LocalDate date) { + return buildActivatedMemorySlices(findByDate(date)); + } + + private void updateDialogMap(LocalDateTime dateTime, String newDialogCache) { runtimeLock.lock(); try { List keysToRemove = new ArrayList<>(); @@ -136,10 +157,6 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone { } } - public HashMap getDialogMap() { - return dialogMap; - } - public String getDialogMapStr() { StringBuilder str = new StringBuilder(); dialogMap.entrySet().stream() @@ -151,6 +168,10 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone { return str.toString(); } + public boolean containsDialogSummary(String summary) { + return dialogMap.containsValue(summary); + } + public String getTopicTree() { TopicTreeNode root = new TopicTreeNode(); for (Map.Entry> entry : topicSlices.entrySet()) { @@ -171,6 +192,50 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone { return stringBuilder.toString(); } + private List buildActivatedMemorySlices(List refs) { + List slices = new ArrayList<>(); + for (SliceRef ref : refs) { + ActivatedMemorySlice slice = buildActivatedMemorySlice(ref); + if (slice != null) { + slices.add(slice); + } + } + return slices; + } + + private ActivatedMemorySlice buildActivatedMemorySlice(SliceRef ref) { + MemoryUnit memoryUnit = memoryCapability.getMemoryUnit(ref.getUnitId()); + MemorySlice memorySlice = memoryCapability.getMemorySlice(ref.getUnitId(), ref.getSliceId()); + if (memoryUnit == null || memorySlice == null) { + return null; + } + List messages = sliceMessages(memoryUnit, memorySlice); + LocalDate date = Instant.ofEpochMilli(memorySlice.getTimestamp()) + .atZone(ZoneId.systemDefault()) + .toLocalDate(); + return ActivatedMemorySlice.builder() + .unitId(ref.getUnitId()) + .sliceId(ref.getSliceId()) + .summary(memorySlice.getSummary()) + .timestamp(memorySlice.getTimestamp()) + .date(date) + .messages(messages) + .build(); + } + + private List sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) { + List conversationMessages = memoryUnit.getConversationMessages(); + if (conversationMessages == null || conversationMessages.isEmpty()) { + return List.of(); + } + int start = Math.max(0, memorySlice.getStartIndex()); + int end = Math.min(conversationMessages.size() - 1, memorySlice.getEndIndex()); + if (start > end) { + return List.of(); + } + return new ArrayList<>(conversationMessages.subList(start, end + 1)); + } + private void printSubTopicsTreeFormat(TopicTreeNode node, String prefix, StringBuilder stringBuilder) { List> entries = new ArrayList<>(node.children.entrySet()); for (int i = 0; i < entries.size(); i++) { diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java index fc0c4365..1df73f3d 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java @@ -6,15 +6,11 @@ import lombok.EqualsAndHashCode; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule; -import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException; import work.slhaf.partner.core.memory.exception.UnExistedTopicException; import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; -import work.slhaf.partner.core.memory.pojo.MemorySlice; -import work.slhaf.partner.core.memory.pojo.MemoryUnit; -import work.slhaf.partner.core.memory.pojo.SliceRef; import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime; import work.slhaf.partner.module.modules.memory.selector.evaluator.SliceSelectEvaluator; import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorInput; @@ -23,9 +19,7 @@ import work.slhaf.partner.module.modules.memory.selector.extractor.entity.Extrac import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorResult; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; -import java.time.Instant; import java.time.LocalDate; -import java.time.ZoneId; import java.util.ArrayList; import java.util.Collection; import java.util.LinkedHashMap; @@ -85,16 +79,15 @@ public class MemorySelector extends AbstractAgentModule.Running matches) { for (ExtractorMatchData match : matches) { try { - List refs = switch (match.getType()) { - case ExtractorMatchData.Constant.TOPIC -> memoryRuntime.findByTopicPath(match.getText()); - case ExtractorMatchData.Constant.DATE -> memoryRuntime.findByDate(LocalDate.parse(match.getText())); + List recalledSlices = switch (match.getType()) { + case ExtractorMatchData.Constant.TOPIC -> + memoryRuntime.queryActivatedMemoryByTopicPath(match.getText()); + case ExtractorMatchData.Constant.DATE -> + memoryRuntime.queryActivatedMemoryByDate(LocalDate.parse(match.getText())); default -> List.of(); }; - for (SliceRef ref : refs) { - ActivatedMemorySlice recalledSlice = buildActivatedMemorySlice(ref); - if (recalledSlice != null) { - candidates.putIfAbsent(ref.getUnitId() + ":" + ref.getSliceId(), recalledSlice); - } + for (ActivatedMemorySlice recalledSlice : recalledSlices) { + candidates.putIfAbsent(recalledSlice.getUnitId() + ":" + recalledSlice.getSliceId(), recalledSlice); } } catch (UnExistedDateIndexException | UnExistedTopicException e) { log.error("[MemorySelector] 不存在的记忆索引", e); @@ -103,42 +96,8 @@ public class MemorySelector extends AbstractAgentModule.Running messages = sliceMessages(memoryUnit, memorySlice); - LocalDate date = Instant.ofEpochMilli(memorySlice.getTimestamp()) - .atZone(ZoneId.systemDefault()) - .toLocalDate(); - return ActivatedMemorySlice.builder() - .unitId(ref.getUnitId()) - .sliceId(ref.getSliceId()) - .summary(memorySlice.getSummary()) - .timestamp(memorySlice.getTimestamp()) - .date(date) - .messages(messages) - .build(); - } - - private List sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) { - List conversationMessages = memoryUnit.getConversationMessages(); - if (conversationMessages == null || conversationMessages.isEmpty()) { - return List.of(); - } - int start = Math.max(0, memorySlice.getStartIndex()); - int end = Math.min(conversationMessages.size() - 1, memorySlice.getEndIndex()); - if (start > end) { - return List.of(); - } - return new ArrayList<>(conversationMessages.subList(start, end + 1)); - } - private void removeDuplicateSlice(Collection candidates) { - Collection values = memoryRuntime.getDialogMap().values(); - candidates.removeIf(m -> values.contains(m.getSummary())); + candidates.removeIf(m -> memoryRuntime.containsDialogSummary(m.getSummary())); } @Override diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/MemoryUpdater.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/MemoryUpdater.java index 345857c0..efc21d5b 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/MemoryUpdater.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/MemoryUpdater.java @@ -15,7 +15,6 @@ import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemoryUnit; -import work.slhaf.partner.core.memory.pojo.SliceRef; import work.slhaf.partner.module.common.module.PostRunningAgentModule; import work.slhaf.partner.module.modules.action.scheduler.ActionScheduler; import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime; @@ -25,7 +24,6 @@ import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.Summar import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeResult; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; -import java.time.LocalDateTime; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -145,26 +143,16 @@ public class MemoryUpdater extends PostRunningAgentModule { SummarizeResult summarizeResult = summarize(summarizeInput); log.debug("[MemoryUpdater] 记忆更新-总结流程-输出: {}", JSONObject.toJSONString(summarizeResult)); MemoryUnit memoryUnit = buildMemoryUnit(chatMessages, summarizeResult); - memoryCapability.saveMemoryUnit(memoryUnit); - MemorySlice memorySlice = memoryUnit.getSlices().getFirst(); - SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId()); - bindTopics(memoryUnit, summarizeResult, sliceRef); - memoryRuntime.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary()); + memoryRuntime.recordMemory( + memoryUnit, + summarizeResult.getTopicPath(), + summarizeResult.getRelatedTopicPath(), + summarizeResult.getSummary() + ); lastUpdatedTime = System.currentTimeMillis(); log.debug("[MemoryUpdater] 记忆更新流程结束..."); } - private void bindTopics(MemoryUnit memoryUnit, SummarizeResult summarizeResult, SliceRef sliceRef) { - memoryRuntime.indexMemoryUnit(memoryUnit); - memoryRuntime.bindTopic(summarizeResult.getTopicPath(), sliceRef); - if (summarizeResult.getRelatedTopicPath() == null) { - return; - } - for (String relatedTopicPath : summarizeResult.getRelatedTopicPath()) { - memoryRuntime.bindTopic(relatedTopicPath, sliceRef); - } - } - private List getCleanedMessages(List chatMessages) { return chatMessages.stream() .map(message -> {