From 24d4510270d9fa078de3f7f8008b6e694509a24b Mon Sep 17 00:00:00 2001 From: slhaf Date: Thu, 10 Apr 2025 17:51:01 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0dateIndex(=E8=AE=B0=E5=BF=86?= =?UTF-8?q?=E5=88=87=E7=89=87=E7=9A=84=E6=97=A5=E6=9C=9F=E7=B4=A2=E5=BC=95?= =?UTF-8?q?)=E3=80=81dialogMap(=E8=BF=91=E6=9C=9F=E5=AF=B9=E8=AF=9D?= =?UTF-8?q?=E7=BC=93=E5=AD=98)=E3=80=81staticMemory(=E7=A1=AE=E5=AE=9A?= =?UTF-8?q?=E6=80=A7=E8=AE=B0=E5=BF=86)=E7=AD=89=E5=AD=97=E6=AE=B5?= =?UTF-8?q?=EF=BC=8C=E5=B9=B6=E5=AE=9E=E7=8E=B0=E7=9B=B8=E5=85=B3=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E6=93=8D=E4=BD=9C;=20=E8=B0=83=E6=95=B4=E4=BA=86Memor?= =?UTF-8?q?ySlice=E4=B8=AD=E7=9A=84=E9=83=A8=E5=88=86=E7=BB=93=E6=9E=84;?= =?UTF-8?q?=20=E6=B7=BB=E5=8A=A0=E4=BA=86=E5=BF=85=E8=A6=81=E7=9A=84?= =?UTF-8?q?=E6=B3=A8=E9=87=8A;?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/work/slhaf/Main.java | 3 - .../java/work/slhaf/memory/MemoryGraph.java | 101 +++++++++++++++--- .../slhaf/memory/content/MemorySlice.java | 31 ++++-- .../work/slhaf/memory/content/SliceData.java | 10 ++ .../work/slhaf/memory/node/TopicNode.java | 1 - src/test/java/memory/InsertTest.java | 3 - src/test/java/memory/SearchTest.java | 9 +- 7 files changed, 125 insertions(+), 33 deletions(-) create mode 100644 src/main/java/work/slhaf/memory/content/SliceData.java diff --git a/src/main/java/work/slhaf/Main.java b/src/main/java/work/slhaf/Main.java index bad5cde9..5e7701c8 100644 --- a/src/main/java/work/slhaf/Main.java +++ b/src/main/java/work/slhaf/Main.java @@ -1,9 +1,6 @@ package work.slhaf; import work.slhaf.memory.MemoryGraph; -import work.slhaf.memory.content.MemorySlice; - -import java.util.Arrays; public class Main { public static void main(String[] args) { diff --git a/src/main/java/work/slhaf/memory/MemoryGraph.java b/src/main/java/work/slhaf/memory/MemoryGraph.java index a5498af4..696bd51e 100644 --- a/src/main/java/work/slhaf/memory/MemoryGraph.java +++ b/src/main/java/work/slhaf/memory/MemoryGraph.java @@ -11,6 +11,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.time.LocalDate; +import java.time.LocalDateTime; import java.util.*; @Data @@ -19,16 +20,44 @@ public class MemoryGraph implements Serializable { @Serial private static final long serialVersionUID = 1L; private static final String STORAGE_DIR = "./data/memory/"; - + //todo: 实现记忆的短期缓存机制 private String id; + /** + * key: 根主题名称 value: 根主题节点 + */ private HashMap topicNodes; public static MemoryGraph memoryGraph; - private HashMap> existedTopics; + + /** + * 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值 + * 该部分在'主题提取LLM'的system prompt中常驻 + */ + private HashMap> existedTopics; + + /** + * 记忆节点的日期索引, 同一日期内按照对话id区分 + * 同时作为临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所 + */ + private HashMap>> dateIndex; + + /** + * 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键,总结为值 + * 该部分作为'主LLM'system prompt常驻 + */ + private HashMap dialogMap; + + /** + * 存储确定性记忆, 如'用户爱好'等确定性信息 + * 该部分作为'主LLM'system prompt常驻 + */ + private HashMap> staticMemory; public MemoryGraph(String id) { this.id = id; this.topicNodes = new HashMap<>(); this.existedTopics = new HashMap<>(); + this.dateIndex = new HashMap<>(); + this.staticMemory = new HashMap<>(); } public static MemoryGraph initialize(String id) { @@ -88,9 +117,6 @@ public class MemoryGraph implements Serializable { public void insertMemory(List topicPath, MemorySlice slice) { topicPath = new ArrayList<>(topicPath); - if (topicNodes == null) { - topicNodes = new HashMap<>(); - } //查看是否存在根主题节点 String rootTopic = topicPath.getFirst(); topicPath.removeFirst(); @@ -99,7 +125,7 @@ public class MemoryGraph implements Serializable { rootNode.setMemoryNodes(new ArrayList<>()); rootNode.setTopicNodes(new HashMap<>()); topicNodes.put(rootTopic, rootNode); - existedTopics.put(rootTopic, new HashSet<>()); + existedTopics.put(rootTopic, new LinkedHashSet<>()); } TopicNode lastTopicNode = topicNodes.get(rootTopic); @@ -115,13 +141,9 @@ public class MemoryGraph implements Serializable { lastTopicNode.setMemoryNodes(nodeList); lastTopicNode.setTopicNodes(new HashMap<>()); existedTopicNodes.add(topic); - /*if (i == topicPath.size() - 1) { - lastTopicNode.setMemoryNodes(new ArrayList<>()); - lastTopicNode.setTopicNodes(new HashMap<>()); - }*/ } } - //检查是否存在当天对应的memoryData + //检查是否存在当天对应的memorySlice LocalDate now = LocalDate.now(); boolean hasSlice = false; MemoryNode node = null; @@ -140,9 +162,56 @@ public class MemoryGraph implements Serializable { lastTopicNode.getMemoryNodes().sort(null); } node.getMemorySliceList().add(slice); + + updateDateIndex(now, slice); + updateDialogMap(slice); } - public List selectMemory(List topicPath) { + private void updateDialogMap(MemorySlice slice) { + String summary = slice.getSliceData().getSummary(); + LocalDateTime now = LocalDateTime.now(); + //移除两天前的上下文补充(切片总结) + List keysToRemove = new ArrayList<>(); + dialogMap.forEach((k, v) -> { + if (now.minusDays(2).isAfter(k)){ + keysToRemove.add(k); + } + }); + for (LocalDateTime dateTime : keysToRemove) { + dialogMap.remove(dateTime); + } + dialogMap.put(now,summary); + } + + private void updateDateIndex(LocalDate now, MemorySlice slice) { + String memoryId = slice.getMemoryId(); + //查看是否存在当前日期的对话切片索引 + if (!dateIndex.containsKey(now)) { + dateIndex.put(now, new HashMap<>()); + } + //查看当前日期的索引中是否存在该对话的索引 + HashMap> currentDateDialogSlices = dateIndex.get(now); + if (!currentDateDialogSlices.containsKey(memoryId)) { + List memorySliceList = new ArrayList<>(); + currentDateDialogSlices.put(memoryId, memorySliceList); + } + //处理上下文关系 + List memorySliceList = currentDateDialogSlices.get(memoryId); + if (memorySliceList.isEmpty()) { + memorySliceList.add(slice); + } else { + //排序 + memorySliceList.sort(null); + MemorySlice tempSlice = memorySliceList.getLast(); + //末尾切片添加当前切片的引用 + tempSlice.setSliceAfter(slice); + //当前切片添加前序切片的引用 + slice.setSliceBefore(tempSlice); + } + + } + + public List selectMemoryByPath(List topicPath) { List targetSliceList = new ArrayList<>(); topicPath = new ArrayList<>(topicPath); String targetTopic = topicPath.getLast(); @@ -178,15 +247,19 @@ public class MemoryGraph implements Serializable { return targetSliceList; } + public HashMap> selectMemoryByDate(LocalDate date){ + return dateIndex.get(date); + } + private TopicNode getTargetParentNode(List topicPath, String targetTopic) { String topTopic = topicPath.getFirst(); - if (!existedTopics.containsKey(topTopic)){ + if (!existedTopics.containsKey(topTopic)) { throw new UnExistedTopicException("不存在的主题: " + topTopic); } TopicNode targetParentNode = topicNodes.get(topTopic); topicPath.removeFirst(); for (String topic : topicPath) { - if (!existedTopics.get(topTopic).contains(topic)){ + if (!existedTopics.get(topTopic).contains(topic)) { throw new UnExistedTopicException("不存在的主题: " + topTopic); } } diff --git a/src/main/java/work/slhaf/memory/content/MemorySlice.java b/src/main/java/work/slhaf/memory/content/MemorySlice.java index 28f15c4c..a5952754 100644 --- a/src/main/java/work/slhaf/memory/content/MemorySlice.java +++ b/src/main/java/work/slhaf/memory/content/MemorySlice.java @@ -1,21 +1,38 @@ package work.slhaf.memory.content; import lombok.Data; -import work.slhaf.memory.node.TopicNode; import java.io.Serializable; -import java.util.LinkedHashMap; import java.util.List; @Data -public class MemorySlice implements Serializable { +public class MemorySlice implements Serializable, Comparable { //关联的完整对话的id private String memoryId; - //该切片在关联的完整对话中的顺序 - private Integer memoryRank; + //该切片在关联的完整对话中的顺序, 由时间戳确定 + private Long timestamp; private String slicePath; private List> relatedTopics; //关联完整对话中的前序切片, 排序为键,完整路径为值 - private LinkedHashMap sliceBefore; - private LinkedHashMap sliceAfter; + private MemorySlice sliceBefore; + private MemorySlice sliceAfter; + + @Override + public int compareTo(MemorySlice memorySlice) { + if (memorySlice.getTimestamp() > this.getTimestamp()) { + return -1; + } else if (memorySlice.getTimestamp() < this.timestamp) { + return 1; + } + return 0; + } + + public SliceData getSliceData(){ + //todo: 待实现获取逻辑 + return new SliceData(); + } + + public void saveSlice(SliceData sliceData){ + //todo: 待实现存储逻辑, 该逻辑内将设置`slicePath` + } } diff --git a/src/main/java/work/slhaf/memory/content/SliceData.java b/src/main/java/work/slhaf/memory/content/SliceData.java new file mode 100644 index 00000000..57db6349 --- /dev/null +++ b/src/main/java/work/slhaf/memory/content/SliceData.java @@ -0,0 +1,10 @@ +package work.slhaf.memory.content; + +import com.alibaba.fastjson2.JSONArray; +import lombok.Data; + +@Data +public class SliceData { + private String summary; + private JSONArray content; +} diff --git a/src/main/java/work/slhaf/memory/node/TopicNode.java b/src/main/java/work/slhaf/memory/node/TopicNode.java index af96c800..6ebbc969 100644 --- a/src/main/java/work/slhaf/memory/node/TopicNode.java +++ b/src/main/java/work/slhaf/memory/node/TopicNode.java @@ -9,6 +9,5 @@ import java.util.List; @Data public class TopicNode implements Serializable { private HashMap topicNodes; -// private Integer weight = 0; private List memoryNodes; } diff --git a/src/test/java/memory/InsertTest.java b/src/test/java/memory/InsertTest.java index b342c456..31dae867 100644 --- a/src/test/java/memory/InsertTest.java +++ b/src/test/java/memory/InsertTest.java @@ -124,7 +124,6 @@ public class InsertTest { private MemorySlice createTestMemorySlice(String id) { MemorySlice slice = new MemorySlice(); slice.setMemoryId(id); - slice.setMemoryRank(1); // 可以设置其他必要属性 return slice; } @@ -134,7 +133,6 @@ public class InsertTest { // 构造 MemorySlice MemorySlice slice = new MemorySlice(); slice.setMemoryId("001"); - slice.setMemoryRank(5); slice.setSlicePath("/demo/path"); List topicPath = Arrays.asList("生活", "学习", "Java"); @@ -162,7 +160,6 @@ public class InsertTest { // 校验:MemorySlice 内容一致 MemorySlice deserializedSlice = javaNode.getMemoryNodes().get(0).getMemorySliceList().get(0); assertEquals("001", deserializedSlice.getMemoryId()); - assertEquals(Integer.valueOf(5), deserializedSlice.getMemoryRank()); assertEquals("/demo/path", deserializedSlice.getSlicePath()); } diff --git a/src/test/java/memory/SearchTest.java b/src/test/java/memory/SearchTest.java index c9e3548a..6f33222b 100644 --- a/src/test/java/memory/SearchTest.java +++ b/src/test/java/memory/SearchTest.java @@ -57,7 +57,7 @@ class SearchTest { List queryPath = new ArrayList<>(); queryPath.add("算法"); queryPath.add("排序"); - List results = memoryGraph.selectMemory(queryPath); + List results = memoryGraph.selectMemoryByPath(queryPath); // 验证结果应包含: // 1. 目标节点所有记忆(java1) @@ -75,7 +75,7 @@ class SearchTest { invalidPath.add("不存在的主题"); assertThrows(UnExistedTopicException.class, () -> { - memoryGraph.selectMemory(invalidPath); + memoryGraph.selectMemoryByPath(invalidPath); }); } @@ -92,7 +92,7 @@ class SearchTest { List queryPath = new ArrayList<>(); queryPath.add("编程"); queryPath.add("Java"); - List results = memoryGraph.selectMemory(queryPath); + List results = memoryGraph.selectMemoryByPath(queryPath); // 应包含:Java记忆 + 父级最新记忆 assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId()))); @@ -134,7 +134,7 @@ class SearchTest { // 执行查询 List queryPath = createTopicPath("编程", "Java"); - List results = memoryGraph.selectMemory(queryPath); + List results = memoryGraph.selectMemoryByPath(queryPath); // 验证结果应包含最新关联记忆(dbNew) assertTrue(results.stream().anyMatch(m -> "dbNew".equals(m.getMemoryId())), @@ -152,7 +152,6 @@ class SearchTest { private MemorySlice createMemorySlice(String id) { MemorySlice slice = new MemorySlice(); slice.setMemoryId(id); - slice.setMemoryRank(1); return slice; }