From 9e0af5e5aa0c719f12e8719a9a93be74f51cad49 Mon Sep 17 00:00:00 2001 From: slhafzjw Date: Wed, 7 May 2025 21:38:41 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8E=A8=E8=BF=9B=E8=AE=B0=E5=BF=86=E6=A8=A1?= =?UTF-8?q?=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 InteractionThreadPoolExecutor 中引入虚拟线程池 (newVirtualThreadPerTaskExecutor) - 更新相关测试文件以适应新的线程池 - 优化 MemorySummarizer 中的单条目摘要逻辑 - 为 SingleSummarizer 、 MultiSummarizer 设计了提示词 - 还差两份提示词没有设计... --- .gitignore | 5 + src/main/java/work/slhaf/agent/Agent.java | 2 +- .../InteractionThreadPoolExecutor.java | 45 ++-- .../memory/selector/MemorySelector.java | 4 +- .../modules/memory/updater/MemoryUpdater.java | 10 +- .../data/StaticMemoryExtractInput.java | 1 - .../updater/summarizer/MemorySummarizer.java | 219 +++++++++++++++++- src/test/java/memory/ThreadPoolTest.java | 34 +++ 8 files changed, 289 insertions(+), 31 deletions(-) create mode 100644 src/test/java/memory/ThreadPoolTest.java diff --git a/.gitignore b/.gitignore index 25510217..c7423c04 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,8 @@ build/ /data/ /config/ /src/test/java/memory/test.json +/src/test/java/memory/result/input1.json +/src/test/java/memory/result/input2.json +/src/test/java/memory/result/output1.json +/src/test/java/memory/result/output2.json +/src/test/java/memory/result/total_input.json diff --git a/src/main/java/work/slhaf/agent/Agent.java b/src/main/java/work/slhaf/agent/Agent.java index 4674595b..9053bf23 100644 --- a/src/main/java/work/slhaf/agent/Agent.java +++ b/src/main/java/work/slhaf/agent/Agent.java @@ -38,7 +38,7 @@ public class Agent implements TaskCallback, InputReceiver { conn.close(); } server.stop(); - log.info("WebSocketServer 已优雅关闭"); + log.info("WebSocketServer 已关闭"); } catch (Exception e) { log.error("关闭失败", e); } diff --git a/src/main/java/work/slhaf/agent/core/interaction/InteractionThreadPoolExecutor.java b/src/main/java/work/slhaf/agent/core/interaction/InteractionThreadPoolExecutor.java index e9e61868..ab704faf 100644 --- a/src/main/java/work/slhaf/agent/core/interaction/InteractionThreadPoolExecutor.java +++ b/src/main/java/work/slhaf/agent/core/interaction/InteractionThreadPoolExecutor.java @@ -1,28 +1,45 @@ package work.slhaf.agent.core.interaction; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.ThreadPoolExecutor; +import lombok.Getter; + +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; -public class InteractionThreadPoolExecutor extends ThreadPoolExecutor { +@Getter +public class InteractionThreadPoolExecutor { private static InteractionThreadPoolExecutor interactionThreadPoolExecutor; - private InteractionThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue workQueue) { - super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue); - } + private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor(); public static InteractionThreadPoolExecutor getInstance() { if (interactionThreadPoolExecutor == null) { - interactionThreadPoolExecutor = new InteractionThreadPoolExecutor( - 8, - 24, - 60, - TimeUnit.SECONDS, - new ArrayBlockingQueue<>(50) - ); + interactionThreadPoolExecutor = new InteractionThreadPoolExecutor(); } return interactionThreadPoolExecutor; } + + + public void invokeAll(List> tasks, int time, TimeUnit timeUnit) { + try { + executorService.invokeAll(tasks, time, timeUnit); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + public void invokeAll(List> tasks) { + try { + executorService.invokeAll(tasks); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + public void execute(Runnable runnable) { + executorService.execute(runnable); + } } diff --git a/src/main/java/work/slhaf/agent/modules/memory/selector/MemorySelector.java b/src/main/java/work/slhaf/agent/modules/memory/selector/MemorySelector.java index 791d6cc7..9e9c0708 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/selector/MemorySelector.java +++ b/src/main/java/work/slhaf/agent/modules/memory/selector/MemorySelector.java @@ -6,11 +6,11 @@ import work.slhaf.agent.core.interaction.data.InteractionContext; import work.slhaf.agent.core.memory.MemoryManager; import work.slhaf.agent.core.memory.pojo.MemoryResult; import work.slhaf.agent.core.memory.pojo.MemorySlice; -import work.slhaf.agent.modules.memory.selector.evaluator.data.EvaluatorInput; import work.slhaf.agent.modules.memory.selector.evaluator.SliceSelectEvaluator; +import work.slhaf.agent.modules.memory.selector.evaluator.data.EvaluatorInput; +import work.slhaf.agent.modules.memory.selector.extractor.MemorySelectExtractor; import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorMatchData; import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorResult; -import work.slhaf.agent.modules.memory.selector.extractor.MemorySelectExtractor; import work.slhaf.agent.shared.memory.EvaluatedSlice; import java.io.IOException; diff --git a/src/main/java/work/slhaf/agent/modules/memory/updater/MemoryUpdater.java b/src/main/java/work/slhaf/agent/modules/memory/updater/MemoryUpdater.java index 8e0328f8..027b515f 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/updater/MemoryUpdater.java +++ b/src/main/java/work/slhaf/agent/modules/memory/updater/MemoryUpdater.java @@ -14,15 +14,13 @@ import work.slhaf.agent.modules.memory.selector.extractor.MemorySelectExtractor; import work.slhaf.agent.modules.memory.updater.static_extractor.StaticMemoryExtractor; import work.slhaf.agent.modules.memory.updater.static_extractor.data.StaticMemoryExtractInput; import work.slhaf.agent.modules.memory.updater.summarizer.MemorySummarizer; -import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult; import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeInput; +import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult; import java.io.IOException; import java.time.LocalDateTime; import java.util.*; import java.util.concurrent.Callable; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; @Data @Slf4j @@ -30,7 +28,6 @@ public class MemoryUpdater implements InteractionModule { private static MemoryUpdater memoryUpdater; - private ExecutorService updateExecutor; private MemoryManager memoryManager; private InteractionThreadPoolExecutor executor; private MemorySelectExtractor memorySelectExtractor; @@ -48,7 +45,6 @@ public class MemoryUpdater implements InteractionModule { memoryUpdater.setMemorySelectExtractor(MemorySelectExtractor.getInstance()); memoryUpdater.setMemorySummarizer(MemorySummarizer.getInstance()); memoryUpdater.setSessionManager(SessionManager.getInstance()); - memoryUpdater.setUpdateExecutor(Executors.newSingleThreadExecutor()); memoryUpdater.setStaticMemoryExtractor(StaticMemoryExtractor.getInstance()); } return memoryUpdater; @@ -59,7 +55,7 @@ public class MemoryUpdater implements InteractionModule { if (interactionContext.isFinished()) { return; } - updateExecutor.execute(() -> { + executor.execute(() -> { //如果token 大于阈值,则更新记忆 JSONObject moduleContext = interactionContext.getModuleContext(); if (moduleContext.getIntValue("total_token") > 24000) { @@ -73,9 +69,7 @@ public class MemoryUpdater implements InteractionModule { } catch (InterruptedException | IOException | ClassNotFoundException e) { log.error("记忆更新线程出错: {}", e.getLocalizedMessage()); } - } - }); } diff --git a/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/data/StaticMemoryExtractInput.java b/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/data/StaticMemoryExtractInput.java index 690872b2..b6a50cf0 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/data/StaticMemoryExtractInput.java +++ b/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/data/StaticMemoryExtractInput.java @@ -4,7 +4,6 @@ import lombok.Builder; import lombok.Data; import work.slhaf.agent.common.chat.pojo.Message; -import java.util.HashMap; import java.util.List; import java.util.Map; diff --git a/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/MemorySummarizer.java b/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/MemorySummarizer.java index 87a20285..9b224707 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/MemorySummarizer.java +++ b/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/MemorySummarizer.java @@ -13,8 +13,8 @@ import work.slhaf.agent.common.config.Config; import work.slhaf.agent.common.model.Model; import work.slhaf.agent.common.model.ModelConstant; import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor; -import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult; import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeInput; +import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult; import java.io.IOException; import java.util.ArrayList; @@ -32,7 +32,10 @@ public class MemorySummarizer extends Model { private static MemorySummarizer memorySummarizer; public static final String MODEL_KEY = "memory_summarizer"; - private static final List prompts = List.of(); + private static final List prompts = List.of( + Constant.SINGLE_SUMMARIZE_PROMPT, + Constant.MULTI_SUMMARIZE_PROMPT + ); private InteractionThreadPoolExecutor executor; @@ -63,14 +66,14 @@ public class MemorySummarizer extends Model { return JSONObject.parseObject(response.getMessage(), SummarizeResult.class); } - private void singleMessageSummarize(List chatMessages) throws InterruptedException { + private void singleMessageSummarize(List chatMessages) { List> tasks = new ArrayList<>(); for (Message chatMessage : chatMessages) { if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) { String content = chatMessage.getContent(); if (chatMessage.getContent().length() > 500) { tasks.add(() -> { - chatMessage.setContent(singleSummarizeExecute(prompts.get(0), content)); + chatMessage.setContent(singleSummarizeExecute(prompts.getFirst(), JSONObject.of("content", content).toString())); return null; }); } @@ -83,7 +86,7 @@ public class MemorySummarizer extends Model { try { ChatResponse response = chatClient.runChat(List.of(new Message(ChatConstant.Character.SYSTEM, prompt), new Message(ChatConstant.Character.USER, content))); - return response.getMessage(); + return JSONObject.parseObject(extractJson(response.getMessage())).getString("content"); } catch (Exception e) { log.error(e.getLocalizedMessage()); return content; @@ -96,4 +99,210 @@ public class MemorySummarizer extends Model { new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary)))); return JSONObject.parseObject(extractJson(response.getMessage())).getString("value"); } + + private static class Constant { + public static final String SINGLE_SUMMARIZE_PROMPT = """ + SINGLE_SUMMARIZER 提示词 + 功能说明 + 你需要根据用户输入的JSON数据中的`content`字段内容,生成精简且保留关键细节的摘要,严格控制输出在500字以内。 + + 输入字段说明 + • `content`: 需要被摘要的原始文本内容(可能包含复杂信息或多段落结构) + + 输出规则 + 1. 基本响应格式: + { + "content": string // 摘要后的文本内容 + } + 2. 摘要质量要求: + • 保留所有关键事实和数据 + • 维持原始信息的因果关系 + • 优先保留具体名词和数字信息 + • 删除冗余修饰词和重复表达 + 3. 长度控制: + • 硬性限制:绝对不超过500字符(按中文计算) + • 理想长度:200-450字符区间 + 4. 特殊处理: + • 当检测到列表/条目信息时:改用分号连接 + • 当存在直接引语时:保留核心引述但可简化引导句 + + 处理流程 + 1. 首次扫描:识别文本中的关键要素(5W1H) + 2. 二次分析:标注需要保留的具体数据/专有名词 + 3. 结构优化: + a. 合并同类段落 + b. 转换长句为短句 + c. 用更简洁的表达替换复杂句式 + 4. 最终校验:检查是否丢失关键信息 + + 完整示例 + 示例1(常规长文本): + 输入:{ + "content": "在2023年第四季度,XX公司实现了显著增长。财报显示总收入达到4.56亿元,同比增长32%。其中主要增长来自智能手机业务板块,该板块贡献了3.12亿元收入,同比增长达45%。同时智能家居业务收入1.44亿元,同比增长12%。公司CEO在财报电话会议中强调,增长主要得益于东南亚市场的成功拓展..." + } + 输出:{ + "content": "XX公司2023年Q4总收入4.56亿元(同比+32%),智能手机业务贡献3.12亿元(+45%),智能家居1.44亿元(+12%),增长主要来自东南亚市场拓展。" + } + + 示例2(多段落文本): + 输入:{ + "content": "本次项目改造涉及三个主要方面。首先,硬件升级包括:1) 更换全部服务器设备;2) 安装新的网络交换机;3) 部署智能安防系统。其次,软件系统将迁移至新平台,需完成数据迁移和接口适配。最后,人员培训计划分三阶段实施..." + } + 输出:{ + "content": "项目改造含硬件升级(更换服务器、新交换机、智能安防)、软件系统迁移(含数据迁移和接口适配)及分三阶段的人员培训。" + } + + 示例3(技术文档): + 输入:{ + "content": "该算法采用改进的卷积神经网络架构,包含3个主要模块:特征提取模块(由5个卷积层组成)、注意力机制模块(含通道和空间注意力)、以及分类模块(使用2个全连接层)。在ImageNet数据集上达到92.3%的准确率..." + } + 输出:{ + "content": "算法使用改进CNN架构,含特征提取(5卷积层)、注意力机制(通道+空间)和分类模块(2全连接层),在ImageNet上准确率92.3%。" + } + """; + + public static final String MULTI_SUMMARIZE_PROMPT = """ + DialogueTopicMapper 提示词 + 功能说明 + 分析对话内容并生成最深为7层的多层次主题路径,支持智能扩展主题树结构,根据用户意图动态调整路径生成策略。 + + 在保证符合以下要求的同时尽快输出 + + 输入字段说明 + • topicTree: 现有主题树结构(多根节点) + • chatMessages: 完整对话记录(需分析双方发言) + + 输出规则 + 0. **只需要输出所需的JSON文本** + 1. 核心结构(保持原格式): + { + "summary": "", // 精简摘要(100-150字) + "topicPath": "", // 主路径(领域纯净的完整抽象链) + "relatedTopicPath": [], // 相关路径(允许跨领域) + "isPrivate": false + } + + 2. 主题路径生成细则: + • 抽象链构建流程: + a. 以`user`的意图为主要锚点,锁定最低节点 + b. 逐层抽象(地标→城市→国家→大洲),需保证抽象链的纯净,确保不会跨越领域 + c. 修剪抽象链,使其保持在[3, 7]层之内,同时每层的抽象节点考虑扩展性及可复用性 + d. 形成最终路径(格式:领域→大类→子类→实例) + + • 意图影响规则: + 用户意图类型 | 主路径特征 | 相关路径特征 + ----------------|-------------------------|------------------- + 知识咨询 | 聚焦专业领域链 | 补充相关学科 + 经验分享 | 生活场景链 | 关联文化/社会 + 事件讨论 | 时空维度链 | 链接相关事件 + + 3. 动态扩展规范: + • 新根节点创建条件: + - 当抽象层级超过现有树结构时(如现有最高为"国家",需创建"大洲") + - 检测到全新领域维度时(如原树无"天文"相关节点) + + 主题树格式示例 + (使用自然换行,无需转义符) + 地理[root] + └── 欧洲 + ├── 法国 + └── 德国 + 生活[root] + └── 旅行 + ├── 自由行 + └── 跟团游 + + 处理流程 + 1. 意图分析阶段: + a. 判断对话类型(咨询/分享/讨论) + b. 标记关键实体和动作 + 2. 路径构建阶段: + a. 自下而上构建抽象链(实例→抽象概念) + b. 验证层级逻辑(子类必须属于父类范畴) + c. 生成最终路径(格式示例:生活->旅行->自由行->欧洲游) + 3. 扩展校验阶段: + a. 新增节点必须通过逻辑验证 + b. 技术术语需符合行业标准 + + 完整示例 + 示例1(日常分享): + 输入:{ + "topicTree": " + 生活[root] + └── 旅行", + "chatMessages": [ + {"role": "user", "content": "刚完成欧洲自由行,在巴黎铁塔拍到绝美夜景"}, + {"role": "assistant", "content": "推荐使用Lightroom处理夜景RAW格式"} + ] + } + 输出:{ + "summary": "用户分享欧洲自由行经历并讨论夜景照片处理", + "topicPath": "生活->旅行->自由行->欧洲->法国->巴黎铁塔", + "relatedTopicPath": [ + "艺术->摄影->夜景拍摄", + "科技->软件->图像处理->Lightroom" + ], + "isPrivate": false + } + + 示例2(专业咨询): + 输入:{ + "topicTree": " + 计算机[root] + └── 编程", + "chatMessages": [ + {"role": "user", "content": "SpringBoot项目如何实现JWT鉴权"}, + {"role": "assistant", "content": "需集成spring-security-jwt依赖..."} + ] + } + 输出:{ + "summary": "讨论SpringBoot项目集成JWT鉴权的技术方案", + "topicPath": "计算机->软件开发->Java->SpringBoot->安全->JWT", + "relatedTopicPath": [ + "计算机->网络安全->认证协议", + "数学->加密算法->非对称加密" + ], + "isPrivate": false + } + + 示例3(事件讨论): + 输入:{ + "topicTree": " + 社会[root] + ├── 教育 + └── 科技", + "chatMessages": [ + {"role": "user", "content": "听说某大学研发出脑机接口新成果"}, + {"role": "assistant", "content": "该技术涉及神经科学和AI的跨学科研究"} + ] + } + 输出:{ + "summary": "讨论某大学在脑机接口领域的跨学科研究成果", + "topicPath": "社会->科技->人工智能->脑机接口", + "relatedTopicPath": [ + "科学->生物学->神经科学", + "教育->高等教育->科研创新" + ], + "isPrivate": false + } + + 示例4(隐私事件): + 输入:{ + "topicTree": " + 法律[root] + └── 隐私", + "chatMessages": [ + {"role": "user", "content": "这个合同条款请仅限我们之间知晓"}, + {"role": "assistant", "content": "已启用加密存储,不会外泄"} + ] + } + 输出:{ + "summary": "用户要求保密合同条款内容", + "topicPath": "法律->合同法->保密条款", + "relatedTopicPath": ["信息技术->数据安全->加密存储"], + "isPrivate": true + } + + """; + } } diff --git a/src/test/java/memory/ThreadPoolTest.java b/src/test/java/memory/ThreadPoolTest.java new file mode 100644 index 00000000..e1f418c2 --- /dev/null +++ b/src/test/java/memory/ThreadPoolTest.java @@ -0,0 +1,34 @@ +package memory; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +public class ThreadPoolTest { + public static void main(String[] args) throws InterruptedException { + testExecutor(Executors.newVirtualThreadPerTaskExecutor()); + +// Thread.sleep(2000); // 等待系统输出稳定 + +// testExecutor("普通线程池", Executors.newFixedThreadPool(100)); + } + + private static void testExecutor(ExecutorService es) throws InterruptedException { + long start = System.currentTimeMillis(); + + for (int i = 0; i < 100000; i++) { + es.submit(() -> { + Thread.sleep(1000); + return 0; + }); + } + + es.shutdown(); + if (es.awaitTermination(5, TimeUnit.MINUTES)) { + long end = System.currentTimeMillis(); + System.out.println("虚拟线程" + "耗时:" + (end - start)); + } else { + System.err.println("虚拟线程" + "未能在规定时间内完成所有任务"); + } + } +}