mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
feat(memory): 增加记忆缓存功能并优化数据结构- 新增 memorySliceCache 和 memoryNodeCacheCounter 用于缓存记忆切片
- 优化数据结构,使用 ConcurrentHashMap 和 CopyOnWriteArrayList 替代 HashMap 和 ArrayList - 为 MemoryNode 添加唯一标识 memoryNodeId,可作为记忆节点文件名 - 更新 selectMemoryByPath 方法,增加缓存逻辑 - 修改 updateDialogMap 方法,优化用户对话缓存更新逻辑
This commit is contained in:
@@ -16,6 +16,8 @@ import java.nio.file.Paths;
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@@ -38,13 +40,13 @@ public class MemoryGraph extends PersistableObject {
|
||||
* 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值
|
||||
* 该部分在'主题提取LLM'的system prompt中常驻
|
||||
*/
|
||||
private HashMap<String, LinkedHashSet<String>> existedTopics;
|
||||
private HashMap<String /*根主题名*/, LinkedHashSet<String> /*子主题列表*/> existedTopics;
|
||||
|
||||
/**
|
||||
* 记忆节点的日期索引, 同一日期内按照对话id区分
|
||||
* 同时作为临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所
|
||||
*/
|
||||
private HashMap<LocalDate, HashMap<String, List<MemorySlice>>> dateIndex;
|
||||
private HashMap<LocalDate, HashMap<String /*对话id, 即slice中的字段'memoryId'*/, List<MemorySlice>>> dateIndex;
|
||||
|
||||
/**
|
||||
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键,总结为值
|
||||
@@ -56,7 +58,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
/**
|
||||
* 近两日的区分用户的对话总结缓存,在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质
|
||||
*/
|
||||
private HashMap<LocalDateTime,HashMap<String/*userId*/,String>> userDialogMap;
|
||||
private ConcurrentHashMap<LocalDateTime, ConcurrentHashMap<String/*userId*/, String>> userDialogMap;
|
||||
|
||||
/**
|
||||
* 当前对话的活动性总结, 拥有比dialogMap更丰富的全文细节, 作为当前对话token超限时的必要上下文压缩存储
|
||||
@@ -67,7 +69,23 @@ public class MemoryGraph extends PersistableObject {
|
||||
* 存储确定性记忆, 如'用户爱好'等确定性信息
|
||||
* 该部分作为'主LLM'system prompt常驻
|
||||
*/
|
||||
private HashMap<String /*userId*/, HashMap<String /*memoryKey*/,String /*memoryValue*/>> staticMemory;
|
||||
private HashMap<String /*userId*/, ConcurrentHashMap<String /*memoryKey*/, String /*memoryValue*/>> staticMemory;
|
||||
|
||||
/**
|
||||
* memorySliceCache计数器,每日清空
|
||||
*/
|
||||
private ConcurrentHashMap<List<String> /*触发查询的主题列表*/, Integer> memoryNodeCacheCounter;
|
||||
|
||||
/**
|
||||
* 记忆切片缓存,每日清空
|
||||
* 用于记录作为终点节点调用次数最多的记忆节点的切片数据
|
||||
*/
|
||||
private ConcurrentHashMap<List<String> /*主题路径*/, List<MemorySlice> /*切片列表*/> memorySliceCache;
|
||||
|
||||
/**
|
||||
* 缓存日期
|
||||
*/
|
||||
private LocalDate cacheDate;
|
||||
|
||||
public MemoryGraph(String id) {
|
||||
this.id = id;
|
||||
@@ -75,6 +93,8 @@ public class MemoryGraph extends PersistableObject {
|
||||
this.existedTopics = new HashMap<>();
|
||||
this.dateIndex = new HashMap<>();
|
||||
this.staticMemory = new HashMap<>();
|
||||
this.memoryNodeCacheCounter = new ConcurrentHashMap<>();
|
||||
this.memorySliceCache = new ConcurrentHashMap<>();
|
||||
}
|
||||
|
||||
public static MemoryGraph initialize(String id) {
|
||||
@@ -133,14 +153,18 @@ public class MemoryGraph extends PersistableObject {
|
||||
}
|
||||
|
||||
public void insertMemory(List<String> topicPath, MemorySlice slice) throws IOException, ClassNotFoundException {
|
||||
//每日刷新缓存
|
||||
checkCacheDate();
|
||||
//如果topicPath在memorySliceCache中存在对应缓存,由于进行的插入操作,则需要移除该缓存,但不清除相关计数
|
||||
memorySliceCache.remove(topicPath);
|
||||
topicPath = new ArrayList<>(topicPath);
|
||||
//查看是否存在根主题节点
|
||||
String rootTopic = topicPath.getFirst();
|
||||
topicPath.removeFirst();
|
||||
if (!topicNodes.containsKey(rootTopic)) {
|
||||
TopicNode rootNode = new TopicNode();
|
||||
rootNode.setMemoryNodes(new ArrayList<>());
|
||||
rootNode.setTopicNodes(new HashMap<>());
|
||||
rootNode.setMemoryNodes(new CopyOnWriteArrayList<>());
|
||||
rootNode.setTopicNodes(new ConcurrentHashMap<>());
|
||||
topicNodes.put(rootTopic, rootNode);
|
||||
existedTopics.put(rootTopic, new LinkedHashSet<>());
|
||||
}
|
||||
@@ -154,9 +178,9 @@ public class MemoryGraph extends PersistableObject {
|
||||
TopicNode newNode = new TopicNode();
|
||||
lastTopicNode.getTopicNodes().put(topic, newNode);
|
||||
lastTopicNode = newNode;
|
||||
List<MemoryNode> nodeList = new ArrayList<>();
|
||||
CopyOnWriteArrayList<MemoryNode> nodeList = new CopyOnWriteArrayList<>();
|
||||
lastTopicNode.setMemoryNodes(nodeList);
|
||||
lastTopicNode.setTopicNodes(new HashMap<>());
|
||||
lastTopicNode.setTopicNodes(new ConcurrentHashMap<>());
|
||||
existedTopicNodes.add(topic);
|
||||
}
|
||||
}
|
||||
@@ -174,6 +198,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
if (!hasSlice) {
|
||||
node = new MemoryNode();
|
||||
node.setLocalDate(now);
|
||||
node.setMemoryNodeId(UUID.randomUUID().toString());
|
||||
node.setMemorySliceList(new ArrayList<>());
|
||||
lastTopicNode.getMemoryNodes().add(node);
|
||||
lastTopicNode.getMemoryNodes().sort(null);
|
||||
@@ -188,11 +213,12 @@ public class MemoryGraph extends PersistableObject {
|
||||
private void updateDialogMap(MemorySlice slice) {
|
||||
String summary = slice.getSummary();
|
||||
LocalDateTime now = LocalDateTime.now();
|
||||
//更新dialogMap
|
||||
|
||||
//更新dialogMap -------------------------
|
||||
//移除两天前的上下文缓存(切片总结)
|
||||
List<LocalDateTime> keysToRemove = new ArrayList<>();
|
||||
dialogMap.forEach((k, v) -> {
|
||||
if (now.minusDays(2).isAfter(k)){
|
||||
if (now.minusDays(2).isAfter(k)) {
|
||||
keysToRemove.add(k);
|
||||
}
|
||||
});
|
||||
@@ -201,11 +227,13 @@ public class MemoryGraph extends PersistableObject {
|
||||
}
|
||||
keysToRemove.clear();
|
||||
//放入新缓存
|
||||
dialogMap.put(now,summary);
|
||||
dialogMap.put(now, summary);
|
||||
//---------------------------------------
|
||||
|
||||
//更新userDialogMap
|
||||
//移除两天前上下文缓存(切片总结)
|
||||
userDialogMap.forEach((k,v) -> {
|
||||
if (now.minusDays(2).isAfter(k)){
|
||||
userDialogMap.forEach((k, v) -> {
|
||||
if (now.minusDays(2).isAfter(k)) {
|
||||
keysToRemove.add(k);
|
||||
}
|
||||
});
|
||||
@@ -213,7 +241,10 @@ public class MemoryGraph extends PersistableObject {
|
||||
userDialogMap.remove(dateTime);
|
||||
}
|
||||
//放入新缓存
|
||||
userDialogMap.get(now).put(slice.getStartUser(),slice.getSummary());
|
||||
userDialogMap
|
||||
.computeIfAbsent(now, k -> new ConcurrentHashMap<>())
|
||||
.merge(slice.getStartUser(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
|
||||
|
||||
}
|
||||
|
||||
private void updateDateIndex(LocalDate now, MemorySlice slice) {
|
||||
@@ -245,6 +276,14 @@ public class MemoryGraph extends PersistableObject {
|
||||
}
|
||||
|
||||
public List<MemorySlice> selectMemoryByPath(List<String> topicPath) throws IOException, ClassNotFoundException {
|
||||
//每日刷新缓存
|
||||
checkCacheDate();
|
||||
//检测缓存并更新计数, 查看是否需要放入缓存
|
||||
updateCacheCounter(topicPath);
|
||||
//查看是否存在缓存,如果存在,则直接返回
|
||||
if (memorySliceCache.containsKey(topicPath)) {
|
||||
return memorySliceCache.get(topicPath);
|
||||
}
|
||||
List<MemorySlice> targetSliceList = new ArrayList<>();
|
||||
topicPath = new ArrayList<>(topicPath);
|
||||
String targetTopic = topicPath.getLast();
|
||||
@@ -277,10 +316,35 @@ public class MemoryGraph extends PersistableObject {
|
||||
if (!targetParentMemoryNodes.isEmpty()) {
|
||||
targetSliceList.addAll(targetParentMemoryNodes.getFirst().getMemorySliceList());
|
||||
}
|
||||
//放入缓存
|
||||
updateCache(topicPath, targetSliceList);
|
||||
return targetSliceList;
|
||||
}
|
||||
|
||||
public HashMap<String,List<MemorySlice>> selectMemoryByDate(LocalDate date){
|
||||
private void updateCache(List<String> topicPath, List<MemorySlice> targetSliceList) {
|
||||
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
|
||||
if (tempCount >= 5) {
|
||||
memorySliceCache.put(topicPath, targetSliceList);
|
||||
}
|
||||
}
|
||||
|
||||
private void updateCacheCounter(List<String> topicPath) {
|
||||
if (memoryNodeCacheCounter.containsKey(topicPath)) {
|
||||
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
|
||||
memoryNodeCacheCounter.put(topicPath, ++tempCount);
|
||||
} else {
|
||||
memoryNodeCacheCounter.put(topicPath, 1);
|
||||
}
|
||||
}
|
||||
|
||||
private void checkCacheDate() {
|
||||
if (cacheDate.isBefore(LocalDate.now())) {
|
||||
memorySliceCache.clear();
|
||||
memoryNodeCacheCounter.clear();
|
||||
}
|
||||
}
|
||||
|
||||
public HashMap<String, List<MemorySlice>> selectMemoryByDate(LocalDate date) {
|
||||
return dateIndex.get(date);
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,12 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
|
||||
private static String SLICE_DATA_DIR = "./data/slice/";
|
||||
|
||||
/**
|
||||
* 记忆节点所属日期, 以日期为文件名在硬盘存储记忆数据(如 2025-04-11.slice)
|
||||
* 记忆节点唯一标识, 用于作为实际文件名, 如(xxxx-xxxxx-xxxxx.slice)
|
||||
*/
|
||||
private String memoryNodeId;
|
||||
|
||||
/**
|
||||
* 记忆节点所属日期
|
||||
*/
|
||||
private LocalDate localDate;
|
||||
|
||||
@@ -44,10 +49,11 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
|
||||
|
||||
public List<MemorySlice> getMemorySliceList() throws IOException, ClassNotFoundException {
|
||||
//检查是否存在对应文件
|
||||
File file = new File(SLICE_DATA_DIR+this.getLocalDate()+".slice");
|
||||
File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
|
||||
if (file.exists()){
|
||||
this.memorySliceList = deserialize(file);
|
||||
}else {
|
||||
//逻辑正常的话,这部分应该不会出现,除非在insertMemory中进行save操作之前出现异常,中断了方法,但程序却没有结束
|
||||
this.memorySliceList = new ArrayList<>();
|
||||
}
|
||||
return this.memorySliceList;
|
||||
@@ -57,7 +63,7 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
|
||||
if (memorySliceList == null){
|
||||
throw new NullSliceListException("memorySliceList为NULL! 检查实现逻辑!");
|
||||
}
|
||||
File file = new File(SLICE_DATA_DIR+this.getLocalDate()+".slice");
|
||||
File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
|
||||
try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(file))){
|
||||
oos.writeObject(this.memorySliceList);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import work.slhaf.memory.pojo.PersistableObject;
|
||||
import java.io.Serial;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@@ -15,6 +17,6 @@ public class TopicNode extends PersistableObject {
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private HashMap<String,TopicNode> topicNodes;
|
||||
private List<MemoryNode> memoryNodes;
|
||||
private ConcurrentHashMap<String,TopicNode> topicNodes;
|
||||
private CopyOnWriteArrayList<MemoryNode> memoryNodes;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user