mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
记忆模块、主模块完成, 进行了一些小测试
- 完成了totalSummarizer、staticMemoryExtractor的提示词设计 - 更新了 coreModel 的提示词设计,使其聚焦于最新用户,同时做到不同用户的上下文语义隔离、知识共享 - 更新了 MemoryUpdater 中针对多人场景的记忆切片设置 involvedUserId 功能 - 在程序结束时将主动触发 MemoryGraph 的持久化 - 在Config中添加了对于StaticMemoryExtractor的适配 - PersistableObject 移动位置至common包
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -44,3 +44,5 @@ build/
|
||||
/src/test/java/memory/result/output1.json
|
||||
/src/test/java/memory/result/output2.json
|
||||
/src/test/java/memory/result/total_input.json
|
||||
/src/test/java/memory/result/input3.json
|
||||
/src/test/java/memory/result/input4.json
|
||||
|
||||
@@ -23,7 +23,7 @@ public class Agent implements TaskCallback, InputReceiver {
|
||||
private InteractionHub interactionHub;
|
||||
private MessageSender messageSender;
|
||||
|
||||
public static Agent initialize() throws IOException {
|
||||
public static void initialize() throws IOException {
|
||||
if (agent == null) {
|
||||
//加载配置
|
||||
Config config = Config.getConfig();
|
||||
@@ -31,23 +31,14 @@ public class Agent implements TaskCallback, InputReceiver {
|
||||
agent.setInteractionHub(InteractionHub.initialize());
|
||||
agent.registerTaskCallback();
|
||||
AgentWebSocketServer server = new AgentWebSocketServer(config.getWebSocketConfig().getPort(),agent);
|
||||
server.start();
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||
try {
|
||||
for (WebSocket conn : server.getConnections()) {
|
||||
conn.close();
|
||||
}
|
||||
server.stop();
|
||||
log.info("WebSocketServer 已关闭");
|
||||
} catch (Exception e) {
|
||||
log.error("关闭失败", e);
|
||||
}
|
||||
}));
|
||||
|
||||
server.launch();
|
||||
agent.setMessageSender(server);
|
||||
|
||||
log.info("Agent 加载完毕..");
|
||||
}
|
||||
}
|
||||
|
||||
public static Agent getInstance() throws IOException {
|
||||
initialize();
|
||||
return agent;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
package work.slhaf.agent.common.chat.pojo;
|
||||
|
||||
import lombok.*;
|
||||
import work.slhaf.agent.common.pojo.PersistableObject;
|
||||
|
||||
import java.io.Serial;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Builder
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class Message {
|
||||
public class Message extends PersistableObject {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
@NonNull
|
||||
private String role;
|
||||
@NonNull
|
||||
|
||||
@@ -9,6 +9,7 @@ import work.slhaf.agent.modules.memory.selector.MemorySelector;
|
||||
import work.slhaf.agent.modules.memory.selector.evaluator.SliceSelectEvaluator;
|
||||
import work.slhaf.agent.modules.memory.selector.extractor.MemorySelectExtractor;
|
||||
import work.slhaf.agent.modules.memory.updater.MemoryUpdater;
|
||||
import work.slhaf.agent.modules.memory.updater.static_extractor.StaticMemoryExtractor;
|
||||
import work.slhaf.agent.modules.memory.updater.summarizer.MemorySummarizer;
|
||||
import work.slhaf.agent.modules.task.TaskEvaluator;
|
||||
import work.slhaf.agent.modules.task.TaskScheduler;
|
||||
@@ -69,8 +70,8 @@ public class Config {
|
||||
List<ModuleConfig> moduleConfigList = List.of(
|
||||
new ModuleConfig(MemorySelector.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
||||
new ModuleConfig(CoreModel.class.getName(),ModuleConfig.Constant.INTERNAL,null),
|
||||
new ModuleConfig(MemoryUpdater.class.getName(),ModuleConfig.Constant.INTERNAL,null),
|
||||
new ModuleConfig(TaskScheduler.class.getName(), ModuleConfig.Constant.INTERNAL, null)
|
||||
new ModuleConfig(MemoryUpdater.class.getName(),ModuleConfig.Constant.INTERNAL,null)
|
||||
// new ModuleConfig(TaskScheduler.class.getName(), ModuleConfig.Constant.INTERNAL, null)
|
||||
);
|
||||
config.setModuleConfigList(moduleConfigList);
|
||||
}
|
||||
@@ -105,7 +106,7 @@ public class Config {
|
||||
modelConfig.setModel(scanner.nextLine());
|
||||
|
||||
}
|
||||
for (int i = 0; i < 5; i++) {
|
||||
for (int i = 0; i < 6; i++) {
|
||||
String modelKey = switch (i) {
|
||||
case 0 -> {
|
||||
System.out.println("CoreModel:");
|
||||
@@ -127,6 +128,10 @@ public class Config {
|
||||
System.out.println("MemorySummarizer:");
|
||||
yield MemorySummarizer.MODEL_KEY;
|
||||
}
|
||||
case 5 -> {
|
||||
System.out.println("StaticMemoryExtractor:");
|
||||
yield StaticMemoryExtractor.MODEL_KEY;
|
||||
}
|
||||
default -> throw new RuntimeException();
|
||||
};
|
||||
if (!singleModel) {
|
||||
|
||||
@@ -5,12 +5,26 @@ public class ModelConstant {
|
||||
CoreModel 提示词
|
||||
|
||||
功能说明
|
||||
你需要根据用户的当前输入(text)生成恰当的回复。只有当以下字段与text内容直接相关时,才需要参考它们:
|
||||
你需要根据用户的当前输入(text)生成恰当的回复。每条用户输入都采用以下格式:
|
||||
|
||||
```
|
||||
[用户昵称(用户uuid)] 实际输入内容
|
||||
|
||||
```
|
||||
|
||||
你需要只基于最新一条消息中的用户(即最后一条user类型消息中括号内的uuid)进行回应,仅参考该用户的历史上下文内容。
|
||||
|
||||
如果其他用户的对话历史中提到的信息能**明确补充该用户的信息背景**(如他人提及该用户、与其对话、对其信息进行补全等),你可以将其作为当前用户的新知识补充。否则,完全忽略其他用户的内容。
|
||||
|
||||
注意,历史消息中将只包含带有前缀 `[用户昵称(用户uuid)]` 的完整输入文本,不会带有下文提到的额外字段。
|
||||
|
||||
字段说明
|
||||
- text:指的是“原始输入内容”,包含带有前缀 `[用户昵称(用户uuid)]` 的完整输入文本
|
||||
- datetime:当text包含时间相关语义时使用
|
||||
- character:当需要根据角色设定调整语气时使用
|
||||
- user_nick:当text中包含对用户的称呼或个性化需求时使用
|
||||
- user_id:用户的唯一标识,该字段真正具有区分用户的作用
|
||||
其他所有字段仅在明确与text内容相关时才予以考虑,否则应完全忽略。
|
||||
- user_id:等于括号中的uuid,用于唯一标识用户
|
||||
- memory_slices/static_memory:仅与当前用户相关
|
||||
|
||||
输入字段优先级
|
||||
1. 首要关注text字段,这是核心输入内容
|
||||
@@ -20,37 +34,43 @@ public class ModelConstant {
|
||||
• user_nick:仅当需要个性化称呼时生效
|
||||
3. 其他所有扩展字段(如memory_slices/static_memory等):
|
||||
- 必须与text内容有明确关联时才参考
|
||||
- 若字段内容与text无关,则完全忽略该字段
|
||||
- 且只考虑当前用户的字段内容,忽略其他用户相关内容
|
||||
|
||||
响应生成规范
|
||||
- 回复必须完全基于text字段的核心语义生成
|
||||
- 禁止引用与当前text无关的历史内容
|
||||
- 若角色设定与当前对话无关,应自动忽略
|
||||
- 输出格式严格为:
|
||||
{
|
||||
"text": "响应内容"
|
||||
}
|
||||
|
||||
核心生成逻辑
|
||||
1. 主内容优先原则
|
||||
- 首先独立分析text字段的语义
|
||||
- 只有当其他字段内容能直接辅助理解text时(如text说"上次说的那个"对应memory_slices中的记录),才调用相关字段
|
||||
- 若text是独立完整表达(如单字、短句、新话题开启),则忽略所有非核心字段
|
||||
- 独立分析text字段的语义
|
||||
- 仅在其他字段能直接辅助理解text的前提下引用(如text中提及“上次说的那个”)
|
||||
- 若text表达独立完整(如新话题),忽略所有非核心字段
|
||||
|
||||
2. 无关字段过滤机制
|
||||
- 当text属于以下情况时,强制忽略所有扩展字段:
|
||||
✓ 短于5个字符的输入(如"在"、"好的")
|
||||
✓ 明显开启新话题的提问(如"量子计算是什么")
|
||||
✓ 不含指代词的独立陈述句
|
||||
- 示例:当text="今天天气如何"时,即使存在量子计算相关的memory_slices也应忽略
|
||||
2. 多用户隔离机制
|
||||
- 每条消息都带有格式 `[用户昵称(用户uuid)]`
|
||||
- 所有分析仅基于最后一条user消息中的用户进行处理
|
||||
- memory_slices/static_memory等内容只会包含该用户的相关信息
|
||||
- 如果历史中其他用户提到了当前用户的信息,可用于补充理解;否则忽略
|
||||
|
||||
3. 响应生成规范
|
||||
- 回复必须完全基于text的核心语义生成
|
||||
- 禁止出现"根据您之前提到的XX"等无关内容引用
|
||||
- 当角色设定(character)与当前对话无关时(如科技助手回答日常问候),暂时覆盖角色设定
|
||||
3. 无关字段过滤机制
|
||||
- text短于5个字符(如“在”、“好的”)
|
||||
- text开启新话题(如“量子计算是什么”)
|
||||
- text为独立句子,无引用上下文指代
|
||||
→ 此类情况强制忽略所有扩展字段
|
||||
|
||||
输出格式
|
||||
{
|
||||
"text": "响应内容" // 必须严格对应text字段的语义
|
||||
}
|
||||
|
||||
最终注意事项
|
||||
1. 回应内容必须紧扣用户输入,且契合角色设定
|
||||
2. 遇到模糊提问时,优先推测最常见的语境理解,不要直接问“你指的是什么”
|
||||
3. 回应应自然衔接,并允许后续系统模块追加更多限定、扩展字段
|
||||
4. 你只需要生成JSON格式的响应对象,字段仅包含`text`,但在模块扩展下,字段内容可以有所增加。确保你可以兼容这些扩展而不破坏结构。
|
||||
5. 若用户的输入(text)与其他字段中的内容无关,可忽略其他字段的内容
|
||||
1. 回应内容必须紧扣用户输入,确保基于当前用户的语境
|
||||
2. 遇到模糊问题时,推测常见语境理解,不要直接提问
|
||||
3. 回应应自然衔接,适配后续可能拼接的上下文或约束
|
||||
4. 输出字段固定为`text`,但内容可根据上下文扩展
|
||||
5. 若text与memory_slices等扩展字段无关,应完全忽略
|
||||
6. 请确保你对每一轮对话都只针对当前输入用户作出回应,保持多用户上下文隔离的准确性
|
||||
|
||||
> 以下模块可能会追加更多内容限制或上下文提示,请确保你的回答能够自然兼容这些后续拼接的内容,并调整输出格式。
|
||||
|
||||
@@ -305,5 +325,95 @@ public class ModelConstant {
|
||||
public static final String BASE_SUMMARIZER_PROMPT = """
|
||||
""";
|
||||
public static final String STATIC_MEMORY_EXTRACTOR_PROMPT = """
|
||||
StaticMemoryExtractor 提示词
|
||||
功能说明
|
||||
你需要根据用户对话记录(messages)和现有静态记忆(existedStaticMemory),分析并输出需要新增或修改的静态记忆项。静态记忆指用户长期有效的个人信息、习惯偏好等常识性数据。
|
||||
|
||||
输入字段说明
|
||||
• `userId`: 用户唯一标识符(仅用于追踪)
|
||||
• `messages`: 对话记录数组(需特别关注user角色的content内容)
|
||||
• `existedStaticMemory`: 现有静态记忆键值对(需对比更新)
|
||||
|
||||
输出规则
|
||||
1. 基本格式:
|
||||
{
|
||||
"[记忆键名]": "[记忆内容]",
|
||||
...
|
||||
}
|
||||
2. 更新逻辑:
|
||||
• 新增记忆:当对话中首次出现明确的新信息时(如"我养了只叫Tom的猫")
|
||||
• 修改记忆:当新信息与原有记忆冲突或需要细化时(如原"居住地":"北京" → "海淀区")
|
||||
• 保留键名:修改时严格保持原记忆键不变
|
||||
3. 内容要求:
|
||||
• 值必须是可直接存储的字符串
|
||||
• 排除临时性/情绪化内容(如"今天好累")
|
||||
• 合并关联信息(如"Python和Java" → "编程语言:Python, Java")
|
||||
|
||||
处理流程
|
||||
1. 扫描messages提取以下信息:
|
||||
a. 人口统计学特征(年龄/职业/居住地等)
|
||||
b. 长期兴趣爱好
|
||||
c. 人际关系(家人/宠物等)
|
||||
d. 长期计划/目标
|
||||
2. 对比existedStaticMemory:
|
||||
a. 新信息 → 新增键值对
|
||||
b. 更精确信息 → 更新对应键的值
|
||||
c. 矛盾信息 → 以最新对话为准
|
||||
3. 过滤无效内容:
|
||||
a. 排除模糊表述(如"可能"、"考虑中")
|
||||
b. 排除时效性短于1个月的信息
|
||||
|
||||
完整示例
|
||||
示例1(新增记忆):
|
||||
输入:{
|
||||
"userId": "U123",
|
||||
"messages": [
|
||||
{"role": "user", "content": "我最近收养了只金毛叫Lucky"},
|
||||
{"role": "assistant", "content": "金毛是很温顺的犬种呢"}
|
||||
],
|
||||
"existedStaticMemory": {"爱好": "爬山"}
|
||||
}
|
||||
输出:{
|
||||
"宠物": "金毛犬Lucky"
|
||||
}
|
||||
|
||||
示例2(修改记忆):
|
||||
输入:{
|
||||
"userId": "U456",
|
||||
"messages": [
|
||||
{"role": "user", "content": "下个月要搬去上海静安区了"},
|
||||
{"role": "assistant", "content": "需要帮您找静安区的餐厅吗?"}
|
||||
],
|
||||
"existedStaticMemory": {"居住地": "北京"}
|
||||
}
|
||||
输出:{
|
||||
"居住地": "上海静安区"
|
||||
}
|
||||
|
||||
示例3(混合更新):
|
||||
输入:{
|
||||
"userId": "U789",
|
||||
"messages": [
|
||||
{"role": "user", "content": "我的MacBook Pro用了3年"},
|
||||
{"role": "assistant", "content": "建议考虑M系列芯片的新款"},
|
||||
{"role": "user", "content": "其实我更喜欢Windows系统"}
|
||||
],
|
||||
"existedStaticMemory": {"电子设备": "iPhone 13", "操作系统偏好": "macOS"}
|
||||
}
|
||||
输出:{
|
||||
"电子设备": "MacBook Pro",
|
||||
"操作系统偏好": "Windows"
|
||||
}
|
||||
|
||||
特殊处理
|
||||
1. 当信息可信度不足时:
|
||||
• 不生成记忆项(如用户说"也许我会学钢琴")
|
||||
2. 当存在多轮矛盾时:
|
||||
• 以最后一次明确表述为准
|
||||
3. 空输入处理:
|
||||
{
|
||||
"error": "no valid input"
|
||||
}
|
||||
4. 当提到其他人时,应区分这个人的事件是否与user真正相关,如果与user无关,应当忽略
|
||||
""";
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package work.slhaf.agent.core.memory.pojo;
|
||||
package work.slhaf.agent.common.pojo;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
@@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import work.slhaf.agent.common.chat.pojo.Message;
|
||||
import work.slhaf.agent.common.pojo.PersistableObject;
|
||||
import work.slhaf.agent.core.memory.exception.UnExistedTopicException;
|
||||
import work.slhaf.agent.core.memory.node.MemoryNode;
|
||||
import work.slhaf.agent.core.memory.node.TopicNode;
|
||||
|
||||
@@ -45,6 +45,14 @@ public class MemoryManager implements InteractionModule {
|
||||
memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId()));
|
||||
memoryManager.setActivatedSlices(new HashMap<>());
|
||||
log.info("MemoryManager注册完毕...");
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||
try {
|
||||
memoryManager.save();
|
||||
log.info("MemoryGraph已保存");
|
||||
} catch (IOException e) {
|
||||
log.error("保存MemoryGraph失败: ", e);
|
||||
}
|
||||
}));
|
||||
}
|
||||
return memoryManager;
|
||||
}
|
||||
@@ -135,7 +143,11 @@ public class MemoryManager implements InteractionModule {
|
||||
memoryGraph.getStaticMemory().get(userId).putAll(newStaticMemory);
|
||||
}
|
||||
|
||||
public void updateDialogMap(LocalDateTime dateTime,String newDialogCache) {
|
||||
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
|
||||
memoryGraph.updateDialogMap(dateTime, newDialogCache);
|
||||
}
|
||||
|
||||
public void save() throws IOException {
|
||||
memoryGraph.serialize();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.agent.core.memory.exception.NullSliceListException;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.agent.core.memory.pojo.PersistableObject;
|
||||
import work.slhaf.agent.common.pojo.PersistableObject;
|
||||
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
|
||||
@@ -2,7 +2,7 @@ package work.slhaf.agent.core.memory.node;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.agent.core.memory.pojo.PersistableObject;
|
||||
import work.slhaf.agent.common.pojo.PersistableObject;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
@@ -3,6 +3,7 @@ package work.slhaf.agent.core.memory.pojo;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.agent.common.chat.pojo.Message;
|
||||
import work.slhaf.agent.common.pojo.PersistableObject;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.util.List;
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
package work.slhaf.agent.core.memory.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.agent.common.pojo.PersistableObject;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.util.List;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class User {
|
||||
public class User extends PersistableObject {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private String uuid;
|
||||
private List<String> info;
|
||||
private String nickName;
|
||||
|
||||
@@ -49,15 +49,13 @@ public class CoreModel extends Model implements InteractionModule {
|
||||
|
||||
@Override
|
||||
public void execute(InteractionContext interactionContext) {
|
||||
//TODO 添加新的system prompt 引导主模型专注于最新的用户输入
|
||||
//TODO 需要更新主模型prompt
|
||||
String tempPrompt = interactionContext.getModulePrompt().toString();
|
||||
if (!tempPrompt.equals(promptCache)) {
|
||||
coreModel.getMessages().set(0, new Message(ChatConstant.Character.SYSTEM, ModelConstant.CORE_MODEL_PROMPT + "\r\n" + tempPrompt));
|
||||
promptCache = tempPrompt;
|
||||
}
|
||||
String user = "[" + interactionContext.getUserNickname() + "(" + interactionContext.getUserId() + ")]";
|
||||
Message userMessage = new Message(ChatConstant.Character.USER, user + interactionContext.getCoreContext().getString("text"));
|
||||
Message userMessage = new Message(ChatConstant.Character.USER, user + " " + interactionContext.getCoreContext());
|
||||
this.messages.add(userMessage);
|
||||
JSONObject response = null;
|
||||
int count = 0;
|
||||
@@ -65,13 +63,15 @@ public class CoreModel extends Model implements InteractionModule {
|
||||
try {
|
||||
ChatResponse chatResponse = this.chat();
|
||||
response = JSONObject.parse(extractJson(chatResponse.getMessage()));
|
||||
this.messages.removeLast();
|
||||
this.messages.add(new Message(ChatConstant.Character.USER, interactionContext.getCoreContext().getString("text")));
|
||||
Message assistantMessage = new Message(ChatConstant.Character.ASSISTANT, response.getString("text"));
|
||||
this.messages.add(assistantMessage);
|
||||
|
||||
//设置上下文
|
||||
interactionContext.getModuleContext().put("total_token", chatResponse.getUsageBean().getTotal_tokens());
|
||||
//区分单人聊天场景
|
||||
if (interactionContext.isSingle()){
|
||||
if (interactionContext.isSingle()) {
|
||||
MetaMessage metaMessage = new MetaMessage(userMessage, assistantMessage);
|
||||
sessionManager.addMetaMessage(interactionContext.getUserId(), metaMessage);
|
||||
}
|
||||
|
||||
@@ -5,9 +5,11 @@ import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.java_websocket.WebSocket;
|
||||
import org.java_websocket.framing.Framedata;
|
||||
import org.java_websocket.handshake.ClientHandshake;
|
||||
import org.java_websocket.server.WebSocketServer;
|
||||
import work.slhaf.agent.core.interaction.InputReceiver;
|
||||
import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionOutputData;
|
||||
|
||||
@@ -18,23 +20,86 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
@Slf4j
|
||||
public class AgentWebSocketServer extends WebSocketServer implements MessageSender {
|
||||
|
||||
private static final long HEARTBEAT_INTERVAL = 10_000;
|
||||
|
||||
@ToString.Exclude
|
||||
private final InputReceiver receiver;
|
||||
private final ConcurrentHashMap<String, WebSocket> userSessions = new ConcurrentHashMap<>();
|
||||
private final InteractionThreadPoolExecutor executor;
|
||||
|
||||
// 记录最后一次收到Pong的时间
|
||||
private final ConcurrentHashMap<WebSocket, Long> lastPongTimes = new ConcurrentHashMap<>();
|
||||
|
||||
public AgentWebSocketServer(int port, InputReceiver receiver) {
|
||||
super(new InetSocketAddress(port));
|
||||
this.receiver = receiver;
|
||||
this.executor = InteractionThreadPoolExecutor.getInstance();
|
||||
}
|
||||
|
||||
public void launch() {
|
||||
this.start();
|
||||
setShutDownHook();
|
||||
startHeartbeatThread();
|
||||
}
|
||||
|
||||
private void startHeartbeatThread() {
|
||||
executor.execute(() -> {
|
||||
while (!Thread.interrupted()){
|
||||
try{
|
||||
Thread.sleep(HEARTBEAT_INTERVAL);
|
||||
checkConnections();
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void checkConnections() {
|
||||
long now = System.currentTimeMillis();
|
||||
for (WebSocket conn : getConnections()) {
|
||||
if (conn.isOpen()) {
|
||||
// 发送Ping
|
||||
conn.sendPing();
|
||||
log.debug("Sent Ping to {}", conn.getRemoteSocketAddress());
|
||||
|
||||
// 检查上次Pong响应是否超时(2倍心跳间隔)
|
||||
Long lastPong = lastPongTimes.get(conn);
|
||||
if (lastPong != null && now - lastPong > HEARTBEAT_INTERVAL * 2) {
|
||||
log.warn("Connection {} timed out, closing...", conn.getRemoteSocketAddress());
|
||||
conn.close(1001, "No Pong response");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onWebsocketPong(WebSocket conn, Framedata f) {
|
||||
lastPongTimes.put(conn, System.currentTimeMillis());
|
||||
log.debug("Received Pong from {}", conn.getRemoteSocketAddress());
|
||||
}
|
||||
|
||||
private void setShutDownHook() {
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||
try {
|
||||
//关闭WebSocketServer
|
||||
this.stop();
|
||||
log.info("WebSocketServer 已关闭");
|
||||
} catch (Exception e) {
|
||||
log.error("WebSocketServer关闭失败: ", e);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onOpen(WebSocket webSocket, ClientHandshake clientHandshake) {
|
||||
log.info("新连接: {}",webSocket.getRemoteSocketAddress());
|
||||
log.info("新连接: {}", webSocket.getRemoteSocketAddress());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClose(WebSocket webSocket, int i, String s, boolean b) {
|
||||
log.info("连接关闭: {}",webSocket.getRemoteSocketAddress());
|
||||
log.info("连接关闭: {}", webSocket.getRemoteSocketAddress());
|
||||
lastPongTimes.remove(webSocket);
|
||||
userSessions.values().removeIf(session -> session.equals(webSocket));
|
||||
}
|
||||
|
||||
@@ -64,8 +129,8 @@ public class AgentWebSocketServer extends WebSocketServer implements MessageSend
|
||||
WebSocket webSocket = userSessions.get(outputData.getUserInfo());
|
||||
if (webSocket != null && webSocket.isOpen()) {
|
||||
webSocket.send(JSONUtil.toJsonStr(outputData));
|
||||
}else {
|
||||
log.warn("用户不在线: {}",outputData.getUserInfo());
|
||||
} else {
|
||||
log.warn("用户不在线: {}", outputData.getUserInfo());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,9 +50,14 @@ public class MemorySelectExtractor extends Model {
|
||||
public ExtractorResult execute(InteractionContext context) {
|
||||
//结构化为指定格式
|
||||
List<Message> chatMessages = new ArrayList<>();
|
||||
for (MetaMessage metaMessage : sessionManager.getSingleMetaMessageMap().get(context.getUserId())) {
|
||||
chatMessages.add(metaMessage.getUserMessage());
|
||||
chatMessages.add(metaMessage.getAssistantMessage());
|
||||
List<MetaMessage> metaMessages = sessionManager.getSingleMetaMessageMap().get(context.getUserId());
|
||||
if (metaMessages == null) {
|
||||
sessionManager.getSingleMetaMessageMap().put(context.getUserId(), new ArrayList<>());
|
||||
} else {
|
||||
for (MetaMessage metaMessage : metaMessages) {
|
||||
chatMessages.add(metaMessage.getUserMessage());
|
||||
chatMessages.add(metaMessage.getAssistantMessage());
|
||||
}
|
||||
}
|
||||
|
||||
ExtractorInput extractorInput = ExtractorInput.builder()
|
||||
|
||||
@@ -3,6 +3,7 @@ package work.slhaf.agent.modules.memory.updater;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.agent.common.chat.constant.ChatConstant;
|
||||
import work.slhaf.agent.common.chat.pojo.Message;
|
||||
import work.slhaf.agent.core.interaction.InteractionModule;
|
||||
import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor;
|
||||
@@ -21,6 +22,8 @@ import java.io.IOException;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
@Data
|
||||
@Slf4j
|
||||
@@ -46,6 +49,7 @@ public class MemoryUpdater implements InteractionModule {
|
||||
memoryUpdater.setMemorySummarizer(MemorySummarizer.getInstance());
|
||||
memoryUpdater.setSessionManager(SessionManager.getInstance());
|
||||
memoryUpdater.setStaticMemoryExtractor(StaticMemoryExtractor.getInstance());
|
||||
memoryUpdater.setExecutor(InteractionThreadPoolExecutor.getInstance());
|
||||
}
|
||||
return memoryUpdater;
|
||||
}
|
||||
@@ -81,6 +85,8 @@ public class MemoryUpdater implements InteractionModule {
|
||||
try {
|
||||
SummarizeResult summarizeResult = memorySummarizer.execute(new SummarizeInput(memoryManager.getChatMessages(), memoryManager.getTopicTree()));
|
||||
MemorySlice memorySlice = getMemorySlice(userId, summarizeResult, memoryManager.getChatMessages());
|
||||
//设置involvedUserId
|
||||
setInvolvedUserId(userId,memorySlice,memoryManager.getChatMessages());
|
||||
memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath());
|
||||
//更新总dialogMap
|
||||
singleMemorySummary.put("total", summarizeResult.getSummary());
|
||||
@@ -91,6 +97,27 @@ public class MemoryUpdater implements InteractionModule {
|
||||
});
|
||||
}
|
||||
|
||||
private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List<Message> chatMessages) {
|
||||
for (Message chatMessage : chatMessages) {
|
||||
if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) {
|
||||
continue;
|
||||
}
|
||||
//匹配userId
|
||||
String content = chatMessage.getContent();
|
||||
Pattern pattern = Pattern.compile("\\[.*\\(([^)]+)\\)\\]");
|
||||
Matcher matcher = pattern.matcher(content);
|
||||
if (!matcher.find()) {
|
||||
continue;
|
||||
}
|
||||
String userId = matcher.group(1);
|
||||
if (userId.equals(startUserId)){
|
||||
continue;
|
||||
}
|
||||
memorySlice.getInvolvedUserIds().add(userId);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private void updateSingleChatSlices(String interactionContext, HashMap<String, String> singleMemorySummary) throws InterruptedException {
|
||||
//更新单聊记忆,同时从chatMessages中去掉单聊记忆
|
||||
Set<String> userIdSet = new HashSet<>(sessionManager.getSingleMetaMessageMap().keySet());
|
||||
|
||||
@@ -20,7 +20,7 @@ public class StaticMemoryExtractor extends Model {
|
||||
|
||||
private static StaticMemoryExtractor staticMemoryExtractor;
|
||||
|
||||
private static final String MODEL_KEY = "static_memory_extractor";
|
||||
public static final String MODEL_KEY = "static_memory_extractor";
|
||||
|
||||
|
||||
public static StaticMemoryExtractor getInstance() throws IOException, ClassNotFoundException {
|
||||
|
||||
@@ -34,7 +34,8 @@ public class MemorySummarizer extends Model {
|
||||
public static final String MODEL_KEY = "memory_summarizer";
|
||||
private static final List<String> prompts = List.of(
|
||||
Constant.SINGLE_SUMMARIZE_PROMPT,
|
||||
Constant.MULTI_SUMMARIZE_PROMPT
|
||||
Constant.MULTI_SUMMARIZE_PROMPT,
|
||||
Constant.TOTAL_SUMMARIZE_PROMPT
|
||||
);
|
||||
|
||||
private InteractionThreadPoolExecutor executor;
|
||||
@@ -304,5 +305,95 @@ public class MemorySummarizer extends Model {
|
||||
}
|
||||
|
||||
""";
|
||||
|
||||
public static final String TOTAL_SUMMARIZE_PROMPT = """
|
||||
TOTAL_SUMMARIZER 提示词
|
||||
功能说明
|
||||
你需要根据输入的多个独立用户对话摘要,生成一份综合性的总结报告。每个用户的对话内容彼此无关联,需保持原始信息的同时进行概括性整合,最终输出标准化JSON格式的响应。
|
||||
|
||||
输入字段说明
|
||||
• 输入数据为JSON对象:
|
||||
- key: 用户uuid(需在输出中保留)
|
||||
- value: 该用户的对话摘要文本(需要处理的内容)
|
||||
|
||||
输出规则
|
||||
1. 基本响应格式:
|
||||
{
|
||||
"content": string // 综合摘要文本
|
||||
}
|
||||
2. 内容要求:
|
||||
• 严格控制在800字以内
|
||||
• 保持客观中立,不添加解释性内容
|
||||
• 使用分号分隔不同用户的摘要内容
|
||||
• 保留原始对话的关键事实信息
|
||||
• 对重复信息进行合并处理
|
||||
3. 格式要求:
|
||||
• 每个用户摘要以"用户[uuid]:"开头
|
||||
• 不同用户摘要间用分号分隔
|
||||
• 末尾不添加总结性陈述
|
||||
|
||||
处理流程
|
||||
1. 解析输入JSON的所有键值对
|
||||
2. 对每个摘要执行:
|
||||
a. 提取关键事实信息
|
||||
b. 删除问候语等非实质性内容
|
||||
c. 简化重复表达
|
||||
3. 合并处理:
|
||||
a. 识别不同摘要中的相同信息点
|
||||
b. 合并相同信息点的不同表述
|
||||
4. 生成最终摘要:
|
||||
a. 按原始输入顺序排列用户摘要
|
||||
b. 确保总字数≤800
|
||||
c. 验证信息完整性
|
||||
|
||||
完整示例
|
||||
示例1(基础情况):
|
||||
输入:{
|
||||
"aaa-111": "需要购买笔记本电脑,预算5000左右,主要用于办公",
|
||||
"bbb-222": "想买游戏本,预算8000-10000,要能运行3A大作",
|
||||
"ccc-333": "咨询轻薄本推荐,经常出差使用"
|
||||
}
|
||||
输出:{
|
||||
"content": "
|
||||
用户[aaa-111]:需要5000元左右的办公笔记本;
|
||||
用户[bbb-222]:寻求8000-10000元的游戏本,要求能运行3A大作;
|
||||
用户[ccc-333]:咨询适合出差使用的轻薄本"
|
||||
}
|
||||
|
||||
示例2(信息合并):
|
||||
输入:{
|
||||
"ddd-444": "想了解Python入门课程,零基础",
|
||||
"eee-555": "询问Java和Python哪个更适合新手",
|
||||
"fff-666": "零基础,想学Python数据分析"
|
||||
}
|
||||
输出:{
|
||||
"content": "
|
||||
用户[ddd-444]:零基础想了解Python入门课程;
|
||||
用户[eee-555]:询问Java和Python对新手的适用性;
|
||||
用户[fff-666]:零基础想学习Python数据分析"
|
||||
}
|
||||
|
||||
示例3(长文本精简):
|
||||
输入:{
|
||||
"ggg-777": "您好!我最近在准备考研,想咨询下时间规划。具体是想了解每天应该分配多少时间给英语复习,我现在英语水平大概是四级刚过的程度...(后续200字详细描述)",
|
||||
"hhh-888": "考研政治怎么准备?需要报班吗?"
|
||||
}
|
||||
输出:{
|
||||
"content": "
|
||||
用户[ggg-777]:咨询考研英语复习时间规划,当前英语水平为四级;
|
||||
用户[hhh-888]:询问考研政治备考方法及是否需要报班"
|
||||
}
|
||||
|
||||
特殊处理
|
||||
1. 当总字数超出限制时:
|
||||
• 尽量保留所有出现的用户摘要
|
||||
2. 当输入为空时:
|
||||
{
|
||||
"content": ""
|
||||
}
|
||||
3. 当用户uuid包含特殊字符时:
|
||||
• 保持原始uuid格式不做修改
|
||||
• 示例:用户[xxx-ddssss-xx]:内容摘要
|
||||
""";
|
||||
}
|
||||
}
|
||||
|
||||
31
src/test/java/memory/RegexTest.java
Normal file
31
src/test/java/memory/RegexTest.java
Normal file
@@ -0,0 +1,31 @@
|
||||
package memory;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
public class RegexTest {
|
||||
|
||||
@Test
|
||||
public void regexTest(){
|
||||
String[] examples = {
|
||||
"[小明(userId)] 我在开会] (te[]st)",
|
||||
"[用户(昵)称(userId)] 你好[呀]",
|
||||
"[测试账号(userId)] 这是一个(test(123))消息"
|
||||
};
|
||||
|
||||
Pattern pattern = Pattern.compile("\\[.*?\\(([^)]+)\\)\\]");
|
||||
|
||||
for (String example : examples) {
|
||||
Matcher matcher = pattern.matcher(example);
|
||||
if (matcher.find()) {
|
||||
System.out.println("在 '" + example + "' 中找到 userId: " + matcher.group(1));
|
||||
System.out.println();
|
||||
} else {
|
||||
System.out.println("在 '" + example + "' 中未找到 userId");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user