实现了MemoryGraph的查找功能,包含目标记忆节点、邻近记忆节点的查找,并编通过AI写了测试用例

This commit is contained in:
2025-04-09 23:20:47 +08:00
parent cad3af346f
commit d75f83b1a2
7 changed files with 303 additions and 21 deletions

View File

@@ -2,12 +2,12 @@ package memory;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import work.slhaf.memory.MemoryGraph;
import work.slhaf.memory.content.MemorySlice;
import work.slhaf.memory.node.MemoryNode;
import work.slhaf.memory.node.TopicNode;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.HashMap;
@@ -18,7 +18,7 @@ import static org.junit.Assert.*;
public class InsertTest {
private MemoryGraph memoryGraph;
private final String testId = "test";
private final String testId = "test_insert";
@Before
public void setUp() {
@@ -48,7 +48,7 @@ public class InsertTest {
assertEquals(1, collectionsNode.getMemoryNodes().size());
MemoryNode memoryNode = collectionsNode.getMemoryNodes().get(0);
assertEquals(LocalDateTime.now().toLocalDate(), memoryNode.getLocalDateTime().toLocalDate());
assertEquals(LocalDate.now(), memoryNode.getLocalDate());
assertEquals(1, memoryNode.getMemorySliceList().size());
assertEquals(slice, memoryNode.getMemorySliceList().get(0));
}
@@ -88,7 +88,7 @@ public class InsertTest {
MemoryNode firstNode = memoryGraph.getTopicNodes().get("Math")
.getTopicNodes().get("Algebra")
.getMemoryNodes().get(0);
firstNode.setLocalDateTime(LocalDateTime.now().minusDays(1));
firstNode.setLocalDate(LocalDate.now().minusDays(1));
// 第二次插入
memoryGraph.insertMemory(topicPath, slice2);
@@ -129,4 +129,41 @@ public class InsertTest {
return slice;
}
@Test
public void testSerializationConsistency() {
// 构造 MemorySlice
MemorySlice slice = new MemorySlice();
slice.setMemoryId("001");
slice.setMemoryRank(5);
slice.setSlicePath("/demo/path");
List<String> topicPath = Arrays.asList("生活", "学习", "Java");
// 插入 memory
memoryGraph.insertMemory(topicPath, slice);
memoryGraph.serialize();
// 反序列化
MemoryGraph loadedGraph = MemoryGraph.initialize(testId);
// 校验topic 是否存在
assertNotNull(loadedGraph.getTopicNodes().get("生活"));
TopicNode lifeNode = loadedGraph.getTopicNodes().get("生活");
assertNotNull(lifeNode.getTopicNodes().get("学习"));
TopicNode studyNode = lifeNode.getTopicNodes().get("学习");
assertNotNull(studyNode.getTopicNodes().get("Java"));
TopicNode javaNode = studyNode.getTopicNodes().get("Java");
// 校验:是否存在 MemoryNode
assertFalse(javaNode.getMemoryNodes().isEmpty());
// 校验MemorySlice 内容一致
MemorySlice deserializedSlice = javaNode.getMemoryNodes().get(0).getMemorySliceList().get(0);
assertEquals("001", deserializedSlice.getMemoryId());
assertEquals(Integer.valueOf(5), deserializedSlice.getMemoryRank());
assertEquals("/demo/path", deserializedSlice.getSlicePath());
}
}

View File

@@ -0,0 +1,166 @@
package memory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import work.slhaf.memory.MemoryGraph;
import work.slhaf.memory.content.MemorySlice;
import work.slhaf.memory.exception.UnExistedTopicException;
import work.slhaf.memory.node.MemoryNode;
import work.slhaf.memory.node.TopicNode;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
class SearchTest {
private MemoryGraph memoryGraph;
private final LocalDate today = LocalDate.now();
private final LocalDate yesterday = LocalDate.now().minusDays(1);
// 初始化测试环境,模拟插入基础数据
@BeforeEach
void setUp() {
memoryGraph = new MemoryGraph("testGraph");
// 构建基础主题路径:根主题 -> 编程 -> Java
List<String> javaPath = new ArrayList<>();
javaPath.add("编程");
javaPath.add("Java");
// 插入今天的Java相关记忆
MemorySlice javaMemory = createMemorySlice("java1");
memoryGraph.insertMemory(javaPath, javaMemory);
// 插入昨天的Java记忆应不会出现在邻近结果中
MemorySlice oldJavaMemory = createMemorySlice("javaOld");
MemoryNode oldNode = new MemoryNode();
oldNode.setLocalDate(yesterday);
oldNode.setMemorySliceList(List.of(oldJavaMemory));
}
// 场景1查询存在的完整主题路径含相关主题
@Test
void selectMemory_shouldReturnTargetAndRelatedAndParentMemories() {
// 准备相关主题数据:根主题 -> 算法 -> 排序
List<String> sortPath = new ArrayList<>();
sortPath.add("算法");
sortPath.add("排序");
MemorySlice sortMemory = createMemorySlice("sort1");
sortMemory.setRelatedTopics(List.of(
createTopicPath("编程", "Java") // 设置反向关联
));
memoryGraph.insertMemory(sortPath, sortMemory);
// 执行查询:编程 -> Java
List<String> queryPath = new ArrayList<>();
queryPath.add("算法");
queryPath.add("排序");
List<MemorySlice> results = memoryGraph.selectMemory(queryPath);
// 验证结果应包含:
// 1. 目标节点所有记忆java1
// 2. 相关主题排序的最新记忆sort1
// 3. 父节点(编程)的最新记忆(需要提前插入)
assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())));
assertTrue(results.stream().anyMatch(m -> "sort1".equals(m.getMemoryId())));
assertEquals(2, results.size()); // 根据具体实现可能调整
}
// 场景2查询不存在的主题路径
@Test
void selectMemory_shouldThrowWhenPathNotExist() {
List<String> invalidPath = new ArrayList<>();
invalidPath.add("不存在的主题");
assertThrows(UnExistedTopicException.class, () -> {
memoryGraph.selectMemory(invalidPath);
});
}
// 场景3无相关主题时仅返回目标节点和父节点记忆
@Test
void selectMemory_withoutRelatedTopics_shouldReturnTargetAndParent() {
// 插入父级记忆:根主题 -> 编程
List<String> parentPath = new ArrayList<>();
parentPath.add("编程");
MemorySlice parentMemory = createMemorySlice("parent1");
memoryGraph.insertMemory(parentPath, parentMemory);
// 执行查询
List<String> queryPath = new ArrayList<>();
queryPath.add("编程");
queryPath.add("Java");
List<MemorySlice> results = memoryGraph.selectMemory(queryPath);
// 应包含Java记忆 + 父级最新记忆
assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())));
assertTrue(results.stream().anyMatch(m -> "parent1".equals(m.getMemoryId())));
assertEquals(2, results.size());
}
// 场景4验证日期排序应优先取最新日期的邻近记忆
@Test
void selectMemory_shouldGetLatestRelatedMemory() {
// 准备相关主题路径:根主题 -> 数据库
List<String> dbPath = new ArrayList<>();
dbPath.add("数据库");
dbPath.add("mysql");
// 插入今天的数据库记忆(正常流程)
MemorySlice newDbMemory = createMemorySlice("dbNew");
memoryGraph.insertMemory(dbPath, newDbMemory);
// 手动构建并插入昨天的数据库记忆
MemorySlice oldDbMemory = createMemorySlice("dbOld");
TopicNode dbTopicNode = memoryGraph.getTopicNodes().get("数据库");
// 创建昨日记忆节点并添加到主题节点
MemoryNode oldMemoryNode = new MemoryNode();
oldMemoryNode.setLocalDate(yesterday);
oldMemoryNode.setMemorySliceList(new ArrayList<>(List.of(oldDbMemory)));
dbTopicNode.getMemoryNodes().add(oldMemoryNode);
// 对记忆节点进行日期排序根据compareTo方法
dbTopicNode.getMemoryNodes().sort(null);
// 创建Java记忆并关联数据库主题
MemorySlice javaMemory = createMemorySlice("java2");
javaMemory.setRelatedTopics(List.of(
createTopicPath("数据库","") // 完整主题路径
));
memoryGraph.insertMemory(createTopicPath("编程", "Java"), javaMemory);
// 执行查询
List<String> queryPath = createTopicPath("编程", "Java");
List<MemorySlice> results = memoryGraph.selectMemory(queryPath);
// 验证结果应包含最新关联记忆dbNew
assertTrue(results.stream().anyMatch(m -> "dbNew".equals(m.getMemoryId())),
"应包含最新的数据库记忆");
assertFalse(results.stream().anyMatch(m -> "dbOld".equals(m.getMemoryId())),
"不应包含过期的数据库记忆");
// 验证结果包含目标记忆java1和java2
assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())),
"应包含基础测试数据");
assertTrue(results.stream().anyMatch(m -> "java2".equals(m.getMemoryId())),
"应包含当前测试插入数据");
}
private MemorySlice createMemorySlice(String id) {
MemorySlice slice = new MemorySlice();
slice.setMemoryId(id);
slice.setMemoryRank(1);
return slice;
}
private ArrayList<String> createTopicPath(String... topics) {
ArrayList<String> path = new ArrayList<>();
for (String topic : topics) {
path.add(topic);
}
return path;
}
}