mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
添加dateIndex(记忆切片的日期索引)、dialogMap(近期对话缓存)、staticMemory(确定性记忆)等字段,并实现相关更新操作;
调整了MemorySlice中的部分结构; 添加了必要的注释;
This commit is contained in:
@@ -1,9 +1,6 @@
|
|||||||
package work.slhaf;
|
package work.slhaf;
|
||||||
|
|
||||||
import work.slhaf.memory.MemoryGraph;
|
import work.slhaf.memory.MemoryGraph;
|
||||||
import work.slhaf.memory.content.MemorySlice;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import java.nio.file.Files;
|
|||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.nio.file.Paths;
|
import java.nio.file.Paths;
|
||||||
import java.time.LocalDate;
|
import java.time.LocalDate;
|
||||||
|
import java.time.LocalDateTime;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@@ -19,16 +20,44 @@ public class MemoryGraph implements Serializable {
|
|||||||
@Serial
|
@Serial
|
||||||
private static final long serialVersionUID = 1L;
|
private static final long serialVersionUID = 1L;
|
||||||
private static final String STORAGE_DIR = "./data/memory/";
|
private static final String STORAGE_DIR = "./data/memory/";
|
||||||
|
//todo: 实现记忆的短期缓存机制
|
||||||
private String id;
|
private String id;
|
||||||
|
/**
|
||||||
|
* key: 根主题名称 value: 根主题节点
|
||||||
|
*/
|
||||||
private HashMap<String, TopicNode> topicNodes;
|
private HashMap<String, TopicNode> topicNodes;
|
||||||
public static MemoryGraph memoryGraph;
|
public static MemoryGraph memoryGraph;
|
||||||
private HashMap<String, Set<String>> existedTopics;
|
|
||||||
|
/**
|
||||||
|
* 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值
|
||||||
|
* 该部分在'主题提取LLM'的system prompt中常驻
|
||||||
|
*/
|
||||||
|
private HashMap<String, LinkedHashSet<String>> existedTopics;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 记忆节点的日期索引, 同一日期内按照对话id区分
|
||||||
|
* 同时作为临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所
|
||||||
|
*/
|
||||||
|
private HashMap<LocalDate, HashMap<String, List<MemorySlice>>> dateIndex;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键,总结为值
|
||||||
|
* 该部分作为'主LLM'system prompt常驻
|
||||||
|
*/
|
||||||
|
private HashMap<LocalDateTime, String> dialogMap;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 存储确定性记忆, 如'用户爱好'等确定性信息
|
||||||
|
* 该部分作为'主LLM'system prompt常驻
|
||||||
|
*/
|
||||||
|
private HashMap<String, LinkedHashMap<LocalDate, String>> staticMemory;
|
||||||
|
|
||||||
public MemoryGraph(String id) {
|
public MemoryGraph(String id) {
|
||||||
this.id = id;
|
this.id = id;
|
||||||
this.topicNodes = new HashMap<>();
|
this.topicNodes = new HashMap<>();
|
||||||
this.existedTopics = new HashMap<>();
|
this.existedTopics = new HashMap<>();
|
||||||
|
this.dateIndex = new HashMap<>();
|
||||||
|
this.staticMemory = new HashMap<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static MemoryGraph initialize(String id) {
|
public static MemoryGraph initialize(String id) {
|
||||||
@@ -88,9 +117,6 @@ public class MemoryGraph implements Serializable {
|
|||||||
|
|
||||||
public void insertMemory(List<String> topicPath, MemorySlice slice) {
|
public void insertMemory(List<String> topicPath, MemorySlice slice) {
|
||||||
topicPath = new ArrayList<>(topicPath);
|
topicPath = new ArrayList<>(topicPath);
|
||||||
if (topicNodes == null) {
|
|
||||||
topicNodes = new HashMap<>();
|
|
||||||
}
|
|
||||||
//查看是否存在根主题节点
|
//查看是否存在根主题节点
|
||||||
String rootTopic = topicPath.getFirst();
|
String rootTopic = topicPath.getFirst();
|
||||||
topicPath.removeFirst();
|
topicPath.removeFirst();
|
||||||
@@ -99,7 +125,7 @@ public class MemoryGraph implements Serializable {
|
|||||||
rootNode.setMemoryNodes(new ArrayList<>());
|
rootNode.setMemoryNodes(new ArrayList<>());
|
||||||
rootNode.setTopicNodes(new HashMap<>());
|
rootNode.setTopicNodes(new HashMap<>());
|
||||||
topicNodes.put(rootTopic, rootNode);
|
topicNodes.put(rootTopic, rootNode);
|
||||||
existedTopics.put(rootTopic, new HashSet<>());
|
existedTopics.put(rootTopic, new LinkedHashSet<>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TopicNode lastTopicNode = topicNodes.get(rootTopic);
|
TopicNode lastTopicNode = topicNodes.get(rootTopic);
|
||||||
@@ -115,13 +141,9 @@ public class MemoryGraph implements Serializable {
|
|||||||
lastTopicNode.setMemoryNodes(nodeList);
|
lastTopicNode.setMemoryNodes(nodeList);
|
||||||
lastTopicNode.setTopicNodes(new HashMap<>());
|
lastTopicNode.setTopicNodes(new HashMap<>());
|
||||||
existedTopicNodes.add(topic);
|
existedTopicNodes.add(topic);
|
||||||
/*if (i == topicPath.size() - 1) {
|
|
||||||
lastTopicNode.setMemoryNodes(new ArrayList<>());
|
|
||||||
lastTopicNode.setTopicNodes(new HashMap<>());
|
|
||||||
}*/
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//检查是否存在当天对应的memoryData
|
//检查是否存在当天对应的memorySlice
|
||||||
LocalDate now = LocalDate.now();
|
LocalDate now = LocalDate.now();
|
||||||
boolean hasSlice = false;
|
boolean hasSlice = false;
|
||||||
MemoryNode node = null;
|
MemoryNode node = null;
|
||||||
@@ -140,9 +162,56 @@ public class MemoryGraph implements Serializable {
|
|||||||
lastTopicNode.getMemoryNodes().sort(null);
|
lastTopicNode.getMemoryNodes().sort(null);
|
||||||
}
|
}
|
||||||
node.getMemorySliceList().add(slice);
|
node.getMemorySliceList().add(slice);
|
||||||
|
|
||||||
|
updateDateIndex(now, slice);
|
||||||
|
updateDialogMap(slice);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<MemorySlice> selectMemory(List<String> topicPath) {
|
private void updateDialogMap(MemorySlice slice) {
|
||||||
|
String summary = slice.getSliceData().getSummary();
|
||||||
|
LocalDateTime now = LocalDateTime.now();
|
||||||
|
//移除两天前的上下文补充(切片总结)
|
||||||
|
List<LocalDateTime> keysToRemove = new ArrayList<>();
|
||||||
|
dialogMap.forEach((k, v) -> {
|
||||||
|
if (now.minusDays(2).isAfter(k)){
|
||||||
|
keysToRemove.add(k);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
for (LocalDateTime dateTime : keysToRemove) {
|
||||||
|
dialogMap.remove(dateTime);
|
||||||
|
}
|
||||||
|
dialogMap.put(now,summary);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateDateIndex(LocalDate now, MemorySlice slice) {
|
||||||
|
String memoryId = slice.getMemoryId();
|
||||||
|
//查看是否存在当前日期的对话切片索引
|
||||||
|
if (!dateIndex.containsKey(now)) {
|
||||||
|
dateIndex.put(now, new HashMap<>());
|
||||||
|
}
|
||||||
|
//查看当前日期的索引中是否存在该对话的索引
|
||||||
|
HashMap<String, List<MemorySlice>> currentDateDialogSlices = dateIndex.get(now);
|
||||||
|
if (!currentDateDialogSlices.containsKey(memoryId)) {
|
||||||
|
List<MemorySlice> memorySliceList = new ArrayList<>();
|
||||||
|
currentDateDialogSlices.put(memoryId, memorySliceList);
|
||||||
|
}
|
||||||
|
//处理上下文关系
|
||||||
|
List<MemorySlice> memorySliceList = currentDateDialogSlices.get(memoryId);
|
||||||
|
if (memorySliceList.isEmpty()) {
|
||||||
|
memorySliceList.add(slice);
|
||||||
|
} else {
|
||||||
|
//排序
|
||||||
|
memorySliceList.sort(null);
|
||||||
|
MemorySlice tempSlice = memorySliceList.getLast();
|
||||||
|
//末尾切片添加当前切片的引用
|
||||||
|
tempSlice.setSliceAfter(slice);
|
||||||
|
//当前切片添加前序切片的引用
|
||||||
|
slice.setSliceBefore(tempSlice);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<MemorySlice> selectMemoryByPath(List<String> topicPath) {
|
||||||
List<MemorySlice> targetSliceList = new ArrayList<>();
|
List<MemorySlice> targetSliceList = new ArrayList<>();
|
||||||
topicPath = new ArrayList<>(topicPath);
|
topicPath = new ArrayList<>(topicPath);
|
||||||
String targetTopic = topicPath.getLast();
|
String targetTopic = topicPath.getLast();
|
||||||
@@ -178,6 +247,10 @@ public class MemoryGraph implements Serializable {
|
|||||||
return targetSliceList;
|
return targetSliceList;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public HashMap<String,List<MemorySlice>> selectMemoryByDate(LocalDate date){
|
||||||
|
return dateIndex.get(date);
|
||||||
|
}
|
||||||
|
|
||||||
private TopicNode getTargetParentNode(List<String> topicPath, String targetTopic) {
|
private TopicNode getTargetParentNode(List<String> topicPath, String targetTopic) {
|
||||||
String topTopic = topicPath.getFirst();
|
String topTopic = topicPath.getFirst();
|
||||||
if (!existedTopics.containsKey(topTopic)) {
|
if (!existedTopics.containsKey(topTopic)) {
|
||||||
|
|||||||
@@ -1,21 +1,38 @@
|
|||||||
package work.slhaf.memory.content;
|
package work.slhaf.memory.content;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import work.slhaf.memory.node.TopicNode;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class MemorySlice implements Serializable {
|
public class MemorySlice implements Serializable, Comparable<MemorySlice> {
|
||||||
//关联的完整对话的id
|
//关联的完整对话的id
|
||||||
private String memoryId;
|
private String memoryId;
|
||||||
//该切片在关联的完整对话中的顺序
|
//该切片在关联的完整对话中的顺序, 由时间戳确定
|
||||||
private Integer memoryRank;
|
private Long timestamp;
|
||||||
private String slicePath;
|
private String slicePath;
|
||||||
private List<List<String>> relatedTopics;
|
private List<List<String>> relatedTopics;
|
||||||
//关联完整对话中的前序切片, 排序为键,完整路径为值
|
//关联完整对话中的前序切片, 排序为键,完整路径为值
|
||||||
private LinkedHashMap<Integer,String> sliceBefore;
|
private MemorySlice sliceBefore;
|
||||||
private LinkedHashMap<Integer,String> sliceAfter;
|
private MemorySlice sliceAfter;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int compareTo(MemorySlice memorySlice) {
|
||||||
|
if (memorySlice.getTimestamp() > this.getTimestamp()) {
|
||||||
|
return -1;
|
||||||
|
} else if (memorySlice.getTimestamp() < this.timestamp) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public SliceData getSliceData(){
|
||||||
|
//todo: 待实现获取逻辑
|
||||||
|
return new SliceData();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void saveSlice(SliceData sliceData){
|
||||||
|
//todo: 待实现存储逻辑, 该逻辑内将设置`slicePath`
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
10
src/main/java/work/slhaf/memory/content/SliceData.java
Normal file
10
src/main/java/work/slhaf/memory/content/SliceData.java
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package work.slhaf.memory.content;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson2.JSONArray;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class SliceData {
|
||||||
|
private String summary;
|
||||||
|
private JSONArray content;
|
||||||
|
}
|
||||||
@@ -9,6 +9,5 @@ import java.util.List;
|
|||||||
@Data
|
@Data
|
||||||
public class TopicNode implements Serializable {
|
public class TopicNode implements Serializable {
|
||||||
private HashMap<String,TopicNode> topicNodes;
|
private HashMap<String,TopicNode> topicNodes;
|
||||||
// private Integer weight = 0;
|
|
||||||
private List<MemoryNode> memoryNodes;
|
private List<MemoryNode> memoryNodes;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -124,7 +124,6 @@ public class InsertTest {
|
|||||||
private MemorySlice createTestMemorySlice(String id) {
|
private MemorySlice createTestMemorySlice(String id) {
|
||||||
MemorySlice slice = new MemorySlice();
|
MemorySlice slice = new MemorySlice();
|
||||||
slice.setMemoryId(id);
|
slice.setMemoryId(id);
|
||||||
slice.setMemoryRank(1);
|
|
||||||
// 可以设置其他必要属性
|
// 可以设置其他必要属性
|
||||||
return slice;
|
return slice;
|
||||||
}
|
}
|
||||||
@@ -134,7 +133,6 @@ public class InsertTest {
|
|||||||
// 构造 MemorySlice
|
// 构造 MemorySlice
|
||||||
MemorySlice slice = new MemorySlice();
|
MemorySlice slice = new MemorySlice();
|
||||||
slice.setMemoryId("001");
|
slice.setMemoryId("001");
|
||||||
slice.setMemoryRank(5);
|
|
||||||
slice.setSlicePath("/demo/path");
|
slice.setSlicePath("/demo/path");
|
||||||
|
|
||||||
List<String> topicPath = Arrays.asList("生活", "学习", "Java");
|
List<String> topicPath = Arrays.asList("生活", "学习", "Java");
|
||||||
@@ -162,7 +160,6 @@ public class InsertTest {
|
|||||||
// 校验:MemorySlice 内容一致
|
// 校验:MemorySlice 内容一致
|
||||||
MemorySlice deserializedSlice = javaNode.getMemoryNodes().get(0).getMemorySliceList().get(0);
|
MemorySlice deserializedSlice = javaNode.getMemoryNodes().get(0).getMemorySliceList().get(0);
|
||||||
assertEquals("001", deserializedSlice.getMemoryId());
|
assertEquals("001", deserializedSlice.getMemoryId());
|
||||||
assertEquals(Integer.valueOf(5), deserializedSlice.getMemoryRank());
|
|
||||||
assertEquals("/demo/path", deserializedSlice.getSlicePath());
|
assertEquals("/demo/path", deserializedSlice.getSlicePath());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class SearchTest {
|
|||||||
List<String> queryPath = new ArrayList<>();
|
List<String> queryPath = new ArrayList<>();
|
||||||
queryPath.add("算法");
|
queryPath.add("算法");
|
||||||
queryPath.add("排序");
|
queryPath.add("排序");
|
||||||
List<MemorySlice> results = memoryGraph.selectMemory(queryPath);
|
List<MemorySlice> results = memoryGraph.selectMemoryByPath(queryPath);
|
||||||
|
|
||||||
// 验证结果应包含:
|
// 验证结果应包含:
|
||||||
// 1. 目标节点所有记忆(java1)
|
// 1. 目标节点所有记忆(java1)
|
||||||
@@ -75,7 +75,7 @@ class SearchTest {
|
|||||||
invalidPath.add("不存在的主题");
|
invalidPath.add("不存在的主题");
|
||||||
|
|
||||||
assertThrows(UnExistedTopicException.class, () -> {
|
assertThrows(UnExistedTopicException.class, () -> {
|
||||||
memoryGraph.selectMemory(invalidPath);
|
memoryGraph.selectMemoryByPath(invalidPath);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,7 +92,7 @@ class SearchTest {
|
|||||||
List<String> queryPath = new ArrayList<>();
|
List<String> queryPath = new ArrayList<>();
|
||||||
queryPath.add("编程");
|
queryPath.add("编程");
|
||||||
queryPath.add("Java");
|
queryPath.add("Java");
|
||||||
List<MemorySlice> results = memoryGraph.selectMemory(queryPath);
|
List<MemorySlice> results = memoryGraph.selectMemoryByPath(queryPath);
|
||||||
|
|
||||||
// 应包含:Java记忆 + 父级最新记忆
|
// 应包含:Java记忆 + 父级最新记忆
|
||||||
assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())));
|
assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())));
|
||||||
@@ -134,7 +134,7 @@ class SearchTest {
|
|||||||
|
|
||||||
// 执行查询
|
// 执行查询
|
||||||
List<String> queryPath = createTopicPath("编程", "Java");
|
List<String> queryPath = createTopicPath("编程", "Java");
|
||||||
List<MemorySlice> results = memoryGraph.selectMemory(queryPath);
|
List<MemorySlice> results = memoryGraph.selectMemoryByPath(queryPath);
|
||||||
|
|
||||||
// 验证结果应包含最新关联记忆(dbNew)
|
// 验证结果应包含最新关联记忆(dbNew)
|
||||||
assertTrue(results.stream().anyMatch(m -> "dbNew".equals(m.getMemoryId())),
|
assertTrue(results.stream().anyMatch(m -> "dbNew".equals(m.getMemoryId())),
|
||||||
@@ -152,7 +152,6 @@ class SearchTest {
|
|||||||
private MemorySlice createMemorySlice(String id) {
|
private MemorySlice createMemorySlice(String id) {
|
||||||
MemorySlice slice = new MemorySlice();
|
MemorySlice slice = new MemorySlice();
|
||||||
slice.setMemoryId(id);
|
slice.setMemoryId(id);
|
||||||
slice.setMemoryRank(1);
|
|
||||||
return slice;
|
return slice;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user