diff --git a/.gitignore b/.gitignore index c7423c04..e447df73 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,5 @@ build/ /src/test/java/memory/result/output1.json /src/test/java/memory/result/output2.json /src/test/java/memory/result/total_input.json +/src/test/java/memory/result/input3.json +/src/test/java/memory/result/input4.json diff --git a/src/main/java/work/slhaf/agent/Agent.java b/src/main/java/work/slhaf/agent/Agent.java index 9053bf23..b5c2c6e9 100644 --- a/src/main/java/work/slhaf/agent/Agent.java +++ b/src/main/java/work/slhaf/agent/Agent.java @@ -23,7 +23,7 @@ public class Agent implements TaskCallback, InputReceiver { private InteractionHub interactionHub; private MessageSender messageSender; - public static Agent initialize() throws IOException { + public static void initialize() throws IOException { if (agent == null) { //加载配置 Config config = Config.getConfig(); @@ -31,23 +31,14 @@ public class Agent implements TaskCallback, InputReceiver { agent.setInteractionHub(InteractionHub.initialize()); agent.registerTaskCallback(); AgentWebSocketServer server = new AgentWebSocketServer(config.getWebSocketConfig().getPort(),agent); - server.start(); - Runtime.getRuntime().addShutdownHook(new Thread(() -> { - try { - for (WebSocket conn : server.getConnections()) { - conn.close(); - } - server.stop(); - log.info("WebSocketServer 已关闭"); - } catch (Exception e) { - log.error("关闭失败", e); - } - })); - + server.launch(); agent.setMessageSender(server); - log.info("Agent 加载完毕.."); } + } + + public static Agent getInstance() throws IOException { + initialize(); return agent; } diff --git a/src/main/java/work/slhaf/agent/common/chat/pojo/Message.java b/src/main/java/work/slhaf/agent/common/chat/pojo/Message.java index 6e490761..5d84b640 100644 --- a/src/main/java/work/slhaf/agent/common/chat/pojo/Message.java +++ b/src/main/java/work/slhaf/agent/common/chat/pojo/Message.java @@ -1,12 +1,20 @@ package work.slhaf.agent.common.chat.pojo; import lombok.*; +import work.slhaf.agent.common.pojo.PersistableObject; +import java.io.Serial; + +@EqualsAndHashCode(callSuper = true) @Builder @Data @AllArgsConstructor @NoArgsConstructor -public class Message { +public class Message extends PersistableObject { + + @Serial + private static final long serialVersionUID = 1L; + @NonNull private String role; @NonNull diff --git a/src/main/java/work/slhaf/agent/common/config/Config.java b/src/main/java/work/slhaf/agent/common/config/Config.java index 62ec2e2b..efb0afd4 100644 --- a/src/main/java/work/slhaf/agent/common/config/Config.java +++ b/src/main/java/work/slhaf/agent/common/config/Config.java @@ -9,6 +9,7 @@ import work.slhaf.agent.modules.memory.selector.MemorySelector; import work.slhaf.agent.modules.memory.selector.evaluator.SliceSelectEvaluator; import work.slhaf.agent.modules.memory.selector.extractor.MemorySelectExtractor; import work.slhaf.agent.modules.memory.updater.MemoryUpdater; +import work.slhaf.agent.modules.memory.updater.static_extractor.StaticMemoryExtractor; import work.slhaf.agent.modules.memory.updater.summarizer.MemorySummarizer; import work.slhaf.agent.modules.task.TaskEvaluator; import work.slhaf.agent.modules.task.TaskScheduler; @@ -69,8 +70,8 @@ public class Config { List moduleConfigList = List.of( new ModuleConfig(MemorySelector.class.getName(), ModuleConfig.Constant.INTERNAL, null), new ModuleConfig(CoreModel.class.getName(),ModuleConfig.Constant.INTERNAL,null), - new ModuleConfig(MemoryUpdater.class.getName(),ModuleConfig.Constant.INTERNAL,null), - new ModuleConfig(TaskScheduler.class.getName(), ModuleConfig.Constant.INTERNAL, null) + new ModuleConfig(MemoryUpdater.class.getName(),ModuleConfig.Constant.INTERNAL,null) +// new ModuleConfig(TaskScheduler.class.getName(), ModuleConfig.Constant.INTERNAL, null) ); config.setModuleConfigList(moduleConfigList); } @@ -105,7 +106,7 @@ public class Config { modelConfig.setModel(scanner.nextLine()); } - for (int i = 0; i < 5; i++) { + for (int i = 0; i < 6; i++) { String modelKey = switch (i) { case 0 -> { System.out.println("CoreModel:"); @@ -127,6 +128,10 @@ public class Config { System.out.println("MemorySummarizer:"); yield MemorySummarizer.MODEL_KEY; } + case 5 -> { + System.out.println("StaticMemoryExtractor:"); + yield StaticMemoryExtractor.MODEL_KEY; + } default -> throw new RuntimeException(); }; if (!singleModel) { diff --git a/src/main/java/work/slhaf/agent/common/model/ModelConstant.java b/src/main/java/work/slhaf/agent/common/model/ModelConstant.java index a0bde30e..86a67726 100644 --- a/src/main/java/work/slhaf/agent/common/model/ModelConstant.java +++ b/src/main/java/work/slhaf/agent/common/model/ModelConstant.java @@ -5,12 +5,26 @@ public class ModelConstant { CoreModel 提示词 功能说明 - 你需要根据用户的当前输入(text)生成恰当的回复。只有当以下字段与text内容直接相关时,才需要参考它们: + 你需要根据用户的当前输入(text)生成恰当的回复。每条用户输入都采用以下格式: + + ``` + [用户昵称(用户uuid)] 实际输入内容 + + ``` + + 你需要只基于最新一条消息中的用户(即最后一条user类型消息中括号内的uuid)进行回应,仅参考该用户的历史上下文内容。 + + 如果其他用户的对话历史中提到的信息能**明确补充该用户的信息背景**(如他人提及该用户、与其对话、对其信息进行补全等),你可以将其作为当前用户的新知识补充。否则,完全忽略其他用户的内容。 + + 注意,历史消息中将只包含带有前缀 `[用户昵称(用户uuid)]` 的完整输入文本,不会带有下文提到的额外字段。 + + 字段说明 + - text:指的是“原始输入内容”,包含带有前缀 `[用户昵称(用户uuid)]` 的完整输入文本 - datetime:当text包含时间相关语义时使用 - character:当需要根据角色设定调整语气时使用 - user_nick:当text中包含对用户的称呼或个性化需求时使用 - - user_id:用户的唯一标识,该字段真正具有区分用户的作用 - 其他所有字段仅在明确与text内容相关时才予以考虑,否则应完全忽略。 + - user_id:等于括号中的uuid,用于唯一标识用户 + - memory_slices/static_memory:仅与当前用户相关 输入字段优先级 1. 首要关注text字段,这是核心输入内容 @@ -20,37 +34,43 @@ public class ModelConstant { • user_nick:仅当需要个性化称呼时生效 3. 其他所有扩展字段(如memory_slices/static_memory等): - 必须与text内容有明确关联时才参考 - - 若字段内容与text无关,则完全忽略该字段 + - 且只考虑当前用户的字段内容,忽略其他用户相关内容 + + 响应生成规范 + - 回复必须完全基于text字段的核心语义生成 + - 禁止引用与当前text无关的历史内容 + - 若角色设定与当前对话无关,应自动忽略 + - 输出格式严格为: + { + "text": "响应内容" + } 核心生成逻辑 1. 主内容优先原则 - - 首先独立分析text字段的语义 - - 只有当其他字段内容能直接辅助理解text时(如text说"上次说的那个"对应memory_slices中的记录),才调用相关字段 - - 若text是独立完整表达(如单字、短句、新话题开启),则忽略所有非核心字段 + - 独立分析text字段的语义 + - 仅在其他字段能直接辅助理解text的前提下引用(如text中提及“上次说的那个”) + - 若text表达独立完整(如新话题),忽略所有非核心字段 - 2. 无关字段过滤机制 - - 当text属于以下情况时,强制忽略所有扩展字段: - ✓ 短于5个字符的输入(如"在"、"好的") - ✓ 明显开启新话题的提问(如"量子计算是什么") - ✓ 不含指代词的独立陈述句 - - 示例:当text="今天天气如何"时,即使存在量子计算相关的memory_slices也应忽略 + 2. 多用户隔离机制 + - 每条消息都带有格式 `[用户昵称(用户uuid)]` + - 所有分析仅基于最后一条user消息中的用户进行处理 + - memory_slices/static_memory等内容只会包含该用户的相关信息 + - 如果历史中其他用户提到了当前用户的信息,可用于补充理解;否则忽略 - 3. 响应生成规范 - - 回复必须完全基于text的核心语义生成 - - 禁止出现"根据您之前提到的XX"等无关内容引用 - - 当角色设定(character)与当前对话无关时(如科技助手回答日常问候),暂时覆盖角色设定 + 3. 无关字段过滤机制 + - text短于5个字符(如“在”、“好的”) + - text开启新话题(如“量子计算是什么”) + - text为独立句子,无引用上下文指代 + → 此类情况强制忽略所有扩展字段 - 输出格式 - { - "text": "响应内容" // 必须严格对应text字段的语义 - } 最终注意事项 - 1. 回应内容必须紧扣用户输入,且契合角色设定 - 2. 遇到模糊提问时,优先推测最常见的语境理解,不要直接问“你指的是什么” - 3. 回应应自然衔接,并允许后续系统模块追加更多限定、扩展字段 - 4. 你只需要生成JSON格式的响应对象,字段仅包含`text`,但在模块扩展下,字段内容可以有所增加。确保你可以兼容这些扩展而不破坏结构。 - 5. 若用户的输入(text)与其他字段中的内容无关,可忽略其他字段的内容 + 1. 回应内容必须紧扣用户输入,确保基于当前用户的语境 + 2. 遇到模糊问题时,推测常见语境理解,不要直接提问 + 3. 回应应自然衔接,适配后续可能拼接的上下文或约束 + 4. 输出字段固定为`text`,但内容可根据上下文扩展 + 5. 若text与memory_slices等扩展字段无关,应完全忽略 + 6. 请确保你对每一轮对话都只针对当前输入用户作出回应,保持多用户上下文隔离的准确性 > 以下模块可能会追加更多内容限制或上下文提示,请确保你的回答能够自然兼容这些后续拼接的内容,并调整输出格式。 @@ -305,5 +325,95 @@ public class ModelConstant { public static final String BASE_SUMMARIZER_PROMPT = """ """; public static final String STATIC_MEMORY_EXTRACTOR_PROMPT = """ + StaticMemoryExtractor 提示词 + 功能说明 + 你需要根据用户对话记录(messages)和现有静态记忆(existedStaticMemory),分析并输出需要新增或修改的静态记忆项。静态记忆指用户长期有效的个人信息、习惯偏好等常识性数据。 + + 输入字段说明 + • `userId`: 用户唯一标识符(仅用于追踪) + • `messages`: 对话记录数组(需特别关注user角色的content内容) + • `existedStaticMemory`: 现有静态记忆键值对(需对比更新) + + 输出规则 + 1. 基本格式: + { + "[记忆键名]": "[记忆内容]", + ... + } + 2. 更新逻辑: + • 新增记忆:当对话中首次出现明确的新信息时(如"我养了只叫Tom的猫") + • 修改记忆:当新信息与原有记忆冲突或需要细化时(如原"居住地":"北京" → "海淀区") + • 保留键名:修改时严格保持原记忆键不变 + 3. 内容要求: + • 值必须是可直接存储的字符串 + • 排除临时性/情绪化内容(如"今天好累") + • 合并关联信息(如"Python和Java" → "编程语言:Python, Java") + + 处理流程 + 1. 扫描messages提取以下信息: + a. 人口统计学特征(年龄/职业/居住地等) + b. 长期兴趣爱好 + c. 人际关系(家人/宠物等) + d. 长期计划/目标 + 2. 对比existedStaticMemory: + a. 新信息 → 新增键值对 + b. 更精确信息 → 更新对应键的值 + c. 矛盾信息 → 以最新对话为准 + 3. 过滤无效内容: + a. 排除模糊表述(如"可能"、"考虑中") + b. 排除时效性短于1个月的信息 + + 完整示例 + 示例1(新增记忆): + 输入:{ + "userId": "U123", + "messages": [ + {"role": "user", "content": "我最近收养了只金毛叫Lucky"}, + {"role": "assistant", "content": "金毛是很温顺的犬种呢"} + ], + "existedStaticMemory": {"爱好": "爬山"} + } + 输出:{ + "宠物": "金毛犬Lucky" + } + + 示例2(修改记忆): + 输入:{ + "userId": "U456", + "messages": [ + {"role": "user", "content": "下个月要搬去上海静安区了"}, + {"role": "assistant", "content": "需要帮您找静安区的餐厅吗?"} + ], + "existedStaticMemory": {"居住地": "北京"} + } + 输出:{ + "居住地": "上海静安区" + } + + 示例3(混合更新): + 输入:{ + "userId": "U789", + "messages": [ + {"role": "user", "content": "我的MacBook Pro用了3年"}, + {"role": "assistant", "content": "建议考虑M系列芯片的新款"}, + {"role": "user", "content": "其实我更喜欢Windows系统"} + ], + "existedStaticMemory": {"电子设备": "iPhone 13", "操作系统偏好": "macOS"} + } + 输出:{ + "电子设备": "MacBook Pro", + "操作系统偏好": "Windows" + } + + 特殊处理 + 1. 当信息可信度不足时: + • 不生成记忆项(如用户说"也许我会学钢琴") + 2. 当存在多轮矛盾时: + • 以最后一次明确表述为准 + 3. 空输入处理: + { + "error": "no valid input" + } + 4. 当提到其他人时,应区分这个人的事件是否与user真正相关,如果与user无关,应当忽略 """; } diff --git a/src/main/java/work/slhaf/agent/core/memory/pojo/PersistableObject.java b/src/main/java/work/slhaf/agent/common/pojo/PersistableObject.java similarity index 69% rename from src/main/java/work/slhaf/agent/core/memory/pojo/PersistableObject.java rename to src/main/java/work/slhaf/agent/common/pojo/PersistableObject.java index 270a8400..5b125d5d 100644 --- a/src/main/java/work/slhaf/agent/core/memory/pojo/PersistableObject.java +++ b/src/main/java/work/slhaf/agent/common/pojo/PersistableObject.java @@ -1,4 +1,4 @@ -package work.slhaf.agent.core.memory.pojo; +package work.slhaf.agent.common.pojo; import java.io.Serializable; diff --git a/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java b/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java index 7b0f094c..2eeec744 100644 --- a/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java +++ b/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java @@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import work.slhaf.agent.common.chat.pojo.Message; +import work.slhaf.agent.common.pojo.PersistableObject; import work.slhaf.agent.core.memory.exception.UnExistedTopicException; import work.slhaf.agent.core.memory.node.MemoryNode; import work.slhaf.agent.core.memory.node.TopicNode; diff --git a/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java b/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java index 010ae4ec..2253a502 100644 --- a/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java +++ b/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java @@ -45,6 +45,14 @@ public class MemoryManager implements InteractionModule { memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId())); memoryManager.setActivatedSlices(new HashMap<>()); log.info("MemoryManager注册完毕..."); + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + memoryManager.save(); + log.info("MemoryGraph已保存"); + } catch (IOException e) { + log.error("保存MemoryGraph失败: ", e); + } + })); } return memoryManager; } @@ -135,7 +143,11 @@ public class MemoryManager implements InteractionModule { memoryGraph.getStaticMemory().get(userId).putAll(newStaticMemory); } - public void updateDialogMap(LocalDateTime dateTime,String newDialogCache) { + public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) { memoryGraph.updateDialogMap(dateTime, newDialogCache); } + + public void save() throws IOException { + memoryGraph.serialize(); + } } diff --git a/src/main/java/work/slhaf/agent/core/memory/node/MemoryNode.java b/src/main/java/work/slhaf/agent/core/memory/node/MemoryNode.java index c587650f..85fb6df5 100644 --- a/src/main/java/work/slhaf/agent/core/memory/node/MemoryNode.java +++ b/src/main/java/work/slhaf/agent/core/memory/node/MemoryNode.java @@ -5,7 +5,7 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import work.slhaf.agent.core.memory.exception.NullSliceListException; import work.slhaf.agent.core.memory.pojo.MemorySlice; -import work.slhaf.agent.core.memory.pojo.PersistableObject; +import work.slhaf.agent.common.pojo.PersistableObject; import java.io.*; import java.nio.file.Files; diff --git a/src/main/java/work/slhaf/agent/core/memory/node/TopicNode.java b/src/main/java/work/slhaf/agent/core/memory/node/TopicNode.java index a84ca505..95af2487 100644 --- a/src/main/java/work/slhaf/agent/core/memory/node/TopicNode.java +++ b/src/main/java/work/slhaf/agent/core/memory/node/TopicNode.java @@ -2,7 +2,7 @@ package work.slhaf.agent.core.memory.node; import lombok.Data; import lombok.EqualsAndHashCode; -import work.slhaf.agent.core.memory.pojo.PersistableObject; +import work.slhaf.agent.common.pojo.PersistableObject; import java.io.Serial; import java.util.concurrent.ConcurrentHashMap; diff --git a/src/main/java/work/slhaf/agent/core/memory/pojo/MemorySlice.java b/src/main/java/work/slhaf/agent/core/memory/pojo/MemorySlice.java index 7b5ac894..e9abf616 100644 --- a/src/main/java/work/slhaf/agent/core/memory/pojo/MemorySlice.java +++ b/src/main/java/work/slhaf/agent/core/memory/pojo/MemorySlice.java @@ -3,6 +3,7 @@ package work.slhaf.agent.core.memory.pojo; import lombok.Data; import lombok.EqualsAndHashCode; import work.slhaf.agent.common.chat.pojo.Message; +import work.slhaf.agent.common.pojo.PersistableObject; import java.io.Serial; import java.util.List; diff --git a/src/main/java/work/slhaf/agent/core/memory/pojo/User.java b/src/main/java/work/slhaf/agent/core/memory/pojo/User.java index 8eed2a71..ac81c718 100644 --- a/src/main/java/work/slhaf/agent/core/memory/pojo/User.java +++ b/src/main/java/work/slhaf/agent/core/memory/pojo/User.java @@ -1,11 +1,19 @@ package work.slhaf.agent.core.memory.pojo; import lombok.Data; +import lombok.EqualsAndHashCode; +import work.slhaf.agent.common.pojo.PersistableObject; +import java.io.Serial; import java.util.List; +@EqualsAndHashCode(callSuper = true) @Data -public class User { +public class User extends PersistableObject { + + @Serial + private static final long serialVersionUID = 1L; + private String uuid; private List info; private String nickName; diff --git a/src/main/java/work/slhaf/agent/core/module/CoreModel.java b/src/main/java/work/slhaf/agent/core/module/CoreModel.java index 7ec7da51..2e8b5a24 100644 --- a/src/main/java/work/slhaf/agent/core/module/CoreModel.java +++ b/src/main/java/work/slhaf/agent/core/module/CoreModel.java @@ -49,15 +49,13 @@ public class CoreModel extends Model implements InteractionModule { @Override public void execute(InteractionContext interactionContext) { - //TODO 添加新的system prompt 引导主模型专注于最新的用户输入 - //TODO 需要更新主模型prompt String tempPrompt = interactionContext.getModulePrompt().toString(); if (!tempPrompt.equals(promptCache)) { coreModel.getMessages().set(0, new Message(ChatConstant.Character.SYSTEM, ModelConstant.CORE_MODEL_PROMPT + "\r\n" + tempPrompt)); promptCache = tempPrompt; } String user = "[" + interactionContext.getUserNickname() + "(" + interactionContext.getUserId() + ")]"; - Message userMessage = new Message(ChatConstant.Character.USER, user + interactionContext.getCoreContext().getString("text")); + Message userMessage = new Message(ChatConstant.Character.USER, user + " " + interactionContext.getCoreContext()); this.messages.add(userMessage); JSONObject response = null; int count = 0; @@ -65,13 +63,15 @@ public class CoreModel extends Model implements InteractionModule { try { ChatResponse chatResponse = this.chat(); response = JSONObject.parse(extractJson(chatResponse.getMessage())); + this.messages.removeLast(); + this.messages.add(new Message(ChatConstant.Character.USER, interactionContext.getCoreContext().getString("text"))); Message assistantMessage = new Message(ChatConstant.Character.ASSISTANT, response.getString("text")); this.messages.add(assistantMessage); //设置上下文 interactionContext.getModuleContext().put("total_token", chatResponse.getUsageBean().getTotal_tokens()); //区分单人聊天场景 - if (interactionContext.isSingle()){ + if (interactionContext.isSingle()) { MetaMessage metaMessage = new MetaMessage(userMessage, assistantMessage); sessionManager.addMetaMessage(interactionContext.getUserId(), metaMessage); } diff --git a/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java b/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java index c79f3110..48557a28 100644 --- a/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java +++ b/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java @@ -5,9 +5,11 @@ import com.alibaba.fastjson2.JSONObject; import lombok.ToString; import lombok.extern.slf4j.Slf4j; import org.java_websocket.WebSocket; +import org.java_websocket.framing.Framedata; import org.java_websocket.handshake.ClientHandshake; import org.java_websocket.server.WebSocketServer; import work.slhaf.agent.core.interaction.InputReceiver; +import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor; import work.slhaf.agent.core.interaction.data.InteractionInputData; import work.slhaf.agent.core.interaction.data.InteractionOutputData; @@ -18,23 +20,86 @@ import java.util.concurrent.ConcurrentHashMap; @Slf4j public class AgentWebSocketServer extends WebSocketServer implements MessageSender { + private static final long HEARTBEAT_INTERVAL = 10_000; + @ToString.Exclude private final InputReceiver receiver; private final ConcurrentHashMap userSessions = new ConcurrentHashMap<>(); + private final InteractionThreadPoolExecutor executor; + + // 记录最后一次收到Pong的时间 + private final ConcurrentHashMap lastPongTimes = new ConcurrentHashMap<>(); public AgentWebSocketServer(int port, InputReceiver receiver) { super(new InetSocketAddress(port)); this.receiver = receiver; + this.executor = InteractionThreadPoolExecutor.getInstance(); + } + + public void launch() { + this.start(); + setShutDownHook(); + startHeartbeatThread(); + } + + private void startHeartbeatThread() { + executor.execute(() -> { + while (!Thread.interrupted()){ + try{ + Thread.sleep(HEARTBEAT_INTERVAL); + checkConnections(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + }); + } + + private void checkConnections() { + long now = System.currentTimeMillis(); + for (WebSocket conn : getConnections()) { + if (conn.isOpen()) { + // 发送Ping + conn.sendPing(); + log.debug("Sent Ping to {}", conn.getRemoteSocketAddress()); + + // 检查上次Pong响应是否超时(2倍心跳间隔) + Long lastPong = lastPongTimes.get(conn); + if (lastPong != null && now - lastPong > HEARTBEAT_INTERVAL * 2) { + log.warn("Connection {} timed out, closing...", conn.getRemoteSocketAddress()); + conn.close(1001, "No Pong response"); + } + } + } + } + + @Override + public void onWebsocketPong(WebSocket conn, Framedata f) { + lastPongTimes.put(conn, System.currentTimeMillis()); + log.debug("Received Pong from {}", conn.getRemoteSocketAddress()); + } + + private void setShutDownHook() { + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + //关闭WebSocketServer + this.stop(); + log.info("WebSocketServer 已关闭"); + } catch (Exception e) { + log.error("WebSocketServer关闭失败: ", e); + } + })); } @Override public void onOpen(WebSocket webSocket, ClientHandshake clientHandshake) { - log.info("新连接: {}",webSocket.getRemoteSocketAddress()); + log.info("新连接: {}", webSocket.getRemoteSocketAddress()); } @Override public void onClose(WebSocket webSocket, int i, String s, boolean b) { - log.info("连接关闭: {}",webSocket.getRemoteSocketAddress()); + log.info("连接关闭: {}", webSocket.getRemoteSocketAddress()); + lastPongTimes.remove(webSocket); userSessions.values().removeIf(session -> session.equals(webSocket)); } @@ -64,8 +129,8 @@ public class AgentWebSocketServer extends WebSocketServer implements MessageSend WebSocket webSocket = userSessions.get(outputData.getUserInfo()); if (webSocket != null && webSocket.isOpen()) { webSocket.send(JSONUtil.toJsonStr(outputData)); - }else { - log.warn("用户不在线: {}",outputData.getUserInfo()); + } else { + log.warn("用户不在线: {}", outputData.getUserInfo()); } } } diff --git a/src/main/java/work/slhaf/agent/modules/memory/selector/extractor/MemorySelectExtractor.java b/src/main/java/work/slhaf/agent/modules/memory/selector/extractor/MemorySelectExtractor.java index c1c74d95..404794e7 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/selector/extractor/MemorySelectExtractor.java +++ b/src/main/java/work/slhaf/agent/modules/memory/selector/extractor/MemorySelectExtractor.java @@ -50,9 +50,14 @@ public class MemorySelectExtractor extends Model { public ExtractorResult execute(InteractionContext context) { //结构化为指定格式 List chatMessages = new ArrayList<>(); - for (MetaMessage metaMessage : sessionManager.getSingleMetaMessageMap().get(context.getUserId())) { - chatMessages.add(metaMessage.getUserMessage()); - chatMessages.add(metaMessage.getAssistantMessage()); + List metaMessages = sessionManager.getSingleMetaMessageMap().get(context.getUserId()); + if (metaMessages == null) { + sessionManager.getSingleMetaMessageMap().put(context.getUserId(), new ArrayList<>()); + } else { + for (MetaMessage metaMessage : metaMessages) { + chatMessages.add(metaMessage.getUserMessage()); + chatMessages.add(metaMessage.getAssistantMessage()); + } } ExtractorInput extractorInput = ExtractorInput.builder() diff --git a/src/main/java/work/slhaf/agent/modules/memory/updater/MemoryUpdater.java b/src/main/java/work/slhaf/agent/modules/memory/updater/MemoryUpdater.java index 027b515f..1021c5d3 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/updater/MemoryUpdater.java +++ b/src/main/java/work/slhaf/agent/modules/memory/updater/MemoryUpdater.java @@ -3,6 +3,7 @@ package work.slhaf.agent.modules.memory.updater; import com.alibaba.fastjson2.JSONObject; import lombok.Data; import lombok.extern.slf4j.Slf4j; +import work.slhaf.agent.common.chat.constant.ChatConstant; import work.slhaf.agent.common.chat.pojo.Message; import work.slhaf.agent.core.interaction.InteractionModule; import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor; @@ -21,6 +22,8 @@ import java.io.IOException; import java.time.LocalDateTime; import java.util.*; import java.util.concurrent.Callable; +import java.util.regex.Matcher; +import java.util.regex.Pattern; @Data @Slf4j @@ -46,6 +49,7 @@ public class MemoryUpdater implements InteractionModule { memoryUpdater.setMemorySummarizer(MemorySummarizer.getInstance()); memoryUpdater.setSessionManager(SessionManager.getInstance()); memoryUpdater.setStaticMemoryExtractor(StaticMemoryExtractor.getInstance()); + memoryUpdater.setExecutor(InteractionThreadPoolExecutor.getInstance()); } return memoryUpdater; } @@ -81,6 +85,8 @@ public class MemoryUpdater implements InteractionModule { try { SummarizeResult summarizeResult = memorySummarizer.execute(new SummarizeInput(memoryManager.getChatMessages(), memoryManager.getTopicTree())); MemorySlice memorySlice = getMemorySlice(userId, summarizeResult, memoryManager.getChatMessages()); + //设置involvedUserId + setInvolvedUserId(userId,memorySlice,memoryManager.getChatMessages()); memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath()); //更新总dialogMap singleMemorySummary.put("total", summarizeResult.getSummary()); @@ -91,6 +97,27 @@ public class MemoryUpdater implements InteractionModule { }); } +private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List chatMessages) { + for (Message chatMessage : chatMessages) { + if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) { + continue; + } + //匹配userId + String content = chatMessage.getContent(); + Pattern pattern = Pattern.compile("\\[.*\\(([^)]+)\\)\\]"); + Matcher matcher = pattern.matcher(content); + if (!matcher.find()) { + continue; + } + String userId = matcher.group(1); + if (userId.equals(startUserId)){ + continue; + } + memorySlice.getInvolvedUserIds().add(userId); + } +} + + private void updateSingleChatSlices(String interactionContext, HashMap singleMemorySummary) throws InterruptedException { //更新单聊记忆,同时从chatMessages中去掉单聊记忆 Set userIdSet = new HashSet<>(sessionManager.getSingleMetaMessageMap().keySet()); diff --git a/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/StaticMemoryExtractor.java b/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/StaticMemoryExtractor.java index 90dbc3e6..4ce88bb1 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/StaticMemoryExtractor.java +++ b/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/StaticMemoryExtractor.java @@ -20,7 +20,7 @@ public class StaticMemoryExtractor extends Model { private static StaticMemoryExtractor staticMemoryExtractor; - private static final String MODEL_KEY = "static_memory_extractor"; + public static final String MODEL_KEY = "static_memory_extractor"; public static StaticMemoryExtractor getInstance() throws IOException, ClassNotFoundException { diff --git a/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/MemorySummarizer.java b/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/MemorySummarizer.java index 9b224707..b7e9f481 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/MemorySummarizer.java +++ b/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/MemorySummarizer.java @@ -34,7 +34,8 @@ public class MemorySummarizer extends Model { public static final String MODEL_KEY = "memory_summarizer"; private static final List prompts = List.of( Constant.SINGLE_SUMMARIZE_PROMPT, - Constant.MULTI_SUMMARIZE_PROMPT + Constant.MULTI_SUMMARIZE_PROMPT, + Constant.TOTAL_SUMMARIZE_PROMPT ); private InteractionThreadPoolExecutor executor; @@ -304,5 +305,95 @@ public class MemorySummarizer extends Model { } """; + + public static final String TOTAL_SUMMARIZE_PROMPT = """ + TOTAL_SUMMARIZER 提示词 + 功能说明 + 你需要根据输入的多个独立用户对话摘要,生成一份综合性的总结报告。每个用户的对话内容彼此无关联,需保持原始信息的同时进行概括性整合,最终输出标准化JSON格式的响应。 + + 输入字段说明 + • 输入数据为JSON对象: + - key: 用户uuid(需在输出中保留) + - value: 该用户的对话摘要文本(需要处理的内容) + + 输出规则 + 1. 基本响应格式: + { + "content": string // 综合摘要文本 + } + 2. 内容要求: + • 严格控制在800字以内 + • 保持客观中立,不添加解释性内容 + • 使用分号分隔不同用户的摘要内容 + • 保留原始对话的关键事实信息 + • 对重复信息进行合并处理 + 3. 格式要求: + • 每个用户摘要以"用户[uuid]:"开头 + • 不同用户摘要间用分号分隔 + • 末尾不添加总结性陈述 + + 处理流程 + 1. 解析输入JSON的所有键值对 + 2. 对每个摘要执行: + a. 提取关键事实信息 + b. 删除问候语等非实质性内容 + c. 简化重复表达 + 3. 合并处理: + a. 识别不同摘要中的相同信息点 + b. 合并相同信息点的不同表述 + 4. 生成最终摘要: + a. 按原始输入顺序排列用户摘要 + b. 确保总字数≤800 + c. 验证信息完整性 + + 完整示例 + 示例1(基础情况): + 输入:{ + "aaa-111": "需要购买笔记本电脑,预算5000左右,主要用于办公", + "bbb-222": "想买游戏本,预算8000-10000,要能运行3A大作", + "ccc-333": "咨询轻薄本推荐,经常出差使用" + } + 输出:{ + "content": " + 用户[aaa-111]:需要5000元左右的办公笔记本; + 用户[bbb-222]:寻求8000-10000元的游戏本,要求能运行3A大作; + 用户[ccc-333]:咨询适合出差使用的轻薄本" + } + + 示例2(信息合并): + 输入:{ + "ddd-444": "想了解Python入门课程,零基础", + "eee-555": "询问Java和Python哪个更适合新手", + "fff-666": "零基础,想学Python数据分析" + } + 输出:{ + "content": " + 用户[ddd-444]:零基础想了解Python入门课程; + 用户[eee-555]:询问Java和Python对新手的适用性; + 用户[fff-666]:零基础想学习Python数据分析" + } + + 示例3(长文本精简): + 输入:{ + "ggg-777": "您好!我最近在准备考研,想咨询下时间规划。具体是想了解每天应该分配多少时间给英语复习,我现在英语水平大概是四级刚过的程度...(后续200字详细描述)", + "hhh-888": "考研政治怎么准备?需要报班吗?" + } + 输出:{ + "content": " + 用户[ggg-777]:咨询考研英语复习时间规划,当前英语水平为四级; + 用户[hhh-888]:询问考研政治备考方法及是否需要报班" + } + + 特殊处理 + 1. 当总字数超出限制时: + • 尽量保留所有出现的用户摘要 + 2. 当输入为空时: + { + "content": "" + } + 3. 当用户uuid包含特殊字符时: + • 保持原始uuid格式不做修改 + • 示例:用户[xxx-ddssss-xx]:内容摘要 + """; } } diff --git a/src/test/java/memory/RegexTest.java b/src/test/java/memory/RegexTest.java new file mode 100644 index 00000000..b00ea413 --- /dev/null +++ b/src/test/java/memory/RegexTest.java @@ -0,0 +1,31 @@ +package memory; + +import org.junit.jupiter.api.Test; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class RegexTest { + + @Test + public void regexTest(){ + String[] examples = { + "[小明(userId)] 我在开会] (te[]st)", + "[用户(昵)称(userId)] 你好[呀]", + "[测试账号(userId)] 这是一个(test(123))消息" + }; + + Pattern pattern = Pattern.compile("\\[.*?\\(([^)]+)\\)\\]"); + + for (String example : examples) { + Matcher matcher = pattern.matcher(example); + if (matcher.find()) { + System.out.println("在 '" + example + "' 中找到 userId: " + matcher.group(1)); + System.out.println(); + } else { + System.out.println("在 '" + example + "' 中未找到 userId"); + } + } + + } +}