From d75f83b1a281843ebe64b949daa0ace916374c49 Mon Sep 17 00:00:00 2001 From: slhaf Date: Wed, 9 Apr 2025 23:20:47 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E4=BA=86MemoryGraph=E7=9A=84?= =?UTF-8?q?=E6=9F=A5=E6=89=BE=E5=8A=9F=E8=83=BD=EF=BC=8C=E5=8C=85=E5=90=AB?= =?UTF-8?q?=E7=9B=AE=E6=A0=87=E8=AE=B0=E5=BF=86=E8=8A=82=E7=82=B9=E3=80=81?= =?UTF-8?q?=E9=82=BB=E8=BF=91=E8=AE=B0=E5=BF=86=E8=8A=82=E7=82=B9=E7=9A=84?= =?UTF-8?q?=E6=9F=A5=E6=89=BE=EF=BC=8C=E5=B9=B6=E7=BC=96=E9=80=9A=E8=BF=87?= =?UTF-8?q?AI=E5=86=99=E4=BA=86=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../java/work/slhaf/memory/MemoryGraph.java | 82 +++++++-- .../slhaf/memory/content/MemorySlice.java | 5 +- .../exception/UnExistedTopicException.java | 7 + .../work/slhaf/memory/node/MemoryNode.java | 18 +- .../work/slhaf/memory/node/TopicNode.java | 1 + src/test/java/memory/InsertTest.java | 45 ++++- src/test/java/memory/SearchTest.java | 166 ++++++++++++++++++ 7 files changed, 303 insertions(+), 21 deletions(-) create mode 100644 src/main/java/work/slhaf/memory/exception/UnExistedTopicException.java create mode 100644 src/test/java/memory/SearchTest.java diff --git a/src/main/java/work/slhaf/memory/MemoryGraph.java b/src/main/java/work/slhaf/memory/MemoryGraph.java index 3b54e1cb..a5498af4 100644 --- a/src/main/java/work/slhaf/memory/MemoryGraph.java +++ b/src/main/java/work/slhaf/memory/MemoryGraph.java @@ -2,6 +2,7 @@ package work.slhaf.memory; import lombok.Data; import work.slhaf.memory.content.MemorySlice; +import work.slhaf.memory.exception.UnExistedTopicException; import work.slhaf.memory.node.MemoryNode; import work.slhaf.memory.node.TopicNode; @@ -9,7 +10,7 @@ import java.io.*; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import java.time.LocalDateTime; +import java.time.LocalDate; import java.util.*; @Data @@ -20,9 +21,9 @@ public class MemoryGraph implements Serializable { private static final String STORAGE_DIR = "./data/memory/"; private String id; - private HashMap topicNodes; + private HashMap topicNodes; public static MemoryGraph memoryGraph; - private HashMap> existedTopics; + private HashMap> existedTopics; public MemoryGraph(String id) { this.id = id; @@ -85,10 +86,9 @@ public class MemoryGraph implements Serializable { } } - public void insertMemory(List topicPath, MemorySlice slice) { topicPath = new ArrayList<>(topicPath); - if (topicNodes == null){ + if (topicNodes == null) { topicNodes = new HashMap<>(); } //查看是否存在根主题节点 @@ -98,17 +98,16 @@ public class MemoryGraph implements Serializable { TopicNode rootNode = new TopicNode(); rootNode.setMemoryNodes(new ArrayList<>()); rootNode.setTopicNodes(new HashMap<>()); - topicNodes.put(rootTopic,rootNode); - existedTopics.put(rootTopic,new HashSet<>()); + topicNodes.put(rootTopic, rootNode); + existedTopics.put(rootTopic, new HashSet<>()); } TopicNode lastTopicNode = topicNodes.get(rootTopic); Set existedTopicNodes = existedTopics.get(rootTopic); - for (int i = 0; i < topicPath.size(); i++) { - String topic = topicPath.get(i); + for (String topic : topicPath) { if (existedTopicNodes.contains(topic)) { lastTopicNode = lastTopicNode.getTopicNodes().get(topic); - }else { + } else { TopicNode newNode = new TopicNode(); lastTopicNode.getTopicNodes().put(topic, newNode); lastTopicNode = newNode; @@ -123,11 +122,11 @@ public class MemoryGraph implements Serializable { } } //检查是否存在当天对应的memoryData - LocalDateTime now = LocalDateTime.now(); + LocalDate now = LocalDate.now(); boolean hasSlice = false; MemoryNode node = null; for (MemoryNode memoryNode : lastTopicNode.getMemoryNodes()) { - if (now.toLocalDate().equals(memoryNode.getLocalDateTime().toLocalDate())){ + if (now.equals(memoryNode.getLocalDate())) { hasSlice = true; node = memoryNode; break; @@ -135,12 +134,69 @@ public class MemoryGraph implements Serializable { } if (!hasSlice) { node = new MemoryNode(); - node.setLocalDateTime(now); + node.setLocalDate(now); node.setMemorySliceList(new ArrayList<>()); lastTopicNode.getMemoryNodes().add(node); + lastTopicNode.getMemoryNodes().sort(null); } node.getMemorySliceList().add(slice); } + public List selectMemory(List topicPath) { + List targetSliceList = new ArrayList<>(); + topicPath = new ArrayList<>(topicPath); + String targetTopic = topicPath.getLast(); + TopicNode targetParentNode = getTargetParentNode(topicPath, targetTopic); + List> relatedTopics = new ArrayList<>(); + //终点记忆节点 + for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) { + List endpointMemorySliceList = memoryNode.getMemorySliceList(); + targetSliceList.addAll(endpointMemorySliceList); + for (MemorySlice memorySlice : endpointMemorySliceList) { + if (memorySlice.getRelatedTopics() != null) { + relatedTopics.addAll(memorySlice.getRelatedTopics()); + } + } + } + //邻近记忆节点 联系 + for (List relatedTopic : relatedTopics) { + List tempTopicPath = new ArrayList<>(relatedTopic); + String tempTargetTopic = tempTopicPath.getLast(); + TopicNode tempTargetParentNode = getTargetParentNode(tempTopicPath, tempTargetTopic); + //获取终点节点及其最新记忆节点 + TopicNode tempTargetNode = tempTargetParentNode.getTopicNodes().get(tempTopicPath.getLast()); + List tempMemoryNodes = tempTargetNode.getMemoryNodes(); + if (!tempMemoryNodes.isEmpty()) { + targetSliceList.addAll(tempMemoryNodes.getFirst().getMemorySliceList()); + } + } + //邻近记忆节点 父级 + List targetParentMemoryNodes = targetParentNode.getMemoryNodes(); + if (!targetParentMemoryNodes.isEmpty()) { + targetSliceList.addAll(targetParentMemoryNodes.getFirst().getMemorySliceList()); + } + return targetSliceList; + } + + private TopicNode getTargetParentNode(List topicPath, String targetTopic) { + String topTopic = topicPath.getFirst(); + if (!existedTopics.containsKey(topTopic)){ + throw new UnExistedTopicException("不存在的主题: " + topTopic); + } + TopicNode targetParentNode = topicNodes.get(topTopic); + topicPath.removeFirst(); + for (String topic : topicPath) { + if (!existedTopics.get(topTopic).contains(topic)){ + throw new UnExistedTopicException("不存在的主题: " + topTopic); + } + } + + //逐层查找目标主题,可选取终点主题节点相邻位置的主题节点。终点记忆节点选取全部memoryNode, 邻近记忆节点选取最新日期的memoryNode + while (!targetParentNode.getTopicNodes().containsKey(targetTopic)) { + targetParentNode = targetParentNode.getTopicNodes().get(topicPath.getFirst()); + topicPath.removeFirst(); + } + return targetParentNode; + } } diff --git a/src/main/java/work/slhaf/memory/content/MemorySlice.java b/src/main/java/work/slhaf/memory/content/MemorySlice.java index a1028915..28f15c4c 100644 --- a/src/main/java/work/slhaf/memory/content/MemorySlice.java +++ b/src/main/java/work/slhaf/memory/content/MemorySlice.java @@ -9,10 +9,13 @@ import java.util.List; @Data public class MemorySlice implements Serializable { + //关联的完整对话的id private String memoryId; + //该切片在关联的完整对话中的顺序 private Integer memoryRank; private String slicePath; - private List relatedTopics; + private List> relatedTopics; + //关联完整对话中的前序切片, 排序为键,完整路径为值 private LinkedHashMap sliceBefore; private LinkedHashMap sliceAfter; } diff --git a/src/main/java/work/slhaf/memory/exception/UnExistedTopicException.java b/src/main/java/work/slhaf/memory/exception/UnExistedTopicException.java new file mode 100644 index 00000000..d3662da3 --- /dev/null +++ b/src/main/java/work/slhaf/memory/exception/UnExistedTopicException.java @@ -0,0 +1,7 @@ +package work.slhaf.memory.exception; + +public class UnExistedTopicException extends RuntimeException { + public UnExistedTopicException(String message) { + super(message); + } +} diff --git a/src/main/java/work/slhaf/memory/node/MemoryNode.java b/src/main/java/work/slhaf/memory/node/MemoryNode.java index c686970c..3580f208 100644 --- a/src/main/java/work/slhaf/memory/node/MemoryNode.java +++ b/src/main/java/work/slhaf/memory/node/MemoryNode.java @@ -4,11 +4,23 @@ import lombok.Data; import work.slhaf.memory.content.MemorySlice; import java.io.Serializable; -import java.time.LocalDateTime; +import java.time.LocalDate; import java.util.List; @Data -public class MemoryNode implements Serializable { - private LocalDateTime localDateTime; +public class MemoryNode implements Serializable, Comparable { + //记忆节点所属日期 + private LocalDate localDate; + //该日期对应的全部记忆切片 private List memorySliceList; + + @Override + public int compareTo(MemoryNode memoryNode) { + if (memoryNode.getLocalDate().isAfter(this.localDate)) { + return -1; + } else if (memoryNode.getLocalDate().isBefore(this.localDate)) { + return 1; + } + return 0; + } } diff --git a/src/main/java/work/slhaf/memory/node/TopicNode.java b/src/main/java/work/slhaf/memory/node/TopicNode.java index 6ebbc969..af96c800 100644 --- a/src/main/java/work/slhaf/memory/node/TopicNode.java +++ b/src/main/java/work/slhaf/memory/node/TopicNode.java @@ -9,5 +9,6 @@ import java.util.List; @Data public class TopicNode implements Serializable { private HashMap topicNodes; +// private Integer weight = 0; private List memoryNodes; } diff --git a/src/test/java/memory/InsertTest.java b/src/test/java/memory/InsertTest.java index 453a4c08..b342c456 100644 --- a/src/test/java/memory/InsertTest.java +++ b/src/test/java/memory/InsertTest.java @@ -2,12 +2,12 @@ package memory; import org.junit.Before; import org.junit.Test; -import org.junit.jupiter.api.BeforeEach; import work.slhaf.memory.MemoryGraph; import work.slhaf.memory.content.MemorySlice; import work.slhaf.memory.node.MemoryNode; import work.slhaf.memory.node.TopicNode; +import java.time.LocalDate; import java.time.LocalDateTime; import java.util.Arrays; import java.util.HashMap; @@ -18,7 +18,7 @@ import static org.junit.Assert.*; public class InsertTest { private MemoryGraph memoryGraph; - private final String testId = "test"; + private final String testId = "test_insert"; @Before public void setUp() { @@ -48,7 +48,7 @@ public class InsertTest { assertEquals(1, collectionsNode.getMemoryNodes().size()); MemoryNode memoryNode = collectionsNode.getMemoryNodes().get(0); - assertEquals(LocalDateTime.now().toLocalDate(), memoryNode.getLocalDateTime().toLocalDate()); + assertEquals(LocalDate.now(), memoryNode.getLocalDate()); assertEquals(1, memoryNode.getMemorySliceList().size()); assertEquals(slice, memoryNode.getMemorySliceList().get(0)); } @@ -88,7 +88,7 @@ public class InsertTest { MemoryNode firstNode = memoryGraph.getTopicNodes().get("Math") .getTopicNodes().get("Algebra") .getMemoryNodes().get(0); - firstNode.setLocalDateTime(LocalDateTime.now().minusDays(1)); + firstNode.setLocalDate(LocalDate.now().minusDays(1)); // 第二次插入 memoryGraph.insertMemory(topicPath, slice2); @@ -129,4 +129,41 @@ public class InsertTest { return slice; } + @Test + public void testSerializationConsistency() { + // 构造 MemorySlice + MemorySlice slice = new MemorySlice(); + slice.setMemoryId("001"); + slice.setMemoryRank(5); + slice.setSlicePath("/demo/path"); + + List topicPath = Arrays.asList("生活", "学习", "Java"); + + // 插入 memory + memoryGraph.insertMemory(topicPath, slice); + memoryGraph.serialize(); + + // 反序列化 + MemoryGraph loadedGraph = MemoryGraph.initialize(testId); + + // 校验:topic 是否存在 + assertNotNull(loadedGraph.getTopicNodes().get("生活")); + TopicNode lifeNode = loadedGraph.getTopicNodes().get("生活"); + + assertNotNull(lifeNode.getTopicNodes().get("学习")); + TopicNode studyNode = lifeNode.getTopicNodes().get("学习"); + + assertNotNull(studyNode.getTopicNodes().get("Java")); + TopicNode javaNode = studyNode.getTopicNodes().get("Java"); + + // 校验:是否存在 MemoryNode + assertFalse(javaNode.getMemoryNodes().isEmpty()); + + // 校验:MemorySlice 内容一致 + MemorySlice deserializedSlice = javaNode.getMemoryNodes().get(0).getMemorySliceList().get(0); + assertEquals("001", deserializedSlice.getMemoryId()); + assertEquals(Integer.valueOf(5), deserializedSlice.getMemoryRank()); + assertEquals("/demo/path", deserializedSlice.getSlicePath()); + } + } diff --git a/src/test/java/memory/SearchTest.java b/src/test/java/memory/SearchTest.java new file mode 100644 index 00000000..c9e3548a --- /dev/null +++ b/src/test/java/memory/SearchTest.java @@ -0,0 +1,166 @@ +package memory; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import work.slhaf.memory.MemoryGraph; +import work.slhaf.memory.content.MemorySlice; +import work.slhaf.memory.exception.UnExistedTopicException; +import work.slhaf.memory.node.MemoryNode; +import work.slhaf.memory.node.TopicNode; + +import java.time.LocalDate; +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +class SearchTest { + private MemoryGraph memoryGraph; + private final LocalDate today = LocalDate.now(); + private final LocalDate yesterday = LocalDate.now().minusDays(1); + + // 初始化测试环境,模拟插入基础数据 + @BeforeEach + void setUp() { + memoryGraph = new MemoryGraph("testGraph"); + + // 构建基础主题路径:根主题 -> 编程 -> Java + List javaPath = new ArrayList<>(); + javaPath.add("编程"); + javaPath.add("Java"); + + // 插入今天的Java相关记忆 + MemorySlice javaMemory = createMemorySlice("java1"); + memoryGraph.insertMemory(javaPath, javaMemory); + + // 插入昨天的Java记忆(应不会出现在邻近结果中) + MemorySlice oldJavaMemory = createMemorySlice("javaOld"); + MemoryNode oldNode = new MemoryNode(); + oldNode.setLocalDate(yesterday); + oldNode.setMemorySliceList(List.of(oldJavaMemory)); + } + + // 场景1:查询存在的完整主题路径(含相关主题) + @Test + void selectMemory_shouldReturnTargetAndRelatedAndParentMemories() { + // 准备相关主题数据:根主题 -> 算法 -> 排序 + List sortPath = new ArrayList<>(); + sortPath.add("算法"); + sortPath.add("排序"); + MemorySlice sortMemory = createMemorySlice("sort1"); + sortMemory.setRelatedTopics(List.of( + createTopicPath("编程", "Java") // 设置反向关联 + )); + memoryGraph.insertMemory(sortPath, sortMemory); + + // 执行查询:编程 -> Java + List queryPath = new ArrayList<>(); + queryPath.add("算法"); + queryPath.add("排序"); + List results = memoryGraph.selectMemory(queryPath); + + // 验证结果应包含: + // 1. 目标节点所有记忆(java1) + // 2. 相关主题(排序)的最新记忆(sort1) + // 3. 父节点(编程)的最新记忆(需要提前插入) + assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId()))); + assertTrue(results.stream().anyMatch(m -> "sort1".equals(m.getMemoryId()))); + assertEquals(2, results.size()); // 根据具体实现可能调整 + } + + // 场景2:查询不存在的主题路径 + @Test + void selectMemory_shouldThrowWhenPathNotExist() { + List invalidPath = new ArrayList<>(); + invalidPath.add("不存在的主题"); + + assertThrows(UnExistedTopicException.class, () -> { + memoryGraph.selectMemory(invalidPath); + }); + } + + // 场景3:无相关主题时仅返回目标节点和父节点记忆 + @Test + void selectMemory_withoutRelatedTopics_shouldReturnTargetAndParent() { + // 插入父级记忆:根主题 -> 编程 + List parentPath = new ArrayList<>(); + parentPath.add("编程"); + MemorySlice parentMemory = createMemorySlice("parent1"); + memoryGraph.insertMemory(parentPath, parentMemory); + + // 执行查询 + List queryPath = new ArrayList<>(); + queryPath.add("编程"); + queryPath.add("Java"); + List results = memoryGraph.selectMemory(queryPath); + + // 应包含:Java记忆 + 父级最新记忆 + assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId()))); + assertTrue(results.stream().anyMatch(m -> "parent1".equals(m.getMemoryId()))); + assertEquals(2, results.size()); + } + + // 场景4:验证日期排序,应优先取最新日期的邻近记忆 + @Test + void selectMemory_shouldGetLatestRelatedMemory() { + // 准备相关主题路径:根主题 -> 数据库 + List dbPath = new ArrayList<>(); + dbPath.add("数据库"); + dbPath.add("mysql"); + + // 插入今天的数据库记忆(正常流程) + MemorySlice newDbMemory = createMemorySlice("dbNew"); + memoryGraph.insertMemory(dbPath, newDbMemory); + + // 手动构建并插入昨天的数据库记忆 + MemorySlice oldDbMemory = createMemorySlice("dbOld"); + TopicNode dbTopicNode = memoryGraph.getTopicNodes().get("数据库"); + + // 创建昨日记忆节点并添加到主题节点 + MemoryNode oldMemoryNode = new MemoryNode(); + oldMemoryNode.setLocalDate(yesterday); + oldMemoryNode.setMemorySliceList(new ArrayList<>(List.of(oldDbMemory))); + dbTopicNode.getMemoryNodes().add(oldMemoryNode); + + // 对记忆节点进行日期排序(根据compareTo方法) + dbTopicNode.getMemoryNodes().sort(null); + + // 创建Java记忆并关联数据库主题 + MemorySlice javaMemory = createMemorySlice("java2"); + javaMemory.setRelatedTopics(List.of( + createTopicPath("数据库","") // 完整主题路径 + )); + memoryGraph.insertMemory(createTopicPath("编程", "Java"), javaMemory); + + // 执行查询 + List queryPath = createTopicPath("编程", "Java"); + List results = memoryGraph.selectMemory(queryPath); + + // 验证结果应包含最新关联记忆(dbNew) + assertTrue(results.stream().anyMatch(m -> "dbNew".equals(m.getMemoryId())), + "应包含最新的数据库记忆"); + assertFalse(results.stream().anyMatch(m -> "dbOld".equals(m.getMemoryId())), + "不应包含过期的数据库记忆"); + + // 验证结果包含目标记忆(java1和java2) + assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())), + "应包含基础测试数据"); + assertTrue(results.stream().anyMatch(m -> "java2".equals(m.getMemoryId())), + "应包含当前测试插入数据"); + } + + private MemorySlice createMemorySlice(String id) { + MemorySlice slice = new MemorySlice(); + slice.setMemoryId(id); + slice.setMemoryRank(1); + return slice; + } + + private ArrayList createTopicPath(String... topics) { + ArrayList path = new ArrayList<>(); + for (String topic : topics) { + path.add(topic); + } + return path; + } +} \ No newline at end of file