feat(memory): 增加记忆缓存功能并优化数据结构- 新增 memorySliceCache 和 memoryNodeCacheCounter 用于缓存记忆切片

- 优化数据结构,使用 ConcurrentHashMap 和 CopyOnWriteArrayList 替代 HashMap 和 ArrayList
- 为 MemoryNode 添加唯一标识 memoryNodeId,可作为记忆节点文件名
- 更新 selectMemoryByPath 方法,增加缓存逻辑
- 修改 updateDialogMap 方法,优化用户对话缓存更新逻辑
This commit is contained in:
2025-04-12 15:26:13 +08:00
parent ae4859004f
commit 6f643b525f
3 changed files with 92 additions and 20 deletions

View File

@@ -16,6 +16,8 @@ import java.nio.file.Paths;
import java.time.LocalDate; import java.time.LocalDate;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.*; import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
@@ -38,13 +40,13 @@ public class MemoryGraph extends PersistableObject {
* 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值 * 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值
* 该部分在'主题提取LLM'的system prompt中常驻 * 该部分在'主题提取LLM'的system prompt中常驻
*/ */
private HashMap<String, LinkedHashSet<String>> existedTopics; private HashMap<String /*根主题名*/, LinkedHashSet<String> /*子主题列表*/> existedTopics;
/** /**
* 记忆节点的日期索引, 同一日期内按照对话id区分 * 记忆节点的日期索引, 同一日期内按照对话id区分
* 同时作为临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所 * 同时作为临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所
*/ */
private HashMap<LocalDate, HashMap<String, List<MemorySlice>>> dateIndex; private HashMap<LocalDate, HashMap<String /*对话id, 即slice中的字段'memoryId'*/, List<MemorySlice>>> dateIndex;
/** /**
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键总结为值 * 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键总结为值
@@ -56,7 +58,7 @@ public class MemoryGraph extends PersistableObject {
/** /**
* 近两日的区分用户的对话总结缓存在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质 * 近两日的区分用户的对话总结缓存在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质
*/ */
private HashMap<LocalDateTime,HashMap<String/*userId*/,String>> userDialogMap; private ConcurrentHashMap<LocalDateTime, ConcurrentHashMap<String/*userId*/, String>> userDialogMap;
/** /**
* 当前对话的活动性总结, 拥有比dialogMap更丰富的全文细节, 作为当前对话token超限时的必要上下文压缩存储 * 当前对话的活动性总结, 拥有比dialogMap更丰富的全文细节, 作为当前对话token超限时的必要上下文压缩存储
@@ -67,7 +69,23 @@ public class MemoryGraph extends PersistableObject {
* 存储确定性记忆, 如'用户爱好'等确定性信息 * 存储确定性记忆, 如'用户爱好'等确定性信息
* 该部分作为'主LLM'system prompt常驻 * 该部分作为'主LLM'system prompt常驻
*/ */
private HashMap<String /*userId*/, HashMap<String /*memoryKey*/,String /*memoryValue*/>> staticMemory; private HashMap<String /*userId*/, ConcurrentHashMap<String /*memoryKey*/, String /*memoryValue*/>> staticMemory;
/**
* memorySliceCache计数器每日清空
*/
private ConcurrentHashMap<List<String> /*触发查询的主题列表*/, Integer> memoryNodeCacheCounter;
/**
* 记忆切片缓存,每日清空
* 用于记录作为终点节点调用次数最多的记忆节点的切片数据
*/
private ConcurrentHashMap<List<String> /*主题路径*/, List<MemorySlice> /*切片列表*/> memorySliceCache;
/**
* 缓存日期
*/
private LocalDate cacheDate;
public MemoryGraph(String id) { public MemoryGraph(String id) {
this.id = id; this.id = id;
@@ -75,6 +93,8 @@ public class MemoryGraph extends PersistableObject {
this.existedTopics = new HashMap<>(); this.existedTopics = new HashMap<>();
this.dateIndex = new HashMap<>(); this.dateIndex = new HashMap<>();
this.staticMemory = new HashMap<>(); this.staticMemory = new HashMap<>();
this.memoryNodeCacheCounter = new ConcurrentHashMap<>();
this.memorySliceCache = new ConcurrentHashMap<>();
} }
public static MemoryGraph initialize(String id) { public static MemoryGraph initialize(String id) {
@@ -133,14 +153,18 @@ public class MemoryGraph extends PersistableObject {
} }
public void insertMemory(List<String> topicPath, MemorySlice slice) throws IOException, ClassNotFoundException { public void insertMemory(List<String> topicPath, MemorySlice slice) throws IOException, ClassNotFoundException {
//每日刷新缓存
checkCacheDate();
//如果topicPath在memorySliceCache中存在对应缓存由于进行的插入操作则需要移除该缓存但不清除相关计数
memorySliceCache.remove(topicPath);
topicPath = new ArrayList<>(topicPath); topicPath = new ArrayList<>(topicPath);
//查看是否存在根主题节点 //查看是否存在根主题节点
String rootTopic = topicPath.getFirst(); String rootTopic = topicPath.getFirst();
topicPath.removeFirst(); topicPath.removeFirst();
if (!topicNodes.containsKey(rootTopic)) { if (!topicNodes.containsKey(rootTopic)) {
TopicNode rootNode = new TopicNode(); TopicNode rootNode = new TopicNode();
rootNode.setMemoryNodes(new ArrayList<>()); rootNode.setMemoryNodes(new CopyOnWriteArrayList<>());
rootNode.setTopicNodes(new HashMap<>()); rootNode.setTopicNodes(new ConcurrentHashMap<>());
topicNodes.put(rootTopic, rootNode); topicNodes.put(rootTopic, rootNode);
existedTopics.put(rootTopic, new LinkedHashSet<>()); existedTopics.put(rootTopic, new LinkedHashSet<>());
} }
@@ -154,9 +178,9 @@ public class MemoryGraph extends PersistableObject {
TopicNode newNode = new TopicNode(); TopicNode newNode = new TopicNode();
lastTopicNode.getTopicNodes().put(topic, newNode); lastTopicNode.getTopicNodes().put(topic, newNode);
lastTopicNode = newNode; lastTopicNode = newNode;
List<MemoryNode> nodeList = new ArrayList<>(); CopyOnWriteArrayList<MemoryNode> nodeList = new CopyOnWriteArrayList<>();
lastTopicNode.setMemoryNodes(nodeList); lastTopicNode.setMemoryNodes(nodeList);
lastTopicNode.setTopicNodes(new HashMap<>()); lastTopicNode.setTopicNodes(new ConcurrentHashMap<>());
existedTopicNodes.add(topic); existedTopicNodes.add(topic);
} }
} }
@@ -174,6 +198,7 @@ public class MemoryGraph extends PersistableObject {
if (!hasSlice) { if (!hasSlice) {
node = new MemoryNode(); node = new MemoryNode();
node.setLocalDate(now); node.setLocalDate(now);
node.setMemoryNodeId(UUID.randomUUID().toString());
node.setMemorySliceList(new ArrayList<>()); node.setMemorySliceList(new ArrayList<>());
lastTopicNode.getMemoryNodes().add(node); lastTopicNode.getMemoryNodes().add(node);
lastTopicNode.getMemoryNodes().sort(null); lastTopicNode.getMemoryNodes().sort(null);
@@ -188,11 +213,12 @@ public class MemoryGraph extends PersistableObject {
private void updateDialogMap(MemorySlice slice) { private void updateDialogMap(MemorySlice slice) {
String summary = slice.getSummary(); String summary = slice.getSummary();
LocalDateTime now = LocalDateTime.now(); LocalDateTime now = LocalDateTime.now();
//更新dialogMap
//更新dialogMap -------------------------
//移除两天前的上下文缓存(切片总结) //移除两天前的上下文缓存(切片总结)
List<LocalDateTime> keysToRemove = new ArrayList<>(); List<LocalDateTime> keysToRemove = new ArrayList<>();
dialogMap.forEach((k, v) -> { dialogMap.forEach((k, v) -> {
if (now.minusDays(2).isAfter(k)){ if (now.minusDays(2).isAfter(k)) {
keysToRemove.add(k); keysToRemove.add(k);
} }
}); });
@@ -201,11 +227,13 @@ public class MemoryGraph extends PersistableObject {
} }
keysToRemove.clear(); keysToRemove.clear();
//放入新缓存 //放入新缓存
dialogMap.put(now,summary); dialogMap.put(now, summary);
//---------------------------------------
//更新userDialogMap //更新userDialogMap
//移除两天前上下文缓存(切片总结) //移除两天前上下文缓存(切片总结)
userDialogMap.forEach((k,v) -> { userDialogMap.forEach((k, v) -> {
if (now.minusDays(2).isAfter(k)){ if (now.minusDays(2).isAfter(k)) {
keysToRemove.add(k); keysToRemove.add(k);
} }
}); });
@@ -213,7 +241,10 @@ public class MemoryGraph extends PersistableObject {
userDialogMap.remove(dateTime); userDialogMap.remove(dateTime);
} }
//放入新缓存 //放入新缓存
userDialogMap.get(now).put(slice.getStartUser(),slice.getSummary()); userDialogMap
.computeIfAbsent(now, k -> new ConcurrentHashMap<>())
.merge(slice.getStartUser(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
} }
private void updateDateIndex(LocalDate now, MemorySlice slice) { private void updateDateIndex(LocalDate now, MemorySlice slice) {
@@ -245,6 +276,14 @@ public class MemoryGraph extends PersistableObject {
} }
public List<MemorySlice> selectMemoryByPath(List<String> topicPath) throws IOException, ClassNotFoundException { public List<MemorySlice> selectMemoryByPath(List<String> topicPath) throws IOException, ClassNotFoundException {
//每日刷新缓存
checkCacheDate();
//检测缓存并更新计数, 查看是否需要放入缓存
updateCacheCounter(topicPath);
//查看是否存在缓存,如果存在,则直接返回
if (memorySliceCache.containsKey(topicPath)) {
return memorySliceCache.get(topicPath);
}
List<MemorySlice> targetSliceList = new ArrayList<>(); List<MemorySlice> targetSliceList = new ArrayList<>();
topicPath = new ArrayList<>(topicPath); topicPath = new ArrayList<>(topicPath);
String targetTopic = topicPath.getLast(); String targetTopic = topicPath.getLast();
@@ -277,10 +316,35 @@ public class MemoryGraph extends PersistableObject {
if (!targetParentMemoryNodes.isEmpty()) { if (!targetParentMemoryNodes.isEmpty()) {
targetSliceList.addAll(targetParentMemoryNodes.getFirst().getMemorySliceList()); targetSliceList.addAll(targetParentMemoryNodes.getFirst().getMemorySliceList());
} }
//放入缓存
updateCache(topicPath, targetSliceList);
return targetSliceList; return targetSliceList;
} }
public HashMap<String,List<MemorySlice>> selectMemoryByDate(LocalDate date){ private void updateCache(List<String> topicPath, List<MemorySlice> targetSliceList) {
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
if (tempCount >= 5) {
memorySliceCache.put(topicPath, targetSliceList);
}
}
private void updateCacheCounter(List<String> topicPath) {
if (memoryNodeCacheCounter.containsKey(topicPath)) {
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
memoryNodeCacheCounter.put(topicPath, ++tempCount);
} else {
memoryNodeCacheCounter.put(topicPath, 1);
}
}
private void checkCacheDate() {
if (cacheDate.isBefore(LocalDate.now())) {
memorySliceCache.clear();
memoryNodeCacheCounter.clear();
}
}
public HashMap<String, List<MemorySlice>> selectMemoryByDate(LocalDate date) {
return dateIndex.get(date); return dateIndex.get(date);
} }

View File

@@ -23,7 +23,12 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
private static String SLICE_DATA_DIR = "./data/slice/"; private static String SLICE_DATA_DIR = "./data/slice/";
/** /**
* 记忆节点所属日期, 以日期为文件名在硬盘存储记忆数据(如 2025-04-11.slice) * 记忆节点唯一标识, 用于作为实际文件名, 如(xxxx-xxxxx-xxxxx.slice)
*/
private String memoryNodeId;
/**
* 记忆节点所属日期
*/ */
private LocalDate localDate; private LocalDate localDate;
@@ -44,10 +49,11 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
public List<MemorySlice> getMemorySliceList() throws IOException, ClassNotFoundException { public List<MemorySlice> getMemorySliceList() throws IOException, ClassNotFoundException {
//检查是否存在对应文件 //检查是否存在对应文件
File file = new File(SLICE_DATA_DIR+this.getLocalDate()+".slice"); File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
if (file.exists()){ if (file.exists()){
this.memorySliceList = deserialize(file); this.memorySliceList = deserialize(file);
}else { }else {
//逻辑正常的话这部分应该不会出现除非在insertMemory中进行save操作之前出现异常中断了方法但程序却没有结束
this.memorySliceList = new ArrayList<>(); this.memorySliceList = new ArrayList<>();
} }
return this.memorySliceList; return this.memorySliceList;
@@ -57,7 +63,7 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
if (memorySliceList == null){ if (memorySliceList == null){
throw new NullSliceListException("memorySliceList为NULL! 检查实现逻辑!"); throw new NullSliceListException("memorySliceList为NULL! 检查实现逻辑!");
} }
File file = new File(SLICE_DATA_DIR+this.getLocalDate()+".slice"); File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(file))){ try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(file))){
oos.writeObject(this.memorySliceList); oos.writeObject(this.memorySliceList);
} }

View File

@@ -7,6 +7,8 @@ import work.slhaf.memory.pojo.PersistableObject;
import java.io.Serial; import java.io.Serial;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
@@ -15,6 +17,6 @@ public class TopicNode extends PersistableObject {
@Serial @Serial
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
private HashMap<String,TopicNode> topicNodes; private ConcurrentHashMap<String,TopicNode> topicNodes;
private List<MemoryNode> memoryNodes; private CopyOnWriteArrayList<MemoryNode> memoryNodes;
} }