实现了MemoryGraph的序列化/反序列化;

实现了MemoryGraph的插入功能;
This commit is contained in:
2025-04-08 22:31:48 +08:00
commit cad3af346f
12 changed files with 464 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
package work.slhaf;
import work.slhaf.memory.MemoryGraph;
import work.slhaf.memory.content.MemorySlice;
import java.util.Arrays;
public class Main {
public static void main(String[] args) {
MemoryGraph graph = MemoryGraph.initialize("test");
}
}

View File

@@ -0,0 +1,146 @@
package work.slhaf.memory;
import lombok.Data;
import work.slhaf.memory.content.MemorySlice;
import work.slhaf.memory.node.MemoryNode;
import work.slhaf.memory.node.TopicNode;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.LocalDateTime;
import java.util.*;
@Data
public class MemoryGraph implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
private static final String STORAGE_DIR = "./data/memory/";
private String id;
private HashMap<String,TopicNode> topicNodes;
public static MemoryGraph memoryGraph;
private HashMap<String,Set<String>> existedTopics;
public MemoryGraph(String id) {
this.id = id;
this.topicNodes = new HashMap<>();
this.existedTopics = new HashMap<>();
}
public static MemoryGraph initialize(String id) {
// 检查存储目录是否存在,不存在则创建
createStorageDirectory();
Path filePath = getFilePath(id);
if (Files.exists(filePath)) {
try {
// 从文件加载
return deserialize(id);
} catch (Exception e) {
System.err.println("加载序列化文件失败,创建新实例: " + e.getMessage());
return new MemoryGraph(id);
}
} else {
// 创建新实例
return new MemoryGraph(id);
}
}
public void serialize() {
Path filePath = getFilePath(this.id);
try (ObjectOutputStream oos = new ObjectOutputStream(
new FileOutputStream(filePath.toFile()))) {
oos.writeObject(this);
System.out.println("MemoryGraph 已保存到: " + filePath);
} catch (IOException e) {
System.err.println("序列化保存失败: " + e.getMessage());
}
}
private static MemoryGraph deserialize(String id) throws IOException, ClassNotFoundException {
Path filePath = getFilePath(id);
try (ObjectInputStream ois = new ObjectInputStream(
new FileInputStream(filePath.toFile()))) {
MemoryGraph graph = (MemoryGraph) ois.readObject();
System.out.println("MemoryGraph 已从文件加载: " + filePath);
return graph;
}
}
private static Path getFilePath(String id) {
return Paths.get(STORAGE_DIR, id + ".memory");
}
private static void createStorageDirectory() {
try {
Files.createDirectories(Paths.get(STORAGE_DIR));
} catch (IOException e) {
System.err.println("创建存储目录失败: " + e.getMessage());
}
}
public void insertMemory(List<String> topicPath, MemorySlice slice) {
topicPath = new ArrayList<>(topicPath);
if (topicNodes == null){
topicNodes = new HashMap<>();
}
//查看是否存在根主题节点
String rootTopic = topicPath.getFirst();
topicPath.removeFirst();
if (!topicNodes.containsKey(rootTopic)) {
TopicNode rootNode = new TopicNode();
rootNode.setMemoryNodes(new ArrayList<>());
rootNode.setTopicNodes(new HashMap<>());
topicNodes.put(rootTopic,rootNode);
existedTopics.put(rootTopic,new HashSet<>());
}
TopicNode lastTopicNode = topicNodes.get(rootTopic);
Set<String> existedTopicNodes = existedTopics.get(rootTopic);
for (int i = 0; i < topicPath.size(); i++) {
String topic = topicPath.get(i);
if (existedTopicNodes.contains(topic)) {
lastTopicNode = lastTopicNode.getTopicNodes().get(topic);
}else {
TopicNode newNode = new TopicNode();
lastTopicNode.getTopicNodes().put(topic, newNode);
lastTopicNode = newNode;
List<MemoryNode> nodeList = new ArrayList<>();
lastTopicNode.setMemoryNodes(nodeList);
lastTopicNode.setTopicNodes(new HashMap<>());
existedTopicNodes.add(topic);
/*if (i == topicPath.size() - 1) {
lastTopicNode.setMemoryNodes(new ArrayList<>());
lastTopicNode.setTopicNodes(new HashMap<>());
}*/
}
}
//检查是否存在当天对应的memoryData
LocalDateTime now = LocalDateTime.now();
boolean hasSlice = false;
MemoryNode node = null;
for (MemoryNode memoryNode : lastTopicNode.getMemoryNodes()) {
if (now.toLocalDate().equals(memoryNode.getLocalDateTime().toLocalDate())){
hasSlice = true;
node = memoryNode;
break;
}
}
if (!hasSlice) {
node = new MemoryNode();
node.setLocalDateTime(now);
node.setMemorySliceList(new ArrayList<>());
lastTopicNode.getMemoryNodes().add(node);
}
node.getMemorySliceList().add(slice);
}
}

View File

@@ -0,0 +1,18 @@
package work.slhaf.memory.content;
import lombok.Data;
import work.slhaf.memory.node.TopicNode;
import java.io.Serializable;
import java.util.LinkedHashMap;
import java.util.List;
@Data
public class MemorySlice implements Serializable {
private String memoryId;
private Integer memoryRank;
private String slicePath;
private List<TopicNode> relatedTopics;
private LinkedHashMap<Integer,String> sliceBefore;
private LinkedHashMap<Integer,String> sliceAfter;
}

View File

@@ -0,0 +1,14 @@
package work.slhaf.memory.node;
import lombok.Data;
import work.slhaf.memory.content.MemorySlice;
import java.io.Serializable;
import java.time.LocalDateTime;
import java.util.List;
@Data
public class MemoryNode implements Serializable {
private LocalDateTime localDateTime;
private List<MemorySlice> memorySliceList;
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.memory.node;
import lombok.Data;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
@Data
public class TopicNode implements Serializable {
private HashMap<String,TopicNode> topicNodes;
private List<MemoryNode> memoryNodes;
}

View File

@@ -0,0 +1,132 @@
package memory;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import work.slhaf.memory.MemoryGraph;
import work.slhaf.memory.content.MemorySlice;
import work.slhaf.memory.node.MemoryNode;
import work.slhaf.memory.node.TopicNode;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import static org.junit.Assert.*;
public class InsertTest {
private MemoryGraph memoryGraph;
private final String testId = "test";
@Before
public void setUp() {
memoryGraph = new MemoryGraph(testId);
memoryGraph.setTopicNodes(new HashMap<>());
memoryGraph.setExistedTopics(new HashMap<>());
}
@Test
public void testInsertMemory_NewRootTopic() {
// 准备测试数据
List<String> topicPath = new LinkedList<>(Arrays.asList("Programming", "Java", "Collections"));
MemorySlice slice = createTestMemorySlice("slice1");
// 执行测试
memoryGraph.insertMemory(topicPath, slice);
// 验证结果
assertTrue(memoryGraph.getTopicNodes().containsKey("Programming"));
TopicNode programmingNode = memoryGraph.getTopicNodes().get("Programming");
assertTrue(programmingNode.getTopicNodes().containsKey("Java"));
TopicNode javaNode = programmingNode.getTopicNodes().get("Java");
assertTrue(javaNode.getTopicNodes().containsKey("Collections"));
TopicNode collectionsNode = javaNode.getTopicNodes().get("Collections");
assertEquals(1, collectionsNode.getMemoryNodes().size());
MemoryNode memoryNode = collectionsNode.getMemoryNodes().get(0);
assertEquals(LocalDateTime.now().toLocalDate(), memoryNode.getLocalDateTime().toLocalDate());
assertEquals(1, memoryNode.getMemorySliceList().size());
assertEquals(slice, memoryNode.getMemorySliceList().get(0));
}
@Test
public void testInsertMemory_ExistingTopicPath() {
// 准备初始数据
List<String> topicPath1 = new LinkedList<>(Arrays.asList("Programming", "Java", "Collections"));
MemorySlice slice1 = createTestMemorySlice("slice1");
memoryGraph.insertMemory(topicPath1, slice1);
// 插入第二个记忆片段到相同路径
List<String> topicPath2 = new LinkedList<>(Arrays.asList("Programming", "Java", "Collections"));
MemorySlice slice2 = createTestMemorySlice("slice2");
memoryGraph.insertMemory(topicPath2, slice2);
// 验证结果
TopicNode collectionsNode = memoryGraph.getTopicNodes().get("Programming")
.getTopicNodes().get("Java")
.getTopicNodes().get("Collections");
assertEquals(1, collectionsNode.getMemoryNodes().size()); // 同一天应该只有一个MemoryNode
assertEquals(2, collectionsNode.getMemoryNodes().get(0).getMemorySliceList().size()); // 但有两个MemorySlice
}
@Test
public void testInsertMemory_DifferentDays() {
// 准备测试数据
List<String> topicPath = new LinkedList<>(Arrays.asList("Math", "Algebra"));
MemorySlice slice1 = createTestMemorySlice("slice1");
MemorySlice slice2 = createTestMemorySlice("slice2");
// 第一次插入
memoryGraph.insertMemory(topicPath, slice1);
// 模拟第二天
MemoryNode firstNode = memoryGraph.getTopicNodes().get("Math")
.getTopicNodes().get("Algebra")
.getMemoryNodes().get(0);
firstNode.setLocalDateTime(LocalDateTime.now().minusDays(1));
// 第二次插入
memoryGraph.insertMemory(topicPath, slice2);
// 验证结果
TopicNode algebraNode = memoryGraph.getTopicNodes().get("Math")
.getTopicNodes().get("Algebra");
assertEquals(2, algebraNode.getMemoryNodes().size()); // 应该有两个MemoryNode
}
@Test
public void testInsertMemory_PartialExistingPath() {
// 准备初始数据 - 创建部分路径
List<String> topicPath1 = new LinkedList<>(Arrays.asList("Science", "Physics"));
MemorySlice slice1 = createTestMemorySlice("slice1");
memoryGraph.insertMemory(topicPath1, slice1);
// 插入到已存在路径的扩展路径
List<String> topicPath2 = new LinkedList<>(Arrays.asList("Science", "Physics", "Mechanics"));
MemorySlice slice2 = createTestMemorySlice("slice2");
memoryGraph.insertMemory(topicPath2, slice2);
// 验证结果
TopicNode physicsNode = memoryGraph.getTopicNodes().get("Science")
.getTopicNodes().get("Physics");
assertTrue(physicsNode.getTopicNodes().containsKey("Mechanics"));
assertEquals(1, physicsNode.getMemoryNodes().size()); // Physics节点有自己的记忆
assertEquals(1, physicsNode.getTopicNodes().get("Mechanics").getMemoryNodes().size()); // Mechanics节点也有记忆
}
private MemorySlice createTestMemorySlice(String id) {
MemorySlice slice = new MemorySlice();
slice.setMemoryId(id);
slice.setMemoryRank(1);
// 可以设置其他必要属性
return slice;
}
}