进行第一阶段的调试修复

- 添加 DebugMonitor 加了一个较为无用的线程,打上断点用于获取即时模块信息
- 调整 MemoryGraph 中临时切片前后序生成容器与日期索引分离,日期索引将存储 localDate 和 memoryId
- 修复了几个 MemoryGraph 用到的类对于序列化的实现
- 将 MemoryManager 中维护的 activatedMemorySlice 同样作为主题提取的依据,并调整了提示词
- 因 主题提取LLM 效果不稳定,故添加了必要的异常处理机制
- MemorySlice 中前后序机制存在循环引用问题,排除 toString 的调用
- 移除了提示词中过多的示例,仅保留一份
- 记忆自动更新线程调整:在 SessionManager 中维护 lastUpdatedTime ,用于标识最近聊天时间; 计数当前对话数量,如果只有系统提示词则不进行记忆更新
- MemoryGraph 获取主题树时将标识记忆节点数量,供主题提取模型识别,减少空主题节点作为目标节点的情况
- 在 PreprocessExecutor 执行时将先判断是否存在memoryId, 除此之外,memoryId也将在记忆自动更新线程执行后刷新
- 将 SessionManager 添加了序列化机制,添加了程序停止时自动序列化保存的钩子
- 当主题提取模型指定了不包含记忆节点的主题节点时,MemoryResult 的空属性将会使其跳过无效的切片评估

待解决问题:
- MemoryGraph 在进行记忆更新时有时会出现错误,出现条件不明
- MemorySelector 拿到的List<EvaluatedSlice> 总是为空,原因不明
- 切片评估线程似乎没有运行,原因不明
This commit is contained in:
2025-05-10 21:51:45 +08:00
parent 550a5ee2b0
commit 15d6b98eac
20 changed files with 387 additions and 264 deletions

View File

@@ -2,8 +2,8 @@ package work.slhaf.agent;
import lombok.Data; import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.java_websocket.WebSocket;
import work.slhaf.agent.common.config.Config; import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.monitor.DebugMonitor;
import work.slhaf.agent.core.InteractionHub; import work.slhaf.agent.core.InteractionHub;
import work.slhaf.agent.core.interaction.InputReceiver; import work.slhaf.agent.core.interaction.InputReceiver;
import work.slhaf.agent.core.interaction.TaskCallback; import work.slhaf.agent.core.interaction.TaskCallback;
@@ -34,6 +34,9 @@ public class Agent implements TaskCallback, InputReceiver {
server.launch(); server.launch();
agent.setMessageSender(server); agent.setMessageSender(server);
log.info("Agent 加载完毕.."); log.info("Agent 加载完毕..");
//启动监测线程
DebugMonitor.initialize();
} }
} }

View File

@@ -2,10 +2,19 @@ package work.slhaf.agent.common.chat.pojo;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.agent.common.pojo.PersistableObject;
import java.io.Serial;
@EqualsAndHashCode(callSuper = true)
@Data @Data
@AllArgsConstructor @AllArgsConstructor
public class MetaMessage { public class MetaMessage extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private Message userMessage; private Message userMessage;
private Message assistantMessage; private Message assistantMessage;
} }

View File

@@ -12,7 +12,6 @@ import work.slhaf.agent.modules.memory.updater.MemoryUpdater;
import work.slhaf.agent.modules.memory.updater.static_extractor.StaticMemoryExtractor; import work.slhaf.agent.modules.memory.updater.static_extractor.StaticMemoryExtractor;
import work.slhaf.agent.modules.memory.updater.summarizer.MemorySummarizer; import work.slhaf.agent.modules.memory.updater.summarizer.MemorySummarizer;
import work.slhaf.agent.modules.task.TaskEvaluator; import work.slhaf.agent.modules.task.TaskEvaluator;
import work.slhaf.agent.modules.task.TaskScheduler;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;

View File

@@ -178,12 +178,14 @@ public class ModelConstant {
输入字段说明 输入字段说明
• `text`: 用户当前输入的文本内容 • `text`: 用户当前输入的文本内容
• `topic_tree`: 当前可用的主题树结构(多层级结构,需返回从根节点([root])到目标节点的完整路径) • `topic_tree`: 当前可用的主题树结构(多层级结构,需返回从根节点([root])到目标节点的完整路径), 主题树中类似`[0]`的标志为主题节点下对应的记忆节点数量当记忆节点数量为0时该主题节点不能作为目标节点
• `date`: 当前对话发生的日期(用于时间推理) • `date`: 当前对话发生的日期(用于时间推理)
• `history`: 用户与LLM的完整对话历史用于主题连续性判断 • `history`: 用户与LLM的完整对话历史用于主题连续性判断
• `activated_memory_slices`: 已经激活的记忆切片
输出规则 输出规则
1. 基本响应格式: 1. 基本响应格式:
@@ -205,6 +207,10 @@ public class ModelConstant {
◦ 除非包含明确的新子主题,否则不重复提取相同主题路径 ◦ 除非包含明确的新子主题,否则不重复提取相同主题路径
• 当激活的记忆切片已经不符合当前主题时:
◦ 除非主题树中存在匹配的主题路径,否则仍不进行提取操作
3. 日期提取规则(保持不变): 3. 日期提取规则(保持不变):
• 仅接受具体日期YYYY-MM-DD格式 • 仅接受具体日期YYYY-MM-DD格式
@@ -226,11 +232,12 @@ public class ModelConstant {
决策流程 决策流程
0. 若主题树为空或者未提供主题树则直接将recall设置为null, 不进行后续判定 0. 若主题树为空或者未提供主题树则直接将recall设置为null, 不进行后续判定
1. 首先分析`history`判断当前对话主题上下文 1. 对于所有记忆节点个数为0的主题节点来说这些节点不能作为主题路径的终点
2. 然后分析`text` 2. 首先分析`text`
a. 检测是否包含具体日期→添加date类型 a. 检测用户提到的具体日期是否明确与某事物/事件相关→添加date类型
b. 检测是否包含新主题→添加topic类型 b. 检测用户提到的事物/事件是否明确与主题树中存在的主题路径相关→添加topic类型
3. 最终综合判断`recall`值, 如果找到了对应的主题路径则recall值为true; 否则为false 3. 分析`history`判断当前对话主题上下文, 如果与`text`中的内容明显无关,则仅只依据`text`内容提取主题路径
4. 最终综合判断`recall`值, 如果找到了对应的主题路径则recall值为true; 否则为false
完整示例 完整示例
示例1主题延续 示例1主题延续
@@ -238,12 +245,12 @@ public class ModelConstant {
"text": "关于NodeJS的并发处理还有哪些要注意的", "text": "关于NodeJS的并发处理还有哪些要注意的",
"topic_tree": " "topic_tree": "
编程[root] 编程[root]
├── JavaScript ├── JavaScript[0]
│ ├── NodeJS │ ├── NodeJS
│ │ ├── 并发处理 │ │ ├── 并发处理[1]
│ │ └── 事件循环 │ │ └── 事件循环[0]
│ └── Express │ └── Express[1]
│ └── 中间件 │ └── 中间件[0]
└── Python", └── Python",
"date": "2024-04-20", "date": "2024-04-20",
"history": [ "history": [
@@ -261,12 +268,12 @@ public class ModelConstant {
"text": "现在我想了解Express中间件的原理", "text": "现在我想了解Express中间件的原理",
"topic_tree": " "topic_tree": "
编程[root] 编程[root]
├── JavaScript ├── JavaScript[0]
│ ├── NodeJS │ ├── NodeJS[0]
│ │ ├── 并发处理 │ │ ├── 并发处理[1]
│ │ └── 事件循环 │ │ └── 事件循环[0]
│ └── Express │ └── Express[0]
│ └── 中间件 │ └── 中间件[1]
└── Python", └── Python",
"date": "2024-04-20", "date": "2024-04-20",
"history": [ "history": [
@@ -286,12 +293,12 @@ public class ModelConstant {
"text": "2024-04-15讨论的Python内容和现在的Express需求", "text": "2024-04-15讨论的Python内容和现在的Express需求",
"topic_tree": " "topic_tree": "
编程[root] 编程[root]
├── JavaScript ├── JavaScript[0]
│ ├── NodeJS │ ├── NodeJS[0]
│ │ ├── 并发处理 │ │ ├── 并发处理[1]
│ │ └── 事件循环 │ │ └── 事件循环[1]
│ └── Express │ └── Express[1]
│ └── 中间件 │ └── 中间件[0]
└── Python", └── Python",
"date": "2024-04-20", "date": "2024-04-20",
"history": [ "history": [
@@ -312,12 +319,12 @@ public class ModelConstant {
"text": "上周说的那个JavaScript特性", "text": "上周说的那个JavaScript特性",
"topic_tree": " "topic_tree": "
编程[root] 编程[root]
├── JavaScript ├── JavaScript[0]
│ ├── NodeJS │ ├── NodeJS[0]
│ │ ├── 并发处理 │ │ ├── 并发处理[1]
│ │ └── 事件循环 │ │ └── 事件循环[1]
│ └── Express │ └── Express[0]
│ └── 中间件 │ └── 中间件[1]
└── Python", └── Python",
"date": "2024-04-20", "date": "2024-04-20",
"history": [...] "history": [...]

View File

@@ -0,0 +1,36 @@
package work.slhaf.agent.common.monitor;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor;
@Slf4j
public class DebugMonitor {
private InteractionThreadPoolExecutor executor;
private static DebugMonitor debugMonitor;
public static void initialize() {
debugMonitor = new DebugMonitor();
debugMonitor.executor = InteractionThreadPoolExecutor.getInstance();
debugMonitor.runMonitor();
}
private void runMonitor() {
executor.execute(() -> {
while (true) {
try {
Thread.sleep(3000);
} catch (InterruptedException e) {
log.error("监测线程报错?");
}
}
});
}
public static DebugMonitor getInstance(){
if (debugMonitor == null) {
initialize();
}
return debugMonitor;
}
}

View File

@@ -1,15 +1,20 @@
package work.slhaf.agent.core.memory; package work.slhaf.agent.core.memory;
import cn.hutool.json.JSONUtil;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import work.slhaf.agent.common.chat.pojo.Message; import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.pojo.PersistableObject; import work.slhaf.agent.common.pojo.PersistableObject;
import work.slhaf.agent.core.memory.exception.UnExistedDateIndexException;
import work.slhaf.agent.core.memory.exception.UnExistedTopicException; import work.slhaf.agent.core.memory.exception.UnExistedTopicException;
import work.slhaf.agent.core.memory.node.MemoryNode; import work.slhaf.agent.core.memory.node.MemoryNode;
import work.slhaf.agent.core.memory.node.TopicNode; import work.slhaf.agent.core.memory.node.TopicNode;
import work.slhaf.agent.core.memory.pojo.*; import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.core.memory.pojo.MemorySliceResult;
import work.slhaf.agent.core.memory.pojo.User;
import java.io.*; import java.io.*;
import java.nio.file.Files; import java.nio.file.Files;
@@ -44,10 +49,14 @@ public class MemoryGraph extends PersistableObject {
private HashMap<String /*根主题名*/, LinkedHashSet<String> /*子主题列表*/> existedTopics; private HashMap<String /*根主题名*/, LinkedHashSet<String> /*子主题列表*/> existedTopics;
/** /**
* 记忆节点的日期索引, 同一日期内按照对话id区分 * 临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所
* 同时作为临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所
*/ */
private HashMap<LocalDate, HashMap<String /*对话id, 即slice中的字段'memoryId'*/, List<MemorySlice>>> dateIndex; private HashMap<String /*对话id, 即slice中的字段'memoryId'*/, List<MemorySlice>> currentDateDialogSlices;
/**
* 记忆节点的日期索引, 同一日期内按照对话id区分
*/
private HashMap<LocalDate, Set<String>> dateIndex;
/** /**
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键总结为值 * 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键总结为值
@@ -110,13 +119,11 @@ public class MemoryGraph extends PersistableObject {
*/ */
private Set<Long> selectedSlices; private Set<Long> selectedSlices;
private String memoryId;
public MemoryGraph(String id) { public MemoryGraph(String id) {
this.id = id; this.id = id;
this.topicNodes = new HashMap<>(); this.topicNodes = new HashMap<>();
this.existedTopics = new HashMap<>(); this.existedTopics = new HashMap<>();
this.dateIndex = new HashMap<>(); this.currentDateDialogSlices = new HashMap<>();
this.staticMemory = new HashMap<>(); this.staticMemory = new HashMap<>();
this.memoryNodeCacheCounter = new ConcurrentHashMap<>(); this.memoryNodeCacheCounter = new ConcurrentHashMap<>();
this.memorySliceCache = new ConcurrentHashMap<>(); this.memorySliceCache = new ConcurrentHashMap<>();
@@ -129,6 +136,7 @@ public class MemoryGraph extends PersistableObject {
this.character = """ this.character = """
实话实说,不做糖衣炮弹。 采取前瞻性的观点。 始终保持尊重。 乐于分享明确的观点。 保持轻松、随和。 直奔主题。 务实至上。 勇于创新,打破常规思维。使用中文回答所有问题。 实话实说,不做糖衣炮弹。 采取前瞻性的观点。 始终保持尊重。 乐于分享明确的观点。 保持轻松、随和。 直奔主题。 务实至上。 勇于创新,打破常规思维。使用中文回答所有问题。
"""; """;
this.dateIndex = new HashMap<>();
} }
public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException { public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException {
@@ -195,12 +203,18 @@ public class MemoryGraph extends PersistableObject {
LocalDate now = LocalDate.now(); LocalDate now = LocalDate.now();
boolean hasSlice = false; boolean hasSlice = false;
MemoryNode node = null; MemoryNode node = null;
for (MemoryNode memoryNode : lastTopicNode.getMemoryNodes()) { try {
if (now.equals(memoryNode.getLocalDate())) { for (MemoryNode memoryNode : lastTopicNode.getMemoryNodes()) {
hasSlice = true; if (now.equals(memoryNode.getLocalDate())) {
node = memoryNode; hasSlice = true;
break; node = memoryNode;
break;
}
} }
} catch (Exception e) {
log.error("插入记忆时出错: ", e);
log.error("主题路径: {}; 切片内容: {}", topicPath, slice);
log.error("主题树状态: {}", JSONUtil.toJsonPrettyStr(topicNodes));
} }
if (!hasSlice) { if (!hasSlice) {
node = new MemoryNode(); node = new MemoryNode();
@@ -217,13 +231,27 @@ public class MemoryGraph extends PersistableObject {
generateTopicPath(relatedTopic); generateTopicPath(relatedTopic);
} }
updateDateIndex(now, slice); updateSlicePrecedent(slice);
updateDateIndex(slice);
if (!slice.isPrivate()) { if (!slice.isPrivate()) {
updateUserDialogMap(slice); updateUserDialogMap(slice);
} }
node.saveMemorySliceList(); node.saveMemorySliceList();
} }
private void updateDateIndex(MemorySlice slice) {
String memoryId = slice.getMemoryId();
LocalDate date = LocalDate.now();
if (!dateIndex.containsKey(date)) {
HashSet<String> memoryIdSet = new HashSet<>();
memoryIdSet.add(memoryId);
dateIndex.put(date, memoryIdSet);
} else {
dateIndex.get(date).add(memoryId);
}
}
private TopicNode generateTopicPath(List<String> topicPath) { private TopicNode generateTopicPath(List<String> topicPath) {
topicPath = new ArrayList<>(topicPath); topicPath = new ArrayList<>(topicPath);
//查看是否存在根主题节点 //查看是否存在根主题节点
@@ -240,11 +268,17 @@ public class MemoryGraph extends PersistableObject {
TopicNode lastTopicNode = topicNodes.get(rootTopic); TopicNode lastTopicNode = topicNodes.get(rootTopic);
Set<String> existedTopicNodes = existedTopics.get(rootTopic); Set<String> existedTopicNodes = existedTopics.get(rootTopic);
for (String topic : topicPath) { for (String topic : topicPath) {
if (existedTopicNodes.contains(topic)) { if (existedTopicNodes.contains(topic) && lastTopicNode.getTopicNodes().containsKey(topic)) {
lastTopicNode = lastTopicNode.getTopicNodes().get(topic); lastTopicNode = lastTopicNode.getTopicNodes().get(topic);
} else { } else {
TopicNode newNode = new TopicNode(); TopicNode newNode = new TopicNode();
lastTopicNode.getTopicNodes().put(topic, newNode); try {
lastTopicNode.getTopicNodes().put(topic, newNode);
} catch (Exception e) {
log.error("主题路径: {}; ", topicPath);
log.error("主题树状态: {}", JSONUtil.toJsonPrettyStr(topicNodes));
}
lastTopicNode = newNode; lastTopicNode = newNode;
CopyOnWriteArrayList<MemoryNode> nodeList = new CopyOnWriteArrayList<>(); CopyOnWriteArrayList<MemoryNode> nodeList = new CopyOnWriteArrayList<>();
lastTopicNode.setMemoryNodes(nodeList); lastTopicNode.setMemoryNodes(nodeList);
@@ -281,16 +315,12 @@ public class MemoryGraph extends PersistableObject {
} }
private void updateDateIndex(LocalDate now, MemorySlice slice) { private void updateSlicePrecedent(MemorySlice slice) {
String memoryId = slice.getMemoryId(); String memoryId = slice.getMemoryId();
//查看是否存在当前日期的对话切片索引 //查看是否切换了memoryId
if (!dateIndex.containsKey(now)) {
dateIndex.put(now, new HashMap<>());
}
//查看当前日期的索引中是否存在该对话的索引
HashMap<String, List<MemorySlice>> currentDateDialogSlices = dateIndex.get(now);
if (!currentDateDialogSlices.containsKey(memoryId)) { if (!currentDateDialogSlices.containsKey(memoryId)) {
List<MemorySlice> memorySliceList = new ArrayList<>(); List<MemorySlice> memorySliceList = new ArrayList<>();
currentDateDialogSlices.clear();
currentDateDialogSlices.put(memoryId, memorySliceList); currentDateDialogSlices.put(memoryId, memorySliceList);
} }
//处理上下文关系 //处理上下文关系
@@ -313,20 +343,20 @@ public class MemoryGraph extends PersistableObject {
public MemoryResult selectMemory(String topicPathStr) throws IOException, ClassNotFoundException { public MemoryResult selectMemory(String topicPathStr) throws IOException, ClassNotFoundException {
List<String> topicPath = List.of(topicPathStr.split("->")); List<String> topicPath = List.of(topicPathStr.split("->"));
List<String> path = new ArrayList<>(topicPath);
MemoryResult memoryResult = new MemoryResult(); MemoryResult memoryResult = new MemoryResult();
//每日刷新缓存 //每日刷新缓存
checkCacheDate(); checkCacheDate();
//检测缓存并更新计数, 查看是否需要放入缓存 //检测缓存并更新计数, 查看是否需要放入缓存
updateCacheCounter(topicPath); updateCacheCounter(path);
//查看是否存在缓存,如果存在,则直接返回 //查看是否存在缓存,如果存在,则直接返回
if (memorySliceCache.containsKey(topicPath)) { if (memorySliceCache.containsKey(path)) {
return memorySliceCache.get(topicPath); return memorySliceCache.get(path);
} }
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>(); CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
topicPath = new ArrayList<>(topicPath); String targetTopic = path.getLast();
String targetTopic = topicPath.getLast(); TopicNode targetParentNode = getTargetParentNode(path, targetTopic);
TopicNode targetParentNode = getTargetParentNode(topicPath, targetTopic);
List<List<String>> relatedTopics = new ArrayList<>(); List<List<String>> relatedTopics = new ArrayList<>();
//终点记忆节点 //终点记忆节点
@@ -389,6 +419,10 @@ public class MemoryGraph extends PersistableObject {
private void updateCache(List<String> topicPath, MemoryResult memoryResult) { private void updateCache(List<String> topicPath, MemoryResult memoryResult) {
Integer tempCount = memoryNodeCacheCounter.get(topicPath); Integer tempCount = memoryNodeCacheCounter.get(topicPath);
if (tempCount == null) {
log.error("tempCount为null? memoryNodeCacheCounter: {}; topicPath: {}", memoryNodeCacheCounter, topicPath);
return;
}
if (tempCount >= 5) { if (tempCount >= 5) {
memorySliceCache.put(topicPath, memoryResult); memorySliceCache.put(topicPath, memoryResult);
} }
@@ -404,17 +438,19 @@ public class MemoryGraph extends PersistableObject {
} }
private void checkCacheDate() { private void checkCacheDate() {
if ( cacheDate == null || cacheDate.isBefore(LocalDate.now())) { if (cacheDate == null || cacheDate.isBefore(LocalDate.now())) {
memorySliceCache.clear(); memorySliceCache.clear();
memoryNodeCacheCounter.clear(); memoryNodeCacheCounter.clear();
cacheDate = LocalDate.now(); cacheDate = LocalDate.now();
} }
} }
public MemoryResult selectMemory(LocalDate date) { public MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException {
MemoryResult memoryResult = new MemoryResult(); MemoryResult memoryResult = new MemoryResult();
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>(); CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
for (List<MemorySlice> value : dateIndex.get(date).values()) { //加载节点并获取记忆切片列表
List<List<MemorySlice>> currentDateDialogSlices = loadSlicesByDate(date);
for (List<MemorySlice> value : currentDateDialogSlices) {
for (MemorySlice memorySlice : value) { for (MemorySlice memorySlice : value) {
if (selectedSlices.contains(memorySlice.getTimestamp())) { if (selectedSlices.contains(memorySlice.getTimestamp())) {
continue; continue;
@@ -429,6 +465,19 @@ public class MemoryGraph extends PersistableObject {
return memoryResult; return memoryResult;
} }
private List<List<MemorySlice>> loadSlicesByDate(LocalDate date) throws IOException, ClassNotFoundException {
if (!dateIndex.containsKey(date)) {
throw new UnExistedDateIndexException("不存在的日期索引: " + date);
}
List<List<MemorySlice>> list = new ArrayList<>();
for (String memoryId : dateIndex.get(date)) {
MemoryNode memoryNode = new MemoryNode();
memoryNode.setMemoryNodeId(memoryId);
list.add(memoryNode.loadMemorySliceList());
}
return list;
}
private TopicNode getTargetParentNode(List<String> topicPath, String targetTopic) { private TopicNode getTargetParentNode(List<String> topicPath, String targetTopic) {
String topTopic = topicPath.getFirst(); String topTopic = topicPath.getFirst();
if (!existedTopics.containsKey(topTopic)) { if (!existedTopics.containsKey(topTopic)) {
@@ -468,7 +517,7 @@ public class MemoryGraph extends PersistableObject {
for (int i = 0; i < entries.size(); i++) { for (int i = 0; i < entries.size(); i++) {
boolean last = (i == entries.size() - 1); boolean last = (i == entries.size() - 1);
Map.Entry<String, TopicNode> entry = entries.get(i); Map.Entry<String, TopicNode> entry = entries.get(i);
stringBuilder.append(prefix).append(last ? "└── " : "├── ").append(entry.getKey()).append("\r\n"); stringBuilder.append(prefix).append(last ? "└── " : "├── ").append(entry.getKey()).append("[").append(entry.getValue().getMemoryNodes().size()).append("]").append("\r\n");
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : ""), stringBuilder); printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : ""), stringBuilder);
} }
} }

View File

@@ -4,8 +4,6 @@ import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.chat.pojo.Message; import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.config.Config; import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.core.interaction.InteractionModule;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.pojo.MemoryResult; import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.MemorySlice; import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.core.memory.pojo.User; import work.slhaf.agent.core.memory.pojo.User;
@@ -21,7 +19,7 @@ import java.util.concurrent.locks.ReentrantLock;
@Data @Data
@Slf4j @Slf4j
public class MemoryManager implements InteractionModule { public class MemoryManager {
private static MemoryManager memoryManager; private static MemoryManager memoryManager;
private final Lock sliceInsertLock = new ReentrantLock(); private final Lock sliceInsertLock = new ReentrantLock();
@@ -33,10 +31,6 @@ public class MemoryManager implements InteractionModule {
private MemoryManager() { private MemoryManager() {
} }
@Override
public void execute(InteractionContext interactionContext) {
}
public static MemoryManager getInstance() throws IOException, ClassNotFoundException { public static MemoryManager getInstance() throws IOException, ClassNotFoundException {
if (memoryManager == null) { if (memoryManager == null) {
@@ -44,24 +38,28 @@ public class MemoryManager implements InteractionModule {
memoryManager = new MemoryManager(); memoryManager = new MemoryManager();
memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId())); memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId()));
memoryManager.setActivatedSlices(new HashMap<>()); memoryManager.setActivatedSlices(new HashMap<>());
memoryManager.setShutdownHook();
log.info("MemoryManager注册完毕..."); log.info("MemoryManager注册完毕...");
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
memoryManager.save();
log.info("MemoryGraph已保存");
} catch (IOException e) {
log.error("保存MemoryGraph失败: ", e);
}
}));
} }
return memoryManager; return memoryManager;
} }
private void setShutdownHook() {
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
memoryManager.save();
log.info("MemoryGraph已保存");
} catch (IOException e) {
log.error("保存MemoryGraph失败: ", e);
}
}));
}
public MemoryResult selectMemory(String path) throws IOException, ClassNotFoundException { public MemoryResult selectMemory(String path) throws IOException, ClassNotFoundException {
return memoryGraph.selectMemory(path); return memoryGraph.selectMemory(path);
} }
public MemoryResult selectMemory(LocalDate date) { public MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException {
return memoryGraph.selectMemory(date); return memoryGraph.selectMemory(date);
} }
@@ -118,14 +116,6 @@ public class MemoryManager implements InteractionModule {
return memoryGraph.getCharacter(); return memoryGraph.getCharacter();
} }
public void resetMemoryId() {
memoryGraph.setMemoryId(UUID.randomUUID().toString());
}
public String getMemoryId() {
return memoryGraph.getMemoryId();
}
public void insertSlice(MemorySlice memorySlice, String topicPath) throws IOException, ClassNotFoundException { public void insertSlice(MemorySlice memorySlice, String topicPath) throws IOException, ClassNotFoundException {
sliceInsertLock.lock(); sliceInsertLock.lock();
List<String> topicPathList = Arrays.stream(topicPath.split("->")).toList(); List<String> topicPathList = Arrays.stream(topicPath.split("->")).toList();

View File

@@ -0,0 +1,7 @@
package work.slhaf.agent.core.memory.exception;
public class UnExistedDateIndexException extends RuntimeException {
public UnExistedDateIndexException(String message) {
super(message);
}
}

View File

@@ -3,9 +3,9 @@ package work.slhaf.agent.core.memory.node;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.pojo.PersistableObject;
import work.slhaf.agent.core.memory.exception.NullSliceListException; import work.slhaf.agent.core.memory.exception.NullSliceListException;
import work.slhaf.agent.core.memory.pojo.MemorySlice; import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.common.pojo.PersistableObject;
import java.io.*; import java.io.*;
import java.nio.file.Files; import java.nio.file.Files;

View File

@@ -1,12 +1,20 @@
package work.slhaf.agent.core.memory.pojo; package work.slhaf.agent.core.memory.pojo;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.agent.common.pojo.PersistableObject;
import java.io.Serial;
import java.util.List; import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
@EqualsAndHashCode(callSuper = true)
@Data @Data
public class MemoryResult { public class MemoryResult extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private CopyOnWriteArrayList<MemorySliceResult> memorySliceResult; private CopyOnWriteArrayList<MemorySliceResult> memorySliceResult;
private List<MemorySlice> relatedMemorySliceResult; private List<MemorySlice> relatedMemorySliceResult;
} }

View File

@@ -2,6 +2,7 @@ package work.slhaf.agent.core.memory.pojo;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.ToString;
import work.slhaf.agent.common.chat.pojo.Message; import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.pojo.PersistableObject; import work.slhaf.agent.common.pojo.PersistableObject;
@@ -40,6 +41,7 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
/** /**
* 关联完整对话中的前序切片, 排序为键,完整路径为值 * 关联完整对话中的前序切片, 排序为键,完整路径为值
*/ */
@ToString.Exclude
private MemorySlice sliceBefore, sliceAfter; private MemorySlice sliceBefore, sliceAfter;
/** /**

View File

@@ -1,30 +1,68 @@
package work.slhaf.agent.core.session; package work.slhaf.agent.core.session;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.chat.pojo.Message; import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.chat.pojo.MetaMessage; import work.slhaf.agent.common.chat.pojo.MetaMessage;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.pojo.PersistableObject;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.UUID;
@EqualsAndHashCode(callSuper = true)
@Data @Data
public class SessionManager { @Slf4j
public class SessionManager extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private static final String STORAGE_DIR = "./data/session/";
private static SessionManager sessionManager; private static SessionManager sessionManager;
private String id;
private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap; private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap;
private String currentMemoryId;
private long lastUpdatedTime;
public static SessionManager getInstance() { public static SessionManager getInstance() throws IOException, ClassNotFoundException {
if (sessionManager == null) { if (sessionManager == null) {
sessionManager = new SessionManager(); String id = Config.getConfig().getAgentId();
sessionManager.setSingleMetaMessageMap(new HashMap<>()); Path filePath = Paths.get(STORAGE_DIR, id + ".session");
if (Files.exists(filePath)) {
sessionManager = deserialize(id);
} else {
sessionManager = new SessionManager();
sessionManager.setSingleMetaMessageMap(new HashMap<>());
sessionManager.id = id;
sessionManager.setShutdownHook();
sessionManager.lastUpdatedTime = 0;
}
} }
return sessionManager; return sessionManager;
} }
private void setShutdownHook() {
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
sessionManager.serialize();
log.info("SessionManager 已保存");
} catch (IOException e) {
log.error("保存 SessionManager 失败: ", e);
}
}));
}
public void addMetaMessage(String userId, MetaMessage metaMessage) { public void addMetaMessage(String userId, MetaMessage metaMessage) {
if (singleMetaMessageMap.containsKey(userId)) { if (singleMetaMessageMap.containsKey(userId)) {
singleMetaMessageMap.get(userId).add(metaMessage); singleMetaMessageMap.get(userId).add(metaMessage);
} else { } else {
singleMetaMessageMap.put(userId, new java.util.ArrayList<>()); singleMetaMessageMap.put(userId, new java.util.ArrayList<>());
@@ -42,4 +80,34 @@ public class SessionManager {
return messages; return messages;
} }
public void refreshMemoryId() {
currentMemoryId = UUID.randomUUID().toString();
}
public void serialize() throws IOException {
Path filePath = Paths.get(STORAGE_DIR, this.id + ".session");
Files.createDirectories(Path.of(STORAGE_DIR));
try (ObjectOutputStream oos = new ObjectOutputStream(
new FileOutputStream(filePath.toFile()))) {
oos.writeObject(this);
log.info("SessionManager 已保存到: {}", filePath);
} catch (IOException e) {
log.error("序列化保存失败: {}", e.getMessage());
}
}
private static SessionManager deserialize(String id) throws IOException, ClassNotFoundException {
Path filePath = Paths.get(STORAGE_DIR, id + ".session");
try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(filePath.toFile()))) {
SessionManager sessionManager = (SessionManager) ois.readObject();
log.info("SessionManager 已从文件加载: {}", filePath);
return sessionManager;
}
}
public void resetLastUpdatedTime() {
lastUpdatedTime = System.currentTimeMillis();
}
} }

View File

@@ -5,6 +5,8 @@ import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.core.interaction.InteractionModule; import work.slhaf.agent.core.interaction.InteractionModule;
import work.slhaf.agent.core.interaction.data.InteractionContext; import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.MemoryManager; import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.memory.exception.UnExistedDateIndexException;
import work.slhaf.agent.core.memory.exception.UnExistedTopicException;
import work.slhaf.agent.core.memory.pojo.MemoryResult; import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.MemorySlice; import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.modules.memory.selector.evaluator.SliceSelectEvaluator; import work.slhaf.agent.modules.memory.selector.evaluator.SliceSelectEvaluator;
@@ -84,12 +86,6 @@ public class MemorySelector implements InteractionModule {
List<EvaluatedSlice> memorySlices = sliceSelectEvaluator.execute(evaluatorInput); List<EvaluatedSlice> memorySlices = sliceSelectEvaluator.execute(evaluatorInput);
memoryManager.getActivatedSlices().put(userId,memorySlices); memoryManager.getActivatedSlices().put(userId,memorySlices);
//向上下文设置切片存入标志,条件:对话历史列表不为空;触发了记忆查询
/*if (!memoryManager.getChatMessages().isEmpty()) {
interactionContext.getModuleContext().put("new_topic", true);
interactionContext.getModuleContext().put("messages_to_store", List.of(memoryManager.getChatMessages()));
}*/
} }
//设置上下文 //设置上下文
@@ -103,13 +99,18 @@ public class MemorySelector implements InteractionModule {
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) throws IOException, ClassNotFoundException { private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) throws IOException, ClassNotFoundException {
for (ExtractorMatchData match : matches) { for (ExtractorMatchData match : matches) {
MemoryResult memoryResult = switch (match.getType()) { try {
case ExtractorMatchData.Constant.DATE -> memoryManager.selectMemory(match.getText()); MemoryResult memoryResult = switch (match.getType()) {
case ExtractorMatchData.Constant.TOPIC -> memoryManager.selectMemory(LocalDate.parse(match.getText())); case ExtractorMatchData.Constant.TOPIC -> memoryManager.selectMemory(match.getText());
default -> null; case ExtractorMatchData.Constant.DATE ->
}; memoryManager.selectMemory(LocalDate.parse(match.getText()));
if (memoryResult == null) continue; default -> null;
memoryResultList.add(memoryResult); };
if (memoryResult == null) continue;
memoryResultList.add(memoryResult);
}catch (UnExistedDateIndexException | UnExistedTopicException e) {
log.error("不存在的记忆索引! 请尝试更换更合适的主题提取LLM!");
}
} }
//清理切片记录 //清理切片记录
memoryManager.cleanSelectedSliceFilter(); memoryManager.cleanSelectedSliceFilter();
@@ -127,6 +128,6 @@ public class MemorySelector implements InteractionModule {
if (memorySlice.isPrivate()) { if (memorySlice.isPrivate()) {
return memorySlice.getStartUserId().equals(userId); return memorySlice.getStartUserId().equals(userId);
} }
return true; return false;
} }
} }

View File

@@ -59,7 +59,11 @@ public class SliceSelectEvaluator extends Model {
List<Callable<Void>> tasks = new ArrayList<>(); List<Callable<Void>> tasks = new ArrayList<>();
Queue<EvaluatedSlice> queue = new ConcurrentLinkedDeque<>(); Queue<EvaluatedSlice> queue = new ConcurrentLinkedDeque<>();
for (MemoryResult memoryResult : memoryResultList) { for (MemoryResult memoryResult : memoryResultList) {
if (memoryResult.getMemorySliceResult().isEmpty() && memoryResult.getRelatedMemorySliceResult().isEmpty()){
continue;
}
tasks.add(() -> { tasks.add(() -> {
log.debug("切片评估...");
List<SliceSummary> sliceSummaryList = new ArrayList<>(); List<SliceSummary> sliceSummaryList = new ArrayList<>();
//映射查找键值 //映射查找键值
Map<Long, SliceSummary> map = new HashMap<>(); Map<Long, SliceSummary> map = new HashMap<>();
@@ -71,6 +75,7 @@ public class SliceSelectEvaluator extends Model {
.history(evaluatorInput.getMessages()) .history(evaluatorInput.getMessages())
.build(); .build();
EvaluatorResult evaluatorResult = JSONObject.parseObject(extractJson(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage()), EvaluatorResult.class); EvaluatorResult evaluatorResult = JSONObject.parseObject(extractJson(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage()), EvaluatorResult.class);
log.debug("评估结果: {}", evaluatorResult);
for (Long result : evaluatorResult.getResults()) { for (Long result : evaluatorResult.getResults()) {
SliceSummary sliceSummary = map.get(result); SliceSummary sliceSummary = map.get(result);
EvaluatedSlice evaluatedSlice = EvaluatedSlice.builder() EvaluatedSlice evaluatedSlice = EvaluatedSlice.builder()

View File

@@ -15,6 +15,7 @@ import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.session.SessionManager; import work.slhaf.agent.core.session.SessionManager;
import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorInput; import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorInput;
import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorResult; import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorResult;
import work.slhaf.agent.shared.memory.EvaluatedSlice;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
@@ -60,11 +61,14 @@ public class MemorySelectExtractor extends Model {
} }
} }
List<EvaluatedSlice> activatedMemorySlices = memoryManager.getActivatedSlices().get(context.getUserId());
ExtractorInput extractorInput = ExtractorInput.builder() ExtractorInput extractorInput = ExtractorInput.builder()
.text(context.getInput()) .text(context.getInput())
.date(context.getDateTime().toLocalDate()) .date(context.getDateTime().toLocalDate())
.history(chatMessages) .history(chatMessages)
.topic_tree(memoryManager.getTopicTree()) .topic_tree(memoryManager.getTopicTree())
.activatedMemorySlices(activatedMemorySlices)
.build(); .build();
String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage()); String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage());

View File

@@ -3,6 +3,7 @@ package work.slhaf.agent.modules.memory.selector.extractor.data;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import work.slhaf.agent.common.chat.pojo.Message; import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.shared.memory.EvaluatedSlice;
import java.time.LocalDate; import java.time.LocalDate;
import java.util.List; import java.util.List;
@@ -14,4 +15,5 @@ public class ExtractorInput {
private String topic_tree; private String topic_tree;
private LocalDate date; private LocalDate date;
private List<Message> history; private List<Message> history;
private List<EvaluatedSlice> activatedMemorySlices;
} }

View File

@@ -41,7 +41,6 @@ public class MemoryUpdater implements InteractionModule {
private MemorySummarizer memorySummarizer; private MemorySummarizer memorySummarizer;
private SessionManager sessionManager; private SessionManager sessionManager;
private StaticMemoryExtractor staticMemoryExtractor; private StaticMemoryExtractor staticMemoryExtractor;
private long lastUpdatedTime = 0;
private MemoryUpdater() { private MemoryUpdater() {
} }
@@ -67,8 +66,13 @@ public class MemoryUpdater implements InteractionModule {
while (!Thread.interrupted()) { while (!Thread.interrupted()) {
try { try {
long currentTime = System.currentTimeMillis(); long currentTime = System.currentTimeMillis();
if (lastUpdatedTime != 0 && currentTime - lastUpdatedTime > UPDATE_TRIGGER_INTERVAL) { long lastUpdatedTime = sessionManager.getLastUpdatedTime();
int chatCount = memoryManager.getChatMessages().size();
if (lastUpdatedTime != 0 && currentTime - lastUpdatedTime > UPDATE_TRIGGER_INTERVAL && chatCount > 1) {
updateMemory(); updateMemory();
//重置MemoryId
sessionManager.refreshMemoryId();
log.info("记忆更新: 自动触发");
} }
Thread.sleep(SCHEDULED_UPDATE_INTERVAL); Thread.sleep(SCHEDULED_UPDATE_INTERVAL);
} catch (Exception e) { } catch (Exception e) {
@@ -90,52 +94,65 @@ public class MemoryUpdater implements InteractionModule {
if (moduleContext.getIntValue("total_token") > 24000) { if (moduleContext.getIntValue("total_token") > 24000) {
try { try {
updateMemory(); updateMemory();
log.info("记忆更新: token超限");
} catch (Exception e) { } catch (Exception e) {
log.error("记忆更新线程出错: {}", e.getLocalizedMessage()); log.error("记忆更新线程出错: {}", e.getLocalizedMessage());
} }
} }
}); });
sessionManager.resetLastUpdatedTime();
} }
private void updateMemory() throws InterruptedException, IOException, ClassNotFoundException { private void updateMemory() throws IOException, ClassNotFoundException {
HashMap<String, String> singleMemorySummary = new HashMap<>(); HashMap<String, String> singleMemorySummary = new HashMap<>();
//更新单聊记忆以及该场景中对应的确定性记忆同时从chatMessages中去掉单聊记忆 //更新单聊记忆以及该场景中对应的确定性记忆同时从chatMessages中去掉单聊记忆
updateSingleChatSlices(singleMemorySummary); updateSingleChatSlices(singleMemorySummary);
//更新多人场景下的记忆及相关的确定性记忆 //更新多人场景下的记忆及相关的确定性记忆
updateMultiChatSlices(singleMemorySummary); updateMultiChatSlices(singleMemorySummary);
//更新最近更新时间 //清空chatMessages
lastUpdatedTime = System.currentTimeMillis(); clearChatMessages();
} }
private void updateMultiChatSlices(HashMap<String, String> singleMemorySummary) throws InterruptedException, IOException, ClassNotFoundException { private void updateMultiChatSlices(HashMap<String, String> singleMemorySummary) {
//此时chatMessages中不再包含单聊记录直接执行摘要以及切片插入 //此时chatMessages中不再包含单聊记录直接执行摘要以及切片插入
//对剩下的多人聊天记录进行进行摘要 //对剩下的多人聊天记录进行进行摘要
executor.execute(() -> { executor.execute(() -> {
try { try {
//以第一条user对应的id为发起用户 List<Message> chatMessages = new ArrayList<>(memoryManager.getChatMessages());
Pattern pattern = Pattern.compile(USERID_REGEX); chatMessages.removeFirst();
Matcher matcher = pattern.matcher(memoryManager.getChatMessages().get(1).getContent()); if (!chatMessages.isEmpty()) {
if (!matcher.find()){ //以第一条user对应的id为发起用户
throw new RuntimeException("未匹配到 userId!"); Pattern pattern = Pattern.compile(USERID_REGEX);
Matcher matcher = pattern.matcher(chatMessages.getFirst().getContent());
if (!matcher.find()) {
throw new RuntimeException("未匹配到 userId!");
}
String userId = matcher.group(1);
SummarizeResult summarizeResult = memorySummarizer.execute(new SummarizeInput(chatMessages, memoryManager.getTopicTree()));
MemorySlice memorySlice = getMemorySlice(userId, summarizeResult, chatMessages);
//设置involvedUserId
setInvolvedUserId(userId, memorySlice, chatMessages);
memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath());
if (!singleMemorySummary.isEmpty()) {
memoryManager.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary());
}
}else{
memoryManager.updateDialogMap(LocalDateTime.now(),memorySummarizer.executeTotalSummary(singleMemorySummary));
} }
String userId = matcher.group(1);
SummarizeResult summarizeResult = memorySummarizer.execute(new SummarizeInput(memoryManager.getChatMessages(), memoryManager.getTopicTree()));
MemorySlice memorySlice = getMemorySlice(userId, summarizeResult, memoryManager.getChatMessages());
//设置involvedUserId
List<Message> messages = new ArrayList<>(memoryManager.getChatMessages());
messages.removeFirst();
setInvolvedUserId(userId, memorySlice, messages);
memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath());
//更新总dialogMap
singleMemorySummary.put("total", summarizeResult.getSummary());
memoryManager.updateDialogMap(LocalDateTime.now(), memorySummarizer.executeTotalSummary(singleMemorySummary));
} catch (IOException | ClassNotFoundException | InterruptedException e) { } catch (IOException | ClassNotFoundException | InterruptedException e) {
log.error("多人场景记忆更新失败: {}", e.getLocalizedMessage()); log.error("多人场景记忆更新失败: {}", e.getLocalizedMessage());
} }
}); });
} }
private void clearChatMessages() {
Message first = memoryManager.getChatMessages().getFirst();
memoryManager.getChatMessages().clear();
memoryManager.getChatMessages().add(first);
}
private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List<Message> chatMessages) { private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List<Message> chatMessages) {
for (Message chatMessage : chatMessages) { for (Message chatMessage : chatMessages) {
if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) { if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) {
@@ -158,7 +175,7 @@ public class MemoryUpdater implements InteractionModule {
} }
private void updateSingleChatSlices(HashMap<String, String> singleMemorySummary) throws InterruptedException { private void updateSingleChatSlices(HashMap<String, String> singleMemorySummary) {
//更新单聊记忆同时从chatMessages中去掉单聊记忆 //更新单聊记忆同时从chatMessages中去掉单聊记忆
Set<String> userIdSet = new HashSet<>(sessionManager.getSingleMetaMessageMap().keySet()); Set<String> userIdSet = new HashSet<>(sessionManager.getSingleMetaMessageMap().keySet());
List<Callable<Void>> tasks = new ArrayList<>(); List<Callable<Void>> tasks = new ArrayList<>();
@@ -177,7 +194,7 @@ public class MemoryUpdater implements InteractionModule {
//添加至singleMemorySummary //添加至singleMemorySummary
singleMemorySummary.put(id, summarizeResult.getSummary()); singleMemorySummary.put(id, summarizeResult.getSummary());
} catch (Exception e) { } catch (Exception e) {
log.error("单聊记忆更新出错: {}", e.getLocalizedMessage()); log.error("单聊记忆更新出错: ", e);
} }
return null; return null;
}); });
@@ -196,8 +213,13 @@ public class MemoryUpdater implements InteractionModule {
executor.invokeAll(tasks); executor.invokeAll(tasks);
} }
private static MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List<Message> chatMessages) { private MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List<Message> chatMessages) {
MemorySlice memorySlice = new MemorySlice(); MemorySlice memorySlice = new MemorySlice();
//设置 memoryId,timestamp
memorySlice.setMemoryId(sessionManager.getCurrentMemoryId());
memorySlice.setTimestamp(System.currentTimeMillis());
//补充信息
memorySlice.setPrivate(summarizeResult.isPrivate()); memorySlice.setPrivate(summarizeResult.isPrivate());
memorySlice.setSummary(summarizeResult.getSummary()); memorySlice.setSummary(summarizeResult.getSummary());
memorySlice.setChatMessages(chatMessages); memorySlice.setChatMessages(chatMessages);

View File

@@ -98,7 +98,7 @@ public class MemorySummarizer extends Model {
public String executeTotalSummary(HashMap<String, String> singleMemorySummary) { public String executeTotalSummary(HashMap<String, String> singleMemorySummary) {
ChatResponse response = chatClient.runChat(List.of(new Message(ChatConstant.Character.SYSTEM, prompts.get(2)), ChatResponse response = chatClient.runChat(List.of(new Message(ChatConstant.Character.SYSTEM, prompts.get(2)),
new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary)))); new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))));
return JSONObject.parseObject(extractJson(response.getMessage())).getString("value"); return JSONObject.parseObject(extractJson(response.getMessage())).getString("content");
} }
private static class Constant { private static class Constant {
@@ -137,29 +137,13 @@ public class MemorySummarizer extends Model {
4. 最终校验:检查是否丢失关键信息 4. 最终校验:检查是否丢失关键信息
完整示例 完整示例
示例1常规长文本 示例:
输入:{ 输入:{
"content": "在2023年第四季度XX公司实现了显著增长。财报显示总收入达到4.56亿元同比增长32%。其中主要增长来自智能手机业务板块该板块贡献了3.12亿元收入同比增长达45%。同时智能家居业务收入1.44亿元同比增长12%。公司CEO在财报电话会议中强调增长主要得益于东南亚市场的成功拓展..." "content": "在2023年第四季度XX公司实现了显著增长。财报显示总收入达到4.56亿元同比增长32%。其中主要增长来自智能手机业务板块该板块贡献了3.12亿元收入同比增长达45%。同时智能家居业务收入1.44亿元同比增长12%。公司CEO在财报电话会议中强调增长主要得益于东南亚市场的成功拓展..."
} }
输出:{ 输出:{
"content": "XX公司2023年Q4总收入4.56亿元(同比+32%智能手机业务贡献3.12亿元(+45%智能家居1.44亿元(+12%),增长主要来自东南亚市场拓展。" "content": "XX公司2023年Q4总收入4.56亿元(同比+32%智能手机业务贡献3.12亿元(+45%智能家居1.44亿元(+12%),增长主要来自东南亚市场拓展。"
} }
示例2多段落文本
输入:{
"content": "本次项目改造涉及三个主要方面。首先硬件升级包括1) 更换全部服务器设备2) 安装新的网络交换机3) 部署智能安防系统。其次,软件系统将迁移至新平台,需完成数据迁移和接口适配。最后,人员培训计划分三阶段实施..."
}
输出:{
"content": "项目改造含硬件升级(更换服务器、新交换机、智能安防)、软件系统迁移(含数据迁移和接口适配)及分三阶段的人员培训。"
}
示例3技术文档
输入:{
"content": "该算法采用改进的卷积神经网络架构包含3个主要模块特征提取模块由5个卷积层组成、注意力机制模块含通道和空间注意力、以及分类模块使用2个全连接层。在ImageNet数据集上达到92.3%的准确率..."
}
输出:{
"content": "算法使用改进CNN架构含特征提取5卷积层、注意力机制通道+空间和分类模块2全连接层在ImageNet上准确率92.3%。"
}
"""; """;
public static final String MULTI_SUMMARIZE_PROMPT = """ public static final String MULTI_SUMMARIZE_PROMPT = """
@@ -185,7 +169,7 @@ public class MemorySummarizer extends Model {
2. 主题路径生成细则: 2. 主题路径生成细则:
• 抽象链构建流程: • 抽象链构建流程:
a. 以`user`的意图为主要锚点,锁定最低节点 a. 以`user`的输入内容意图为主要锚点,锁定最低节点
b. 逐层抽象(地标→城市→国家→大洲),需保证抽象链的纯净,确保不会跨越领域 b. 逐层抽象(地标→城市→国家→大洲),需保证抽象链的纯净,确保不会跨越领域
c. 修剪抽象链,使其保持在[3, 7]层之内,同时每层的抽象节点考虑扩展性及可复用性 c. 修剪抽象链,使其保持在[3, 7]层之内,同时每层的抽象节点考虑扩展性及可复用性
d. 形成最终路径(格式:领域→大类→子类→实例) d. 形成最终路径(格式:领域→大类→子类→实例)
@@ -226,7 +210,7 @@ public class MemorySummarizer extends Model {
b. 技术术语需符合行业标准 b. 技术术语需符合行业标准
完整示例 完整示例
示例1日常分享 示例:
输入:{ 输入:{
"topicTree": " "topicTree": "
生活[root] 生活[root]
@@ -237,7 +221,7 @@ public class MemorySummarizer extends Model {
] ]
} }
输出:{ 输出:{
"summary": "用户分享欧洲自由行经历并讨论夜景照片处理", "summary": "用户分享欧洲自由行经历并讨论夜景照片处理...",
"topicPath": "生活->旅行->自由行->欧洲->法国->巴黎铁塔", "topicPath": "生活->旅行->自由行->欧洲->法国->巴黎铁塔",
"relatedTopicPath": [ "relatedTopicPath": [
"艺术->摄影->夜景拍摄", "艺术->摄影->夜景拍摄",
@@ -245,65 +229,6 @@ public class MemorySummarizer extends Model {
], ],
"isPrivate": false "isPrivate": false
} }
示例2专业咨询
输入:{
"topicTree": "
计算机[root]
└── 编程",
"chatMessages": [
{"role": "user", "content": "SpringBoot项目如何实现JWT鉴权"},
{"role": "assistant", "content": "需集成spring-security-jwt依赖..."}
]
}
输出:{
"summary": "讨论SpringBoot项目集成JWT鉴权的技术方案",
"topicPath": "计算机->软件开发->Java->SpringBoot->安全->JWT",
"relatedTopicPath": [
"计算机->网络安全->认证协议",
"数学->加密算法->非对称加密"
],
"isPrivate": false
}
示例3事件讨论
输入:{
"topicTree": "
社会[root]
├── 教育
└── 科技",
"chatMessages": [
{"role": "user", "content": "听说某大学研发出脑机接口新成果"},
{"role": "assistant", "content": "该技术涉及神经科学和AI的跨学科研究"}
]
}
输出:{
"summary": "讨论某大学在脑机接口领域的跨学科研究成果",
"topicPath": "社会->科技->人工智能->脑机接口",
"relatedTopicPath": [
"科学->生物学->神经科学",
"教育->高等教育->科研创新"
],
"isPrivate": false
}
示例4隐私事件
输入:{
"topicTree": "
法律[root]
└── 隐私",
"chatMessages": [
{"role": "user", "content": "这个合同条款请仅限我们之间知晓"},
{"role": "assistant", "content": "已启用加密存储,不会外泄"}
]
}
输出:{
"summary": "用户要求保密合同条款内容",
"topicPath": "法律->合同法->保密条款",
"relatedTopicPath": ["信息技术->数据安全->加密存储"],
"isPrivate": true
}
"""; """;
public static final String TOTAL_SUMMARIZE_PROMPT = """ public static final String TOTAL_SUMMARIZE_PROMPT = """
@@ -347,7 +272,7 @@ public class MemorySummarizer extends Model {
c. 验证信息完整性 c. 验证信息完整性
完整示例 完整示例
示例1基础情况 示例:
输入:{ 输入:{
"aaa-111": "需要购买笔记本电脑预算5000左右主要用于办公", "aaa-111": "需要购买笔记本电脑预算5000左右主要用于办公",
"bbb-222": "想买游戏本预算8000-10000要能运行3A大作", "bbb-222": "想买游戏本预算8000-10000要能运行3A大作",
@@ -360,30 +285,6 @@ public class MemorySummarizer extends Model {
用户[ccc-333]:咨询适合出差使用的轻薄本" 用户[ccc-333]:咨询适合出差使用的轻薄本"
} }
示例2信息合并
输入:{
"ddd-444": "想了解Python入门课程零基础",
"eee-555": "询问Java和Python哪个更适合新手",
"fff-666": "零基础想学Python数据分析"
}
输出:{
"content": "
用户[ddd-444]零基础想了解Python入门课程
用户[eee-555]询问Java和Python对新手的适用性
用户[fff-666]零基础想学习Python数据分析"
}
示例3长文本精简
输入:{
"ggg-777": "您好!我最近在准备考研,想咨询下时间规划。具体是想了解每天应该分配多少时间给英语复习,我现在英语水平大概是四级刚过的程度...后续200字详细描述",
"hhh-888": "考研政治怎么准备?需要报班吗?"
}
输出:{
"content": "
用户[ggg-777]:咨询考研英语复习时间规划,当前英语水平为四级;
用户[hhh-888]:询问考研政治备考方法及是否需要报班"
}
特殊处理 特殊处理
1. 当总字数超出限制时: 1. 当总字数超出限制时:
• 尽量保留所有出现的用户摘要 • 尽量保留所有出现的用户摘要

View File

@@ -5,6 +5,7 @@ import lombok.Data;
import work.slhaf.agent.core.interaction.data.InteractionContext; import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.interaction.data.InteractionInputData; import work.slhaf.agent.core.interaction.data.InteractionInputData;
import work.slhaf.agent.core.memory.MemoryManager; import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.session.SessionManager;
import java.io.IOException; import java.io.IOException;
import java.time.LocalDateTime; import java.time.LocalDateTime;
@@ -16,6 +17,7 @@ public class PreprocessExecutor {
private static PreprocessExecutor preprocessExecutor; private static PreprocessExecutor preprocessExecutor;
private MemoryManager memoryManager; private MemoryManager memoryManager;
private SessionManager sessionManager;
private PreprocessExecutor() { private PreprocessExecutor() {
} }
@@ -24,14 +26,27 @@ public class PreprocessExecutor {
if (preprocessExecutor == null) { if (preprocessExecutor == null) {
preprocessExecutor = new PreprocessExecutor(); preprocessExecutor = new PreprocessExecutor();
preprocessExecutor.setMemoryManager(MemoryManager.getInstance()); preprocessExecutor.setMemoryManager(MemoryManager.getInstance());
preprocessExecutor.setSessionManager(SessionManager.getInstance());
} }
return preprocessExecutor; return preprocessExecutor;
} }
public InteractionContext execute(InteractionInputData inputData) { public InteractionContext execute(InteractionInputData inputData) {
InteractionContext context = new InteractionContext(); checkAndSetMemoryId();
String userId = memoryManager.getUserId(inputData.getUserInfo(), inputData.getUserNickName()); return getInteractionContext(inputData);
}
private void checkAndSetMemoryId() {
String currentMemoryId = sessionManager.getCurrentMemoryId();
if (currentMemoryId == null || memoryManager.getChatMessages().isEmpty()) {
sessionManager.refreshMemoryId();
}
}
private InteractionContext getInteractionContext(InteractionInputData inputData) {
InteractionContext context = new InteractionContext();
String userId = memoryManager.getUserId(inputData.getUserInfo(), inputData.getUserNickName());
context.setUserId(userId); context.setUserId(userId);
context.setUserNickname(inputData.getUserNickName()); context.setUserNickname(inputData.getUserNickName());
context.setUserInfo(inputData.getUserInfo()); context.setUserInfo(inputData.getUserInfo());
@@ -55,7 +70,6 @@ public class PreprocessExecutor {
context.setSingle(inputData.isSingle()); context.setSingle(inputData.isSingle());
context.setFinished(false); context.setFinished(false);
return context; return context;
} }
} }

View File

@@ -1,34 +1,30 @@
package memory; package memory;
import java.util.concurrent.ExecutorService; import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
public class ThreadPoolTest { public class ThreadPoolTest {
public static void main(String[] args) throws InterruptedException {
testExecutor(Executors.newVirtualThreadPerTaskExecutor());
// Thread.sleep(2000); // 等待系统输出稳定 @Test
public void testExecutor() throws InterruptedException {
// testExecutor("普通线程池", Executors.newFixedThreadPool(100)); List<Callable<Void>> tasks = new ArrayList<>();
} for (int i = 0; i < 5; i++) {
int finalI = i;
private static void testExecutor(ExecutorService es) throws InterruptedException { tasks.add(() -> {
long start = System.currentTimeMillis(); System.out.println("开始: " + finalI);
Thread.sleep(5000);
for (int i = 0; i < 100000; i++) { System.out.println("结束: " + finalI);
es.submit(() -> { return null;
Thread.sleep(1000);
return 0;
}); });
} }
es.shutdown(); Executors.newVirtualThreadPerTaskExecutor().invokeAll(tasks, 10, TimeUnit.SECONDS);
if (es.awaitTermination(5, TimeUnit.MINUTES)) {
long end = System.currentTimeMillis(); System.out.println("hello");
System.out.println("虚拟线程" + "耗时:" + (end - start));
} else {
System.err.println("虚拟线程" + "未能在规定时间内完成所有任务");
}
} }
} }