From 5ad80d8b86bbb62134c53959ccfe1a319f2a8ce7 Mon Sep 17 00:00:00 2001 From: slhafzjw Date: Tue, 10 Mar 2026 19:41:05 +0800 Subject: [PATCH] refactor(memory): decouple memory storage and runtime structures --- .../core/cognation/CognationCapability.java | 23 - .../partner/core/cognation/CognationCore.java | 110 +-- .../partner/core/memory/MemoryCapability.java | 43 +- .../slhaf/partner/core/memory/MemoryCore.java | 641 +++--------------- ...edSlice.java => ActivatedMemorySlice.java} | 9 +- .../core/memory/pojo/MemoryResult.java | 26 - .../partner/core/memory/pojo/MemorySlice.java | 60 +- .../partner/core/memory/pojo/MemoryUnit.java | 23 + .../{MemorySliceResult.java => SliceRef.java} | 16 +- .../action/executor/ActionExecutor.java | 4 +- .../executor/entity/CorrectorInput.java | 4 +- .../executor/entity/ExtractorInput.java | 4 +- .../action/interventor/ActionInterventor.java | 7 +- .../evaluator/InterventionEvaluator.java | 4 +- .../evaluator/entity/EvaluatorInput.java | 4 +- .../modules/action/planner/ActionPlanner.java | 2 +- .../planner/evaluator/ActionEvaluator.java | 4 +- .../evaluator/entity/EvaluatorBatchInput.java | 4 +- .../evaluator/entity/EvaluatorInput.java | 4 +- .../modules/core/CommunicationProducer.java | 3 - .../modules/memory/runtime/MemoryRuntime.java | 247 +++++++ .../memory/selector/MemorySelector.java | 113 +-- .../evaluator/SliceSelectEvaluator.java | 113 ++- .../evaluator/entity/EvaluatorInput.java | 4 +- .../extractor/MemorySelectExtractor.java | 20 +- .../extractor/entity/ExtractorInput.java | 4 +- .../modules/memory/updater/MemoryUpdater.java | 207 ++---- .../summarizer/entity/SummarizeResult.java | 1 - .../modules/process/PreprocessExecutor.java | 7 +- .../java/experimental/ReflectionTest.java | 7 +- .../executor/ActionExecutorTest.java | 2 +- .../core/CommunicationProducerTest.java | 17 +- 32 files changed, 603 insertions(+), 1134 deletions(-) rename Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/{EvaluatedSlice.java => ActivatedMemorySlice.java} (62%) delete mode 100644 Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemoryResult.java create mode 100644 Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemoryUnit.java rename Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/{MemorySliceResult.java => SliceRef.java} (50%) create mode 100644 Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/runtime/MemoryRuntime.java diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/cognation/CognationCapability.java b/Partner-Core/src/main/java/work/slhaf/partner/core/cognation/CognationCapability.java index 7735a2e0..9dda332f 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/cognation/CognationCapability.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/cognation/CognationCapability.java @@ -2,11 +2,8 @@ package work.slhaf.partner.core.cognation; import work.slhaf.partner.api.agent.factory.capability.annotation.Capability; import work.slhaf.partner.api.chat.pojo.Message; -import work.slhaf.partner.api.chat.pojo.MetaMessage; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.concurrent.locks.Lock; @Capability("cognation") @@ -20,26 +17,6 @@ public interface CognationCapability { void rollChatMessagesWithSnapshot(int snapshotSize, int retainDivisor); - void cleanMessage(List messages); - Lock getMessageLock(); - void addMetaMessage(String userId, MetaMessage metaMessage); - - List unpackAndClear(String userId); - - void refreshMemoryId(); - - void resetLastUpdatedTime(); - - long getLastUpdatedTime(); - - HashMap> getSingleMetaMessageMap(); - - Map> drainSingleMetaMessages(); - - List snapshotSingleMetaMessages(String userId); - - String getCurrentMemoryId(); - } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/cognation/CognationCore.java b/Partner-Core/src/main/java/work/slhaf/partner/core/cognation/CognationCore.java index d2804194..34db08d5 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/cognation/CognationCore.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/cognation/CognationCore.java @@ -1,6 +1,5 @@ package work.slhaf.partner.core.cognation; -import com.alibaba.fastjson2.JSONObject; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; @@ -9,13 +8,13 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod; import work.slhaf.partner.api.agent.runtime.interaction.AgentRuntime; import work.slhaf.partner.api.chat.pojo.Message; -import work.slhaf.partner.api.chat.pojo.MetaMessage; import work.slhaf.partner.core.PartnerCore; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; import java.io.IOException; import java.io.Serial; -import java.util.*; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; @@ -35,10 +34,6 @@ public class CognationCore extends PartnerCore { * 主模型的聊天记录 */ private List chatMessages = new ArrayList<>(); - private HashMap> singleMetaMessageMap = new HashMap<>(); - private String currentMemoryId; - private long lastUpdatedTime; - public CognationCore() throws IOException, ClassNotFoundException { } @@ -86,112 +81,11 @@ public class CognationCore extends PartnerCore { } } - @CapabilityMethod - public long getLastUpdatedTime() { - return lastUpdatedTime; - } - - @CapabilityMethod - public HashMap> getSingleMetaMessageMap() { - return singleMetaMessageMap; - } - - @CapabilityMethod - public String getCurrentMemoryId() { - return currentMemoryId; - } - - @CapabilityMethod - public void cleanMessage(List messages) { - messageLock.lock(); - try { - this.getChatMessages().removeAll(messages); - } finally { - messageLock.unlock(); - } - } - @CapabilityMethod public Lock getMessageLock() { return messageLock; } - @CapabilityMethod - public void addMetaMessage(String userId, MetaMessage metaMessage) { - log.debug("[{}] 当前会话历史: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap)); - messageLock.lock(); - try { - if (singleMetaMessageMap.containsKey(userId)) { - singleMetaMessageMap.get(userId).add(metaMessage); - } else { - singleMetaMessageMap.put(userId, new java.util.ArrayList<>()); - singleMetaMessageMap.get(userId).add(metaMessage); - } - } finally { - messageLock.unlock(); - } - log.debug("[{}] 会话历史更新: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap)); - } - - @CapabilityMethod - public List unpackAndClear(String userId) { - messageLock.lock(); - try { - List messages = new ArrayList<>(); - List metaMessages = singleMetaMessageMap.get(userId); - if (metaMessages == null) { - return messages; - } - for (MetaMessage metaMessage : metaMessages) { - messages.add(metaMessage.getUserMessage()); - messages.add(metaMessage.getAssistantMessage()); - } - singleMetaMessageMap.remove(userId); - return messages; - } finally { - messageLock.unlock(); - } - } - - @CapabilityMethod - public Map> drainSingleMetaMessages() { - messageLock.lock(); - try { - Map> drained = new HashMap<>(); - for (Map.Entry> entry : singleMetaMessageMap.entrySet()) { - drained.put(entry.getKey(), new ArrayList<>(entry.getValue())); - } - singleMetaMessageMap.clear(); - return drained; - } finally { - messageLock.unlock(); - } - } - - @CapabilityMethod - public List snapshotSingleMetaMessages(String userId) { - messageLock.lock(); - try { - List metaMessages = singleMetaMessageMap.get(userId); - if (metaMessages == null) { - return List.of(); - } - return List.copyOf(metaMessages); - } finally { - messageLock.unlock(); - } - } - - @CapabilityMethod - public void refreshMemoryId() { - currentMemoryId = UUID.randomUUID().toString(); - } - - @CapabilityMethod - public void resetLastUpdatedTime() { - lastUpdatedTime = System.currentTimeMillis(); - } - @Override protected String getCoreKey() { return "cognation-core"; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCapability.java b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCapability.java index c918fdb6..9c19f925 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCapability.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCapability.java @@ -1,51 +1,36 @@ package work.slhaf.partner.core.memory; import work.slhaf.partner.api.agent.factory.capability.annotation.Capability; -import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; -import work.slhaf.partner.core.memory.pojo.MemoryResult; +import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; import work.slhaf.partner.core.memory.pojo.MemorySlice; +import work.slhaf.partner.core.memory.pojo.MemoryUnit; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.util.HashMap; +import java.util.Collection; import java.util.List; -import java.util.concurrent.ConcurrentHashMap; @Capability(value = "memory") public interface MemoryCapability { - void cleanSelectedSliceFilter(); + void clearActivatedSlices(); - String getTopicTree(); + void updateActivatedSlices(List memorySlices); - HashMap getDialogMap(); + boolean hasActivatedSlices(); - ConcurrentHashMap getUserDialogMap(String userId); + int getActivatedSlicesSize(); - void updateDialogMap(LocalDateTime dateTime, String newDialogCache); + List getActivatedSlices(); - String getDialogMapStr(); + void saveMemoryUnit(MemoryUnit memoryUnit); - String getUserDialogMapStr(String userId); + MemoryUnit getMemoryUnit(String unitId); - void updateActivatedSlices(String userId, List memorySlices); + MemorySlice getMemorySlice(String unitId, String sliceId); - String getActivatedSlicesStr(String userId); + Collection listMemoryUnits(); - HashMap> getActivatedSlices(); + void refreshMemoryId(); - void clearActivatedSlices(String userId); - - boolean hasActivatedSlices(String userId); - - int getActivatedSlicesSize(String userId); - - List getActivatedSlices(String userId); - - MemoryResult selectMemory(String topicPathStr); - - MemoryResult selectMemory(LocalDate date); - - void insertSlice(MemorySlice memorySlice, String topicPath); + String getCurrentMemoryId(); } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java index 5b25a4f3..5a901505 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java @@ -7,19 +7,12 @@ import lombok.extern.slf4j.Slf4j; import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore; import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod; import work.slhaf.partner.core.PartnerCore; -import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException; -import work.slhaf.partner.core.memory.exception.UnExistedTopicException; -import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; -import work.slhaf.partner.core.memory.pojo.MemoryResult; +import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; import work.slhaf.partner.core.memory.pojo.MemorySlice; -import work.slhaf.partner.core.memory.pojo.MemorySliceResult; -import work.slhaf.partner.core.memory.pojo.node.MemoryNode; -import work.slhaf.partner.core.memory.pojo.node.TopicNode; +import work.slhaf.partner.core.memory.pojo.MemoryUnit; import java.io.IOException; import java.io.Serial; -import java.time.LocalDate; -import java.time.LocalDateTime; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; @@ -35,571 +28,121 @@ public class MemoryCore extends PartnerCore { @Serial private static final long serialVersionUID = 1L; - private final Lock sliceInsertLock = new ReentrantLock(); - /** - * key: 根主题名称 value: 根主题节点 - */ - private HashMap topicNodes = new HashMap<>(); - /** - * 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值 - * 该部分在'主题提取LLM'的system prompt中常驻 - */ - private HashMap /*子主题列表*/> existedTopics = new HashMap<>(); - /** - * 临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所 - */ - private HashMap> currentDateDialogSlices = new HashMap<>(); - /** - * 记忆节点的日期索引, 同一日期内按照对话id区分 - */ - private HashMap> dateIndex = new HashMap<>(); - /** - * 已被选中的切片时间戳集合,需要及时清理 - */ - private Set selectedSlices = new HashSet<>(); - private HashMap> userIndex = new HashMap<>(); - private MemoryCache cache = new MemoryCache(); + + private final Lock memoryLock = new ReentrantLock(); + private ConcurrentHashMap memoryUnits = new ConcurrentHashMap<>(); + private List activatedSlices = new CopyOnWriteArrayList<>(); + private String currentMemoryId; public MemoryCore() throws IOException, ClassNotFoundException { } - @CapabilityMethod - public MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException { - MemoryResult memoryResult = new MemoryResult(); - CopyOnWriteArrayList targetSliceList = new CopyOnWriteArrayList<>(); - //加载节点并获取记忆切片列表 - List> currentDateDialogSlices = loadSlicesByDate(date); - for (List value : currentDateDialogSlices) { - for (MemorySlice memorySlice : value) { - if (selectedSlices.contains(memorySlice.getTimestamp())) { - continue; - } - MemorySliceResult memorySliceResult = new MemorySliceResult(); - memorySliceResult.setMemorySlice(memorySlice); - targetSliceList.add(memorySliceResult); - selectedSlices.add(memorySlice.getTimestamp()); - } - } - memoryResult.setMemorySliceResult(targetSliceList); - return cacheFilter(memoryResult); + public void clearActivatedSlices() { + activatedSlices.clear(); } @CapabilityMethod - public void insertSlice(MemorySlice memorySlice, String topicPath) { - sliceInsertLock.lock(); - List topicPathList = Arrays.stream(topicPath.split("->")).toList(); + public void updateActivatedSlices(List memorySlices) { + activatedSlices = new CopyOnWriteArrayList<>(memorySlices); + } + + @CapabilityMethod + public boolean hasActivatedSlices() { + return !activatedSlices.isEmpty(); + } + + @CapabilityMethod + public int getActivatedSlicesSize() { + return activatedSlices.size(); + } + + @CapabilityMethod + public List getActivatedSlices() { + return new ArrayList<>(activatedSlices); + } + + @CapabilityMethod + public void saveMemoryUnit(MemoryUnit memoryUnit) { + memoryLock.lock(); try { - //检查是否存在当天对应的memorySlice并确定是否插入 - //每日刷新缓存 - checkCacheDate(); - //如果topicPath在memorySliceCache中存在对应缓存,由于进行的插入操作,则需要移除该缓存,但不清除相关计数 - clearCacheByTopicPath(topicPathList); - insertMemory(topicPathList, memorySlice); - if (!memorySlice.isPrivate()) { - updateUserDialogMap(memorySlice); - } - } catch (Exception e) { - log.error("插入记忆时出错: ", e); + normalizeMemoryUnit(memoryUnit); + memoryUnits.put(memoryUnit.getId(), memoryUnit); + } finally { + memoryLock.unlock(); } - log.debug("插入切片: {}, 路径: {}", memorySlice, topicPath); - sliceInsertLock.unlock(); } @CapabilityMethod - public String getTopicTree() { - StringBuilder stringBuilder = new StringBuilder(); - for (Map.Entry entry : topicNodes.entrySet()) { - String rootName = entry.getKey(); - TopicNode rootNode = entry.getValue(); - stringBuilder.append(rootName).append("[root]").append("\r\n"); - printSubTopicsTreeFormat(rootNode, "", stringBuilder); - } - return stringBuilder.toString(); + public MemoryUnit getMemoryUnit(String unitId) { + return memoryUnits.get(unitId); } @CapabilityMethod - public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) { - List keysToRemove = new ArrayList<>(); - HashMap dialogMap = cache.dialogMap; - dialogMap.forEach((k, v) -> { - if (dateTime.minusDays(2).isAfter(k)) { - keysToRemove.add(k); - } - }); - for (LocalDateTime temp : keysToRemove) { - dialogMap.remove(temp); - } - keysToRemove.clear(); - //放入新缓存 - dialogMap.put(dateTime, newDialogCache); - } - - @CapabilityMethod - public HashMap getDialogMap() { - return cache.dialogMap; - } - - @CapabilityMethod - public ConcurrentHashMap getUserDialogMap(String userId) { - return cache.userDialogMap.get(userId); - } - - @CapabilityMethod - public String getDialogMapStr() { - StringBuilder str = new StringBuilder(); - this.getDialogMap().forEach((dateTime, dialog) -> str.append("\n\n").append("[").append(dateTime).append("]\n") - .append(dialog)); - return str.toString(); - } - - @CapabilityMethod - public String getUserDialogMapStr(String userId) { - ConcurrentHashMap> userDialogMap = cache.userDialogMap; - if (userDialogMap.containsKey(userId)) { - StringBuilder str = new StringBuilder(); - Collection dialogMapValues = this.getDialogMap().values(); - userDialogMap.get(userId).forEach((dateTime, dialog) -> { - if (dialogMapValues.contains(dialog)) { - return; - } - str.append("\n\n").append("[").append(dateTime).append("]\n") - .append(dialog); - }); - return str.toString(); - } else { + public MemorySlice getMemorySlice(String unitId, String sliceId) { + MemoryUnit memoryUnit = memoryUnits.get(unitId); + if (memoryUnit == null || memoryUnit.getSlices() == null) { return null; } - } - - @CapabilityMethod - public MemoryResult selectMemory(String topicPathStr) { - MemoryResult memoryResult; - List topicPath = List.of(topicPathStr.split("->")); - try { - List path = new ArrayList<>(topicPath); - //每日刷新缓存 - checkCacheDate(); - //检测缓存并更新计数, 查看是否需要放入缓存 - updateCacheCounter(path); - //查看是否存在缓存,如果存在,则直接返回 - if ((memoryResult = selectCache(path)) != null) { - return memoryResult; + for (MemorySlice slice : memoryUnit.getSlices()) { + if (sliceId.equals(slice.getId())) { + return slice; } - memoryResult = selectMemory(path); - //尝试更新缓存 - updateCache(topicPath, memoryResult); - } catch (Exception e) { - log.error("[{}] selectMemory error: ", getCoreKey(), e); - log.error("[{}] 路径: {}", getCoreKey(), topicPathStr); - log.error("[{}] 主题树: {}", getCoreKey(), getTopicTree()); - memoryResult = new MemoryResult(); - memoryResult.setRelatedMemorySliceResult(new ArrayList<>()); - memoryResult.setMemorySliceResult(new CopyOnWriteArrayList<>()); - } - return cacheFilter(memoryResult); - } - - @CapabilityMethod - public void updateActivatedSlices(String userId, List memorySlices) { - cache.activatedSlices.put(userId, memorySlices); - log.debug("[{}] 已更新激活切片, userId: {}", getCoreKey(), userId); - } - - @CapabilityMethod - public String getActivatedSlicesStr(String userId) { - HashMap> activatedSlices = cache.activatedSlices; - if (activatedSlices.containsKey(userId)) { - StringBuilder str = new StringBuilder(); - activatedSlices.get(userId).forEach(slice -> str.append("\n\n").append("[").append(slice.getDate()).append("]\n") - .append(slice.getSummary())); - return str.toString(); - } else { - return null; - } - } - - @CapabilityMethod - public HashMap> getActivatedSlices() { - return cache.activatedSlices; - } - - @CapabilityMethod - public void clearActivatedSlices(String userId) { - cache.activatedSlices.remove(userId); - } - - @CapabilityMethod - public boolean hasActivatedSlices(String userId) { - HashMap> activatedSlices = cache.activatedSlices; - if (!activatedSlices.containsKey(userId)) { - return false; - } - return !activatedSlices.get(userId).isEmpty(); - } - - @CapabilityMethod - public int getActivatedSlicesSize(String userId) { - return cache.activatedSlices.get(userId).size(); - } - - @CapabilityMethod - public List getActivatedSlices(String userId) { - return cache.activatedSlices.get(userId); - } - - @CapabilityMethod - public void cleanSelectedSliceFilter() { - this.selectedSlices.clear(); - } - - private List> loadSlicesByDate(LocalDate date) throws IOException, ClassNotFoundException { - if (!dateIndex.containsKey(date)) { - throw new UnExistedDateIndexException("不存在的日期索引: " + date); - } - List> list = new ArrayList<>(); - for (String memoryNodeId : dateIndex.get(date)) { - MemoryNode memoryNode = new MemoryNode(); - memoryNode.setMemoryNodeId(memoryNodeId); - list.add(memoryNode.loadMemorySliceList()); - } - return list; - } - - private void printSubTopicsTreeFormat(TopicNode node, String prefix, StringBuilder stringBuilder) { - if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return; - - List> entries = new ArrayList<>(node.getTopicNodes().entrySet()); - for (int i = 0; i < entries.size(); i++) { - boolean last = (i == entries.size() - 1); - Map.Entry entry = entries.get(i); - stringBuilder.append(prefix).append(last ? "└── " : "├── ").append(entry.getKey()).append("[").append(entry.getValue().getMemoryNodes().size()).append("]").append("\r\n"); - printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), stringBuilder); - } - } - - private void insertMemory(List topicPath, MemorySlice slice) throws IOException, ClassNotFoundException { - LocalDate now = LocalDate.now(); - boolean hasSlice = false; - MemoryNode node = null; - TopicNode lastTopicNode = generateTopicPath(topicPath); - for (MemoryNode memoryNode : lastTopicNode.getMemoryNodes()) { - if (now.equals(memoryNode.getLocalDate())) { - hasSlice = true; - node = memoryNode; - break; - } - } - if (!hasSlice) { - node = new MemoryNode(); - node.setLocalDate(now); - node.setMemoryNodeId(UUID.randomUUID().toString()); - node.setMemorySliceList(new CopyOnWriteArrayList<>()); - lastTopicNode.getMemoryNodes().add(node); - lastTopicNode.getMemoryNodes().sort(null); - } - node.loadMemorySliceList().add(slice); - - //生成relatedTopicPath - for (List relatedTopic : slice.getRelatedTopics()) { - generateTopicPath(relatedTopic); - } - - updateSlicePrecedent(slice); - updateDateIndex(slice); - updateUserIndex(slice); - - node.saveMemorySliceList(); - - } - - private void updateUserIndex(MemorySlice slice) { - String memoryId = slice.getMemoryId(); - String userId = slice.getStartUserId(); - if (!userIndex.containsKey(userId)) { - List memoryIdSet = new ArrayList<>(); - memoryIdSet.add(memoryId); - userIndex.put(userId, memoryIdSet); - } else { - userIndex.get(userId).add(memoryId); - } - } - - - private TopicNode generateTopicPath(List topicPath) { - topicPath = new ArrayList<>(topicPath); - //查看是否存在根主题节点 - String rootTopic = topicPath.getFirst(); - topicPath.removeFirst(); - if (!topicNodes.containsKey(rootTopic)) { - synchronized (this) { - if (!topicNodes.containsKey(rootTopic)) { - TopicNode rootNode = new TopicNode(); - topicNodes.put(rootTopic, rootNode); - existedTopics.put(rootTopic, new LinkedHashSet<>()); - } - } - } - - TopicNode current = topicNodes.get(rootTopic); - Set existedTopicNodes = existedTopics.get(rootTopic); - for (String topic : topicPath) { - if (existedTopicNodes.contains(topic) && current.getTopicNodes().containsKey(topic)) { - current = current.getTopicNodes().get(topic); - } else { - TopicNode newNode = new TopicNode(); - current.getTopicNodes().put(topic, newNode); - current = newNode; - - current.setMemoryNodes(new CopyOnWriteArrayList<>()); - current.setTopicNodes(new ConcurrentHashMap<>()); - existedTopicNodes.add(topic); - } - } - return current; - } - - private void updateSlicePrecedent(MemorySlice slice) { - String memoryId = slice.getMemoryId(); - //查看是否切换了memoryId - if (!currentDateDialogSlices.containsKey(memoryId)) { - List memorySliceList = new ArrayList<>(); - currentDateDialogSlices.clear(); - currentDateDialogSlices.put(memoryId, memorySliceList); - } - //处理上下文关系 - List memorySliceList = currentDateDialogSlices.get(memoryId); - if (memorySliceList.isEmpty()) { - memorySliceList.add(slice); - } else { - //排序 - memorySliceList.sort(null); - MemorySlice tempSlice = memorySliceList.getLast(); - //设置私密状态一致 - tempSlice.setPrivate(slice.isPrivate()); - //末尾切片添加当前切片的引用 - tempSlice.setSliceAfter(slice); - //当前切片添加前序切片的引用 - slice.setSliceBefore(tempSlice); - } - - } - - private void updateDateIndex(MemorySlice slice) { - String memoryId = slice.getMemoryId(); - LocalDate date = LocalDate.now(); - if (!dateIndex.containsKey(date)) { - HashSet memoryIdSet = new HashSet<>(); - memoryIdSet.add(memoryId); - dateIndex.put(date, memoryIdSet); - } else { - dateIndex.get(date).add(memoryId); - } - } - - public MemoryResult selectMemory(List path) throws IOException, ClassNotFoundException { - MemoryResult memoryResult = new MemoryResult(); - CopyOnWriteArrayList targetSliceList = new CopyOnWriteArrayList<>(); - String targetTopic = path.getLast(); - TopicNode targetParentNode = getTargetParentNode(path, targetTopic); - List> relatedTopics = new ArrayList<>(); - - //终点记忆节点 - MemorySliceResult sliceResult = new MemorySliceResult(); - for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) { - List endpointMemorySliceList = memoryNode.loadMemorySliceList(); - for (MemorySlice memorySlice : endpointMemorySliceList) { - if (selectedSlices.contains(memorySlice.getTimestamp())) { - continue; - } - sliceResult.setSliceBefore(memorySlice.getSliceBefore()); - sliceResult.setMemorySlice(memorySlice); - sliceResult.setSliceAfter(memorySlice.getSliceAfter()); - targetSliceList.add(sliceResult); - selectedSlices.add(memorySlice.getTimestamp()); - } - for (MemorySlice memorySlice : endpointMemorySliceList) { - if (memorySlice.getRelatedTopics() != null) { - relatedTopics.addAll(memorySlice.getRelatedTopics()); - } - } - } - memoryResult.setMemorySliceResult(targetSliceList); - - //邻近节点 - List relatedMemorySlice = new ArrayList<>(); - //邻近记忆节点 联系 - for (List relatedTopic : relatedTopics) { - List tempTopicPath = new ArrayList<>(relatedTopic); - String tempTargetTopic = tempTopicPath.getLast(); - TopicNode tempTargetParentNode = getTargetParentNode(tempTopicPath, tempTargetTopic); - //获取终点节点及其最新记忆节点 - TopicNode tempTargetNode = tempTargetParentNode.getTopicNodes().get(tempTopicPath.getLast()); - setRelatedMemorySlices(tempTargetNode, relatedMemorySlice); - } - - //邻近记忆节点 父级 - setRelatedMemorySlices(targetParentNode, relatedMemorySlice); - - //将上述结果包装为MemoryResult - memoryResult.setRelatedMemorySliceResult(relatedMemorySlice); - return memoryResult; - } - - private void setRelatedMemorySlices(TopicNode targetParentNode, List relatedMemorySlice) throws IOException, ClassNotFoundException { - List targetParentMemoryNodes = targetParentNode.getMemoryNodes(); - if (!targetParentMemoryNodes.isEmpty()) { - for (MemorySlice memorySlice : targetParentMemoryNodes.getFirst().loadMemorySliceList()) { - if (selectedSlices.contains(memorySlice.getTimestamp())) { - continue; - } - relatedMemorySlice.add(memorySlice); - selectedSlices.add(memorySlice.getTimestamp()); - } - } - } - - private TopicNode getTargetParentNode(List topicPath, String targetTopic) { - String topTopic = topicPath.getFirst(); - if (!existedTopics.containsKey(topTopic)) { - throw new UnExistedTopicException("不存在的主题: " + topTopic); - } - TopicNode targetParentNode = topicNodes.get(topTopic); - topicPath.removeFirst(); - for (String topic : topicPath) { - if (!existedTopics.get(topTopic).contains(topic)) { - throw new UnExistedTopicException("不存在的主题: " + topTopic); - } - } - - //逐层查找目标主题 - while (!targetParentNode.getTopicNodes().containsKey(targetTopic)) { - targetParentNode = targetParentNode.getTopicNodes().get(topicPath.getFirst()); - topicPath.removeFirst(); - } - return targetParentNode; - } - - private void updateCacheCounter(List topicPath) { - ConcurrentHashMap, Integer> memoryNodeCacheCounter = cache.memoryNodeCacheCounter; - if (memoryNodeCacheCounter.containsKey(topicPath)) { - Integer tempCount = memoryNodeCacheCounter.get(topicPath); - memoryNodeCacheCounter.put(topicPath, ++tempCount); - } else { - memoryNodeCacheCounter.put(topicPath, 1); - } - } - - private void checkCacheDate() { - if (cache.cacheDate == null || cache.cacheDate.isBefore(LocalDate.now())) { - cache.memorySliceCache.clear(); - cache.memoryNodeCacheCounter.clear(); - cache.cacheDate = LocalDate.now(); - } - } - - private void updateCache(List topicPath, MemoryResult memoryResult) { - ConcurrentHashMap, Integer> memoryNodeCacheCounter = cache.memoryNodeCacheCounter; - Integer tempCount = memoryNodeCacheCounter.get(topicPath); - if (tempCount == null) { - log.warn("[CacheCore] tempCount为null? memoryNodeCacheCounter: {}; topicPath: {}", memoryNodeCacheCounter, topicPath); - return; - } - if (tempCount >= 5) { - cache.memorySliceCache.put(topicPath, memoryResult); - } - } - - private void updateUserDialogMap(MemorySlice slice) { - String summary = slice.getSummary(); - LocalDateTime now = LocalDateTime.now(); - ConcurrentHashMap> userDialogMap = cache.userDialogMap; - - //更新userDialogMap - //移除两天前上下文缓存(切片总结) - List keysToRemove = new ArrayList<>(); - userDialogMap.forEach((k, v) -> v.forEach((i, j) -> { - if (now.minusDays(2).isAfter(i)) { - keysToRemove.add(i); - } - })); - for (LocalDateTime dateTime : keysToRemove) { - userDialogMap.forEach((k, v) -> v.remove(dateTime)); - } - //放入新缓存 - userDialogMap - .computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>()) - .merge(now, summary, (oldVal, newVal) -> oldVal + " " + newVal); - - } - - private void clearCacheByTopicPath(List topicPath) { - cache.memorySliceCache.remove(topicPath); - } - - private MemoryResult selectCache(List path) { - ConcurrentHashMap, MemoryResult> memorySliceCache = cache.memorySliceCache; - if (memorySliceCache.containsKey(path)) { - return memorySliceCache.get(path); } return null; } + @CapabilityMethod + public Collection listMemoryUnits() { + return new ArrayList<>(memoryUnits.values()); + } + + @CapabilityMethod + public void refreshMemoryId() { + currentMemoryId = UUID.randomUUID().toString(); + } + + @CapabilityMethod + public String getCurrentMemoryId() { + return currentMemoryId; + } + + private void normalizeMemoryUnit(MemoryUnit memoryUnit) { + if (memoryUnit.getId() == null || memoryUnit.getId().isBlank()) { + memoryUnit.setId(UUID.randomUUID().toString()); + } + if (memoryUnit.getTimestamp() == null || memoryUnit.getTimestamp() <= 0) { + memoryUnit.setTimestamp(System.currentTimeMillis()); + } + if (memoryUnit.getConversationMessages() == null) { + memoryUnit.setConversationMessages(new ArrayList<>()); + } + if (memoryUnit.getSlices() == null) { + memoryUnit.setSlices(new ArrayList<>()); + } + int maxIndex = Math.max(memoryUnit.getConversationMessages().size() - 1, 0); + for (MemorySlice slice : memoryUnit.getSlices()) { + if (slice.getId() == null || slice.getId().isBlank()) { + slice.setId(UUID.randomUUID().toString()); + } + if (slice.getTimestamp() == null || slice.getTimestamp() <= 0) { + slice.setTimestamp(memoryUnit.getTimestamp()); + } + if (slice.getStartIndex() == null || slice.getStartIndex() < 0) { + slice.setStartIndex(0); + } + if (slice.getEndIndex() == null || slice.getEndIndex() < slice.getStartIndex()) { + slice.setEndIndex(maxIndex); + } + if (slice.getEndIndex() > maxIndex) { + slice.setEndIndex(maxIndex); + } + } + memoryUnit.getSlices().sort(Comparator.naturalOrder()); + } + @Override protected String getCoreKey() { return "memory-core"; } - - public ConcurrentHashMap> getUserDialogMap() { - return cache.userDialogMap; - } - - - private MemoryResult cacheFilter(MemoryResult memoryResult) { - //过滤掉与缓存重复的切片 - CopyOnWriteArrayList memorySliceResult = memoryResult.getMemorySliceResult(); - List relatedMemorySliceResult = memoryResult.getRelatedMemorySliceResult(); - cache.dialogMap.forEach((k, v) -> { - memorySliceResult.removeIf(m -> m.getMemorySlice().getSummary().equals(v)); - relatedMemorySliceResult.removeIf(m -> m.getSummary().equals(v)); - }); - return memoryResult; - } - - @SuppressWarnings("FieldMayBeFinal") - private static class MemoryCache { - - /** - * 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键,总结为值 - * 该部分作为'主LLM'system prompt常驻 - * 该部分作为近两日的整体对话缓存, 不区分用户 - */ - private HashMap dialogMap = new HashMap<>(); - - /** - * 近两日的区分用户的对话总结缓存,在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质 - */ - private ConcurrentHashMap> userDialogMap = new ConcurrentHashMap<>(); - - /** - * memorySliceCache计数器,每日清空 - */ - private ConcurrentHashMap /*触发查询的主题列表*/, Integer> memoryNodeCacheCounter = new ConcurrentHashMap<>(); - - /** - * 记忆切片缓存,每日清空 - * 用于记录作为终点节点调用次数最多的记忆节点的切片数据 - */ - private ConcurrentHashMap /*主题路径*/, MemoryResult /*切片列表*/> memorySliceCache = new ConcurrentHashMap<>(); - - /** - * 缓存日期 - */ - private LocalDate cacheDate; - - private HashMap> activatedSlices = new HashMap<>(); - - private MemoryCache() { - } - } } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/EvaluatedSlice.java b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/ActivatedMemorySlice.java similarity index 62% rename from Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/EvaluatedSlice.java rename to Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/ActivatedMemorySlice.java index c756bf44..3dfdb54e 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/EvaluatedSlice.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/ActivatedMemorySlice.java @@ -3,20 +3,25 @@ package work.slhaf.partner.core.memory.pojo; import lombok.Builder; import lombok.Data; import lombok.EqualsAndHashCode; +import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.api.common.entity.PersistableObject; import java.io.Serial; import java.time.LocalDate; +import java.util.List; @EqualsAndHashCode(callSuper = true) @Data @Builder -public class EvaluatedSlice extends PersistableObject { +public class ActivatedMemorySlice extends PersistableObject { @Serial private static final long serialVersionUID = 1L; - // private List chatMessages; + private String unitId; + private String sliceId; private LocalDate date; + private Long timestamp; private String summary; + private List messages; } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemoryResult.java b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemoryResult.java deleted file mode 100644 index cad42d7c..00000000 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemoryResult.java +++ /dev/null @@ -1,26 +0,0 @@ -package work.slhaf.partner.core.memory.pojo; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import work.slhaf.partner.api.common.entity.PersistableObject; - -import java.io.Serial; -import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; - -@EqualsAndHashCode(callSuper = true) -@Data -public class MemoryResult extends PersistableObject { - - @Serial - private static final long serialVersionUID = 1L; - - private CopyOnWriteArrayList memorySliceResult; - private List relatedMemorySliceResult; - - public boolean isEmpty() { - boolean a = memorySliceResult == null || memorySliceResult.isEmpty(); - boolean b = relatedMemorySliceResult == null || relatedMemorySliceResult.isEmpty(); - return a && b; - } -} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySlice.java b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySlice.java index 84f9f899..1ca07cc5 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySlice.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySlice.java @@ -2,12 +2,9 @@ package work.slhaf.partner.core.memory.pojo; import lombok.Data; import lombok.EqualsAndHashCode; -import lombok.ToString; -import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.api.common.entity.PersistableObject; import java.io.Serial; -import java.util.List; @EqualsAndHashCode(callSuper = true) @Data @@ -16,59 +13,11 @@ public class MemorySlice extends PersistableObject implements Comparable.slice", 如2025-04-11.slice - */ + private String id; + private Integer startIndex; + private Integer endIndex; private String summary; - - private List chatMessages; - - /** - * 关联的其他主题, 即"邻近节点(联系)" - */ - private List> relatedTopics; - - /** - * 关联完整对话中的前序切片, 排序为键,完整路径为值 - */ - @ToString.Exclude - private MemorySlice sliceBefore, sliceAfter; - - /** - * 多用户设定 - * 发起该切片对话的用户 - */ - private String startUserId; - - /** - * 该切片涉及到的用户uuid - */ - private List involvedUserIds; - - /** - * 是否仅供发起用户作为记忆参考 - */ - private boolean isPrivate; - - /** - * 摘要向量化结果 - */ - private float[] summaryEmbedding; - - /** - * 是否向量化 - */ - private boolean embedded; + private Long timestamp; @Override public int compareTo(MemorySlice memorySlice) { @@ -79,5 +28,4 @@ public class MemorySlice extends PersistableObject implements Comparable conversationMessages = new ArrayList<>(); + private Long timestamp; + private List slices = new ArrayList<>(); +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySliceResult.java b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/SliceRef.java similarity index 50% rename from Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySliceResult.java rename to Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/SliceRef.java index 06e3f8d6..f25881ff 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySliceResult.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/SliceRef.java @@ -1,24 +1,22 @@ package work.slhaf.partner.core.memory.pojo; -import com.alibaba.fastjson2.annotation.JSONField; +import lombok.AllArgsConstructor; import lombok.Data; import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; import work.slhaf.partner.api.common.entity.PersistableObject; import java.io.Serial; @EqualsAndHashCode(callSuper = true) @Data -public class MemorySliceResult extends PersistableObject { +@NoArgsConstructor +@AllArgsConstructor +public class SliceRef extends PersistableObject { @Serial private static final long serialVersionUID = 1L; - @JSONField(serialize = false) - private MemorySlice sliceBefore; - - private MemorySlice memorySlice; - - @JSONField(serialize = false) - private MemorySlice sliceAfter; + private String unitId; + private String sliceId; } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/ActionExecutor.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/ActionExecutor.java index e436d96e..2ba47ad0 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/ActionExecutor.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/ActionExecutor.java @@ -367,7 +367,7 @@ public class ActionExecutor extends AbstractAgentModule.Standalone { private ExtractorInput buildExtractorInput(MetaAction action, String source, List historyActionResults, List additionalContext) { ExtractorInput input = new ExtractorInput(); - input.setEvaluatedSlices(memoryCapability.getActivatedSlices(source)); + input.setActivatedMemorySlices(memoryCapability.getActivatedSlices()); input.setRecentMessages(cognationCapability.getChatMessages()); input.setMetaActionInfo(actionCapability.loadMetaActionInfo(action.getKey())); input.setHistoryActionResults(historyActionResults); @@ -384,7 +384,7 @@ public class ActionExecutor extends AbstractAgentModule.Standalone { .history(executableAction.getHistory().get(executableAction.getExecutingStage())) .status(executableAction.getStatus()) .recentMessages(cognationCapability.getChatMessages()) - .activatedSlices(memoryCapability.getActivatedSlices(source)) + .activatedSlices(memoryCapability.getActivatedSlices()) .build(); } } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/entity/CorrectorInput.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/entity/CorrectorInput.java index db3d22a5..b728e54b 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/entity/CorrectorInput.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/entity/CorrectorInput.java @@ -4,7 +4,7 @@ import lombok.Builder; import lombok.Data; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.action.entity.ExecutableAction; -import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; +import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; import java.util.List; @@ -20,5 +20,5 @@ public class CorrectorInput { private ExecutableAction.Status status; private List recentMessages; - private List activatedSlices; + private List activatedSlices; } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/entity/ExtractorInput.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/entity/ExtractorInput.java index 8189022c..cc274428 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/entity/ExtractorInput.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/entity/ExtractorInput.java @@ -3,7 +3,7 @@ package work.slhaf.partner.module.modules.action.executor.entity; import lombok.Data; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.action.entity.MetaActionInfo; -import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; +import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; import java.util.List; @@ -16,7 +16,7 @@ public class ExtractorInput { /** * 可参考的记忆切片 */ - private List evaluatedSlices; + private List activatedMemorySlices; /** * 历史行动执行结果 */ diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/interventor/ActionInterventor.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/interventor/ActionInterventor.java index 6d82cdd8..dae0fa8e 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/interventor/ActionInterventor.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/interventor/ActionInterventor.java @@ -22,6 +22,7 @@ import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.Eva import work.slhaf.partner.module.modules.action.interventor.recognizer.InterventionRecognizer; import work.slhaf.partner.module.modules.action.interventor.recognizer.entity.RecognizerInput; import work.slhaf.partner.module.modules.action.interventor.recognizer.entity.RecognizerResult; +import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; import java.util.ArrayList; @@ -51,6 +52,8 @@ public class ActionInterventor extends AbstractAgentModule.Running recentMessages, List activatedSlices, + private String buildPrompt(List recentMessages, List activatedSlices, ExecutableAction executableAction, String tendency) { JSONObject json = new JSONObject(); json.put("干预倾向", tendency); diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/interventor/evaluator/entity/EvaluatorInput.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/interventor/evaluator/entity/EvaluatorInput.java index 379e5a67..1077726f 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/interventor/evaluator/entity/EvaluatorInput.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/interventor/evaluator/entity/EvaluatorInput.java @@ -3,7 +3,7 @@ package work.slhaf.partner.module.modules.action.interventor.evaluator.entity; import lombok.Data; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.action.entity.ExecutableAction; -import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; +import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; import java.util.List; import java.util.Map; @@ -12,6 +12,6 @@ import java.util.Map; public class EvaluatorInput { private Map executingInterventions; private Map preparedInterventions; - private List activatedSlices; + private List activatedSlices; private List recentMessages; } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/planner/ActionPlanner.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/planner/ActionPlanner.java index e44a5ea7..230a2b7f 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/planner/ActionPlanner.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/planner/ActionPlanner.java @@ -328,7 +328,7 @@ public class ActionPlanner extends AbstractAgentModule.Running recentMessages; - private List activatedSlices; + private List activatedSlices; private Map availableActions; private String tendency; } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/entity/EvaluatorInput.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/entity/EvaluatorInput.java index 79908620..e93cac3f 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/entity/EvaluatorInput.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/planner/evaluator/entity/EvaluatorInput.java @@ -2,7 +2,7 @@ package work.slhaf.partner.module.modules.action.planner.evaluator.entity; import lombok.Data; import work.slhaf.partner.api.chat.pojo.Message; -import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; +import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; import work.slhaf.partner.core.perceive.pojo.User; import java.util.List; @@ -11,6 +11,6 @@ import java.util.List; public class EvaluatorInput { private List recentMessages; private User user; - private List activatedSlices; + private List activatedSlices; private List tendencies; } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/core/CommunicationProducer.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/core/CommunicationProducer.java index d883c33f..963f55ce 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/core/CommunicationProducer.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/core/CommunicationProducer.java @@ -12,7 +12,6 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.component.annotation.Init; import work.slhaf.partner.api.agent.runtime.interaction.flow.ContextBlock; import work.slhaf.partner.api.chat.pojo.Message; -import work.slhaf.partner.api.chat.pojo.MetaMessage; import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; @@ -126,8 +125,6 @@ public class CommunicationProducer extends AbstractAgentModule.Running> topicSlices = new HashMap<>(); + private Map> dateIndex = new HashMap<>(); + private HashMap dialogMap = new HashMap<>(); + + @Init + public void init() { + loadState(); + Runtime.getRuntime().addShutdownHook(new Thread(this::saveStateSafely)); + } + + public void bindTopic(String topicPath, SliceRef sliceRef) { + String normalizedPath = normalizeTopicPath(topicPath); + runtimeLock.lock(); + try { + CopyOnWriteArrayList refs = topicSlices.computeIfAbsent(normalizedPath, key -> new CopyOnWriteArrayList<>()); + boolean exists = refs.stream().anyMatch(ref -> Objects.equals(ref.getUnitId(), sliceRef.getUnitId()) + && Objects.equals(ref.getSliceId(), sliceRef.getSliceId())); + if (!exists) { + refs.add(sliceRef); + } + saveState(); + } finally { + runtimeLock.unlock(); + } + } + + public void indexMemoryUnit(MemoryUnit memoryUnit) { + runtimeLock.lock(); + try { + for (CopyOnWriteArrayList refs : dateIndex.values()) { + refs.removeIf(ref -> memoryUnit.getId().equals(ref.getUnitId())); + } + if (memoryUnit.getSlices() != null) { + for (MemorySlice slice : memoryUnit.getSlices()) { + LocalDate date = Instant.ofEpochMilli(slice.getTimestamp()) + .atZone(ZoneId.systemDefault()) + .toLocalDate(); + dateIndex.computeIfAbsent(date, key -> new CopyOnWriteArrayList<>()) + .addIfAbsent(new SliceRef(memoryUnit.getId(), slice.getId())); + } + } + saveState(); + } finally { + runtimeLock.unlock(); + } + } + + public List findByTopicPath(String topicPath) { + String normalizedPath = normalizeTopicPath(topicPath); + List refs = topicSlices.get(normalizedPath); + if (refs == null || refs.isEmpty()) { + throw new UnExistedTopicException("不存在的主题: " + normalizedPath); + } + return new ArrayList<>(refs); + } + + public List findByDate(LocalDate date) { + List refs = dateIndex.get(date); + if (refs == null || refs.isEmpty()) { + throw new UnExistedDateIndexException("不存在的日期索引: " + date); + } + return new ArrayList<>(refs); + } + + public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) { + runtimeLock.lock(); + try { + List keysToRemove = new ArrayList<>(); + dialogMap.forEach((k, v) -> { + if (dateTime.minusDays(2).isAfter(k)) { + keysToRemove.add(k); + } + }); + for (LocalDateTime temp : keysToRemove) { + dialogMap.remove(temp); + } + dialogMap.put(dateTime, newDialogCache); + saveState(); + } finally { + runtimeLock.unlock(); + } + } + + public HashMap getDialogMap() { + return dialogMap; + } + + public String getDialogMapStr() { + StringBuilder str = new StringBuilder(); + dialogMap.entrySet().stream() + .sorted(Map.Entry.comparingByKey()) + .forEach(entry -> str.append("\n\n[") + .append(entry.getKey()) + .append("]\n") + .append(entry.getValue())); + return str.toString(); + } + + public String getTopicTree() { + TopicTreeNode root = new TopicTreeNode(); + for (Map.Entry> entry : topicSlices.entrySet()) { + String[] parts = entry.getKey().split("->"); + TopicTreeNode current = root; + for (String part : parts) { + current = current.children.computeIfAbsent(part, key -> new TopicTreeNode()); + } + current.count += entry.getValue().size(); + } + + StringBuilder stringBuilder = new StringBuilder(); + List> roots = new ArrayList<>(root.children.entrySet()); + for (Map.Entry entry : roots) { + stringBuilder.append(entry.getKey()).append("[root]").append("\r\n"); + printSubTopicsTreeFormat(entry.getValue(), "", stringBuilder); + } + return stringBuilder.toString(); + } + + private void printSubTopicsTreeFormat(TopicTreeNode node, String prefix, StringBuilder stringBuilder) { + List> entries = new ArrayList<>(node.children.entrySet()); + for (int i = 0; i < entries.size(); i++) { + boolean last = i == entries.size() - 1; + Map.Entry entry = entries.get(i); + stringBuilder.append(prefix) + .append(last ? "└── " : "├── ") + .append(entry.getKey()) + .append("[") + .append(entry.getValue().count) + .append("]") + .append("\r\n"); + printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), stringBuilder); + } + } + + private String normalizeTopicPath(String topicPath) { + return topicPath == null ? "" : topicPath.trim(); + } + + private void loadState() { + Path filePath = getFilePath(); + if (!Files.exists(filePath)) { + return; + } + try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(filePath.toFile()))) { + RuntimeState state = (RuntimeState) ois.readObject(); + topicSlices = state.topicSlices; + dateIndex = state.dateIndex; + dialogMap = state.dialogMap; + } catch (Exception e) { + log.error("[MemoryRuntime] 加载运行态失败", e); + topicSlices = new HashMap<>(); + dateIndex = new HashMap<>(); + dialogMap = new HashMap<>(); + } + } + + private void saveStateSafely() { + runtimeLock.lock(); + try { + saveState(); + } finally { + runtimeLock.unlock(); + } + } + + private void saveState() { + Path filePath = getFilePath(); + Path tempPath = getTempFilePath(); + try { + Files.createDirectories(Paths.get(MEMORY_DATA)); + FileUtils.createParentDirectories(filePath.toFile()); + try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(tempPath.toFile()))) { + RuntimeState state = new RuntimeState(); + state.topicSlices = new HashMap<>(topicSlices); + state.dateIndex = new HashMap<>(dateIndex); + state.dialogMap = new HashMap<>(dialogMap); + oos.writeObject(state); + } + Files.move(tempPath, filePath, java.nio.file.StandardCopyOption.REPLACE_EXISTING); + } catch (IOException e) { + log.error("[MemoryRuntime] 保存运行态失败", e); + } + } + + private Path getFilePath() { + String id = ((PartnerAgentConfigLoader) AgentConfigLoader.INSTANCE).getConfig().getAgentId(); + return Paths.get(MEMORY_DATA, id + "-" + RUNTIME_KEY + ".memory"); + } + + private Path getTempFilePath() { + String id = ((PartnerAgentConfigLoader) AgentConfigLoader.INSTANCE).getConfig().getAgentId(); + return Paths.get(MEMORY_DATA, id + "-" + RUNTIME_KEY + "-temp.memory"); + } + + private static final class TopicTreeNode { + private final Map children = new LinkedHashMap<>(); + private int count; + } + + private static final class RuntimeState extends PersistableObject { + @Serial + private static final long serialVersionUID = 1L; + + private Map> topicSlices = new HashMap<>(); + private Map> dateIndex = new HashMap<>(); + private HashMap dialogMap = new HashMap<>(); + } +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java index f2a79cd7..fc0c4365 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java @@ -6,13 +6,16 @@ import lombok.EqualsAndHashCode; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule; +import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException; import work.slhaf.partner.core.memory.exception.UnExistedTopicException; -import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; -import work.slhaf.partner.core.memory.pojo.MemoryResult; +import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; import work.slhaf.partner.core.memory.pojo.MemorySlice; +import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.core.memory.pojo.SliceRef; +import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime; import work.slhaf.partner.module.modules.memory.selector.evaluator.SliceSelectEvaluator; import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorInput; import work.slhaf.partner.module.modules.memory.selector.extractor.MemorySelectExtractor; @@ -20,9 +23,12 @@ import work.slhaf.partner.module.modules.memory.selector.extractor.entity.Extrac import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorResult; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; +import java.time.Instant; import java.time.LocalDate; +import java.time.ZoneId; import java.util.ArrayList; import java.util.Collection; +import java.util.LinkedHashMap; import java.util.List; @EqualsAndHashCode(callSuper = true) @@ -33,89 +39,106 @@ public class MemorySelector extends AbstractAgentModule.Running evaluatedSlices = selectAndEvaluateMemory(runningFlowContext, extractorResult); - memoryCapability.updateActivatedSlices(userId, evaluatedSlices); + memoryCapability.clearActivatedSlices(); + List activatedSlices = selectAndEvaluateMemory(runningFlowContext, extractorResult); + memoryCapability.updateActivatedSlices(activatedSlices); } setModuleContextRecall(runningFlowContext); } - private List selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, ExtractorResult extractorResult) { + private List selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, + ExtractorResult extractorResult) { log.debug("[MemorySelector] 触发记忆回溯..."); - //查找切片 - String userId = runningFlowContext.getSource(); - List memoryResultList = new ArrayList<>(); - setMemoryResultList(memoryResultList, extractorResult.getMatches(), userId); - //评估切片 + LinkedHashMap candidates = new LinkedHashMap<>(); + setMemoryCandidates(candidates, extractorResult.getMatches()); + removeDuplicateSlice(candidates.values()); EvaluatorInput evaluatorInput = EvaluatorInput.builder() .input(runningFlowContext.getInput()) - .memoryResults(memoryResultList) + .memorySlices(new ArrayList<>(candidates.values())) .messages(cognationCapability.getChatMessages()) .build(); log.debug("[MemorySelector] 切片评估输入: {}", JSONObject.toJSONString(evaluatorInput)); - List memorySlices = sliceSelectEvaluator.execute(evaluatorInput); + List memorySlices = sliceSelectEvaluator.execute(evaluatorInput); log.debug("[MemorySelector] 切片评估结果: {}", JSONObject.toJSONString(memorySlices)); return memorySlices; } private void setModuleContextRecall(PartnerRunningFlowContext runningFlowContext) { - String userId = runningFlowContext.getSource(); - boolean recall = memoryCapability.hasActivatedSlices(userId); + boolean recall = memoryCapability.hasActivatedSlices(); runningFlowContext.getModuleContext().getExtraContext().put("recall", recall); if (recall) { - runningFlowContext.getModuleContext().getExtraContext().put("recall_count", memoryCapability.getActivatedSlicesSize(userId)); + runningFlowContext.getModuleContext().getExtraContext().put("recall_count", memoryCapability.getActivatedSlicesSize()); } } - private void setMemoryResultList(List memoryResultList, List matches, String userId) { + private void setMemoryCandidates(LinkedHashMap candidates, + List matches) { for (ExtractorMatchData match : matches) { try { - MemoryResult memoryResult = switch (match.getType()) { - case ExtractorMatchData.Constant.TOPIC -> memoryCapability.selectMemory(match.getText()); - case ExtractorMatchData.Constant.DATE -> - memoryCapability.selectMemory(LocalDate.parse(match.getText())); - default -> null; + List refs = switch (match.getType()) { + case ExtractorMatchData.Constant.TOPIC -> memoryRuntime.findByTopicPath(match.getText()); + case ExtractorMatchData.Constant.DATE -> memoryRuntime.findByDate(LocalDate.parse(match.getText())); + default -> List.of(); }; - if (memoryResult == null || memoryResult.isEmpty()) continue; - removeDuplicateSlice(memoryResult); - memoryResultList.add(memoryResult); + for (SliceRef ref : refs) { + ActivatedMemorySlice recalledSlice = buildActivatedMemorySlice(ref); + if (recalledSlice != null) { + candidates.putIfAbsent(ref.getUnitId() + ":" + ref.getSliceId(), recalledSlice); + } + } } catch (UnExistedDateIndexException | UnExistedTopicException e) { - log.error("[MemorySelector] 不存在的记忆索引! 请尝试更换更合适的主题提取LLM!", e); + log.error("[MemorySelector] 不存在的记忆索引", e); log.error("[MemorySelector] 错误索引: {}", match.getText()); } } - //清理切片记录 - memoryCapability.cleanSelectedSliceFilter(); - //根据userInfo过滤是否为私人记忆 - for (MemoryResult memoryResult : memoryResultList) { - //过滤终点记忆 - memoryResult.getMemorySliceResult().removeIf(m -> removeOrNot(m.getMemorySlice(), userId)); - //过滤邻近记忆 - memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userId)); - } } - private void removeDuplicateSlice(MemoryResult memoryResult) { - Collection values = memoryCapability.getDialogMap().values(); - memoryResult.getRelatedMemorySliceResult().removeIf(m -> values.contains(m.getSummary())); - memoryResult.getMemorySliceResult().removeIf(m -> values.contains(m.getMemorySlice().getSummary())); + private ActivatedMemorySlice buildActivatedMemorySlice(SliceRef ref) { + MemoryUnit memoryUnit = memoryCapability.getMemoryUnit(ref.getUnitId()); + MemorySlice memorySlice = memoryCapability.getMemorySlice(ref.getUnitId(), ref.getSliceId()); + if (memoryUnit == null || memorySlice == null) { + return null; + } + List messages = sliceMessages(memoryUnit, memorySlice); + LocalDate date = Instant.ofEpochMilli(memorySlice.getTimestamp()) + .atZone(ZoneId.systemDefault()) + .toLocalDate(); + return ActivatedMemorySlice.builder() + .unitId(ref.getUnitId()) + .sliceId(ref.getSliceId()) + .summary(memorySlice.getSummary()) + .timestamp(memorySlice.getTimestamp()) + .date(date) + .messages(messages) + .build(); } - private boolean removeOrNot(MemorySlice memorySlice, String userId) { - if (memorySlice.isPrivate()) { - return memorySlice.getStartUserId().equals(userId); + private List sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) { + List conversationMessages = memoryUnit.getConversationMessages(); + if (conversationMessages == null || conversationMessages.isEmpty()) { + return List.of(); } - return false; + int start = Math.max(0, memorySlice.getStartIndex()); + int end = Math.min(conversationMessages.size() - 1, memorySlice.getEndIndex()); + if (start > end) { + return List.of(); + } + return new ArrayList<>(conversationMessages.subList(start, end + 1)); + } + + private void removeDuplicateSlice(Collection candidates) { + Collection values = memoryRuntime.getDialogMap().values(); + candidates.removeIf(m -> values.contains(m.getSummary())); } @Override diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/evaluator/SliceSelectEvaluator.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/evaluator/SliceSelectEvaluator.java index fc141d84..0d707d36 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/evaluator/SliceSelectEvaluator.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/evaluator/SliceSelectEvaluator.java @@ -1,6 +1,5 @@ package work.slhaf.partner.module.modules.memory.selector.evaluator; -import cn.hutool.core.date.DateUtil; import cn.hutool.json.JSONUtil; import com.alibaba.fastjson2.JSONObject; import lombok.Data; @@ -10,10 +9,7 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.component.annotation.Init; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor; -import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; -import work.slhaf.partner.core.memory.pojo.MemoryResult; -import work.slhaf.partner.core.memory.pojo.MemorySlice; -import work.slhaf.partner.core.memory.pojo.MemorySliceResult; +import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorBatchInput; import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorInput; import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorResult; @@ -27,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger; @EqualsAndHashCode(callSuper = true) @Data -public class SliceSelectEvaluator extends AbstractAgentModule.Sub> implements ActivateModel { +public class SliceSelectEvaluator extends AbstractAgentModule.Sub> implements ActivateModel { private InteractionThreadPoolExecutor executor; @Init @@ -36,83 +32,58 @@ public class SliceSelectEvaluator extends AbstractAgentModule.Sub execute(EvaluatorInput evaluatorInput) { + public List execute(EvaluatorInput evaluatorInput) { log.debug("[SliceSelectEvaluator] 切片评估模块开始..."); - List memoryResultList = evaluatorInput.getMemoryResults(); + List memorySlices = evaluatorInput.getMemorySlices(); List> tasks = new ArrayList<>(); - Queue queue = new ConcurrentLinkedDeque<>(); + Queue queue = new ConcurrentLinkedDeque<>(); AtomicInteger count = new AtomicInteger(0); - for (MemoryResult memoryResult : memoryResultList) { - if (memoryResult.getMemorySliceResult().isEmpty() && memoryResult.getRelatedMemorySliceResult().isEmpty()) { - continue; - } - tasks.add(() -> { - int thisCount = count.incrementAndGet(); - log.debug("[SliceSelectEvaluator] 评估[{}]开始", thisCount); - List sliceSummaryList = new ArrayList<>(); - //映射查找键值 - Map map = new HashMap<>(); - try { - setSliceSummaryList(memoryResult, sliceSummaryList, map); - EvaluatorBatchInput batchInput = EvaluatorBatchInput.builder() - .text(evaluatorInput.getInput()) - .memory_slices(sliceSummaryList) - .history(evaluatorInput.getMessages()) - .build(); - log.debug("[SliceSelectEvaluator] 评估[{}]输入: {}", thisCount, JSONObject.toJSONString(batchInput)); - EvaluatorResult evaluatorResult = formattedChat( - List.of(new Message(Message.Character.USER, JSONUtil.toJsonStr(batchInput))), - EvaluatorResult.class - ); - log.debug("[SliceSelectEvaluator] 评估[{}]结果: {}", thisCount, JSONObject.toJSONString(evaluatorResult)); - for (Long result : evaluatorResult.getResults()) { - SliceSummary sliceSummary = map.get(result); - EvaluatedSlice evaluatedSlice = EvaluatedSlice.builder() - .summary(sliceSummary.getSummary()) - .date(sliceSummary.getDate()) - .build(); -// setEvaluatedSliceMessages(evaluatedSlice, memoryResult, sliceSummary.getId()); - queue.offer(evaluatedSlice); - } - } catch (Exception e) { - log.error("[SliceSelectEvaluator] 评估[{}]出现错误: {}", thisCount, e.getLocalizedMessage()); - } - return null; - }); + if (memorySlices == null || memorySlices.isEmpty()) { + return List.of(); } + tasks.add(() -> { + int thisCount = count.incrementAndGet(); + log.debug("[SliceSelectEvaluator] 评估[{}]开始", thisCount); + List sliceSummaryList = new ArrayList<>(); + Map map = new HashMap<>(); + try { + setSliceSummaryList(memorySlices, sliceSummaryList, map); + EvaluatorBatchInput batchInput = EvaluatorBatchInput.builder() + .text(evaluatorInput.getInput()) + .memory_slices(sliceSummaryList) + .history(evaluatorInput.getMessages()) + .build(); + log.debug("[SliceSelectEvaluator] 评估[{}]输入: {}", thisCount, JSONObject.toJSONString(batchInput)); + EvaluatorResult evaluatorResult = formattedChat( + List.of(new Message(Message.Character.USER, JSONUtil.toJsonStr(batchInput))), + EvaluatorResult.class + ); + log.debug("[SliceSelectEvaluator] 评估[{}]结果: {}", thisCount, JSONObject.toJSONString(evaluatorResult)); + for (Long result : evaluatorResult.getResults()) { + ActivatedMemorySlice slice = map.get(result); + if (slice != null) { + queue.offer(slice); + } + } + } catch (Exception e) { + log.error("[SliceSelectEvaluator] 评估[{}]出现错误: {}", thisCount, e.getLocalizedMessage()); + } + return null; + }); executor.invokeAll(tasks, 30, TimeUnit.SECONDS); log.debug("[SliceSelectEvaluator] 评估模块结束, 输出队列: {}", queue); - List temp = queue.stream().toList(); - return new ArrayList<>(temp); + return new ArrayList<>(queue); } - private void setSliceSummaryList(MemoryResult memoryResult, List sliceSummaryList, Map map) { - for (MemorySliceResult memorySliceResult : memoryResult.getMemorySliceResult()) { - SliceSummary sliceSummary = new SliceSummary(); - sliceSummary.setId(memorySliceResult.getMemorySlice().getTimestamp()); - StringBuilder stringBuilder = new StringBuilder(); - if (memorySliceResult.getSliceBefore() != null) { - stringBuilder.append(memorySliceResult.getSliceBefore().getSummary()) - .append("\r\n"); - } - stringBuilder.append(memorySliceResult.getMemorySlice().getSummary()); - if (memorySliceResult.getSliceAfter() != null) { - stringBuilder.append("\r\n") - .append(memorySliceResult.getSliceAfter().getSummary()) - .append("\r\n"); - } - sliceSummary.setSummary(stringBuilder.toString()); - Long timestamp = memorySliceResult.getMemorySlice().getTimestamp(); - sliceSummary.setDate(DateUtil.date(timestamp).toLocalDateTime().toLocalDate()); - sliceSummaryList.add(sliceSummary); - map.put(timestamp, sliceSummary); - } - for (MemorySlice memorySlice : memoryResult.getRelatedMemorySliceResult()) { + private void setSliceSummaryList(List memorySlices, List sliceSummaryList, + Map map) { + for (ActivatedMemorySlice memorySlice : memorySlices) { SliceSummary sliceSummary = new SliceSummary(); sliceSummary.setId(memorySlice.getTimestamp()); sliceSummary.setSummary(memorySlice.getSummary()); + sliceSummary.setDate(memorySlice.getDate()); sliceSummaryList.add(sliceSummary); - map.put(memorySlice.getTimestamp(), sliceSummary); + map.put(memorySlice.getTimestamp(), memorySlice); } } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/evaluator/entity/EvaluatorInput.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/evaluator/entity/EvaluatorInput.java index 29ff6a96..ce950540 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/evaluator/entity/EvaluatorInput.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/evaluator/entity/EvaluatorInput.java @@ -3,7 +3,7 @@ package work.slhaf.partner.module.modules.memory.selector.evaluator.entity; import lombok.Builder; import lombok.Data; import work.slhaf.partner.api.chat.pojo.Message; -import work.slhaf.partner.core.memory.pojo.MemoryResult; +import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; import java.util.List; @@ -12,5 +12,5 @@ import java.util.List; public class EvaluatorInput { private String input; private List messages; - private List memoryResults; + private List memorySlices; } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/extractor/MemorySelectExtractor.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/extractor/MemorySelectExtractor.java index c92015fa..d6fea0ba 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/extractor/MemorySelectExtractor.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/extractor/MemorySelectExtractor.java @@ -6,17 +6,17 @@ import lombok.EqualsAndHashCode; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; +import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule; import work.slhaf.partner.api.chat.pojo.Message; -import work.slhaf.partner.api.chat.pojo.MetaMessage; import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.memory.MemoryCapability; -import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; +import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; +import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime; import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorInput; import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorMatchData; import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorResult; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; -import java.util.ArrayList; import java.util.List; import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath; @@ -29,25 +29,21 @@ public class MemorySelectExtractor extends AbstractAgentModule.Sub chatMessages = new ArrayList<>(); - List metaMessages = cognationCapability.snapshotSingleMetaMessages(context.getSource()); - for (MetaMessage metaMessage : metaMessages) { - chatMessages.add(metaMessage.getUserMessage()); - chatMessages.add(metaMessage.getAssistantMessage()); - } + List chatMessages = cognationCapability.snapshotChatMessages(); ExtractorResult extractorResult; try { - List activatedMemorySlices = memoryCapability.getActivatedSlices(context.getSource()); + List activatedMemorySlices = memoryCapability.getActivatedSlices(); ExtractorInput extractorInput = ExtractorInput.builder() .text(context.getInput()) .date(context.getInfo().getDateTime().toLocalDate()) .history(chatMessages) - .topic_tree(memoryCapability.getTopicTree()) + .topic_tree(memoryRuntime.getTopicTree()) .activatedMemorySlices(activatedMemorySlices) .build(); log.debug("[MemorySelectExtractor] 主题提取输入: {}", JSONUtil.toJsonStr(extractorInput)); diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/extractor/entity/ExtractorInput.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/extractor/entity/ExtractorInput.java index 23908392..28c12b16 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/extractor/entity/ExtractorInput.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/extractor/entity/ExtractorInput.java @@ -3,7 +3,7 @@ package work.slhaf.partner.module.modules.memory.selector.extractor.entity; import lombok.Builder; import lombok.Data; import work.slhaf.partner.api.chat.pojo.Message; -import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; +import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice; import java.time.LocalDate; import java.util.List; @@ -15,5 +15,5 @@ public class ExtractorInput { private String topic_tree; private LocalDate date; private List history; - private List activatedMemorySlices; + private List activatedMemorySlices; } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/MemoryUpdater.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/MemoryUpdater.java index cd176b2f..345857c0 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/MemoryUpdater.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/MemoryUpdater.java @@ -8,31 +8,29 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapabili import work.slhaf.partner.api.agent.factory.component.annotation.Init; import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule; import work.slhaf.partner.api.chat.pojo.Message; -import work.slhaf.partner.api.chat.pojo.MetaMessage; import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor; import work.slhaf.partner.core.action.entity.Schedulable; import work.slhaf.partner.core.action.entity.StateAction; import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.pojo.MemorySlice; -import work.slhaf.partner.core.perceive.PerceiveCapability; +import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.core.memory.pojo.SliceRef; import work.slhaf.partner.module.common.module.PostRunningAgentModule; import work.slhaf.partner.module.modules.action.scheduler.ActionScheduler; +import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime; import work.slhaf.partner.module.modules.memory.updater.summarizer.MultiSummarizer; import work.slhaf.partner.module.modules.memory.updater.summarizer.SingleSummarizer; -import work.slhaf.partner.module.modules.memory.updater.summarizer.TotalSummarizer; import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeInput; import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeResult; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; import java.time.LocalDateTime; -import java.util.*; -import java.util.concurrent.Callable; -import java.util.concurrent.ConcurrentHashMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.UUID; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; - -import static work.slhaf.partner.common.util.ExtractUtil.extractUserId; @EqualsAndHashCode(callSuper = true) @Data @@ -46,21 +44,20 @@ public class MemoryUpdater extends PostRunningAgentModule { private CognationCapability cognationCapability; @InjectCapability private MemoryCapability memoryCapability; - @InjectCapability - private PerceiveCapability perceiveCapability; + @InjectModule + private MemoryRuntime memoryRuntime; @InjectModule private MultiSummarizer multiSummarizer; @InjectModule private SingleSummarizer singleSummarizer; @InjectModule - private TotalSummarizer totalSummarizer; - private final AtomicBoolean updating = new AtomicBoolean(false); - - private InteractionThreadPoolExecutor executor; - @InjectModule private ActionScheduler actionScheduler; + private final AtomicBoolean updating = new AtomicBoolean(false); + private InteractionThreadPoolExecutor executor; + private volatile long lastUpdatedTime; + @Init public void init() { executor = InteractionThreadPoolExecutor.getInstance(); @@ -86,15 +83,13 @@ public class MemoryUpdater extends PostRunningAgentModule { @Override public void doExecute(PartnerRunningFlowContext context) { executor.execute(() -> { - // 如果token 大于阈值,则更新记忆 JSONObject moduleContext = context.getModuleContext().getExtraContext(); boolean recall = moduleContext.getBoolean("recall"); if (recall) { - log.debug("[MemoryUpdater] 存在回忆"); int recallCount = moduleContext.getIntValue("recall_count"); - log.debug("[MemoryUpdater] 记忆切片数量 [{}]", recallCount); + log.debug("[MemoryUpdater] 当前激活记忆数量 [{}]", recallCount); } - boolean trigger = context.getModuleContext().getExtraContext().getBoolean("post_process_trigger"); + boolean trigger = moduleContext.getBoolean("post_process_trigger"); if (!trigger) { return; } @@ -110,7 +105,6 @@ public class MemoryUpdater extends PostRunningAgentModule { private void tryAutoUpdate() { long currentTime = System.currentTimeMillis(); - long lastUpdatedTime = cognationCapability.getLastUpdatedTime(); int chatCount = cognationCapability.snapshotChatMessages().size(); if (lastUpdatedTime != 0 && currentTime - lastUpdatedTime > UPDATE_TRIGGER_INTERVAL && chatCount > 1) { triggerMemoryUpdate(true); @@ -131,7 +125,7 @@ public class MemoryUpdater extends PostRunningAgentModule { updateMemory(chatSnapshot); cognationCapability.rollChatMessagesWithSnapshot(chatSnapshot.size(), CONTEXT_RETAIN_DIVISOR); if (refreshMemoryId) { - cognationCapability.refreshMemoryId(); + memoryCapability.refreshMemoryId(); } } catch (Exception e) { log.error("[MemoryUpdater] 记忆更新线程出错: ", e); @@ -142,75 +136,35 @@ public class MemoryUpdater extends PostRunningAgentModule { private void updateMemory(List chatSnapshot) { log.debug("[MemoryUpdater] 记忆更新流程开始..."); - Map singleMemorySummary = new ConcurrentHashMap<>(); - Map> singleChatMessages = drainSingleChatMessages(); - // 更新单聊记忆,同时从chatMessages中去掉单聊记忆 - updateSingleChatSlices(singleMemorySummary, singleChatMessages); - // 更新多人场景下的记忆及相关的确定性记忆 - List multiChatMessages = excludeSingleChatMessages(chatSnapshot, singleChatMessages); - updateMultiChatSlices(singleMemorySummary, multiChatMessages); - cognationCapability.resetLastUpdatedTime(); + List chatMessages = getCleanedMessages(chatSnapshot); + if (chatMessages.isEmpty()) { + return; + } + SummarizeInput summarizeInput = new SummarizeInput(chatMessages, memoryRuntime.getTopicTree()); + log.debug("[MemoryUpdater] 记忆更新-总结流程-输入: {}", JSONObject.toJSONString(summarizeInput)); + SummarizeResult summarizeResult = summarize(summarizeInput); + log.debug("[MemoryUpdater] 记忆更新-总结流程-输出: {}", JSONObject.toJSONString(summarizeResult)); + MemoryUnit memoryUnit = buildMemoryUnit(chatMessages, summarizeResult); + memoryCapability.saveMemoryUnit(memoryUnit); + MemorySlice memorySlice = memoryUnit.getSlices().getFirst(); + SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId()); + bindTopics(memoryUnit, summarizeResult, sliceRef); + memoryRuntime.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary()); + lastUpdatedTime = System.currentTimeMillis(); log.debug("[MemoryUpdater] 记忆更新流程结束..."); } - private Map> drainSingleChatMessages() { - Map> drainedMessages = new HashMap<>(); - Map> drainedMetaMessages = cognationCapability.drainSingleMetaMessages(); - for (Map.Entry> entry : drainedMetaMessages.entrySet()) { - List messages = new ArrayList<>(); - for (MetaMessage metaMessage : entry.getValue()) { - messages.add(metaMessage.getUserMessage()); - messages.add(metaMessage.getAssistantMessage()); - } - drainedMessages.put(entry.getKey(), messages); + private void bindTopics(MemoryUnit memoryUnit, SummarizeResult summarizeResult, SliceRef sliceRef) { + memoryRuntime.indexMemoryUnit(memoryUnit); + memoryRuntime.bindTopic(summarizeResult.getTopicPath(), sliceRef); + if (summarizeResult.getRelatedTopicPath() == null) { + return; } - return drainedMessages; - } - - private List excludeSingleChatMessages(List chatSnapshot, Map> singleChatMessages) { - Set singleMessages = new HashSet<>(); - for (List messages : singleChatMessages.values()) { - singleMessages.addAll(messages); + for (String relatedTopicPath : summarizeResult.getRelatedTopicPath()) { + memoryRuntime.bindTopic(relatedTopicPath, sliceRef); } - return chatSnapshot.stream() - .filter(message -> !singleMessages.contains(message)) - .toList(); } - private void updateMultiChatSlices(Map singleMemorySummary, List multiChatMessages) { - // 此时chatMessages中不再包含单聊记录,直接执行摘要以及切片插入 - // 对剩下的多人聊天记录进行进行摘要 - Callable task = () -> { - log.debug("[MemoryUpdater] 多人聊天记忆更新流程开始..."); - List chatMessages = getCleanedMessages(multiChatMessages); - if (!chatMessages.isEmpty()) { - log.debug("[MemoryUpdater] 存在多人聊天记录, 流程正常进行..."); - // 以第一条user对应的id为发起用户 - String userId = extractUserId(chatMessages.getFirst().getContent()); - if (userId == null) { - throw new RuntimeException("未匹配到 userId!"); - } - SummarizeInput summarizeInput = new SummarizeInput(chatMessages, memoryCapability.getTopicTree()); - log.debug("[MemoryUpdater] 多人聊天记忆更新-总结流程-输入: {}", summarizeInput); - SummarizeResult summarizeResult = summarize(summarizeInput); - log.debug("[MemoryUpdater] 多人聊天记忆更新-总结流程-输出: {}", summarizeResult); - MemorySlice memorySlice = getMemorySlice(userId, summarizeResult, chatMessages); - // 设置involvedUserId - setInvolvedUserId(userId, memorySlice, chatMessages); - memoryCapability.insertSlice(memorySlice, summarizeResult.getTopicPath()); - memoryCapability.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary()); - } else { - log.debug("[MemoryUpdater] 不存在多人聊天记录, 将以单聊总结为对话缓存的主要输入: {}", singleMemorySummary); - memoryCapability.updateDialogMap(LocalDateTime.now(), totalSummarizer.execute(new HashMap<>(singleMemorySummary))); - } - log.debug("[MemoryUpdater] 对话缓存更新完毕"); - log.debug("[MemoryUpdater] 多人聊天记忆更新流程结束..."); - return null; - }; - executor.invokeAll(List.of(task)); - } - - // TODO need to move time information into perceive core private List getCleanedMessages(List chatMessages) { return chatMessages.stream() .map(message -> { @@ -226,84 +180,27 @@ public class MemoryUpdater extends PostRunningAgentModule { }).toList(); } - private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List chatMessages) { - for (Message chatMessage : chatMessages) { - if (chatMessage.getRole() == Message.Character.ASSISTANT) { - continue; - } - // 匹配userId - String userId = extractUserId(chatMessage.getContent()); - if (userId == null) { - continue; - } - if (userId.equals(startUserId)) { - continue; - } - memorySlice.setInvolvedUserIds(new ArrayList<>()); - memorySlice.getInvolvedUserIds().add(userId); - } - } - - private void updateSingleChatSlices(Map singleMemorySummary, Map> singleChatMessages) { - log.debug("[MemoryUpdater] 单聊记忆更新流程开始..."); - List> tasks = new ArrayList<>(); - AtomicInteger count = new AtomicInteger(0); - for (Map.Entry> entry : singleChatMessages.entrySet()) { - String id = entry.getKey(); - List messages = entry.getValue(); - if (messages.isEmpty()) { - continue; - } - tasks.add(() -> { - int thisCount = count.incrementAndGet(); - log.debug("[MemoryUpdater] 单聊记忆[{}]更新: {}", thisCount, id); - try { - // 单聊记忆更新 - SummarizeInput summarizeInput = new SummarizeInput(messages, memoryCapability.getTopicTree()); - log.debug("[MemoryUpdater] 单聊记忆[{}]更新-总结流程-输入: {}", thisCount, JSONObject.toJSONString(summarizeInput)); - SummarizeResult summarizeResult = summarize(summarizeInput); - log.debug("[MemoryUpdater] 单聊记忆[{}]更新-总结流程-输出: {}", thisCount, JSONObject.toJSONString(summarizeResult)); - MemorySlice memorySlice = getMemorySlice(id, summarizeResult, messages); - // 插入时userDialogMap已经进行更新 - memoryCapability.insertSlice(memorySlice, summarizeResult.getTopicPath()); - // 从chatMessages中移除单聊记录 - cognationCapability.cleanMessage(messages); - // 添加至singleMemorySummary - String key = perceiveCapability.getUser(id).getNickName() + "[" + id + "]"; - singleMemorySummary.put(key, summarizeResult.getSummary()); - log.debug("[MemoryUpdater] 单聊记忆[{}]更新成功: ", thisCount); - } catch (Exception e) { - log.error("[MemoryUpdater] 单聊记忆[{}]更新出错: ", thisCount, e); - } - return null; - }); - } - executor.invokeAll(tasks); - log.debug("[MemoryUpdater] 单聊记忆更新结束..."); - } - private SummarizeResult summarize(SummarizeInput summarizeInput) { singleSummarizer.execute(summarizeInput.getChatMessages()); return multiSummarizer.execute(summarizeInput); } - private MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List chatMessages) { + private MemoryUnit buildMemoryUnit(List chatMessages, SummarizeResult summarizeResult) { + long now = System.currentTimeMillis(); MemorySlice memorySlice = new MemorySlice(); - // 设置 memoryId,timestamp - memorySlice.setMemoryId(cognationCapability.getCurrentMemoryId()); - memorySlice.setTimestamp(System.currentTimeMillis()); - // 补充信息 - memorySlice.setPrivate(summarizeResult.isPrivate()); + memorySlice.setId(UUID.randomUUID().toString()); + memorySlice.setStartIndex(0); + memorySlice.setEndIndex(Math.max(chatMessages.size() - 1, 0)); memorySlice.setSummary(summarizeResult.getSummary()); - memorySlice.setChatMessages(chatMessages); - memorySlice.setStartUserId(userId); - List> relatedTopicPathList = new ArrayList<>(); - for (String string : summarizeResult.getRelatedTopicPath()) { - List list = Arrays.stream(string.split("->")).toList(); - relatedTopicPathList.add(list); - } - memorySlice.setRelatedTopics(relatedTopicPathList); - return memorySlice; + memorySlice.setTimestamp(now); + + MemoryUnit memoryUnit = new MemoryUnit(); + String memoryId = memoryCapability.getCurrentMemoryId(); + memoryUnit.setId(memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId); + memoryUnit.setTimestamp(now); + memoryUnit.setConversationMessages(new ArrayList<>(chatMessages)); + memoryUnit.setSlices(new ArrayList<>(List.of(memorySlice))); + return memoryUnit; } @Override diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/summarizer/entity/SummarizeResult.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/summarizer/entity/SummarizeResult.java index 5485149f..dc606a7b 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/summarizer/entity/SummarizeResult.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/summarizer/entity/SummarizeResult.java @@ -9,5 +9,4 @@ public class SummarizeResult { private String summary; private String topicPath; private List relatedTopicPath; - private boolean isPrivate; } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/process/PreprocessExecutor.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/process/PreprocessExecutor.java index 98500d45..01cac0c4 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/process/PreprocessExecutor.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/process/PreprocessExecutor.java @@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.core.cognation.CognationCapability; +import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.perceive.PerceiveCapability; import work.slhaf.partner.core.perceive.pojo.User; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; @@ -18,6 +19,8 @@ public class PreprocessExecutor extends AbstractAgentModule.Running { - if ("selectMemory".equals(method.getName())) { + if ("getCurrentMemoryId".equals(method.getName())) { System.out.println(111); - return new MemoryResult(); + return "memory-id"; } return null; }); - memory.selectMemory("111"); + memory.getCurrentMemoryId(); } } diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/modules/action/dispatcher/executor/ActionExecutorTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/modules/action/dispatcher/executor/ActionExecutorTest.java index f212e601..f6646d98 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/module/modules/action/dispatcher/executor/ActionExecutorTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/modules/action/dispatcher/executor/ActionExecutorTest.java @@ -74,7 +74,7 @@ class ActionExecutorTest { @BeforeEach void setUp() { lenient().when(cognationCapability.getChatMessages()).thenReturn(Collections.emptyList()); - lenient().when(memoryCapability.getActivatedSlices(anyString())).thenReturn(Collections.emptyList()); + lenient().when(memoryCapability.getActivatedSlices()).thenReturn(Collections.emptyList()); lenient().when(actionCapability.putPhaserRecord(any(Phaser.class), any(ExecutableAction.class))) .thenAnswer(inv -> new PhaserRecord(inv.getArgument(0), inv.getArgument(1))); lenient().when(actionCapability.loadMetaActionInfo(anyString())).thenAnswer(inv -> { diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/modules/core/CommunicationProducerTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/modules/core/CommunicationProducerTest.java index 6983cdb9..25e95b8b 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/module/modules/core/CommunicationProducerTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/modules/core/CommunicationProducerTest.java @@ -1,17 +1,15 @@ package work.slhaf.partner.module.modules.core; +import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.jetbrains.annotations.NotNull; -import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.w3c.dom.Document; import org.w3c.dom.Element; import work.slhaf.partner.api.agent.runtime.interaction.flow.ContextBlock; import work.slhaf.partner.api.chat.pojo.Message; -import work.slhaf.partner.api.chat.pojo.MetaMessage; import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; @@ -19,13 +17,8 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.locks.ReentrantLock; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.anyString; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.lenient; -import static org.mockito.Mockito.verify; @ExtendWith(MockitoExtension.class) class CommunicationProducerTest { @@ -100,12 +93,6 @@ class CommunicationProducerTest { Message lastAssistantMessage = producer.getChatMessages().get(3); assertEquals("收到", lastAssistantMessage.getContent()); - ArgumentCaptor metaMessageCaptor = ArgumentCaptor.forClass(MetaMessage.class); - verify(cognationCapability).addMetaMessage(anyString(), metaMessageCaptor.capture()); - MetaMessage metaMessage = metaMessageCaptor.getValue(); - assertNotNull(metaMessage); - assertTrue(metaMessage.getUserMessage().getContent().startsWith("[USER]: u-1: 你好,介绍一下你现在看到的上下文")); - assertEquals("收到", metaMessage.getAssistantMessage().getContent()); assertEquals("收到", context.getCoreResponse().getString("text")); }