mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
进行第一阶段的调试修复
- 添加 DebugMonitor 加了一个较为无用的线程,打上断点用于获取即时模块信息 - 调整 MemoryGraph 中临时切片前后序生成容器与日期索引分离,日期索引将存储 localDate 和 memoryId - 修复了几个 MemoryGraph 用到的类对于序列化的实现 - 将 MemoryManager 中维护的 activatedMemorySlice 同样作为主题提取的依据,并调整了提示词 - 因 主题提取LLM 效果不稳定,故添加了必要的异常处理机制 - MemorySlice 中前后序机制存在循环引用问题,排除 toString 的调用 - 移除了提示词中过多的示例,仅保留一份 - 记忆自动更新线程调整:在 SessionManager 中维护 lastUpdatedTime ,用于标识最近聊天时间; 计数当前对话数量,如果只有系统提示词则不进行记忆更新 - MemoryGraph 获取主题树时将标识记忆节点数量,供主题提取模型识别,减少空主题节点作为目标节点的情况 - 在 PreprocessExecutor 执行时将先判断是否存在memoryId, 除此之外,memoryId也将在记忆自动更新线程执行后刷新 - 将 SessionManager 添加了序列化机制,添加了程序停止时自动序列化保存的钩子 - 当主题提取模型指定了不包含记忆节点的主题节点时,MemoryResult 的空属性将会使其跳过无效的切片评估 待解决问题: - MemoryGraph 在进行记忆更新时有时会出现错误,出现条件不明 - MemorySelector 拿到的List<EvaluatedSlice> 总是为空,原因不明 - 切片评估线程似乎没有运行,原因不明
This commit is contained in:
@@ -2,8 +2,8 @@ package work.slhaf.agent;
|
|||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.java_websocket.WebSocket;
|
|
||||||
import work.slhaf.agent.common.config.Config;
|
import work.slhaf.agent.common.config.Config;
|
||||||
|
import work.slhaf.agent.common.monitor.DebugMonitor;
|
||||||
import work.slhaf.agent.core.InteractionHub;
|
import work.slhaf.agent.core.InteractionHub;
|
||||||
import work.slhaf.agent.core.interaction.InputReceiver;
|
import work.slhaf.agent.core.interaction.InputReceiver;
|
||||||
import work.slhaf.agent.core.interaction.TaskCallback;
|
import work.slhaf.agent.core.interaction.TaskCallback;
|
||||||
@@ -34,6 +34,9 @@ public class Agent implements TaskCallback, InputReceiver {
|
|||||||
server.launch();
|
server.launch();
|
||||||
agent.setMessageSender(server);
|
agent.setMessageSender(server);
|
||||||
log.info("Agent 加载完毕..");
|
log.info("Agent 加载完毕..");
|
||||||
|
|
||||||
|
//启动监测线程
|
||||||
|
DebugMonitor.initialize();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,19 @@ package work.slhaf.agent.common.chat.pojo;
|
|||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import work.slhaf.agent.common.pojo.PersistableObject;
|
||||||
|
|
||||||
|
import java.io.Serial;
|
||||||
|
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Data
|
@Data
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class MetaMessage {
|
public class MetaMessage extends PersistableObject {
|
||||||
|
|
||||||
|
@Serial
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private Message userMessage;
|
private Message userMessage;
|
||||||
private Message assistantMessage;
|
private Message assistantMessage;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import work.slhaf.agent.modules.memory.updater.MemoryUpdater;
|
|||||||
import work.slhaf.agent.modules.memory.updater.static_extractor.StaticMemoryExtractor;
|
import work.slhaf.agent.modules.memory.updater.static_extractor.StaticMemoryExtractor;
|
||||||
import work.slhaf.agent.modules.memory.updater.summarizer.MemorySummarizer;
|
import work.slhaf.agent.modules.memory.updater.summarizer.MemorySummarizer;
|
||||||
import work.slhaf.agent.modules.task.TaskEvaluator;
|
import work.slhaf.agent.modules.task.TaskEvaluator;
|
||||||
import work.slhaf.agent.modules.task.TaskScheduler;
|
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|||||||
@@ -178,12 +178,14 @@ public class ModelConstant {
|
|||||||
输入字段说明
|
输入字段说明
|
||||||
• `text`: 用户当前输入的文本内容
|
• `text`: 用户当前输入的文本内容
|
||||||
|
|
||||||
• `topic_tree`: 当前可用的主题树结构(多层级结构,需返回从根节点([root])到目标节点的完整路径)
|
• `topic_tree`: 当前可用的主题树结构(多层级结构,需返回从根节点([root])到目标节点的完整路径), 主题树中类似`[0]`的标志为主题节点下对应的记忆节点数量,当记忆节点数量为0时,该主题节点不能作为目标节点
|
||||||
|
|
||||||
• `date`: 当前对话发生的日期(用于时间推理)
|
• `date`: 当前对话发生的日期(用于时间推理)
|
||||||
|
|
||||||
• `history`: 用户与LLM的完整对话历史(用于主题连续性判断)
|
• `history`: 用户与LLM的完整对话历史(用于主题连续性判断)
|
||||||
|
|
||||||
|
• `activated_memory_slices`: 已经激活的记忆切片
|
||||||
|
|
||||||
|
|
||||||
输出规则
|
输出规则
|
||||||
1. 基本响应格式:
|
1. 基本响应格式:
|
||||||
@@ -205,6 +207,10 @@ public class ModelConstant {
|
|||||||
|
|
||||||
◦ 除非包含明确的新子主题,否则不重复提取相同主题路径
|
◦ 除非包含明确的新子主题,否则不重复提取相同主题路径
|
||||||
|
|
||||||
|
• 当激活的记忆切片已经不符合当前主题时:
|
||||||
|
|
||||||
|
◦ 除非主题树中存在匹配的主题路径,否则仍不进行提取操作
|
||||||
|
|
||||||
|
|
||||||
3. 日期提取规则(保持不变):
|
3. 日期提取规则(保持不变):
|
||||||
• 仅接受具体日期(YYYY-MM-DD格式)
|
• 仅接受具体日期(YYYY-MM-DD格式)
|
||||||
@@ -226,11 +232,12 @@ public class ModelConstant {
|
|||||||
|
|
||||||
决策流程
|
决策流程
|
||||||
0. 若主题树为空或者未提供主题树,则直接将recall设置为null, 不进行后续判定
|
0. 若主题树为空或者未提供主题树,则直接将recall设置为null, 不进行后续判定
|
||||||
1. 首先分析`history`判断当前对话主题上下文
|
1. 对于所有记忆节点个数为0的主题节点来说,这些节点不能作为主题路径的终点
|
||||||
2. 然后分析`text`:
|
2. 首先分析`text`:
|
||||||
a. 检测是否包含具体日期→添加date类型
|
a. 检测用户提到的具体日期是否明确与某事物/事件相关→添加date类型
|
||||||
b. 检测是否包含新主题→添加topic类型
|
b. 检测用户提到的事物/事件是否明确与主题树中存在的主题路径相关→添加topic类型
|
||||||
3. 最终综合判断`recall`值, 如果找到了对应的主题路径,则recall值为true; 否则为false
|
3. 分析`history`判断当前对话主题上下文, 如果与`text`中的内容明显无关,则仅只依据`text`内容提取主题路径
|
||||||
|
4. 最终综合判断`recall`值, 如果找到了对应的主题路径,则recall值为true; 否则为false
|
||||||
|
|
||||||
完整示例
|
完整示例
|
||||||
示例1(主题延续):
|
示例1(主题延续):
|
||||||
@@ -238,12 +245,12 @@ public class ModelConstant {
|
|||||||
"text": "关于NodeJS的并发处理,还有哪些要注意的",
|
"text": "关于NodeJS的并发处理,还有哪些要注意的",
|
||||||
"topic_tree": "
|
"topic_tree": "
|
||||||
编程[root]
|
编程[root]
|
||||||
├── JavaScript
|
├── JavaScript[0]
|
||||||
│ ├── NodeJS
|
│ ├── NodeJS
|
||||||
│ │ ├── 并发处理
|
│ │ ├── 并发处理[1]
|
||||||
│ │ └── 事件循环
|
│ │ └── 事件循环[0]
|
||||||
│ └── Express
|
│ └── Express[1]
|
||||||
│ └── 中间件
|
│ └── 中间件[0]
|
||||||
└── Python",
|
└── Python",
|
||||||
"date": "2024-04-20",
|
"date": "2024-04-20",
|
||||||
"history": [
|
"history": [
|
||||||
@@ -261,12 +268,12 @@ public class ModelConstant {
|
|||||||
"text": "现在我想了解Express中间件的原理",
|
"text": "现在我想了解Express中间件的原理",
|
||||||
"topic_tree": "
|
"topic_tree": "
|
||||||
编程[root]
|
编程[root]
|
||||||
├── JavaScript
|
├── JavaScript[0]
|
||||||
│ ├── NodeJS
|
│ ├── NodeJS[0]
|
||||||
│ │ ├── 并发处理
|
│ │ ├── 并发处理[1]
|
||||||
│ │ └── 事件循环
|
│ │ └── 事件循环[0]
|
||||||
│ └── Express
|
│ └── Express[0]
|
||||||
│ └── 中间件
|
│ └── 中间件[1]
|
||||||
└── Python",
|
└── Python",
|
||||||
"date": "2024-04-20",
|
"date": "2024-04-20",
|
||||||
"history": [
|
"history": [
|
||||||
@@ -286,12 +293,12 @@ public class ModelConstant {
|
|||||||
"text": "2024-04-15讨论的Python内容和现在的Express需求",
|
"text": "2024-04-15讨论的Python内容和现在的Express需求",
|
||||||
"topic_tree": "
|
"topic_tree": "
|
||||||
编程[root]
|
编程[root]
|
||||||
├── JavaScript
|
├── JavaScript[0]
|
||||||
│ ├── NodeJS
|
│ ├── NodeJS[0]
|
||||||
│ │ ├── 并发处理
|
│ │ ├── 并发处理[1]
|
||||||
│ │ └── 事件循环
|
│ │ └── 事件循环[1]
|
||||||
│ └── Express
|
│ └── Express[1]
|
||||||
│ └── 中间件
|
│ └── 中间件[0]
|
||||||
└── Python",
|
└── Python",
|
||||||
"date": "2024-04-20",
|
"date": "2024-04-20",
|
||||||
"history": [
|
"history": [
|
||||||
@@ -312,12 +319,12 @@ public class ModelConstant {
|
|||||||
"text": "上周说的那个JavaScript特性",
|
"text": "上周说的那个JavaScript特性",
|
||||||
"topic_tree": "
|
"topic_tree": "
|
||||||
编程[root]
|
编程[root]
|
||||||
├── JavaScript
|
├── JavaScript[0]
|
||||||
│ ├── NodeJS
|
│ ├── NodeJS[0]
|
||||||
│ │ ├── 并发处理
|
│ │ ├── 并发处理[1]
|
||||||
│ │ └── 事件循环
|
│ │ └── 事件循环[1]
|
||||||
│ └── Express
|
│ └── Express[0]
|
||||||
│ └── 中间件
|
│ └── 中间件[1]
|
||||||
└── Python",
|
└── Python",
|
||||||
"date": "2024-04-20",
|
"date": "2024-04-20",
|
||||||
"history": [...]
|
"history": [...]
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package work.slhaf.agent.common.monitor;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class DebugMonitor {
|
||||||
|
|
||||||
|
private InteractionThreadPoolExecutor executor;
|
||||||
|
private static DebugMonitor debugMonitor;
|
||||||
|
|
||||||
|
public static void initialize() {
|
||||||
|
debugMonitor = new DebugMonitor();
|
||||||
|
debugMonitor.executor = InteractionThreadPoolExecutor.getInstance();
|
||||||
|
debugMonitor.runMonitor();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void runMonitor() {
|
||||||
|
executor.execute(() -> {
|
||||||
|
while (true) {
|
||||||
|
try {
|
||||||
|
Thread.sleep(3000);
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
log.error("监测线程报错?");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
public static DebugMonitor getInstance(){
|
||||||
|
if (debugMonitor == null) {
|
||||||
|
initialize();
|
||||||
|
}
|
||||||
|
return debugMonitor;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,15 +1,20 @@
|
|||||||
package work.slhaf.agent.core.memory;
|
package work.slhaf.agent.core.memory;
|
||||||
|
|
||||||
|
import cn.hutool.json.JSONUtil;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import work.slhaf.agent.common.chat.pojo.Message;
|
import work.slhaf.agent.common.chat.pojo.Message;
|
||||||
import work.slhaf.agent.common.pojo.PersistableObject;
|
import work.slhaf.agent.common.pojo.PersistableObject;
|
||||||
|
import work.slhaf.agent.core.memory.exception.UnExistedDateIndexException;
|
||||||
import work.slhaf.agent.core.memory.exception.UnExistedTopicException;
|
import work.slhaf.agent.core.memory.exception.UnExistedTopicException;
|
||||||
import work.slhaf.agent.core.memory.node.MemoryNode;
|
import work.slhaf.agent.core.memory.node.MemoryNode;
|
||||||
import work.slhaf.agent.core.memory.node.TopicNode;
|
import work.slhaf.agent.core.memory.node.TopicNode;
|
||||||
import work.slhaf.agent.core.memory.pojo.*;
|
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||||
|
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||||
|
import work.slhaf.agent.core.memory.pojo.MemorySliceResult;
|
||||||
|
import work.slhaf.agent.core.memory.pojo.User;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
@@ -44,10 +49,14 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
private HashMap<String /*根主题名*/, LinkedHashSet<String> /*子主题列表*/> existedTopics;
|
private HashMap<String /*根主题名*/, LinkedHashSet<String> /*子主题列表*/> existedTopics;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 记忆节点的日期索引, 同一日期内按照对话id区分
|
* 临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所
|
||||||
* 同时作为临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所
|
|
||||||
*/
|
*/
|
||||||
private HashMap<LocalDate, HashMap<String /*对话id, 即slice中的字段'memoryId'*/, List<MemorySlice>>> dateIndex;
|
private HashMap<String /*对话id, 即slice中的字段'memoryId'*/, List<MemorySlice>> currentDateDialogSlices;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 记忆节点的日期索引, 同一日期内按照对话id区分
|
||||||
|
*/
|
||||||
|
private HashMap<LocalDate, Set<String>> dateIndex;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键,总结为值
|
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键,总结为值
|
||||||
@@ -110,13 +119,11 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
*/
|
*/
|
||||||
private Set<Long> selectedSlices;
|
private Set<Long> selectedSlices;
|
||||||
|
|
||||||
private String memoryId;
|
|
||||||
|
|
||||||
public MemoryGraph(String id) {
|
public MemoryGraph(String id) {
|
||||||
this.id = id;
|
this.id = id;
|
||||||
this.topicNodes = new HashMap<>();
|
this.topicNodes = new HashMap<>();
|
||||||
this.existedTopics = new HashMap<>();
|
this.existedTopics = new HashMap<>();
|
||||||
this.dateIndex = new HashMap<>();
|
this.currentDateDialogSlices = new HashMap<>();
|
||||||
this.staticMemory = new HashMap<>();
|
this.staticMemory = new HashMap<>();
|
||||||
this.memoryNodeCacheCounter = new ConcurrentHashMap<>();
|
this.memoryNodeCacheCounter = new ConcurrentHashMap<>();
|
||||||
this.memorySliceCache = new ConcurrentHashMap<>();
|
this.memorySliceCache = new ConcurrentHashMap<>();
|
||||||
@@ -129,6 +136,7 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
this.character = """
|
this.character = """
|
||||||
实话实说,不做糖衣炮弹。 采取前瞻性的观点。 始终保持尊重。 乐于分享明确的观点。 保持轻松、随和。 直奔主题。 务实至上。 勇于创新,打破常规思维。使用中文回答所有问题。
|
实话实说,不做糖衣炮弹。 采取前瞻性的观点。 始终保持尊重。 乐于分享明确的观点。 保持轻松、随和。 直奔主题。 务实至上。 勇于创新,打破常规思维。使用中文回答所有问题。
|
||||||
""";
|
""";
|
||||||
|
this.dateIndex = new HashMap<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException {
|
public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException {
|
||||||
@@ -195,12 +203,18 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
LocalDate now = LocalDate.now();
|
LocalDate now = LocalDate.now();
|
||||||
boolean hasSlice = false;
|
boolean hasSlice = false;
|
||||||
MemoryNode node = null;
|
MemoryNode node = null;
|
||||||
for (MemoryNode memoryNode : lastTopicNode.getMemoryNodes()) {
|
try {
|
||||||
if (now.equals(memoryNode.getLocalDate())) {
|
for (MemoryNode memoryNode : lastTopicNode.getMemoryNodes()) {
|
||||||
hasSlice = true;
|
if (now.equals(memoryNode.getLocalDate())) {
|
||||||
node = memoryNode;
|
hasSlice = true;
|
||||||
break;
|
node = memoryNode;
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("插入记忆时出错: ", e);
|
||||||
|
log.error("主题路径: {}; 切片内容: {}", topicPath, slice);
|
||||||
|
log.error("主题树状态: {}", JSONUtil.toJsonPrettyStr(topicNodes));
|
||||||
}
|
}
|
||||||
if (!hasSlice) {
|
if (!hasSlice) {
|
||||||
node = new MemoryNode();
|
node = new MemoryNode();
|
||||||
@@ -217,13 +231,27 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
generateTopicPath(relatedTopic);
|
generateTopicPath(relatedTopic);
|
||||||
}
|
}
|
||||||
|
|
||||||
updateDateIndex(now, slice);
|
updateSlicePrecedent(slice);
|
||||||
|
updateDateIndex(slice);
|
||||||
|
|
||||||
if (!slice.isPrivate()) {
|
if (!slice.isPrivate()) {
|
||||||
updateUserDialogMap(slice);
|
updateUserDialogMap(slice);
|
||||||
}
|
}
|
||||||
node.saveMemorySliceList();
|
node.saveMemorySliceList();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void updateDateIndex(MemorySlice slice) {
|
||||||
|
String memoryId = slice.getMemoryId();
|
||||||
|
LocalDate date = LocalDate.now();
|
||||||
|
if (!dateIndex.containsKey(date)) {
|
||||||
|
HashSet<String> memoryIdSet = new HashSet<>();
|
||||||
|
memoryIdSet.add(memoryId);
|
||||||
|
dateIndex.put(date, memoryIdSet);
|
||||||
|
} else {
|
||||||
|
dateIndex.get(date).add(memoryId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private TopicNode generateTopicPath(List<String> topicPath) {
|
private TopicNode generateTopicPath(List<String> topicPath) {
|
||||||
topicPath = new ArrayList<>(topicPath);
|
topicPath = new ArrayList<>(topicPath);
|
||||||
//查看是否存在根主题节点
|
//查看是否存在根主题节点
|
||||||
@@ -240,11 +268,17 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
TopicNode lastTopicNode = topicNodes.get(rootTopic);
|
TopicNode lastTopicNode = topicNodes.get(rootTopic);
|
||||||
Set<String> existedTopicNodes = existedTopics.get(rootTopic);
|
Set<String> existedTopicNodes = existedTopics.get(rootTopic);
|
||||||
for (String topic : topicPath) {
|
for (String topic : topicPath) {
|
||||||
if (existedTopicNodes.contains(topic)) {
|
if (existedTopicNodes.contains(topic) && lastTopicNode.getTopicNodes().containsKey(topic)) {
|
||||||
lastTopicNode = lastTopicNode.getTopicNodes().get(topic);
|
lastTopicNode = lastTopicNode.getTopicNodes().get(topic);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
TopicNode newNode = new TopicNode();
|
TopicNode newNode = new TopicNode();
|
||||||
lastTopicNode.getTopicNodes().put(topic, newNode);
|
try {
|
||||||
|
lastTopicNode.getTopicNodes().put(topic, newNode);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("主题路径: {}; ", topicPath);
|
||||||
|
log.error("主题树状态: {}", JSONUtil.toJsonPrettyStr(topicNodes));
|
||||||
|
}
|
||||||
lastTopicNode = newNode;
|
lastTopicNode = newNode;
|
||||||
CopyOnWriteArrayList<MemoryNode> nodeList = new CopyOnWriteArrayList<>();
|
CopyOnWriteArrayList<MemoryNode> nodeList = new CopyOnWriteArrayList<>();
|
||||||
lastTopicNode.setMemoryNodes(nodeList);
|
lastTopicNode.setMemoryNodes(nodeList);
|
||||||
@@ -281,16 +315,12 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateDateIndex(LocalDate now, MemorySlice slice) {
|
private void updateSlicePrecedent(MemorySlice slice) {
|
||||||
String memoryId = slice.getMemoryId();
|
String memoryId = slice.getMemoryId();
|
||||||
//查看是否存在当前日期的对话切片索引
|
//查看是否切换了memoryId
|
||||||
if (!dateIndex.containsKey(now)) {
|
|
||||||
dateIndex.put(now, new HashMap<>());
|
|
||||||
}
|
|
||||||
//查看当前日期的索引中是否存在该对话的索引
|
|
||||||
HashMap<String, List<MemorySlice>> currentDateDialogSlices = dateIndex.get(now);
|
|
||||||
if (!currentDateDialogSlices.containsKey(memoryId)) {
|
if (!currentDateDialogSlices.containsKey(memoryId)) {
|
||||||
List<MemorySlice> memorySliceList = new ArrayList<>();
|
List<MemorySlice> memorySliceList = new ArrayList<>();
|
||||||
|
currentDateDialogSlices.clear();
|
||||||
currentDateDialogSlices.put(memoryId, memorySliceList);
|
currentDateDialogSlices.put(memoryId, memorySliceList);
|
||||||
}
|
}
|
||||||
//处理上下文关系
|
//处理上下文关系
|
||||||
@@ -313,20 +343,20 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
|
|
||||||
public MemoryResult selectMemory(String topicPathStr) throws IOException, ClassNotFoundException {
|
public MemoryResult selectMemory(String topicPathStr) throws IOException, ClassNotFoundException {
|
||||||
List<String> topicPath = List.of(topicPathStr.split("->"));
|
List<String> topicPath = List.of(topicPathStr.split("->"));
|
||||||
|
List<String> path = new ArrayList<>(topicPath);
|
||||||
MemoryResult memoryResult = new MemoryResult();
|
MemoryResult memoryResult = new MemoryResult();
|
||||||
|
|
||||||
//每日刷新缓存
|
//每日刷新缓存
|
||||||
checkCacheDate();
|
checkCacheDate();
|
||||||
//检测缓存并更新计数, 查看是否需要放入缓存
|
//检测缓存并更新计数, 查看是否需要放入缓存
|
||||||
updateCacheCounter(topicPath);
|
updateCacheCounter(path);
|
||||||
//查看是否存在缓存,如果存在,则直接返回
|
//查看是否存在缓存,如果存在,则直接返回
|
||||||
if (memorySliceCache.containsKey(topicPath)) {
|
if (memorySliceCache.containsKey(path)) {
|
||||||
return memorySliceCache.get(topicPath);
|
return memorySliceCache.get(path);
|
||||||
}
|
}
|
||||||
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
|
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
|
||||||
topicPath = new ArrayList<>(topicPath);
|
String targetTopic = path.getLast();
|
||||||
String targetTopic = topicPath.getLast();
|
TopicNode targetParentNode = getTargetParentNode(path, targetTopic);
|
||||||
TopicNode targetParentNode = getTargetParentNode(topicPath, targetTopic);
|
|
||||||
List<List<String>> relatedTopics = new ArrayList<>();
|
List<List<String>> relatedTopics = new ArrayList<>();
|
||||||
|
|
||||||
//终点记忆节点
|
//终点记忆节点
|
||||||
@@ -389,6 +419,10 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
|
|
||||||
private void updateCache(List<String> topicPath, MemoryResult memoryResult) {
|
private void updateCache(List<String> topicPath, MemoryResult memoryResult) {
|
||||||
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
|
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
|
||||||
|
if (tempCount == null) {
|
||||||
|
log.error("tempCount为null? memoryNodeCacheCounter: {}; topicPath: {}", memoryNodeCacheCounter, topicPath);
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (tempCount >= 5) {
|
if (tempCount >= 5) {
|
||||||
memorySliceCache.put(topicPath, memoryResult);
|
memorySliceCache.put(topicPath, memoryResult);
|
||||||
}
|
}
|
||||||
@@ -404,17 +438,19 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void checkCacheDate() {
|
private void checkCacheDate() {
|
||||||
if ( cacheDate == null || cacheDate.isBefore(LocalDate.now())) {
|
if (cacheDate == null || cacheDate.isBefore(LocalDate.now())) {
|
||||||
memorySliceCache.clear();
|
memorySliceCache.clear();
|
||||||
memoryNodeCacheCounter.clear();
|
memoryNodeCacheCounter.clear();
|
||||||
cacheDate = LocalDate.now();
|
cacheDate = LocalDate.now();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public MemoryResult selectMemory(LocalDate date) {
|
public MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException {
|
||||||
MemoryResult memoryResult = new MemoryResult();
|
MemoryResult memoryResult = new MemoryResult();
|
||||||
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
|
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
|
||||||
for (List<MemorySlice> value : dateIndex.get(date).values()) {
|
//加载节点并获取记忆切片列表
|
||||||
|
List<List<MemorySlice>> currentDateDialogSlices = loadSlicesByDate(date);
|
||||||
|
for (List<MemorySlice> value : currentDateDialogSlices) {
|
||||||
for (MemorySlice memorySlice : value) {
|
for (MemorySlice memorySlice : value) {
|
||||||
if (selectedSlices.contains(memorySlice.getTimestamp())) {
|
if (selectedSlices.contains(memorySlice.getTimestamp())) {
|
||||||
continue;
|
continue;
|
||||||
@@ -429,6 +465,19 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
return memoryResult;
|
return memoryResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private List<List<MemorySlice>> loadSlicesByDate(LocalDate date) throws IOException, ClassNotFoundException {
|
||||||
|
if (!dateIndex.containsKey(date)) {
|
||||||
|
throw new UnExistedDateIndexException("不存在的日期索引: " + date);
|
||||||
|
}
|
||||||
|
List<List<MemorySlice>> list = new ArrayList<>();
|
||||||
|
for (String memoryId : dateIndex.get(date)) {
|
||||||
|
MemoryNode memoryNode = new MemoryNode();
|
||||||
|
memoryNode.setMemoryNodeId(memoryId);
|
||||||
|
list.add(memoryNode.loadMemorySliceList());
|
||||||
|
}
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
private TopicNode getTargetParentNode(List<String> topicPath, String targetTopic) {
|
private TopicNode getTargetParentNode(List<String> topicPath, String targetTopic) {
|
||||||
String topTopic = topicPath.getFirst();
|
String topTopic = topicPath.getFirst();
|
||||||
if (!existedTopics.containsKey(topTopic)) {
|
if (!existedTopics.containsKey(topTopic)) {
|
||||||
@@ -468,7 +517,7 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
for (int i = 0; i < entries.size(); i++) {
|
for (int i = 0; i < entries.size(); i++) {
|
||||||
boolean last = (i == entries.size() - 1);
|
boolean last = (i == entries.size() - 1);
|
||||||
Map.Entry<String, TopicNode> entry = entries.get(i);
|
Map.Entry<String, TopicNode> entry = entries.get(i);
|
||||||
stringBuilder.append(prefix).append(last ? "└── " : "├── ").append(entry.getKey()).append("\r\n");
|
stringBuilder.append(prefix).append(last ? "└── " : "├── ").append(entry.getKey()).append("[").append(entry.getValue().getMemoryNodes().size()).append("]").append("\r\n");
|
||||||
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), stringBuilder);
|
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), stringBuilder);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import lombok.Data;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import work.slhaf.agent.common.chat.pojo.Message;
|
import work.slhaf.agent.common.chat.pojo.Message;
|
||||||
import work.slhaf.agent.common.config.Config;
|
import work.slhaf.agent.common.config.Config;
|
||||||
import work.slhaf.agent.core.interaction.InteractionModule;
|
|
||||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
|
||||||
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||||
import work.slhaf.agent.core.memory.pojo.User;
|
import work.slhaf.agent.core.memory.pojo.User;
|
||||||
@@ -21,7 +19,7 @@ import java.util.concurrent.locks.ReentrantLock;
|
|||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class MemoryManager implements InteractionModule {
|
public class MemoryManager {
|
||||||
|
|
||||||
private static MemoryManager memoryManager;
|
private static MemoryManager memoryManager;
|
||||||
private final Lock sliceInsertLock = new ReentrantLock();
|
private final Lock sliceInsertLock = new ReentrantLock();
|
||||||
@@ -33,10 +31,6 @@ public class MemoryManager implements InteractionModule {
|
|||||||
private MemoryManager() {
|
private MemoryManager() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(InteractionContext interactionContext) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public static MemoryManager getInstance() throws IOException, ClassNotFoundException {
|
public static MemoryManager getInstance() throws IOException, ClassNotFoundException {
|
||||||
if (memoryManager == null) {
|
if (memoryManager == null) {
|
||||||
@@ -44,24 +38,28 @@ public class MemoryManager implements InteractionModule {
|
|||||||
memoryManager = new MemoryManager();
|
memoryManager = new MemoryManager();
|
||||||
memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId()));
|
memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId()));
|
||||||
memoryManager.setActivatedSlices(new HashMap<>());
|
memoryManager.setActivatedSlices(new HashMap<>());
|
||||||
|
memoryManager.setShutdownHook();
|
||||||
log.info("MemoryManager注册完毕...");
|
log.info("MemoryManager注册完毕...");
|
||||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
|
||||||
try {
|
|
||||||
memoryManager.save();
|
|
||||||
log.info("MemoryGraph已保存");
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.error("保存MemoryGraph失败: ", e);
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
return memoryManager;
|
return memoryManager;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void setShutdownHook() {
|
||||||
|
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||||
|
try {
|
||||||
|
memoryManager.save();
|
||||||
|
log.info("MemoryGraph已保存");
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error("保存MemoryGraph失败: ", e);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
public MemoryResult selectMemory(String path) throws IOException, ClassNotFoundException {
|
public MemoryResult selectMemory(String path) throws IOException, ClassNotFoundException {
|
||||||
return memoryGraph.selectMemory(path);
|
return memoryGraph.selectMemory(path);
|
||||||
}
|
}
|
||||||
|
|
||||||
public MemoryResult selectMemory(LocalDate date) {
|
public MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException {
|
||||||
return memoryGraph.selectMemory(date);
|
return memoryGraph.selectMemory(date);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,14 +116,6 @@ public class MemoryManager implements InteractionModule {
|
|||||||
return memoryGraph.getCharacter();
|
return memoryGraph.getCharacter();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void resetMemoryId() {
|
|
||||||
memoryGraph.setMemoryId(UUID.randomUUID().toString());
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getMemoryId() {
|
|
||||||
return memoryGraph.getMemoryId();
|
|
||||||
}
|
|
||||||
|
|
||||||
public void insertSlice(MemorySlice memorySlice, String topicPath) throws IOException, ClassNotFoundException {
|
public void insertSlice(MemorySlice memorySlice, String topicPath) throws IOException, ClassNotFoundException {
|
||||||
sliceInsertLock.lock();
|
sliceInsertLock.lock();
|
||||||
List<String> topicPathList = Arrays.stream(topicPath.split("->")).toList();
|
List<String> topicPathList = Arrays.stream(topicPath.split("->")).toList();
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package work.slhaf.agent.core.memory.exception;
|
||||||
|
|
||||||
|
public class UnExistedDateIndexException extends RuntimeException {
|
||||||
|
public UnExistedDateIndexException(String message) {
|
||||||
|
super(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,9 +3,9 @@ package work.slhaf.agent.core.memory.node;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import work.slhaf.agent.common.pojo.PersistableObject;
|
||||||
import work.slhaf.agent.core.memory.exception.NullSliceListException;
|
import work.slhaf.agent.core.memory.exception.NullSliceListException;
|
||||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||||
import work.slhaf.agent.common.pojo.PersistableObject;
|
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
|
|||||||
@@ -1,12 +1,20 @@
|
|||||||
package work.slhaf.agent.core.memory.pojo;
|
package work.slhaf.agent.core.memory.pojo;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import work.slhaf.agent.common.pojo.PersistableObject;
|
||||||
|
|
||||||
|
import java.io.Serial;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.CopyOnWriteArrayList;
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Data
|
@Data
|
||||||
public class MemoryResult {
|
public class MemoryResult extends PersistableObject {
|
||||||
|
|
||||||
|
@Serial
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
private CopyOnWriteArrayList<MemorySliceResult> memorySliceResult;
|
private CopyOnWriteArrayList<MemorySliceResult> memorySliceResult;
|
||||||
private List<MemorySlice> relatedMemorySliceResult;
|
private List<MemorySlice> relatedMemorySliceResult;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package work.slhaf.agent.core.memory.pojo;
|
|||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.ToString;
|
||||||
import work.slhaf.agent.common.chat.pojo.Message;
|
import work.slhaf.agent.common.chat.pojo.Message;
|
||||||
import work.slhaf.agent.common.pojo.PersistableObject;
|
import work.slhaf.agent.common.pojo.PersistableObject;
|
||||||
|
|
||||||
@@ -40,6 +41,7 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
|
|||||||
/**
|
/**
|
||||||
* 关联完整对话中的前序切片, 排序为键,完整路径为值
|
* 关联完整对话中的前序切片, 排序为键,完整路径为值
|
||||||
*/
|
*/
|
||||||
|
@ToString.Exclude
|
||||||
private MemorySlice sliceBefore, sliceAfter;
|
private MemorySlice sliceBefore, sliceAfter;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -1,30 +1,68 @@
|
|||||||
package work.slhaf.agent.core.session;
|
package work.slhaf.agent.core.session;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import work.slhaf.agent.common.chat.pojo.Message;
|
import work.slhaf.agent.common.chat.pojo.Message;
|
||||||
import work.slhaf.agent.common.chat.pojo.MetaMessage;
|
import work.slhaf.agent.common.chat.pojo.MetaMessage;
|
||||||
|
import work.slhaf.agent.common.config.Config;
|
||||||
|
import work.slhaf.agent.common.pojo.PersistableObject;
|
||||||
|
|
||||||
|
import java.io.*;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.nio.file.Paths;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Data
|
@Data
|
||||||
public class SessionManager {
|
@Slf4j
|
||||||
|
public class SessionManager extends PersistableObject {
|
||||||
|
|
||||||
|
@Serial
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
private static final String STORAGE_DIR = "./data/session/";
|
||||||
|
|
||||||
private static SessionManager sessionManager;
|
private static SessionManager sessionManager;
|
||||||
|
|
||||||
|
private String id;
|
||||||
private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap;
|
private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap;
|
||||||
|
private String currentMemoryId;
|
||||||
|
private long lastUpdatedTime;
|
||||||
|
|
||||||
public static SessionManager getInstance() {
|
public static SessionManager getInstance() throws IOException, ClassNotFoundException {
|
||||||
if (sessionManager == null) {
|
if (sessionManager == null) {
|
||||||
sessionManager = new SessionManager();
|
String id = Config.getConfig().getAgentId();
|
||||||
sessionManager.setSingleMetaMessageMap(new HashMap<>());
|
Path filePath = Paths.get(STORAGE_DIR, id + ".session");
|
||||||
|
if (Files.exists(filePath)) {
|
||||||
|
sessionManager = deserialize(id);
|
||||||
|
} else {
|
||||||
|
sessionManager = new SessionManager();
|
||||||
|
sessionManager.setSingleMetaMessageMap(new HashMap<>());
|
||||||
|
sessionManager.id = id;
|
||||||
|
sessionManager.setShutdownHook();
|
||||||
|
sessionManager.lastUpdatedTime = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return sessionManager;
|
return sessionManager;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void setShutdownHook() {
|
||||||
|
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||||
|
try {
|
||||||
|
sessionManager.serialize();
|
||||||
|
log.info("SessionManager 已保存");
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error("保存 SessionManager 失败: ", e);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
public void addMetaMessage(String userId, MetaMessage metaMessage) {
|
public void addMetaMessage(String userId, MetaMessage metaMessage) {
|
||||||
if (singleMetaMessageMap.containsKey(userId)) {
|
if (singleMetaMessageMap.containsKey(userId)) {
|
||||||
singleMetaMessageMap.get(userId).add(metaMessage);
|
singleMetaMessageMap.get(userId).add(metaMessage);
|
||||||
} else {
|
} else {
|
||||||
singleMetaMessageMap.put(userId, new java.util.ArrayList<>());
|
singleMetaMessageMap.put(userId, new java.util.ArrayList<>());
|
||||||
@@ -42,4 +80,34 @@ public class SessionManager {
|
|||||||
return messages;
|
return messages;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void refreshMemoryId() {
|
||||||
|
currentMemoryId = UUID.randomUUID().toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void serialize() throws IOException {
|
||||||
|
Path filePath = Paths.get(STORAGE_DIR, this.id + ".session");
|
||||||
|
Files.createDirectories(Path.of(STORAGE_DIR));
|
||||||
|
try (ObjectOutputStream oos = new ObjectOutputStream(
|
||||||
|
new FileOutputStream(filePath.toFile()))) {
|
||||||
|
oos.writeObject(this);
|
||||||
|
log.info("SessionManager 已保存到: {}", filePath);
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error("序列化保存失败: {}", e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static SessionManager deserialize(String id) throws IOException, ClassNotFoundException {
|
||||||
|
Path filePath = Paths.get(STORAGE_DIR, id + ".session");
|
||||||
|
try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(filePath.toFile()))) {
|
||||||
|
SessionManager sessionManager = (SessionManager) ois.readObject();
|
||||||
|
log.info("SessionManager 已从文件加载: {}", filePath);
|
||||||
|
return sessionManager;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void resetLastUpdatedTime() {
|
||||||
|
lastUpdatedTime = System.currentTimeMillis();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import work.slhaf.agent.core.interaction.InteractionModule;
|
import work.slhaf.agent.core.interaction.InteractionModule;
|
||||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||||
import work.slhaf.agent.core.memory.MemoryManager;
|
import work.slhaf.agent.core.memory.MemoryManager;
|
||||||
|
import work.slhaf.agent.core.memory.exception.UnExistedDateIndexException;
|
||||||
|
import work.slhaf.agent.core.memory.exception.UnExistedTopicException;
|
||||||
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||||
import work.slhaf.agent.modules.memory.selector.evaluator.SliceSelectEvaluator;
|
import work.slhaf.agent.modules.memory.selector.evaluator.SliceSelectEvaluator;
|
||||||
@@ -84,12 +86,6 @@ public class MemorySelector implements InteractionModule {
|
|||||||
List<EvaluatedSlice> memorySlices = sliceSelectEvaluator.execute(evaluatorInput);
|
List<EvaluatedSlice> memorySlices = sliceSelectEvaluator.execute(evaluatorInput);
|
||||||
memoryManager.getActivatedSlices().put(userId,memorySlices);
|
memoryManager.getActivatedSlices().put(userId,memorySlices);
|
||||||
|
|
||||||
//向上下文设置切片存入标志,条件:对话历史列表不为空;触发了记忆查询
|
|
||||||
/*if (!memoryManager.getChatMessages().isEmpty()) {
|
|
||||||
interactionContext.getModuleContext().put("new_topic", true);
|
|
||||||
interactionContext.getModuleContext().put("messages_to_store", List.of(memoryManager.getChatMessages()));
|
|
||||||
}*/
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//设置上下文
|
//设置上下文
|
||||||
@@ -103,13 +99,18 @@ public class MemorySelector implements InteractionModule {
|
|||||||
|
|
||||||
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) throws IOException, ClassNotFoundException {
|
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) throws IOException, ClassNotFoundException {
|
||||||
for (ExtractorMatchData match : matches) {
|
for (ExtractorMatchData match : matches) {
|
||||||
MemoryResult memoryResult = switch (match.getType()) {
|
try {
|
||||||
case ExtractorMatchData.Constant.DATE -> memoryManager.selectMemory(match.getText());
|
MemoryResult memoryResult = switch (match.getType()) {
|
||||||
case ExtractorMatchData.Constant.TOPIC -> memoryManager.selectMemory(LocalDate.parse(match.getText()));
|
case ExtractorMatchData.Constant.TOPIC -> memoryManager.selectMemory(match.getText());
|
||||||
default -> null;
|
case ExtractorMatchData.Constant.DATE ->
|
||||||
};
|
memoryManager.selectMemory(LocalDate.parse(match.getText()));
|
||||||
if (memoryResult == null) continue;
|
default -> null;
|
||||||
memoryResultList.add(memoryResult);
|
};
|
||||||
|
if (memoryResult == null) continue;
|
||||||
|
memoryResultList.add(memoryResult);
|
||||||
|
}catch (UnExistedDateIndexException | UnExistedTopicException e) {
|
||||||
|
log.error("不存在的记忆索引! 请尝试更换更合适的主题提取LLM!");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
//清理切片记录
|
//清理切片记录
|
||||||
memoryManager.cleanSelectedSliceFilter();
|
memoryManager.cleanSelectedSliceFilter();
|
||||||
@@ -127,6 +128,6 @@ public class MemorySelector implements InteractionModule {
|
|||||||
if (memorySlice.isPrivate()) {
|
if (memorySlice.isPrivate()) {
|
||||||
return memorySlice.getStartUserId().equals(userId);
|
return memorySlice.getStartUserId().equals(userId);
|
||||||
}
|
}
|
||||||
return true;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,7 +59,11 @@ public class SliceSelectEvaluator extends Model {
|
|||||||
List<Callable<Void>> tasks = new ArrayList<>();
|
List<Callable<Void>> tasks = new ArrayList<>();
|
||||||
Queue<EvaluatedSlice> queue = new ConcurrentLinkedDeque<>();
|
Queue<EvaluatedSlice> queue = new ConcurrentLinkedDeque<>();
|
||||||
for (MemoryResult memoryResult : memoryResultList) {
|
for (MemoryResult memoryResult : memoryResultList) {
|
||||||
|
if (memoryResult.getMemorySliceResult().isEmpty() && memoryResult.getRelatedMemorySliceResult().isEmpty()){
|
||||||
|
continue;
|
||||||
|
}
|
||||||
tasks.add(() -> {
|
tasks.add(() -> {
|
||||||
|
log.debug("切片评估...");
|
||||||
List<SliceSummary> sliceSummaryList = new ArrayList<>();
|
List<SliceSummary> sliceSummaryList = new ArrayList<>();
|
||||||
//映射查找键值
|
//映射查找键值
|
||||||
Map<Long, SliceSummary> map = new HashMap<>();
|
Map<Long, SliceSummary> map = new HashMap<>();
|
||||||
@@ -71,6 +75,7 @@ public class SliceSelectEvaluator extends Model {
|
|||||||
.history(evaluatorInput.getMessages())
|
.history(evaluatorInput.getMessages())
|
||||||
.build();
|
.build();
|
||||||
EvaluatorResult evaluatorResult = JSONObject.parseObject(extractJson(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage()), EvaluatorResult.class);
|
EvaluatorResult evaluatorResult = JSONObject.parseObject(extractJson(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage()), EvaluatorResult.class);
|
||||||
|
log.debug("评估结果: {}", evaluatorResult);
|
||||||
for (Long result : evaluatorResult.getResults()) {
|
for (Long result : evaluatorResult.getResults()) {
|
||||||
SliceSummary sliceSummary = map.get(result);
|
SliceSummary sliceSummary = map.get(result);
|
||||||
EvaluatedSlice evaluatedSlice = EvaluatedSlice.builder()
|
EvaluatedSlice evaluatedSlice = EvaluatedSlice.builder()
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import work.slhaf.agent.core.memory.MemoryManager;
|
|||||||
import work.slhaf.agent.core.session.SessionManager;
|
import work.slhaf.agent.core.session.SessionManager;
|
||||||
import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorInput;
|
import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorInput;
|
||||||
import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorResult;
|
import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorResult;
|
||||||
|
import work.slhaf.agent.shared.memory.EvaluatedSlice;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -60,11 +61,14 @@ public class MemorySelectExtractor extends Model {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
List<EvaluatedSlice> activatedMemorySlices = memoryManager.getActivatedSlices().get(context.getUserId());
|
||||||
|
|
||||||
ExtractorInput extractorInput = ExtractorInput.builder()
|
ExtractorInput extractorInput = ExtractorInput.builder()
|
||||||
.text(context.getInput())
|
.text(context.getInput())
|
||||||
.date(context.getDateTime().toLocalDate())
|
.date(context.getDateTime().toLocalDate())
|
||||||
.history(chatMessages)
|
.history(chatMessages)
|
||||||
.topic_tree(memoryManager.getTopicTree())
|
.topic_tree(memoryManager.getTopicTree())
|
||||||
|
.activatedMemorySlices(activatedMemorySlices)
|
||||||
.build();
|
.build();
|
||||||
String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage());
|
String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage());
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package work.slhaf.agent.modules.memory.selector.extractor.data;
|
|||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import work.slhaf.agent.common.chat.pojo.Message;
|
import work.slhaf.agent.common.chat.pojo.Message;
|
||||||
|
import work.slhaf.agent.shared.memory.EvaluatedSlice;
|
||||||
|
|
||||||
import java.time.LocalDate;
|
import java.time.LocalDate;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -14,4 +15,5 @@ public class ExtractorInput {
|
|||||||
private String topic_tree;
|
private String topic_tree;
|
||||||
private LocalDate date;
|
private LocalDate date;
|
||||||
private List<Message> history;
|
private List<Message> history;
|
||||||
|
private List<EvaluatedSlice> activatedMemorySlices;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ public class MemoryUpdater implements InteractionModule {
|
|||||||
private MemorySummarizer memorySummarizer;
|
private MemorySummarizer memorySummarizer;
|
||||||
private SessionManager sessionManager;
|
private SessionManager sessionManager;
|
||||||
private StaticMemoryExtractor staticMemoryExtractor;
|
private StaticMemoryExtractor staticMemoryExtractor;
|
||||||
private long lastUpdatedTime = 0;
|
|
||||||
|
|
||||||
private MemoryUpdater() {
|
private MemoryUpdater() {
|
||||||
}
|
}
|
||||||
@@ -67,8 +66,13 @@ public class MemoryUpdater implements InteractionModule {
|
|||||||
while (!Thread.interrupted()) {
|
while (!Thread.interrupted()) {
|
||||||
try {
|
try {
|
||||||
long currentTime = System.currentTimeMillis();
|
long currentTime = System.currentTimeMillis();
|
||||||
if (lastUpdatedTime != 0 && currentTime - lastUpdatedTime > UPDATE_TRIGGER_INTERVAL) {
|
long lastUpdatedTime = sessionManager.getLastUpdatedTime();
|
||||||
|
int chatCount = memoryManager.getChatMessages().size();
|
||||||
|
if (lastUpdatedTime != 0 && currentTime - lastUpdatedTime > UPDATE_TRIGGER_INTERVAL && chatCount > 1) {
|
||||||
updateMemory();
|
updateMemory();
|
||||||
|
//重置MemoryId
|
||||||
|
sessionManager.refreshMemoryId();
|
||||||
|
log.info("记忆更新: 自动触发");
|
||||||
}
|
}
|
||||||
Thread.sleep(SCHEDULED_UPDATE_INTERVAL);
|
Thread.sleep(SCHEDULED_UPDATE_INTERVAL);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -90,52 +94,65 @@ public class MemoryUpdater implements InteractionModule {
|
|||||||
if (moduleContext.getIntValue("total_token") > 24000) {
|
if (moduleContext.getIntValue("total_token") > 24000) {
|
||||||
try {
|
try {
|
||||||
updateMemory();
|
updateMemory();
|
||||||
|
log.info("记忆更新: token超限");
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("记忆更新线程出错: {}", e.getLocalizedMessage());
|
log.error("记忆更新线程出错: {}", e.getLocalizedMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
sessionManager.resetLastUpdatedTime();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateMemory() throws InterruptedException, IOException, ClassNotFoundException {
|
private void updateMemory() throws IOException, ClassNotFoundException {
|
||||||
HashMap<String, String> singleMemorySummary = new HashMap<>();
|
HashMap<String, String> singleMemorySummary = new HashMap<>();
|
||||||
//更新单聊记忆以及该场景中对应的确定性记忆,同时从chatMessages中去掉单聊记忆
|
//更新单聊记忆以及该场景中对应的确定性记忆,同时从chatMessages中去掉单聊记忆
|
||||||
updateSingleChatSlices(singleMemorySummary);
|
updateSingleChatSlices(singleMemorySummary);
|
||||||
//更新多人场景下的记忆及相关的确定性记忆
|
//更新多人场景下的记忆及相关的确定性记忆
|
||||||
updateMultiChatSlices(singleMemorySummary);
|
updateMultiChatSlices(singleMemorySummary);
|
||||||
//更新最近更新时间
|
//清空chatMessages
|
||||||
lastUpdatedTime = System.currentTimeMillis();
|
clearChatMessages();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateMultiChatSlices(HashMap<String, String> singleMemorySummary) throws InterruptedException, IOException, ClassNotFoundException {
|
private void updateMultiChatSlices(HashMap<String, String> singleMemorySummary) {
|
||||||
//此时chatMessages中不再包含单聊记录,直接执行摘要以及切片插入
|
//此时chatMessages中不再包含单聊记录,直接执行摘要以及切片插入
|
||||||
//对剩下的多人聊天记录进行进行摘要
|
//对剩下的多人聊天记录进行进行摘要
|
||||||
executor.execute(() -> {
|
executor.execute(() -> {
|
||||||
try {
|
try {
|
||||||
//以第一条user对应的id为发起用户
|
List<Message> chatMessages = new ArrayList<>(memoryManager.getChatMessages());
|
||||||
Pattern pattern = Pattern.compile(USERID_REGEX);
|
chatMessages.removeFirst();
|
||||||
Matcher matcher = pattern.matcher(memoryManager.getChatMessages().get(1).getContent());
|
if (!chatMessages.isEmpty()) {
|
||||||
if (!matcher.find()){
|
//以第一条user对应的id为发起用户
|
||||||
throw new RuntimeException("未匹配到 userId!");
|
Pattern pattern = Pattern.compile(USERID_REGEX);
|
||||||
|
Matcher matcher = pattern.matcher(chatMessages.getFirst().getContent());
|
||||||
|
if (!matcher.find()) {
|
||||||
|
throw new RuntimeException("未匹配到 userId!");
|
||||||
|
}
|
||||||
|
String userId = matcher.group(1);
|
||||||
|
SummarizeResult summarizeResult = memorySummarizer.execute(new SummarizeInput(chatMessages, memoryManager.getTopicTree()));
|
||||||
|
MemorySlice memorySlice = getMemorySlice(userId, summarizeResult, chatMessages);
|
||||||
|
//设置involvedUserId
|
||||||
|
setInvolvedUserId(userId, memorySlice, chatMessages);
|
||||||
|
memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath());
|
||||||
|
|
||||||
|
if (!singleMemorySummary.isEmpty()) {
|
||||||
|
memoryManager.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary());
|
||||||
|
}
|
||||||
|
}else{
|
||||||
|
memoryManager.updateDialogMap(LocalDateTime.now(),memorySummarizer.executeTotalSummary(singleMemorySummary));
|
||||||
}
|
}
|
||||||
String userId = matcher.group(1);
|
|
||||||
SummarizeResult summarizeResult = memorySummarizer.execute(new SummarizeInput(memoryManager.getChatMessages(), memoryManager.getTopicTree()));
|
|
||||||
MemorySlice memorySlice = getMemorySlice(userId, summarizeResult, memoryManager.getChatMessages());
|
|
||||||
//设置involvedUserId
|
|
||||||
List<Message> messages = new ArrayList<>(memoryManager.getChatMessages());
|
|
||||||
messages.removeFirst();
|
|
||||||
setInvolvedUserId(userId, memorySlice, messages);
|
|
||||||
memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath());
|
|
||||||
//更新总dialogMap
|
|
||||||
singleMemorySummary.put("total", summarizeResult.getSummary());
|
|
||||||
memoryManager.updateDialogMap(LocalDateTime.now(), memorySummarizer.executeTotalSummary(singleMemorySummary));
|
|
||||||
} catch (IOException | ClassNotFoundException | InterruptedException e) {
|
} catch (IOException | ClassNotFoundException | InterruptedException e) {
|
||||||
log.error("多人场景记忆更新失败: {}", e.getLocalizedMessage());
|
log.error("多人场景记忆更新失败: {}", e.getLocalizedMessage());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void clearChatMessages() {
|
||||||
|
Message first = memoryManager.getChatMessages().getFirst();
|
||||||
|
memoryManager.getChatMessages().clear();
|
||||||
|
memoryManager.getChatMessages().add(first);
|
||||||
|
}
|
||||||
|
|
||||||
private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List<Message> chatMessages) {
|
private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List<Message> chatMessages) {
|
||||||
for (Message chatMessage : chatMessages) {
|
for (Message chatMessage : chatMessages) {
|
||||||
if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) {
|
if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) {
|
||||||
@@ -158,7 +175,7 @@ public class MemoryUpdater implements InteractionModule {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private void updateSingleChatSlices(HashMap<String, String> singleMemorySummary) throws InterruptedException {
|
private void updateSingleChatSlices(HashMap<String, String> singleMemorySummary) {
|
||||||
//更新单聊记忆,同时从chatMessages中去掉单聊记忆
|
//更新单聊记忆,同时从chatMessages中去掉单聊记忆
|
||||||
Set<String> userIdSet = new HashSet<>(sessionManager.getSingleMetaMessageMap().keySet());
|
Set<String> userIdSet = new HashSet<>(sessionManager.getSingleMetaMessageMap().keySet());
|
||||||
List<Callable<Void>> tasks = new ArrayList<>();
|
List<Callable<Void>> tasks = new ArrayList<>();
|
||||||
@@ -177,7 +194,7 @@ public class MemoryUpdater implements InteractionModule {
|
|||||||
//添加至singleMemorySummary
|
//添加至singleMemorySummary
|
||||||
singleMemorySummary.put(id, summarizeResult.getSummary());
|
singleMemorySummary.put(id, summarizeResult.getSummary());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("单聊记忆更新出错: {}", e.getLocalizedMessage());
|
log.error("单聊记忆更新出错: ", e);
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
});
|
});
|
||||||
@@ -196,8 +213,13 @@ public class MemoryUpdater implements InteractionModule {
|
|||||||
executor.invokeAll(tasks);
|
executor.invokeAll(tasks);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List<Message> chatMessages) {
|
private MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List<Message> chatMessages) {
|
||||||
MemorySlice memorySlice = new MemorySlice();
|
MemorySlice memorySlice = new MemorySlice();
|
||||||
|
//设置 memoryId,timestamp
|
||||||
|
memorySlice.setMemoryId(sessionManager.getCurrentMemoryId());
|
||||||
|
memorySlice.setTimestamp(System.currentTimeMillis());
|
||||||
|
|
||||||
|
//补充信息
|
||||||
memorySlice.setPrivate(summarizeResult.isPrivate());
|
memorySlice.setPrivate(summarizeResult.isPrivate());
|
||||||
memorySlice.setSummary(summarizeResult.getSummary());
|
memorySlice.setSummary(summarizeResult.getSummary());
|
||||||
memorySlice.setChatMessages(chatMessages);
|
memorySlice.setChatMessages(chatMessages);
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ public class MemorySummarizer extends Model {
|
|||||||
public String executeTotalSummary(HashMap<String, String> singleMemorySummary) {
|
public String executeTotalSummary(HashMap<String, String> singleMemorySummary) {
|
||||||
ChatResponse response = chatClient.runChat(List.of(new Message(ChatConstant.Character.SYSTEM, prompts.get(2)),
|
ChatResponse response = chatClient.runChat(List.of(new Message(ChatConstant.Character.SYSTEM, prompts.get(2)),
|
||||||
new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))));
|
new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))));
|
||||||
return JSONObject.parseObject(extractJson(response.getMessage())).getString("value");
|
return JSONObject.parseObject(extractJson(response.getMessage())).getString("content");
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class Constant {
|
private static class Constant {
|
||||||
@@ -137,29 +137,13 @@ public class MemorySummarizer extends Model {
|
|||||||
4. 最终校验:检查是否丢失关键信息
|
4. 最终校验:检查是否丢失关键信息
|
||||||
|
|
||||||
完整示例
|
完整示例
|
||||||
示例1(常规长文本):
|
示例:
|
||||||
输入:{
|
输入:{
|
||||||
"content": "在2023年第四季度,XX公司实现了显著增长。财报显示总收入达到4.56亿元,同比增长32%。其中主要增长来自智能手机业务板块,该板块贡献了3.12亿元收入,同比增长达45%。同时智能家居业务收入1.44亿元,同比增长12%。公司CEO在财报电话会议中强调,增长主要得益于东南亚市场的成功拓展..."
|
"content": "在2023年第四季度,XX公司实现了显著增长。财报显示总收入达到4.56亿元,同比增长32%。其中主要增长来自智能手机业务板块,该板块贡献了3.12亿元收入,同比增长达45%。同时智能家居业务收入1.44亿元,同比增长12%。公司CEO在财报电话会议中强调,增长主要得益于东南亚市场的成功拓展..."
|
||||||
}
|
}
|
||||||
输出:{
|
输出:{
|
||||||
"content": "XX公司2023年Q4总收入4.56亿元(同比+32%),智能手机业务贡献3.12亿元(+45%),智能家居1.44亿元(+12%),增长主要来自东南亚市场拓展。"
|
"content": "XX公司2023年Q4总收入4.56亿元(同比+32%),智能手机业务贡献3.12亿元(+45%),智能家居1.44亿元(+12%),增长主要来自东南亚市场拓展。"
|
||||||
}
|
}
|
||||||
|
|
||||||
示例2(多段落文本):
|
|
||||||
输入:{
|
|
||||||
"content": "本次项目改造涉及三个主要方面。首先,硬件升级包括:1) 更换全部服务器设备;2) 安装新的网络交换机;3) 部署智能安防系统。其次,软件系统将迁移至新平台,需完成数据迁移和接口适配。最后,人员培训计划分三阶段实施..."
|
|
||||||
}
|
|
||||||
输出:{
|
|
||||||
"content": "项目改造含硬件升级(更换服务器、新交换机、智能安防)、软件系统迁移(含数据迁移和接口适配)及分三阶段的人员培训。"
|
|
||||||
}
|
|
||||||
|
|
||||||
示例3(技术文档):
|
|
||||||
输入:{
|
|
||||||
"content": "该算法采用改进的卷积神经网络架构,包含3个主要模块:特征提取模块(由5个卷积层组成)、注意力机制模块(含通道和空间注意力)、以及分类模块(使用2个全连接层)。在ImageNet数据集上达到92.3%的准确率..."
|
|
||||||
}
|
|
||||||
输出:{
|
|
||||||
"content": "算法使用改进CNN架构,含特征提取(5卷积层)、注意力机制(通道+空间)和分类模块(2全连接层),在ImageNet上准确率92.3%。"
|
|
||||||
}
|
|
||||||
""";
|
""";
|
||||||
|
|
||||||
public static final String MULTI_SUMMARIZE_PROMPT = """
|
public static final String MULTI_SUMMARIZE_PROMPT = """
|
||||||
@@ -185,7 +169,7 @@ public class MemorySummarizer extends Model {
|
|||||||
|
|
||||||
2. 主题路径生成细则:
|
2. 主题路径生成细则:
|
||||||
• 抽象链构建流程:
|
• 抽象链构建流程:
|
||||||
a. 以`user`的意图为主要锚点,锁定最低节点
|
a. 以`user`的输入内容意图为主要锚点,锁定最低节点
|
||||||
b. 逐层抽象(地标→城市→国家→大洲),需保证抽象链的纯净,确保不会跨越领域
|
b. 逐层抽象(地标→城市→国家→大洲),需保证抽象链的纯净,确保不会跨越领域
|
||||||
c. 修剪抽象链,使其保持在[3, 7]层之内,同时每层的抽象节点考虑扩展性及可复用性
|
c. 修剪抽象链,使其保持在[3, 7]层之内,同时每层的抽象节点考虑扩展性及可复用性
|
||||||
d. 形成最终路径(格式:领域→大类→子类→实例)
|
d. 形成最终路径(格式:领域→大类→子类→实例)
|
||||||
@@ -226,7 +210,7 @@ public class MemorySummarizer extends Model {
|
|||||||
b. 技术术语需符合行业标准
|
b. 技术术语需符合行业标准
|
||||||
|
|
||||||
完整示例
|
完整示例
|
||||||
示例1(日常分享):
|
示例:
|
||||||
输入:{
|
输入:{
|
||||||
"topicTree": "
|
"topicTree": "
|
||||||
生活[root]
|
生活[root]
|
||||||
@@ -237,7 +221,7 @@ public class MemorySummarizer extends Model {
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
输出:{
|
输出:{
|
||||||
"summary": "用户分享欧洲自由行经历并讨论夜景照片处理",
|
"summary": "用户分享欧洲自由行经历并讨论夜景照片处理...",
|
||||||
"topicPath": "生活->旅行->自由行->欧洲->法国->巴黎铁塔",
|
"topicPath": "生活->旅行->自由行->欧洲->法国->巴黎铁塔",
|
||||||
"relatedTopicPath": [
|
"relatedTopicPath": [
|
||||||
"艺术->摄影->夜景拍摄",
|
"艺术->摄影->夜景拍摄",
|
||||||
@@ -245,65 +229,6 @@ public class MemorySummarizer extends Model {
|
|||||||
],
|
],
|
||||||
"isPrivate": false
|
"isPrivate": false
|
||||||
}
|
}
|
||||||
|
|
||||||
示例2(专业咨询):
|
|
||||||
输入:{
|
|
||||||
"topicTree": "
|
|
||||||
计算机[root]
|
|
||||||
└── 编程",
|
|
||||||
"chatMessages": [
|
|
||||||
{"role": "user", "content": "SpringBoot项目如何实现JWT鉴权"},
|
|
||||||
{"role": "assistant", "content": "需集成spring-security-jwt依赖..."}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
输出:{
|
|
||||||
"summary": "讨论SpringBoot项目集成JWT鉴权的技术方案",
|
|
||||||
"topicPath": "计算机->软件开发->Java->SpringBoot->安全->JWT",
|
|
||||||
"relatedTopicPath": [
|
|
||||||
"计算机->网络安全->认证协议",
|
|
||||||
"数学->加密算法->非对称加密"
|
|
||||||
],
|
|
||||||
"isPrivate": false
|
|
||||||
}
|
|
||||||
|
|
||||||
示例3(事件讨论):
|
|
||||||
输入:{
|
|
||||||
"topicTree": "
|
|
||||||
社会[root]
|
|
||||||
├── 教育
|
|
||||||
└── 科技",
|
|
||||||
"chatMessages": [
|
|
||||||
{"role": "user", "content": "听说某大学研发出脑机接口新成果"},
|
|
||||||
{"role": "assistant", "content": "该技术涉及神经科学和AI的跨学科研究"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
输出:{
|
|
||||||
"summary": "讨论某大学在脑机接口领域的跨学科研究成果",
|
|
||||||
"topicPath": "社会->科技->人工智能->脑机接口",
|
|
||||||
"relatedTopicPath": [
|
|
||||||
"科学->生物学->神经科学",
|
|
||||||
"教育->高等教育->科研创新"
|
|
||||||
],
|
|
||||||
"isPrivate": false
|
|
||||||
}
|
|
||||||
|
|
||||||
示例4(隐私事件):
|
|
||||||
输入:{
|
|
||||||
"topicTree": "
|
|
||||||
法律[root]
|
|
||||||
└── 隐私",
|
|
||||||
"chatMessages": [
|
|
||||||
{"role": "user", "content": "这个合同条款请仅限我们之间知晓"},
|
|
||||||
{"role": "assistant", "content": "已启用加密存储,不会外泄"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
输出:{
|
|
||||||
"summary": "用户要求保密合同条款内容",
|
|
||||||
"topicPath": "法律->合同法->保密条款",
|
|
||||||
"relatedTopicPath": ["信息技术->数据安全->加密存储"],
|
|
||||||
"isPrivate": true
|
|
||||||
}
|
|
||||||
|
|
||||||
""";
|
""";
|
||||||
|
|
||||||
public static final String TOTAL_SUMMARIZE_PROMPT = """
|
public static final String TOTAL_SUMMARIZE_PROMPT = """
|
||||||
@@ -347,7 +272,7 @@ public class MemorySummarizer extends Model {
|
|||||||
c. 验证信息完整性
|
c. 验证信息完整性
|
||||||
|
|
||||||
完整示例
|
完整示例
|
||||||
示例1(基础情况):
|
示例:
|
||||||
输入:{
|
输入:{
|
||||||
"aaa-111": "需要购买笔记本电脑,预算5000左右,主要用于办公",
|
"aaa-111": "需要购买笔记本电脑,预算5000左右,主要用于办公",
|
||||||
"bbb-222": "想买游戏本,预算8000-10000,要能运行3A大作",
|
"bbb-222": "想买游戏本,预算8000-10000,要能运行3A大作",
|
||||||
@@ -360,30 +285,6 @@ public class MemorySummarizer extends Model {
|
|||||||
用户[ccc-333]:咨询适合出差使用的轻薄本"
|
用户[ccc-333]:咨询适合出差使用的轻薄本"
|
||||||
}
|
}
|
||||||
|
|
||||||
示例2(信息合并):
|
|
||||||
输入:{
|
|
||||||
"ddd-444": "想了解Python入门课程,零基础",
|
|
||||||
"eee-555": "询问Java和Python哪个更适合新手",
|
|
||||||
"fff-666": "零基础,想学Python数据分析"
|
|
||||||
}
|
|
||||||
输出:{
|
|
||||||
"content": "
|
|
||||||
用户[ddd-444]:零基础想了解Python入门课程;
|
|
||||||
用户[eee-555]:询问Java和Python对新手的适用性;
|
|
||||||
用户[fff-666]:零基础想学习Python数据分析"
|
|
||||||
}
|
|
||||||
|
|
||||||
示例3(长文本精简):
|
|
||||||
输入:{
|
|
||||||
"ggg-777": "您好!我最近在准备考研,想咨询下时间规划。具体是想了解每天应该分配多少时间给英语复习,我现在英语水平大概是四级刚过的程度...(后续200字详细描述)",
|
|
||||||
"hhh-888": "考研政治怎么准备?需要报班吗?"
|
|
||||||
}
|
|
||||||
输出:{
|
|
||||||
"content": "
|
|
||||||
用户[ggg-777]:咨询考研英语复习时间规划,当前英语水平为四级;
|
|
||||||
用户[hhh-888]:询问考研政治备考方法及是否需要报班"
|
|
||||||
}
|
|
||||||
|
|
||||||
特殊处理
|
特殊处理
|
||||||
1. 当总字数超出限制时:
|
1. 当总字数超出限制时:
|
||||||
• 尽量保留所有出现的用户摘要
|
• 尽量保留所有出现的用户摘要
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import lombok.Data;
|
|||||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||||
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
||||||
import work.slhaf.agent.core.memory.MemoryManager;
|
import work.slhaf.agent.core.memory.MemoryManager;
|
||||||
|
import work.slhaf.agent.core.session.SessionManager;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDateTime;
|
||||||
@@ -16,6 +17,7 @@ public class PreprocessExecutor {
|
|||||||
private static PreprocessExecutor preprocessExecutor;
|
private static PreprocessExecutor preprocessExecutor;
|
||||||
|
|
||||||
private MemoryManager memoryManager;
|
private MemoryManager memoryManager;
|
||||||
|
private SessionManager sessionManager;
|
||||||
|
|
||||||
private PreprocessExecutor() {
|
private PreprocessExecutor() {
|
||||||
}
|
}
|
||||||
@@ -24,14 +26,27 @@ public class PreprocessExecutor {
|
|||||||
if (preprocessExecutor == null) {
|
if (preprocessExecutor == null) {
|
||||||
preprocessExecutor = new PreprocessExecutor();
|
preprocessExecutor = new PreprocessExecutor();
|
||||||
preprocessExecutor.setMemoryManager(MemoryManager.getInstance());
|
preprocessExecutor.setMemoryManager(MemoryManager.getInstance());
|
||||||
|
preprocessExecutor.setSessionManager(SessionManager.getInstance());
|
||||||
}
|
}
|
||||||
return preprocessExecutor;
|
return preprocessExecutor;
|
||||||
}
|
}
|
||||||
|
|
||||||
public InteractionContext execute(InteractionInputData inputData) {
|
public InteractionContext execute(InteractionInputData inputData) {
|
||||||
InteractionContext context = new InteractionContext();
|
checkAndSetMemoryId();
|
||||||
String userId = memoryManager.getUserId(inputData.getUserInfo(), inputData.getUserNickName());
|
return getInteractionContext(inputData);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void checkAndSetMemoryId() {
|
||||||
|
String currentMemoryId = sessionManager.getCurrentMemoryId();
|
||||||
|
if (currentMemoryId == null || memoryManager.getChatMessages().isEmpty()) {
|
||||||
|
sessionManager.refreshMemoryId();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private InteractionContext getInteractionContext(InteractionInputData inputData) {
|
||||||
|
InteractionContext context = new InteractionContext();
|
||||||
|
|
||||||
|
String userId = memoryManager.getUserId(inputData.getUserInfo(), inputData.getUserNickName());
|
||||||
context.setUserId(userId);
|
context.setUserId(userId);
|
||||||
context.setUserNickname(inputData.getUserNickName());
|
context.setUserNickname(inputData.getUserNickName());
|
||||||
context.setUserInfo(inputData.getUserInfo());
|
context.setUserInfo(inputData.getUserInfo());
|
||||||
@@ -55,7 +70,6 @@ public class PreprocessExecutor {
|
|||||||
|
|
||||||
context.setSingle(inputData.isSingle());
|
context.setSingle(inputData.isSingle());
|
||||||
context.setFinished(false);
|
context.setFinished(false);
|
||||||
|
|
||||||
return context;
|
return context;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,34 +1,30 @@
|
|||||||
package memory;
|
package memory;
|
||||||
|
|
||||||
import java.util.concurrent.ExecutorService;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.Callable;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
public class ThreadPoolTest {
|
public class ThreadPoolTest {
|
||||||
public static void main(String[] args) throws InterruptedException {
|
|
||||||
testExecutor(Executors.newVirtualThreadPerTaskExecutor());
|
|
||||||
|
|
||||||
// Thread.sleep(2000); // 等待系统输出稳定
|
@Test
|
||||||
|
public void testExecutor() throws InterruptedException {
|
||||||
// testExecutor("普通线程池", Executors.newFixedThreadPool(100));
|
List<Callable<Void>> tasks = new ArrayList<>();
|
||||||
}
|
for (int i = 0; i < 5; i++) {
|
||||||
|
int finalI = i;
|
||||||
private static void testExecutor(ExecutorService es) throws InterruptedException {
|
tasks.add(() -> {
|
||||||
long start = System.currentTimeMillis();
|
System.out.println("开始: " + finalI);
|
||||||
|
Thread.sleep(5000);
|
||||||
for (int i = 0; i < 100000; i++) {
|
System.out.println("结束: " + finalI);
|
||||||
es.submit(() -> {
|
return null;
|
||||||
Thread.sleep(1000);
|
|
||||||
return 0;
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
es.shutdown();
|
Executors.newVirtualThreadPerTaskExecutor().invokeAll(tasks, 10, TimeUnit.SECONDS);
|
||||||
if (es.awaitTermination(5, TimeUnit.MINUTES)) {
|
|
||||||
long end = System.currentTimeMillis();
|
System.out.println("hello");
|
||||||
System.out.println("虚拟线程" + "耗时:" + (end - start));
|
|
||||||
} else {
|
|
||||||
System.err.println("虚拟线程" + "未能在规定时间内完成所有任务");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user