mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
实现了MemoryGraph的查找功能,包含目标记忆节点、邻近记忆节点的查找,并编通过AI写了测试用例
This commit is contained in:
@@ -2,6 +2,7 @@ package work.slhaf.memory;
|
|||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import work.slhaf.memory.content.MemorySlice;
|
import work.slhaf.memory.content.MemorySlice;
|
||||||
|
import work.slhaf.memory.exception.UnExistedTopicException;
|
||||||
import work.slhaf.memory.node.MemoryNode;
|
import work.slhaf.memory.node.MemoryNode;
|
||||||
import work.slhaf.memory.node.TopicNode;
|
import work.slhaf.memory.node.TopicNode;
|
||||||
|
|
||||||
@@ -9,7 +10,7 @@ import java.io.*;
|
|||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.nio.file.Paths;
|
import java.nio.file.Paths;
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDate;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@@ -20,9 +21,9 @@ public class MemoryGraph implements Serializable {
|
|||||||
private static final String STORAGE_DIR = "./data/memory/";
|
private static final String STORAGE_DIR = "./data/memory/";
|
||||||
|
|
||||||
private String id;
|
private String id;
|
||||||
private HashMap<String,TopicNode> topicNodes;
|
private HashMap<String, TopicNode> topicNodes;
|
||||||
public static MemoryGraph memoryGraph;
|
public static MemoryGraph memoryGraph;
|
||||||
private HashMap<String,Set<String>> existedTopics;
|
private HashMap<String, Set<String>> existedTopics;
|
||||||
|
|
||||||
public MemoryGraph(String id) {
|
public MemoryGraph(String id) {
|
||||||
this.id = id;
|
this.id = id;
|
||||||
@@ -85,10 +86,9 @@ public class MemoryGraph implements Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public void insertMemory(List<String> topicPath, MemorySlice slice) {
|
public void insertMemory(List<String> topicPath, MemorySlice slice) {
|
||||||
topicPath = new ArrayList<>(topicPath);
|
topicPath = new ArrayList<>(topicPath);
|
||||||
if (topicNodes == null){
|
if (topicNodes == null) {
|
||||||
topicNodes = new HashMap<>();
|
topicNodes = new HashMap<>();
|
||||||
}
|
}
|
||||||
//查看是否存在根主题节点
|
//查看是否存在根主题节点
|
||||||
@@ -98,17 +98,16 @@ public class MemoryGraph implements Serializable {
|
|||||||
TopicNode rootNode = new TopicNode();
|
TopicNode rootNode = new TopicNode();
|
||||||
rootNode.setMemoryNodes(new ArrayList<>());
|
rootNode.setMemoryNodes(new ArrayList<>());
|
||||||
rootNode.setTopicNodes(new HashMap<>());
|
rootNode.setTopicNodes(new HashMap<>());
|
||||||
topicNodes.put(rootTopic,rootNode);
|
topicNodes.put(rootTopic, rootNode);
|
||||||
existedTopics.put(rootTopic,new HashSet<>());
|
existedTopics.put(rootTopic, new HashSet<>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TopicNode lastTopicNode = topicNodes.get(rootTopic);
|
TopicNode lastTopicNode = topicNodes.get(rootTopic);
|
||||||
Set<String> existedTopicNodes = existedTopics.get(rootTopic);
|
Set<String> existedTopicNodes = existedTopics.get(rootTopic);
|
||||||
for (int i = 0; i < topicPath.size(); i++) {
|
for (String topic : topicPath) {
|
||||||
String topic = topicPath.get(i);
|
|
||||||
if (existedTopicNodes.contains(topic)) {
|
if (existedTopicNodes.contains(topic)) {
|
||||||
lastTopicNode = lastTopicNode.getTopicNodes().get(topic);
|
lastTopicNode = lastTopicNode.getTopicNodes().get(topic);
|
||||||
}else {
|
} else {
|
||||||
TopicNode newNode = new TopicNode();
|
TopicNode newNode = new TopicNode();
|
||||||
lastTopicNode.getTopicNodes().put(topic, newNode);
|
lastTopicNode.getTopicNodes().put(topic, newNode);
|
||||||
lastTopicNode = newNode;
|
lastTopicNode = newNode;
|
||||||
@@ -123,11 +122,11 @@ public class MemoryGraph implements Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
//检查是否存在当天对应的memoryData
|
//检查是否存在当天对应的memoryData
|
||||||
LocalDateTime now = LocalDateTime.now();
|
LocalDate now = LocalDate.now();
|
||||||
boolean hasSlice = false;
|
boolean hasSlice = false;
|
||||||
MemoryNode node = null;
|
MemoryNode node = null;
|
||||||
for (MemoryNode memoryNode : lastTopicNode.getMemoryNodes()) {
|
for (MemoryNode memoryNode : lastTopicNode.getMemoryNodes()) {
|
||||||
if (now.toLocalDate().equals(memoryNode.getLocalDateTime().toLocalDate())){
|
if (now.equals(memoryNode.getLocalDate())) {
|
||||||
hasSlice = true;
|
hasSlice = true;
|
||||||
node = memoryNode;
|
node = memoryNode;
|
||||||
break;
|
break;
|
||||||
@@ -135,12 +134,69 @@ public class MemoryGraph implements Serializable {
|
|||||||
}
|
}
|
||||||
if (!hasSlice) {
|
if (!hasSlice) {
|
||||||
node = new MemoryNode();
|
node = new MemoryNode();
|
||||||
node.setLocalDateTime(now);
|
node.setLocalDate(now);
|
||||||
node.setMemorySliceList(new ArrayList<>());
|
node.setMemorySliceList(new ArrayList<>());
|
||||||
lastTopicNode.getMemoryNodes().add(node);
|
lastTopicNode.getMemoryNodes().add(node);
|
||||||
|
lastTopicNode.getMemoryNodes().sort(null);
|
||||||
}
|
}
|
||||||
node.getMemorySliceList().add(slice);
|
node.getMemorySliceList().add(slice);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<MemorySlice> selectMemory(List<String> topicPath) {
|
||||||
|
List<MemorySlice> targetSliceList = new ArrayList<>();
|
||||||
|
topicPath = new ArrayList<>(topicPath);
|
||||||
|
String targetTopic = topicPath.getLast();
|
||||||
|
TopicNode targetParentNode = getTargetParentNode(topicPath, targetTopic);
|
||||||
|
List<List<String>> relatedTopics = new ArrayList<>();
|
||||||
|
//终点记忆节点
|
||||||
|
for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) {
|
||||||
|
List<MemorySlice> endpointMemorySliceList = memoryNode.getMemorySliceList();
|
||||||
|
targetSliceList.addAll(endpointMemorySliceList);
|
||||||
|
for (MemorySlice memorySlice : endpointMemorySliceList) {
|
||||||
|
if (memorySlice.getRelatedTopics() != null) {
|
||||||
|
relatedTopics.addAll(memorySlice.getRelatedTopics());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//邻近记忆节点 联系
|
||||||
|
for (List<String> relatedTopic : relatedTopics) {
|
||||||
|
List<String> tempTopicPath = new ArrayList<>(relatedTopic);
|
||||||
|
String tempTargetTopic = tempTopicPath.getLast();
|
||||||
|
TopicNode tempTargetParentNode = getTargetParentNode(tempTopicPath, tempTargetTopic);
|
||||||
|
//获取终点节点及其最新记忆节点
|
||||||
|
TopicNode tempTargetNode = tempTargetParentNode.getTopicNodes().get(tempTopicPath.getLast());
|
||||||
|
List<MemoryNode> tempMemoryNodes = tempTargetNode.getMemoryNodes();
|
||||||
|
if (!tempMemoryNodes.isEmpty()) {
|
||||||
|
targetSliceList.addAll(tempMemoryNodes.getFirst().getMemorySliceList());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//邻近记忆节点 父级
|
||||||
|
List<MemoryNode> targetParentMemoryNodes = targetParentNode.getMemoryNodes();
|
||||||
|
if (!targetParentMemoryNodes.isEmpty()) {
|
||||||
|
targetSliceList.addAll(targetParentMemoryNodes.getFirst().getMemorySliceList());
|
||||||
|
}
|
||||||
|
return targetSliceList;
|
||||||
|
}
|
||||||
|
|
||||||
|
private TopicNode getTargetParentNode(List<String> topicPath, String targetTopic) {
|
||||||
|
String topTopic = topicPath.getFirst();
|
||||||
|
if (!existedTopics.containsKey(topTopic)){
|
||||||
|
throw new UnExistedTopicException("不存在的主题: " + topTopic);
|
||||||
|
}
|
||||||
|
TopicNode targetParentNode = topicNodes.get(topTopic);
|
||||||
|
topicPath.removeFirst();
|
||||||
|
for (String topic : topicPath) {
|
||||||
|
if (!existedTopics.get(topTopic).contains(topic)){
|
||||||
|
throw new UnExistedTopicException("不存在的主题: " + topTopic);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//逐层查找目标主题,可选取终点主题节点相邻位置的主题节点。终点记忆节点选取全部memoryNode, 邻近记忆节点选取最新日期的memoryNode
|
||||||
|
while (!targetParentNode.getTopicNodes().containsKey(targetTopic)) {
|
||||||
|
targetParentNode = targetParentNode.getTopicNodes().get(topicPath.getFirst());
|
||||||
|
topicPath.removeFirst();
|
||||||
|
}
|
||||||
|
return targetParentNode;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,10 +9,13 @@ import java.util.List;
|
|||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class MemorySlice implements Serializable {
|
public class MemorySlice implements Serializable {
|
||||||
|
//关联的完整对话的id
|
||||||
private String memoryId;
|
private String memoryId;
|
||||||
|
//该切片在关联的完整对话中的顺序
|
||||||
private Integer memoryRank;
|
private Integer memoryRank;
|
||||||
private String slicePath;
|
private String slicePath;
|
||||||
private List<TopicNode> relatedTopics;
|
private List<List<String>> relatedTopics;
|
||||||
|
//关联完整对话中的前序切片, 排序为键,完整路径为值
|
||||||
private LinkedHashMap<Integer,String> sliceBefore;
|
private LinkedHashMap<Integer,String> sliceBefore;
|
||||||
private LinkedHashMap<Integer,String> sliceAfter;
|
private LinkedHashMap<Integer,String> sliceAfter;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package work.slhaf.memory.exception;
|
||||||
|
|
||||||
|
public class UnExistedTopicException extends RuntimeException {
|
||||||
|
public UnExistedTopicException(String message) {
|
||||||
|
super(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,11 +4,23 @@ import lombok.Data;
|
|||||||
import work.slhaf.memory.content.MemorySlice;
|
import work.slhaf.memory.content.MemorySlice;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDate;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class MemoryNode implements Serializable {
|
public class MemoryNode implements Serializable, Comparable<MemoryNode> {
|
||||||
private LocalDateTime localDateTime;
|
//记忆节点所属日期
|
||||||
|
private LocalDate localDate;
|
||||||
|
//该日期对应的全部记忆切片
|
||||||
private List<MemorySlice> memorySliceList;
|
private List<MemorySlice> memorySliceList;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int compareTo(MemoryNode memoryNode) {
|
||||||
|
if (memoryNode.getLocalDate().isAfter(this.localDate)) {
|
||||||
|
return -1;
|
||||||
|
} else if (memoryNode.getLocalDate().isBefore(this.localDate)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,5 +9,6 @@ import java.util.List;
|
|||||||
@Data
|
@Data
|
||||||
public class TopicNode implements Serializable {
|
public class TopicNode implements Serializable {
|
||||||
private HashMap<String,TopicNode> topicNodes;
|
private HashMap<String,TopicNode> topicNodes;
|
||||||
|
// private Integer weight = 0;
|
||||||
private List<MemoryNode> memoryNodes;
|
private List<MemoryNode> memoryNodes;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,12 +2,12 @@ package memory;
|
|||||||
|
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
|
||||||
import work.slhaf.memory.MemoryGraph;
|
import work.slhaf.memory.MemoryGraph;
|
||||||
import work.slhaf.memory.content.MemorySlice;
|
import work.slhaf.memory.content.MemorySlice;
|
||||||
import work.slhaf.memory.node.MemoryNode;
|
import work.slhaf.memory.node.MemoryNode;
|
||||||
import work.slhaf.memory.node.TopicNode;
|
import work.slhaf.memory.node.TopicNode;
|
||||||
|
|
||||||
|
import java.time.LocalDate;
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDateTime;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -18,7 +18,7 @@ import static org.junit.Assert.*;
|
|||||||
|
|
||||||
public class InsertTest {
|
public class InsertTest {
|
||||||
private MemoryGraph memoryGraph;
|
private MemoryGraph memoryGraph;
|
||||||
private final String testId = "test";
|
private final String testId = "test_insert";
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
@@ -48,7 +48,7 @@ public class InsertTest {
|
|||||||
|
|
||||||
assertEquals(1, collectionsNode.getMemoryNodes().size());
|
assertEquals(1, collectionsNode.getMemoryNodes().size());
|
||||||
MemoryNode memoryNode = collectionsNode.getMemoryNodes().get(0);
|
MemoryNode memoryNode = collectionsNode.getMemoryNodes().get(0);
|
||||||
assertEquals(LocalDateTime.now().toLocalDate(), memoryNode.getLocalDateTime().toLocalDate());
|
assertEquals(LocalDate.now(), memoryNode.getLocalDate());
|
||||||
assertEquals(1, memoryNode.getMemorySliceList().size());
|
assertEquals(1, memoryNode.getMemorySliceList().size());
|
||||||
assertEquals(slice, memoryNode.getMemorySliceList().get(0));
|
assertEquals(slice, memoryNode.getMemorySliceList().get(0));
|
||||||
}
|
}
|
||||||
@@ -88,7 +88,7 @@ public class InsertTest {
|
|||||||
MemoryNode firstNode = memoryGraph.getTopicNodes().get("Math")
|
MemoryNode firstNode = memoryGraph.getTopicNodes().get("Math")
|
||||||
.getTopicNodes().get("Algebra")
|
.getTopicNodes().get("Algebra")
|
||||||
.getMemoryNodes().get(0);
|
.getMemoryNodes().get(0);
|
||||||
firstNode.setLocalDateTime(LocalDateTime.now().minusDays(1));
|
firstNode.setLocalDate(LocalDate.now().minusDays(1));
|
||||||
|
|
||||||
// 第二次插入
|
// 第二次插入
|
||||||
memoryGraph.insertMemory(topicPath, slice2);
|
memoryGraph.insertMemory(topicPath, slice2);
|
||||||
@@ -129,4 +129,41 @@ public class InsertTest {
|
|||||||
return slice;
|
return slice;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSerializationConsistency() {
|
||||||
|
// 构造 MemorySlice
|
||||||
|
MemorySlice slice = new MemorySlice();
|
||||||
|
slice.setMemoryId("001");
|
||||||
|
slice.setMemoryRank(5);
|
||||||
|
slice.setSlicePath("/demo/path");
|
||||||
|
|
||||||
|
List<String> topicPath = Arrays.asList("生活", "学习", "Java");
|
||||||
|
|
||||||
|
// 插入 memory
|
||||||
|
memoryGraph.insertMemory(topicPath, slice);
|
||||||
|
memoryGraph.serialize();
|
||||||
|
|
||||||
|
// 反序列化
|
||||||
|
MemoryGraph loadedGraph = MemoryGraph.initialize(testId);
|
||||||
|
|
||||||
|
// 校验:topic 是否存在
|
||||||
|
assertNotNull(loadedGraph.getTopicNodes().get("生活"));
|
||||||
|
TopicNode lifeNode = loadedGraph.getTopicNodes().get("生活");
|
||||||
|
|
||||||
|
assertNotNull(lifeNode.getTopicNodes().get("学习"));
|
||||||
|
TopicNode studyNode = lifeNode.getTopicNodes().get("学习");
|
||||||
|
|
||||||
|
assertNotNull(studyNode.getTopicNodes().get("Java"));
|
||||||
|
TopicNode javaNode = studyNode.getTopicNodes().get("Java");
|
||||||
|
|
||||||
|
// 校验:是否存在 MemoryNode
|
||||||
|
assertFalse(javaNode.getMemoryNodes().isEmpty());
|
||||||
|
|
||||||
|
// 校验:MemorySlice 内容一致
|
||||||
|
MemorySlice deserializedSlice = javaNode.getMemoryNodes().get(0).getMemorySliceList().get(0);
|
||||||
|
assertEquals("001", deserializedSlice.getMemoryId());
|
||||||
|
assertEquals(Integer.valueOf(5), deserializedSlice.getMemoryRank());
|
||||||
|
assertEquals("/demo/path", deserializedSlice.getSlicePath());
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
166
src/test/java/memory/SearchTest.java
Normal file
166
src/test/java/memory/SearchTest.java
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
package memory;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import work.slhaf.memory.MemoryGraph;
|
||||||
|
import work.slhaf.memory.content.MemorySlice;
|
||||||
|
import work.slhaf.memory.exception.UnExistedTopicException;
|
||||||
|
import work.slhaf.memory.node.MemoryNode;
|
||||||
|
import work.slhaf.memory.node.TopicNode;
|
||||||
|
|
||||||
|
import java.time.LocalDate;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
class SearchTest {
|
||||||
|
private MemoryGraph memoryGraph;
|
||||||
|
private final LocalDate today = LocalDate.now();
|
||||||
|
private final LocalDate yesterday = LocalDate.now().minusDays(1);
|
||||||
|
|
||||||
|
// 初始化测试环境,模拟插入基础数据
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
memoryGraph = new MemoryGraph("testGraph");
|
||||||
|
|
||||||
|
// 构建基础主题路径:根主题 -> 编程 -> Java
|
||||||
|
List<String> javaPath = new ArrayList<>();
|
||||||
|
javaPath.add("编程");
|
||||||
|
javaPath.add("Java");
|
||||||
|
|
||||||
|
// 插入今天的Java相关记忆
|
||||||
|
MemorySlice javaMemory = createMemorySlice("java1");
|
||||||
|
memoryGraph.insertMemory(javaPath, javaMemory);
|
||||||
|
|
||||||
|
// 插入昨天的Java记忆(应不会出现在邻近结果中)
|
||||||
|
MemorySlice oldJavaMemory = createMemorySlice("javaOld");
|
||||||
|
MemoryNode oldNode = new MemoryNode();
|
||||||
|
oldNode.setLocalDate(yesterday);
|
||||||
|
oldNode.setMemorySliceList(List.of(oldJavaMemory));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 场景1:查询存在的完整主题路径(含相关主题)
|
||||||
|
@Test
|
||||||
|
void selectMemory_shouldReturnTargetAndRelatedAndParentMemories() {
|
||||||
|
// 准备相关主题数据:根主题 -> 算法 -> 排序
|
||||||
|
List<String> sortPath = new ArrayList<>();
|
||||||
|
sortPath.add("算法");
|
||||||
|
sortPath.add("排序");
|
||||||
|
MemorySlice sortMemory = createMemorySlice("sort1");
|
||||||
|
sortMemory.setRelatedTopics(List.of(
|
||||||
|
createTopicPath("编程", "Java") // 设置反向关联
|
||||||
|
));
|
||||||
|
memoryGraph.insertMemory(sortPath, sortMemory);
|
||||||
|
|
||||||
|
// 执行查询:编程 -> Java
|
||||||
|
List<String> queryPath = new ArrayList<>();
|
||||||
|
queryPath.add("算法");
|
||||||
|
queryPath.add("排序");
|
||||||
|
List<MemorySlice> results = memoryGraph.selectMemory(queryPath);
|
||||||
|
|
||||||
|
// 验证结果应包含:
|
||||||
|
// 1. 目标节点所有记忆(java1)
|
||||||
|
// 2. 相关主题(排序)的最新记忆(sort1)
|
||||||
|
// 3. 父节点(编程)的最新记忆(需要提前插入)
|
||||||
|
assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())));
|
||||||
|
assertTrue(results.stream().anyMatch(m -> "sort1".equals(m.getMemoryId())));
|
||||||
|
assertEquals(2, results.size()); // 根据具体实现可能调整
|
||||||
|
}
|
||||||
|
|
||||||
|
// 场景2:查询不存在的主题路径
|
||||||
|
@Test
|
||||||
|
void selectMemory_shouldThrowWhenPathNotExist() {
|
||||||
|
List<String> invalidPath = new ArrayList<>();
|
||||||
|
invalidPath.add("不存在的主题");
|
||||||
|
|
||||||
|
assertThrows(UnExistedTopicException.class, () -> {
|
||||||
|
memoryGraph.selectMemory(invalidPath);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// 场景3:无相关主题时仅返回目标节点和父节点记忆
|
||||||
|
@Test
|
||||||
|
void selectMemory_withoutRelatedTopics_shouldReturnTargetAndParent() {
|
||||||
|
// 插入父级记忆:根主题 -> 编程
|
||||||
|
List<String> parentPath = new ArrayList<>();
|
||||||
|
parentPath.add("编程");
|
||||||
|
MemorySlice parentMemory = createMemorySlice("parent1");
|
||||||
|
memoryGraph.insertMemory(parentPath, parentMemory);
|
||||||
|
|
||||||
|
// 执行查询
|
||||||
|
List<String> queryPath = new ArrayList<>();
|
||||||
|
queryPath.add("编程");
|
||||||
|
queryPath.add("Java");
|
||||||
|
List<MemorySlice> results = memoryGraph.selectMemory(queryPath);
|
||||||
|
|
||||||
|
// 应包含:Java记忆 + 父级最新记忆
|
||||||
|
assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())));
|
||||||
|
assertTrue(results.stream().anyMatch(m -> "parent1".equals(m.getMemoryId())));
|
||||||
|
assertEquals(2, results.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 场景4:验证日期排序,应优先取最新日期的邻近记忆
|
||||||
|
@Test
|
||||||
|
void selectMemory_shouldGetLatestRelatedMemory() {
|
||||||
|
// 准备相关主题路径:根主题 -> 数据库
|
||||||
|
List<String> dbPath = new ArrayList<>();
|
||||||
|
dbPath.add("数据库");
|
||||||
|
dbPath.add("mysql");
|
||||||
|
|
||||||
|
// 插入今天的数据库记忆(正常流程)
|
||||||
|
MemorySlice newDbMemory = createMemorySlice("dbNew");
|
||||||
|
memoryGraph.insertMemory(dbPath, newDbMemory);
|
||||||
|
|
||||||
|
// 手动构建并插入昨天的数据库记忆
|
||||||
|
MemorySlice oldDbMemory = createMemorySlice("dbOld");
|
||||||
|
TopicNode dbTopicNode = memoryGraph.getTopicNodes().get("数据库");
|
||||||
|
|
||||||
|
// 创建昨日记忆节点并添加到主题节点
|
||||||
|
MemoryNode oldMemoryNode = new MemoryNode();
|
||||||
|
oldMemoryNode.setLocalDate(yesterday);
|
||||||
|
oldMemoryNode.setMemorySliceList(new ArrayList<>(List.of(oldDbMemory)));
|
||||||
|
dbTopicNode.getMemoryNodes().add(oldMemoryNode);
|
||||||
|
|
||||||
|
// 对记忆节点进行日期排序(根据compareTo方法)
|
||||||
|
dbTopicNode.getMemoryNodes().sort(null);
|
||||||
|
|
||||||
|
// 创建Java记忆并关联数据库主题
|
||||||
|
MemorySlice javaMemory = createMemorySlice("java2");
|
||||||
|
javaMemory.setRelatedTopics(List.of(
|
||||||
|
createTopicPath("数据库","") // 完整主题路径
|
||||||
|
));
|
||||||
|
memoryGraph.insertMemory(createTopicPath("编程", "Java"), javaMemory);
|
||||||
|
|
||||||
|
// 执行查询
|
||||||
|
List<String> queryPath = createTopicPath("编程", "Java");
|
||||||
|
List<MemorySlice> results = memoryGraph.selectMemory(queryPath);
|
||||||
|
|
||||||
|
// 验证结果应包含最新关联记忆(dbNew)
|
||||||
|
assertTrue(results.stream().anyMatch(m -> "dbNew".equals(m.getMemoryId())),
|
||||||
|
"应包含最新的数据库记忆");
|
||||||
|
assertFalse(results.stream().anyMatch(m -> "dbOld".equals(m.getMemoryId())),
|
||||||
|
"不应包含过期的数据库记忆");
|
||||||
|
|
||||||
|
// 验证结果包含目标记忆(java1和java2)
|
||||||
|
assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())),
|
||||||
|
"应包含基础测试数据");
|
||||||
|
assertTrue(results.stream().anyMatch(m -> "java2".equals(m.getMemoryId())),
|
||||||
|
"应包含当前测试插入数据");
|
||||||
|
}
|
||||||
|
|
||||||
|
private MemorySlice createMemorySlice(String id) {
|
||||||
|
MemorySlice slice = new MemorySlice();
|
||||||
|
slice.setMemoryId(id);
|
||||||
|
slice.setMemoryRank(1);
|
||||||
|
return slice;
|
||||||
|
}
|
||||||
|
|
||||||
|
private ArrayList<String> createTopicPath(String... topics) {
|
||||||
|
ArrayList<String> path = new ArrayList<>();
|
||||||
|
for (String topic : topics) {
|
||||||
|
path.add(topic);
|
||||||
|
}
|
||||||
|
return path;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user