mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
feat(memory): 增加记忆缓存功能并优化数据结构- 新增 memorySliceCache 和 memoryNodeCacheCounter 用于缓存记忆切片
- 优化数据结构,使用 ConcurrentHashMap 和 CopyOnWriteArrayList 替代 HashMap 和 ArrayList - 为 MemoryNode 添加唯一标识 memoryNodeId,可作为记忆节点文件名 - 更新 selectMemoryByPath 方法,增加缓存逻辑 - 修改 updateDialogMap 方法,优化用户对话缓存更新逻辑
This commit is contained in:
@@ -16,6 +16,8 @@ import java.nio.file.Paths;
|
|||||||
import java.time.LocalDate;
|
import java.time.LocalDate;
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDateTime;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Data
|
@Data
|
||||||
@@ -38,13 +40,13 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
* 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值
|
* 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值
|
||||||
* 该部分在'主题提取LLM'的system prompt中常驻
|
* 该部分在'主题提取LLM'的system prompt中常驻
|
||||||
*/
|
*/
|
||||||
private HashMap<String, LinkedHashSet<String>> existedTopics;
|
private HashMap<String /*根主题名*/, LinkedHashSet<String> /*子主题列表*/> existedTopics;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 记忆节点的日期索引, 同一日期内按照对话id区分
|
* 记忆节点的日期索引, 同一日期内按照对话id区分
|
||||||
* 同时作为临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所
|
* 同时作为临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所
|
||||||
*/
|
*/
|
||||||
private HashMap<LocalDate, HashMap<String, List<MemorySlice>>> dateIndex;
|
private HashMap<LocalDate, HashMap<String /*对话id, 即slice中的字段'memoryId'*/, List<MemorySlice>>> dateIndex;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键,总结为值
|
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键,总结为值
|
||||||
@@ -56,7 +58,7 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
/**
|
/**
|
||||||
* 近两日的区分用户的对话总结缓存,在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质
|
* 近两日的区分用户的对话总结缓存,在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质
|
||||||
*/
|
*/
|
||||||
private HashMap<LocalDateTime,HashMap<String/*userId*/,String>> userDialogMap;
|
private ConcurrentHashMap<LocalDateTime, ConcurrentHashMap<String/*userId*/, String>> userDialogMap;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 当前对话的活动性总结, 拥有比dialogMap更丰富的全文细节, 作为当前对话token超限时的必要上下文压缩存储
|
* 当前对话的活动性总结, 拥有比dialogMap更丰富的全文细节, 作为当前对话token超限时的必要上下文压缩存储
|
||||||
@@ -67,7 +69,23 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
* 存储确定性记忆, 如'用户爱好'等确定性信息
|
* 存储确定性记忆, 如'用户爱好'等确定性信息
|
||||||
* 该部分作为'主LLM'system prompt常驻
|
* 该部分作为'主LLM'system prompt常驻
|
||||||
*/
|
*/
|
||||||
private HashMap<String /*userId*/, HashMap<String /*memoryKey*/,String /*memoryValue*/>> staticMemory;
|
private HashMap<String /*userId*/, ConcurrentHashMap<String /*memoryKey*/, String /*memoryValue*/>> staticMemory;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* memorySliceCache计数器,每日清空
|
||||||
|
*/
|
||||||
|
private ConcurrentHashMap<List<String> /*触发查询的主题列表*/, Integer> memoryNodeCacheCounter;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 记忆切片缓存,每日清空
|
||||||
|
* 用于记录作为终点节点调用次数最多的记忆节点的切片数据
|
||||||
|
*/
|
||||||
|
private ConcurrentHashMap<List<String> /*主题路径*/, List<MemorySlice> /*切片列表*/> memorySliceCache;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 缓存日期
|
||||||
|
*/
|
||||||
|
private LocalDate cacheDate;
|
||||||
|
|
||||||
public MemoryGraph(String id) {
|
public MemoryGraph(String id) {
|
||||||
this.id = id;
|
this.id = id;
|
||||||
@@ -75,6 +93,8 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
this.existedTopics = new HashMap<>();
|
this.existedTopics = new HashMap<>();
|
||||||
this.dateIndex = new HashMap<>();
|
this.dateIndex = new HashMap<>();
|
||||||
this.staticMemory = new HashMap<>();
|
this.staticMemory = new HashMap<>();
|
||||||
|
this.memoryNodeCacheCounter = new ConcurrentHashMap<>();
|
||||||
|
this.memorySliceCache = new ConcurrentHashMap<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static MemoryGraph initialize(String id) {
|
public static MemoryGraph initialize(String id) {
|
||||||
@@ -133,14 +153,18 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void insertMemory(List<String> topicPath, MemorySlice slice) throws IOException, ClassNotFoundException {
|
public void insertMemory(List<String> topicPath, MemorySlice slice) throws IOException, ClassNotFoundException {
|
||||||
|
//每日刷新缓存
|
||||||
|
checkCacheDate();
|
||||||
|
//如果topicPath在memorySliceCache中存在对应缓存,由于进行的插入操作,则需要移除该缓存,但不清除相关计数
|
||||||
|
memorySliceCache.remove(topicPath);
|
||||||
topicPath = new ArrayList<>(topicPath);
|
topicPath = new ArrayList<>(topicPath);
|
||||||
//查看是否存在根主题节点
|
//查看是否存在根主题节点
|
||||||
String rootTopic = topicPath.getFirst();
|
String rootTopic = topicPath.getFirst();
|
||||||
topicPath.removeFirst();
|
topicPath.removeFirst();
|
||||||
if (!topicNodes.containsKey(rootTopic)) {
|
if (!topicNodes.containsKey(rootTopic)) {
|
||||||
TopicNode rootNode = new TopicNode();
|
TopicNode rootNode = new TopicNode();
|
||||||
rootNode.setMemoryNodes(new ArrayList<>());
|
rootNode.setMemoryNodes(new CopyOnWriteArrayList<>());
|
||||||
rootNode.setTopicNodes(new HashMap<>());
|
rootNode.setTopicNodes(new ConcurrentHashMap<>());
|
||||||
topicNodes.put(rootTopic, rootNode);
|
topicNodes.put(rootTopic, rootNode);
|
||||||
existedTopics.put(rootTopic, new LinkedHashSet<>());
|
existedTopics.put(rootTopic, new LinkedHashSet<>());
|
||||||
}
|
}
|
||||||
@@ -154,9 +178,9 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
TopicNode newNode = new TopicNode();
|
TopicNode newNode = new TopicNode();
|
||||||
lastTopicNode.getTopicNodes().put(topic, newNode);
|
lastTopicNode.getTopicNodes().put(topic, newNode);
|
||||||
lastTopicNode = newNode;
|
lastTopicNode = newNode;
|
||||||
List<MemoryNode> nodeList = new ArrayList<>();
|
CopyOnWriteArrayList<MemoryNode> nodeList = new CopyOnWriteArrayList<>();
|
||||||
lastTopicNode.setMemoryNodes(nodeList);
|
lastTopicNode.setMemoryNodes(nodeList);
|
||||||
lastTopicNode.setTopicNodes(new HashMap<>());
|
lastTopicNode.setTopicNodes(new ConcurrentHashMap<>());
|
||||||
existedTopicNodes.add(topic);
|
existedTopicNodes.add(topic);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -174,6 +198,7 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
if (!hasSlice) {
|
if (!hasSlice) {
|
||||||
node = new MemoryNode();
|
node = new MemoryNode();
|
||||||
node.setLocalDate(now);
|
node.setLocalDate(now);
|
||||||
|
node.setMemoryNodeId(UUID.randomUUID().toString());
|
||||||
node.setMemorySliceList(new ArrayList<>());
|
node.setMemorySliceList(new ArrayList<>());
|
||||||
lastTopicNode.getMemoryNodes().add(node);
|
lastTopicNode.getMemoryNodes().add(node);
|
||||||
lastTopicNode.getMemoryNodes().sort(null);
|
lastTopicNode.getMemoryNodes().sort(null);
|
||||||
@@ -188,11 +213,12 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
private void updateDialogMap(MemorySlice slice) {
|
private void updateDialogMap(MemorySlice slice) {
|
||||||
String summary = slice.getSummary();
|
String summary = slice.getSummary();
|
||||||
LocalDateTime now = LocalDateTime.now();
|
LocalDateTime now = LocalDateTime.now();
|
||||||
//更新dialogMap
|
|
||||||
|
//更新dialogMap -------------------------
|
||||||
//移除两天前的上下文缓存(切片总结)
|
//移除两天前的上下文缓存(切片总结)
|
||||||
List<LocalDateTime> keysToRemove = new ArrayList<>();
|
List<LocalDateTime> keysToRemove = new ArrayList<>();
|
||||||
dialogMap.forEach((k, v) -> {
|
dialogMap.forEach((k, v) -> {
|
||||||
if (now.minusDays(2).isAfter(k)){
|
if (now.minusDays(2).isAfter(k)) {
|
||||||
keysToRemove.add(k);
|
keysToRemove.add(k);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -201,11 +227,13 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
}
|
}
|
||||||
keysToRemove.clear();
|
keysToRemove.clear();
|
||||||
//放入新缓存
|
//放入新缓存
|
||||||
dialogMap.put(now,summary);
|
dialogMap.put(now, summary);
|
||||||
|
//---------------------------------------
|
||||||
|
|
||||||
//更新userDialogMap
|
//更新userDialogMap
|
||||||
//移除两天前上下文缓存(切片总结)
|
//移除两天前上下文缓存(切片总结)
|
||||||
userDialogMap.forEach((k,v) -> {
|
userDialogMap.forEach((k, v) -> {
|
||||||
if (now.minusDays(2).isAfter(k)){
|
if (now.minusDays(2).isAfter(k)) {
|
||||||
keysToRemove.add(k);
|
keysToRemove.add(k);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -213,7 +241,10 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
userDialogMap.remove(dateTime);
|
userDialogMap.remove(dateTime);
|
||||||
}
|
}
|
||||||
//放入新缓存
|
//放入新缓存
|
||||||
userDialogMap.get(now).put(slice.getStartUser(),slice.getSummary());
|
userDialogMap
|
||||||
|
.computeIfAbsent(now, k -> new ConcurrentHashMap<>())
|
||||||
|
.merge(slice.getStartUser(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateDateIndex(LocalDate now, MemorySlice slice) {
|
private void updateDateIndex(LocalDate now, MemorySlice slice) {
|
||||||
@@ -245,6 +276,14 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public List<MemorySlice> selectMemoryByPath(List<String> topicPath) throws IOException, ClassNotFoundException {
|
public List<MemorySlice> selectMemoryByPath(List<String> topicPath) throws IOException, ClassNotFoundException {
|
||||||
|
//每日刷新缓存
|
||||||
|
checkCacheDate();
|
||||||
|
//检测缓存并更新计数, 查看是否需要放入缓存
|
||||||
|
updateCacheCounter(topicPath);
|
||||||
|
//查看是否存在缓存,如果存在,则直接返回
|
||||||
|
if (memorySliceCache.containsKey(topicPath)) {
|
||||||
|
return memorySliceCache.get(topicPath);
|
||||||
|
}
|
||||||
List<MemorySlice> targetSliceList = new ArrayList<>();
|
List<MemorySlice> targetSliceList = new ArrayList<>();
|
||||||
topicPath = new ArrayList<>(topicPath);
|
topicPath = new ArrayList<>(topicPath);
|
||||||
String targetTopic = topicPath.getLast();
|
String targetTopic = topicPath.getLast();
|
||||||
@@ -277,10 +316,35 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
if (!targetParentMemoryNodes.isEmpty()) {
|
if (!targetParentMemoryNodes.isEmpty()) {
|
||||||
targetSliceList.addAll(targetParentMemoryNodes.getFirst().getMemorySliceList());
|
targetSliceList.addAll(targetParentMemoryNodes.getFirst().getMemorySliceList());
|
||||||
}
|
}
|
||||||
|
//放入缓存
|
||||||
|
updateCache(topicPath, targetSliceList);
|
||||||
return targetSliceList;
|
return targetSliceList;
|
||||||
}
|
}
|
||||||
|
|
||||||
public HashMap<String,List<MemorySlice>> selectMemoryByDate(LocalDate date){
|
private void updateCache(List<String> topicPath, List<MemorySlice> targetSliceList) {
|
||||||
|
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
|
||||||
|
if (tempCount >= 5) {
|
||||||
|
memorySliceCache.put(topicPath, targetSliceList);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateCacheCounter(List<String> topicPath) {
|
||||||
|
if (memoryNodeCacheCounter.containsKey(topicPath)) {
|
||||||
|
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
|
||||||
|
memoryNodeCacheCounter.put(topicPath, ++tempCount);
|
||||||
|
} else {
|
||||||
|
memoryNodeCacheCounter.put(topicPath, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void checkCacheDate() {
|
||||||
|
if (cacheDate.isBefore(LocalDate.now())) {
|
||||||
|
memorySliceCache.clear();
|
||||||
|
memoryNodeCacheCounter.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public HashMap<String, List<MemorySlice>> selectMemoryByDate(LocalDate date) {
|
||||||
return dateIndex.get(date);
|
return dateIndex.get(date);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,12 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
|
|||||||
private static String SLICE_DATA_DIR = "./data/slice/";
|
private static String SLICE_DATA_DIR = "./data/slice/";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 记忆节点所属日期, 以日期为文件名在硬盘存储记忆数据(如 2025-04-11.slice)
|
* 记忆节点唯一标识, 用于作为实际文件名, 如(xxxx-xxxxx-xxxxx.slice)
|
||||||
|
*/
|
||||||
|
private String memoryNodeId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 记忆节点所属日期
|
||||||
*/
|
*/
|
||||||
private LocalDate localDate;
|
private LocalDate localDate;
|
||||||
|
|
||||||
@@ -44,10 +49,11 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
|
|||||||
|
|
||||||
public List<MemorySlice> getMemorySliceList() throws IOException, ClassNotFoundException {
|
public List<MemorySlice> getMemorySliceList() throws IOException, ClassNotFoundException {
|
||||||
//检查是否存在对应文件
|
//检查是否存在对应文件
|
||||||
File file = new File(SLICE_DATA_DIR+this.getLocalDate()+".slice");
|
File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
|
||||||
if (file.exists()){
|
if (file.exists()){
|
||||||
this.memorySliceList = deserialize(file);
|
this.memorySliceList = deserialize(file);
|
||||||
}else {
|
}else {
|
||||||
|
//逻辑正常的话,这部分应该不会出现,除非在insertMemory中进行save操作之前出现异常,中断了方法,但程序却没有结束
|
||||||
this.memorySliceList = new ArrayList<>();
|
this.memorySliceList = new ArrayList<>();
|
||||||
}
|
}
|
||||||
return this.memorySliceList;
|
return this.memorySliceList;
|
||||||
@@ -57,7 +63,7 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
|
|||||||
if (memorySliceList == null){
|
if (memorySliceList == null){
|
||||||
throw new NullSliceListException("memorySliceList为NULL! 检查实现逻辑!");
|
throw new NullSliceListException("memorySliceList为NULL! 检查实现逻辑!");
|
||||||
}
|
}
|
||||||
File file = new File(SLICE_DATA_DIR+this.getLocalDate()+".slice");
|
File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
|
||||||
try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(file))){
|
try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(file))){
|
||||||
oos.writeObject(this.memorySliceList);
|
oos.writeObject(this.memorySliceList);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import work.slhaf.memory.pojo.PersistableObject;
|
|||||||
import java.io.Serial;
|
import java.io.Serial;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Data
|
@Data
|
||||||
@@ -15,6 +17,6 @@ public class TopicNode extends PersistableObject {
|
|||||||
@Serial
|
@Serial
|
||||||
private static final long serialVersionUID = 1L;
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private HashMap<String,TopicNode> topicNodes;
|
private ConcurrentHashMap<String,TopicNode> topicNodes;
|
||||||
private List<MemoryNode> memoryNodes;
|
private CopyOnWriteArrayList<MemoryNode> memoryNodes;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user