推进记忆模块

- 在 InteractionThreadPoolExecutor 中引入虚拟线程池 (newVirtualThreadPerTaskExecutor)
- 更新相关测试文件以适应新的线程池
- 优化 MemorySummarizer 中的单条目摘要逻辑
- 为 SingleSummarizer 、 MultiSummarizer 设计了提示词
- 还差两份提示词没有设计...
This commit is contained in:
2025-05-07 21:38:41 +08:00
parent 3dd21f840e
commit 9e0af5e5aa
8 changed files with 289 additions and 31 deletions

View File

@@ -38,7 +38,7 @@ public class Agent implements TaskCallback, InputReceiver {
conn.close();
}
server.stop();
log.info("WebSocketServer 已优雅关闭");
log.info("WebSocketServer 已关闭");
} catch (Exception e) {
log.error("关闭失败", e);
}

View File

@@ -1,28 +1,45 @@
package work.slhaf.agent.core.interaction;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import lombok.Getter;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
public class InteractionThreadPoolExecutor extends ThreadPoolExecutor {
@Getter
public class InteractionThreadPoolExecutor {
private static InteractionThreadPoolExecutor interactionThreadPoolExecutor;
private InteractionThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue) {
super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
}
private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor();
public static InteractionThreadPoolExecutor getInstance() {
if (interactionThreadPoolExecutor == null) {
interactionThreadPoolExecutor = new InteractionThreadPoolExecutor(
8,
24,
60,
TimeUnit.SECONDS,
new ArrayBlockingQueue<>(50)
);
interactionThreadPoolExecutor = new InteractionThreadPoolExecutor();
}
return interactionThreadPoolExecutor;
}
public <T> void invokeAll(List<Callable<T>> tasks, int time, TimeUnit timeUnit) {
try {
executorService.invokeAll(tasks, time, timeUnit);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
public <T> void invokeAll(List<Callable<T>> tasks) {
try {
executorService.invokeAll(tasks);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
public void execute(Runnable runnable) {
executorService.execute(runnable);
}
}

View File

@@ -6,11 +6,11 @@ import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.modules.memory.selector.evaluator.data.EvaluatorInput;
import work.slhaf.agent.modules.memory.selector.evaluator.SliceSelectEvaluator;
import work.slhaf.agent.modules.memory.selector.evaluator.data.EvaluatorInput;
import work.slhaf.agent.modules.memory.selector.extractor.MemorySelectExtractor;
import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorMatchData;
import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorResult;
import work.slhaf.agent.modules.memory.selector.extractor.MemorySelectExtractor;
import work.slhaf.agent.shared.memory.EvaluatedSlice;
import java.io.IOException;

View File

@@ -14,15 +14,13 @@ import work.slhaf.agent.modules.memory.selector.extractor.MemorySelectExtractor;
import work.slhaf.agent.modules.memory.updater.static_extractor.StaticMemoryExtractor;
import work.slhaf.agent.modules.memory.updater.static_extractor.data.StaticMemoryExtractInput;
import work.slhaf.agent.modules.memory.updater.summarizer.MemorySummarizer;
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult;
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeInput;
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@Data
@Slf4j
@@ -30,7 +28,6 @@ public class MemoryUpdater implements InteractionModule {
private static MemoryUpdater memoryUpdater;
private ExecutorService updateExecutor;
private MemoryManager memoryManager;
private InteractionThreadPoolExecutor executor;
private MemorySelectExtractor memorySelectExtractor;
@@ -48,7 +45,6 @@ public class MemoryUpdater implements InteractionModule {
memoryUpdater.setMemorySelectExtractor(MemorySelectExtractor.getInstance());
memoryUpdater.setMemorySummarizer(MemorySummarizer.getInstance());
memoryUpdater.setSessionManager(SessionManager.getInstance());
memoryUpdater.setUpdateExecutor(Executors.newSingleThreadExecutor());
memoryUpdater.setStaticMemoryExtractor(StaticMemoryExtractor.getInstance());
}
return memoryUpdater;
@@ -59,7 +55,7 @@ public class MemoryUpdater implements InteractionModule {
if (interactionContext.isFinished()) {
return;
}
updateExecutor.execute(() -> {
executor.execute(() -> {
//如果token 大于阈值,则更新记忆
JSONObject moduleContext = interactionContext.getModuleContext();
if (moduleContext.getIntValue("total_token") > 24000) {
@@ -73,9 +69,7 @@ public class MemoryUpdater implements InteractionModule {
} catch (InterruptedException | IOException | ClassNotFoundException e) {
log.error("记忆更新线程出错: {}", e.getLocalizedMessage());
}
}
});
}

View File

@@ -4,7 +4,6 @@ import lombok.Builder;
import lombok.Data;
import work.slhaf.agent.common.chat.pojo.Message;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

View File

@@ -13,8 +13,8 @@ import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor;
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult;
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeInput;
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult;
import java.io.IOException;
import java.util.ArrayList;
@@ -32,7 +32,10 @@ public class MemorySummarizer extends Model {
private static MemorySummarizer memorySummarizer;
public static final String MODEL_KEY = "memory_summarizer";
private static final List<String> prompts = List.of();
private static final List<String> prompts = List.of(
Constant.SINGLE_SUMMARIZE_PROMPT,
Constant.MULTI_SUMMARIZE_PROMPT
);
private InteractionThreadPoolExecutor executor;
@@ -63,14 +66,14 @@ public class MemorySummarizer extends Model {
return JSONObject.parseObject(response.getMessage(), SummarizeResult.class);
}
private void singleMessageSummarize(List<Message> chatMessages) throws InterruptedException {
private void singleMessageSummarize(List<Message> chatMessages) {
List<Callable<Void>> tasks = new ArrayList<>();
for (Message chatMessage : chatMessages) {
if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) {
String content = chatMessage.getContent();
if (chatMessage.getContent().length() > 500) {
tasks.add(() -> {
chatMessage.setContent(singleSummarizeExecute(prompts.get(0), content));
chatMessage.setContent(singleSummarizeExecute(prompts.getFirst(), JSONObject.of("content", content).toString()));
return null;
});
}
@@ -83,7 +86,7 @@ public class MemorySummarizer extends Model {
try {
ChatResponse response = chatClient.runChat(List.of(new Message(ChatConstant.Character.SYSTEM, prompt),
new Message(ChatConstant.Character.USER, content)));
return response.getMessage();
return JSONObject.parseObject(extractJson(response.getMessage())).getString("content");
} catch (Exception e) {
log.error(e.getLocalizedMessage());
return content;
@@ -96,4 +99,210 @@ public class MemorySummarizer extends Model {
new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))));
return JSONObject.parseObject(extractJson(response.getMessage())).getString("value");
}
private static class Constant {
public static final String SINGLE_SUMMARIZE_PROMPT = """
SINGLE_SUMMARIZER 提示词
功能说明
你需要根据用户输入的JSON数据中的`content`字段内容生成精简且保留关键细节的摘要严格控制输出在500字以内。
输入字段说明
• `content`: 需要被摘要的原始文本内容(可能包含复杂信息或多段落结构)
输出规则
1. 基本响应格式:
{
"content": string // 摘要后的文本内容
}
2. 摘要质量要求:
• 保留所有关键事实和数据
• 维持原始信息的因果关系
• 优先保留具体名词和数字信息
• 删除冗余修饰词和重复表达
3. 长度控制:
• 硬性限制绝对不超过500字符按中文计算
• 理想长度200-450字符区间
4. 特殊处理:
• 当检测到列表/条目信息时:改用分号连接
• 当存在直接引语时:保留核心引述但可简化引导句
处理流程
1. 首次扫描识别文本中的关键要素5W1H
2. 二次分析:标注需要保留的具体数据/专有名词
3. 结构优化:
a. 合并同类段落
b. 转换长句为短句
c. 用更简洁的表达替换复杂句式
4. 最终校验:检查是否丢失关键信息
完整示例
示例1常规长文本
输入:{
"content": "在2023年第四季度XX公司实现了显著增长。财报显示总收入达到4.56亿元同比增长32%。其中主要增长来自智能手机业务板块该板块贡献了3.12亿元收入同比增长达45%。同时智能家居业务收入1.44亿元同比增长12%。公司CEO在财报电话会议中强调增长主要得益于东南亚市场的成功拓展..."
}
输出:{
"content": "XX公司2023年Q4总收入4.56亿元(同比+32%智能手机业务贡献3.12亿元(+45%智能家居1.44亿元(+12%),增长主要来自东南亚市场拓展。"
}
示例2多段落文本
输入:{
"content": "本次项目改造涉及三个主要方面。首先硬件升级包括1) 更换全部服务器设备2) 安装新的网络交换机3) 部署智能安防系统。其次,软件系统将迁移至新平台,需完成数据迁移和接口适配。最后,人员培训计划分三阶段实施..."
}
输出:{
"content": "项目改造含硬件升级(更换服务器、新交换机、智能安防)、软件系统迁移(含数据迁移和接口适配)及分三阶段的人员培训。"
}
示例3技术文档
输入:{
"content": "该算法采用改进的卷积神经网络架构包含3个主要模块特征提取模块由5个卷积层组成、注意力机制模块含通道和空间注意力、以及分类模块使用2个全连接层。在ImageNet数据集上达到92.3%的准确率..."
}
输出:{
"content": "算法使用改进CNN架构含特征提取5卷积层、注意力机制通道+空间和分类模块2全连接层在ImageNet上准确率92.3%。"
}
""";
public static final String MULTI_SUMMARIZE_PROMPT = """
DialogueTopicMapper 提示词
功能说明
分析对话内容并生成最深为7层的多层次主题路径支持智能扩展主题树结构根据用户意图动态调整路径生成策略。
在保证符合以下要求的同时尽快输出
输入字段说明
• topicTree: 现有主题树结构(多根节点)
• chatMessages: 完整对话记录(需分析双方发言)
输出规则
0. **只需要输出所需的JSON文本**
1. 核心结构(保持原格式):
{
"summary": "", // 精简摘要100-150字
"topicPath": "", // 主路径(领域纯净的完整抽象链)
"relatedTopicPath": [], // 相关路径(允许跨领域)
"isPrivate": false
}
2. 主题路径生成细则:
• 抽象链构建流程:
a. 以`user`的意图为主要锚点,锁定最低节点
b. 逐层抽象(地标→城市→国家→大洲),需保证抽象链的纯净,确保不会跨越领域
c. 修剪抽象链,使其保持在[3, 7]层之内,同时每层的抽象节点考虑扩展性及可复用性
d. 形成最终路径(格式:领域→大类→子类→实例)
• 意图影响规则:
用户意图类型 | 主路径特征 | 相关路径特征
----------------|-------------------------|-------------------
知识咨询 | 聚焦专业领域链 | 补充相关学科
经验分享 | 生活场景链 | 关联文化/社会
事件讨论 | 时空维度链 | 链接相关事件
3. 动态扩展规范:
• 新根节点创建条件:
- 当抽象层级超过现有树结构时(如现有最高为"国家",需创建"大洲"
- 检测到全新领域维度时(如原树无"天文"相关节点)
主题树格式示例
(使用自然换行,无需转义符)
地理[root]
└── 欧洲
├── 法国
└── 德国
生活[root]
└── 旅行
├── 自由行
└── 跟团游
处理流程
1. 意图分析阶段:
a. 判断对话类型(咨询/分享/讨论)
b. 标记关键实体和动作
2. 路径构建阶段:
a. 自下而上构建抽象链(实例→抽象概念)
b. 验证层级逻辑(子类必须属于父类范畴)
c. 生成最终路径(格式示例:生活->旅行->自由行->欧洲游)
3. 扩展校验阶段:
a. 新增节点必须通过逻辑验证
b. 技术术语需符合行业标准
完整示例
示例1日常分享
输入:{
"topicTree": "
生活[root]
└── 旅行",
"chatMessages": [
{"role": "user", "content": "刚完成欧洲自由行,在巴黎铁塔拍到绝美夜景"},
{"role": "assistant", "content": "推荐使用Lightroom处理夜景RAW格式"}
]
}
输出:{
"summary": "用户分享欧洲自由行经历并讨论夜景照片处理",
"topicPath": "生活->旅行->自由行->欧洲->法国->巴黎铁塔",
"relatedTopicPath": [
"艺术->摄影->夜景拍摄",
"科技->软件->图像处理->Lightroom"
],
"isPrivate": false
}
示例2专业咨询
输入:{
"topicTree": "
计算机[root]
└── 编程",
"chatMessages": [
{"role": "user", "content": "SpringBoot项目如何实现JWT鉴权"},
{"role": "assistant", "content": "需集成spring-security-jwt依赖..."}
]
}
输出:{
"summary": "讨论SpringBoot项目集成JWT鉴权的技术方案",
"topicPath": "计算机->软件开发->Java->SpringBoot->安全->JWT",
"relatedTopicPath": [
"计算机->网络安全->认证协议",
"数学->加密算法->非对称加密"
],
"isPrivate": false
}
示例3事件讨论
输入:{
"topicTree": "
社会[root]
├── 教育
└── 科技",
"chatMessages": [
{"role": "user", "content": "听说某大学研发出脑机接口新成果"},
{"role": "assistant", "content": "该技术涉及神经科学和AI的跨学科研究"}
]
}
输出:{
"summary": "讨论某大学在脑机接口领域的跨学科研究成果",
"topicPath": "社会->科技->人工智能->脑机接口",
"relatedTopicPath": [
"科学->生物学->神经科学",
"教育->高等教育->科研创新"
],
"isPrivate": false
}
示例4隐私事件
输入:{
"topicTree": "
法律[root]
└── 隐私",
"chatMessages": [
{"role": "user", "content": "这个合同条款请仅限我们之间知晓"},
{"role": "assistant", "content": "已启用加密存储,不会外泄"}
]
}
输出:{
"summary": "用户要求保密合同条款内容",
"topicPath": "法律->合同法->保密条款",
"relatedTopicPath": ["信息技术->数据安全->加密存储"],
"isPrivate": true
}
""";
}
}

View File

@@ -0,0 +1,34 @@
package memory;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
public class ThreadPoolTest {
public static void main(String[] args) throws InterruptedException {
testExecutor(Executors.newVirtualThreadPerTaskExecutor());
// Thread.sleep(2000); // 等待系统输出稳定
// testExecutor("普通线程池", Executors.newFixedThreadPool(100));
}
private static void testExecutor(ExecutorService es) throws InterruptedException {
long start = System.currentTimeMillis();
for (int i = 0; i < 100000; i++) {
es.submit(() -> {
Thread.sleep(1000);
return 0;
});
}
es.shutdown();
if (es.awaitTermination(5, TimeUnit.MINUTES)) {
long end = System.currentTimeMillis();
System.out.println("虚拟线程" + "耗时:" + (end - start));
} else {
System.err.println("虚拟线程" + "未能在规定时间内完成所有任务");
}
}
}