Files
Partner/src/test/java/memory/InsertTest.java
slhaf 7594a1c43b - 在InteractionHub中新增了执行模块列表功能,将输出内容交给agent进行输出封装
- 移动 InteractionContext 和 InteractionModule 至本体项目
- 调整 InteractionContext 字段内容,目前已较为完善
- 新增了 PreprocessExecutor 和 MemoryUpdater
- 优化了代码结构,提高了模块化和可扩展性,模块化前遗留问题应该已解决完毕,主流程待实现
- 添加了线程池的单例实现
- 添加了模块加载器的外部模块加载功能
- 在 Model 中新增 singleChat 方法,用于流程模块的不保留上下文对话
- 将 MemoryManager 移动至 core 包下,因为 MemoryManager 将参与多个模块内部
- 将调取记忆、更新记忆功能抽取为独立模块,便于流程控制
- 添加了 TaskData 类,用于存储任务信息,后续需考虑TaskData的序列化机制
2025-04-18 22:19:04 +08:00

165 lines
6.5 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package memory;
import org.junit.Before;
import org.junit.Test;
import work.slhaf.agent.core.memory.MemoryGraph;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.core.memory.node.MemoryNode;
import work.slhaf.agent.core.memory.node.TopicNode;
import java.io.IOException;
import java.time.LocalDate;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import static org.junit.Assert.*;
/**
 * Tests for {@code MemoryGraph.insertMemory(topicPath, slice)} and for a serialize/reload round trip.
 *
 * <p>Each test starts from a {@link MemoryGraph} created for {@code testId} whose topic maps have been
 * replaced with empty maps (see {@link #setUp()}), inserts {@link MemorySlice}s along topic paths,
 * and then inspects the resulting {@link TopicNode} / {@link MemoryNode} tree.
 *
 * <p>NOTE(review): {@link #testSerializationConsistency()} calls {@code serialize()} and nothing in
 * this class removes what it writes. If the graph or its memory nodes persist to disk under
 * {@code testId}, state may leak between tests or runs — confirm and add an {@code @After} cleanup
 * if so.
 */
public class InsertTest {
    private MemoryGraph memoryGraph;
    private final String testId = "test_insert";

    /** Creates a graph for {@code testId} and empties its topic maps so every test starts clean. */
    @Before
    public void setUp() {
        memoryGraph = new MemoryGraph(testId);
        // Reset explicitly; presumably the constructor may populate these from existing state — TODO confirm.
        memoryGraph.setTopicNodes(new HashMap<>());
        memoryGraph.setExistedTopics(new HashMap<>());
    }

    /** Inserting under a brand-new path creates every topic level and a single MemoryNode dated today. */
    @Test
    public void testInsertMemory_NewRootTopic() throws IOException, ClassNotFoundException {
        // Arrange
        List<String> topicPath = newTopicPath("Programming", "Java", "Collections");
        MemorySlice slice = createTestMemorySlice("slice1");

        // Act: bracket the call with dates so the date check below cannot flake across midnight.
        LocalDate dayBefore = LocalDate.now();
        memoryGraph.insertMemory(topicPath, slice);
        LocalDate dayAfter = LocalDate.now();

        // Assert: every level of the path now exists.
        assertTrue(memoryGraph.getTopicNodes().containsKey("Programming"));
        TopicNode programmingNode = memoryGraph.getTopicNodes().get("Programming");
        assertTrue(programmingNode.getTopicNodes().containsKey("Java"));
        TopicNode javaNode = programmingNode.getTopicNodes().get("Java");
        assertTrue(javaNode.getTopicNodes().containsKey("Collections"));
        TopicNode collectionsNode = javaNode.getTopicNodes().get("Collections");
        assertEquals(1, collectionsNode.getMemoryNodes().size());

        // Assert: the leaf holds one MemoryNode created "today".
        MemoryNode memoryNode = collectionsNode.getMemoryNodes().get(0);
        LocalDate nodeDate = memoryNode.getLocalDate();
        assertFalse("node dated before the insert call", nodeDate.isBefore(dayBefore));
        assertFalse("node dated after the insert call", nodeDate.isAfter(dayAfter));

        // Load once and reuse; the name suggests the list may be read from storage on each call.
        List<MemorySlice> slices = memoryNode.loadMemorySliceList();
        assertEquals(1, slices.size());
        assertEquals(slice, slices.get(0));
    }

    /** Two inserts into the same path on the same day share one MemoryNode holding both slices. */
    @Test
    public void testInsertMemory_ExistingTopicPath() throws IOException, ClassNotFoundException {
        // Arrange: first slice creates the path.
        memoryGraph.insertMemory(newTopicPath("Programming", "Java", "Collections"), createTestMemorySlice("slice1"));

        // Act: second slice goes to the identical path.
        memoryGraph.insertMemory(newTopicPath("Programming", "Java", "Collections"), createTestMemorySlice("slice2"));

        // Assert
        TopicNode collectionsNode = memoryGraph.getTopicNodes().get("Programming")
                .getTopicNodes().get("Java")
                .getTopicNodes().get("Collections");
        assertEquals(1, collectionsNode.getMemoryNodes().size()); // same day -> a single MemoryNode
        assertEquals(2, collectionsNode.getMemoryNodes().get(0).loadMemorySliceList().size()); // ...holding both slices
    }

    /** An insert after the existing MemoryNode has been back-dated produces a second MemoryNode. */
    @Test
    public void testInsertMemory_DifferentDays() throws IOException, ClassNotFoundException {
        // Arrange
        MemorySlice slice1 = createTestMemorySlice("slice1");
        MemorySlice slice2 = createTestMemorySlice("slice2");

        // First insert.
        memoryGraph.insertMemory(newTopicPath("Math", "Algebra"), slice1);

        // Back-date the first node to yesterday so the next insert looks like it happens on a later day.
        MemoryNode firstNode = memoryGraph.getTopicNodes().get("Math")
                .getTopicNodes().get("Algebra")
                .getMemoryNodes().get(0);
        firstNode.setLocalDate(LocalDate.now().minusDays(1));

        // Second insert: a fresh path list, in case insertMemory consumed the first one.
        memoryGraph.insertMemory(newTopicPath("Math", "Algebra"), slice2);

        // Assert
        TopicNode algebraNode = memoryGraph.getTopicNodes().get("Math")
                .getTopicNodes().get("Algebra");
        assertEquals(2, algebraNode.getMemoryNodes().size()); // one MemoryNode per day
    }

    /** Extending an existing path adds a child topic while the parent keeps its own MemoryNode. */
    @Test
    public void testInsertMemory_PartialExistingPath() throws IOException, ClassNotFoundException {
        // Arrange: create the shorter path first.
        memoryGraph.insertMemory(newTopicPath("Science", "Physics"), createTestMemorySlice("slice1"));

        // Act: insert under an extension of that path.
        memoryGraph.insertMemory(newTopicPath("Science", "Physics", "Mechanics"), createTestMemorySlice("slice2"));

        // Assert
        TopicNode physicsNode = memoryGraph.getTopicNodes().get("Science")
                .getTopicNodes().get("Physics");
        assertTrue(physicsNode.getTopicNodes().containsKey("Mechanics"));
        assertEquals(1, physicsNode.getMemoryNodes().size()); // Physics keeps its own memory
        assertEquals(1, physicsNode.getTopicNodes().get("Mechanics").getMemoryNodes().size()); // Mechanics has one too
    }

    /**
     * Inserts a slice, serializes the graph, reloads it via {@code MemoryGraph.getInstance(testId)}
     * and checks that the topic tree and the slice's memory id survive.
     *
     * <p>NOTE(review): if {@code getInstance} returns a cached instance rather than reading the
     * serialized data, this test does not actually exercise deserialization — confirm.
     */
    @Test
    public void testSerializationConsistency() throws IOException, ClassNotFoundException {
        // Insert one slice (non-ASCII topic names are intentional) and persist the graph.
        MemorySlice slice = createTestMemorySlice("001");
        memoryGraph.insertMemory(newTopicPath("生活", "学习", "Java"), slice);
        memoryGraph.serialize();

        // Reload.
        MemoryGraph loadedGraph = MemoryGraph.getInstance(testId);

        // Every topic level must exist.
        TopicNode lifeNode = loadedGraph.getTopicNodes().get("生活");
        assertNotNull(lifeNode);
        TopicNode studyNode = lifeNode.getTopicNodes().get("学习");
        assertNotNull(studyNode);
        TopicNode javaNode = studyNode.getTopicNodes().get("Java");
        assertNotNull(javaNode);

        // A MemoryNode must exist and its slice must carry the original memory id.
        assertFalse(javaNode.getMemoryNodes().isEmpty());
        MemorySlice deserializedSlice = javaNode.getMemoryNodes().get(0).loadMemorySliceList().get(0);
        assertEquals("001", deserializedSlice.getMemoryId());
    }

    /**
     * Builds a fresh, mutable topic path for a single {@code insertMemory} call.
     *
     * <p>A new list is used per call because the callee may consume or otherwise modify the list it
     * receives (NOTE(review): confirm against MemoryGraph); a fresh mutable copy is safe either way,
     * unlike reusing one list or passing a fixed-size {@code Arrays.asList} view.
     *
     * @param topics topic names from root to leaf
     * @return a new mutable list containing {@code topics} in order
     */
    private static List<String> newTopicPath(String... topics) {
        return new LinkedList<>(Arrays.asList(topics));
    }

    /**
     * Creates a minimal {@link MemorySlice} with only its memory id set.
     *
     * @param id the memory id to assign
     * @return the new slice
     */
    private MemorySlice createTestMemorySlice(String id) {
        MemorySlice slice = new MemorySlice();
        slice.setMemoryId(id);
        // Other properties can be set here if a test needs them.
        return slice;
    }
}