推进记忆更新模块以及单智能体多用户相关设计

- 更新 PreprocessExecutor 以设置用户 ID
- 新增 MetaMessage 类用于封装用户和助手的消息对
- 新增 SessionManager 用于记录独立用户在共享上下文中的聊天记录
- 发现`单智能体多用户`的相关设计问题,已有思路,待解决
This commit is contained in:
2025-04-29 21:57:07 +08:00
parent 40ac6bef03
commit b8b5661d79
16 changed files with 217 additions and 47 deletions

View File

@@ -0,0 +1,11 @@
package work.slhaf.agent.common.chat.pojo;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class MetaMessage {
private Message userMessage;
private Message assistantMessage;
}

View File

@@ -9,6 +9,7 @@ public class ModelConstant {
- datetime当text包含时间相关语义时使用
- character当需要根据角色设定调整语气时使用
- user_nick当text中包含对用户的称呼或个性化需求时使用
- user_id用户的唯一标识该字段真正具有区分用户的作用
其他所有字段仅在明确与text内容相关时才予以考虑否则应完全忽略。
输入字段优先级

View File

@@ -52,8 +52,6 @@ public class InteractionModulesLoader {
private static InteractionModule loadInternalModule(String className) {
try {
Class<?> clazz = Class.forName(className);
//TODO 后续需要规范`getInstance`方法的实现
return (InteractionModule) clazz.getMethod("getInstance").invoke(null);
} catch (ClassNotFoundException | InvocationTargetException | IllegalAccessException | NoSuchMethodException e) {
throw new RuntimeException("Fail to load internal module: " + className,e);

View File

@@ -7,9 +7,11 @@ import java.time.LocalDateTime;
@Data
public class InteractionContext {
protected String userInfo;
protected String userId;
protected String userNickname;
protected String userInfo;
protected LocalDateTime dateTime;
protected boolean single;
protected boolean finished;
protected String input;

View File

@@ -11,4 +11,5 @@ public class InteractionInputData {
private String content;
private LocalDateTime localDateTime;
private String platform;
private boolean single;
}

View File

@@ -109,6 +109,8 @@ public class MemoryGraph extends PersistableObject {
*/
private Set<Long> selectedSlices;
private String memoryId;
public MemoryGraph(String id) {
this.id = id;
this.topicNodes = new HashMap<>();

View File

@@ -7,28 +7,31 @@ import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.core.interaction.InteractionModule;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.core.memory.pojo.User;
import work.slhaf.agent.shared.memory.EvaluatedSlice;
import java.io.IOException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@Data
@Slf4j
public class MemoryManager implements InteractionModule {
private static MemoryManager memoryManager;
private final Lock sliceInsertLock = new ReentrantLock();
private final Lock messageCleanLock = new ReentrantLock();
private MemoryGraph memoryGraph;
private HashMap<String,List<EvaluatedSlice>> activatedSlices;
private HashMap<String, List<EvaluatedSlice>> activatedSlices;
private MemoryManager(){}
private MemoryManager() {
}
@Override
public void execute(InteractionContext interactionContext) {
@@ -54,14 +57,14 @@ public class MemoryManager implements InteractionModule {
return memoryGraph.selectMemory(date);
}
public void cleanSelectedSliceFilter(){
public void cleanSelectedSliceFilter() {
memoryGraph.getSelectedSlices().clear();
}
public String getUserId(String userInfo,String nickName) {
public String getUserId(String userInfo, String nickName) {
String userId = null;
for (User user : memoryGraph.getUsers()) {
if (user.getInfo().contains(userInfo)){
if (user.getInfo().contains(userInfo)) {
userId = user.getUuid();
}
}
@@ -73,7 +76,7 @@ public class MemoryManager implements InteractionModule {
return userId;
}
public List<Message> getChatMessages(){
public List<Message> getChatMessages() {
return memoryGraph.getChatMessages();
}
@@ -91,7 +94,7 @@ public class MemoryManager implements InteractionModule {
return memoryGraph.getTopicTree();
}
public ConcurrentHashMap<String,String> getStaticMemory(String userId) {
public ConcurrentHashMap<String, String> getStaticMemory(String userId) {
return memoryGraph.getStaticMemory().get(userId);
}
@@ -106,4 +109,25 @@ public class MemoryManager implements InteractionModule {
public String getCharacter() {
return memoryGraph.getCharacter();
}
public void resetMemoryId() {
memoryGraph.setMemoryId(UUID.randomUUID().toString());
}
public String getMemoryId() {
return memoryGraph.getMemoryId();
}
public void insertSlice(MemorySlice memorySlice, String topicPath) throws IOException, ClassNotFoundException {
sliceInsertLock.lock();
List<String> topicPathList = Arrays.stream(topicPath.split("->")).toList();
memoryGraph.insertMemory(topicPathList, memorySlice);
sliceInsertLock.unlock();
}
public void cleanMessage(List<Message> messages) {
messageCleanLock.lock();
memoryGraph.getChatMessages().removeAll(messages);
messageCleanLock.unlock();
}
}

View File

@@ -7,12 +7,14 @@ import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.chat.constant.ChatConstant;
import work.slhaf.agent.common.chat.pojo.ChatResponse;
import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.chat.pojo.MetaMessage;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.core.interaction.InteractionModule;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.session.SessionManager;
import java.io.IOException;
@@ -27,6 +29,7 @@ public class CoreModel extends Model implements InteractionModule {
private static CoreModel coreModel;
private MemoryManager memoryManager;
private SessionManager sessionManager;
private String promptCache;
private CoreModel() {
@@ -46,22 +49,32 @@ public class CoreModel extends Model implements InteractionModule {
@Override
public void execute(InteractionContext interactionContext) {
//TODO 添加新的system prompt 引导主模型专注于最新的用户输入
//TODO 需要更新主模型prompt
String tempPrompt = interactionContext.getModulePrompt().toString();
if (!tempPrompt.equals(promptCache)) {
coreModel.getMessages().set(0, new Message(ChatConstant.Character.SYSTEM, ModelConstant.CORE_MODEL_PROMPT + "\r\n" + tempPrompt));
promptCache = tempPrompt;
}
this.messages.add(new Message(ChatConstant.Character.USER, interactionContext.getCoreContext().getString("text")));
ChatResponse chatResponse = this.chat();
String user = "[" + interactionContext.getUserNickname() + "(" + interactionContext.getUserId() + ")]";
Message userMessage = new Message(ChatConstant.Character.USER, user + interactionContext.getCoreContext().getString("text"));
this.messages.add(userMessage);
JSONObject response = null;
int count = 0;
while (true) {
try {
ChatResponse chatResponse = this.chat();
response = JSONObject.parse(extractJson(chatResponse.getMessage()));
this.messages.add(new Message(ChatConstant.Character.ASSISTANT, response.getString("text")));
Message assistantMessage = new Message(ChatConstant.Character.ASSISTANT, response.getString("text"));
this.messages.add(assistantMessage);
//设置上下文
interactionContext.getModuleContext().put("total_token", chatResponse.getUsageBean().getTotal_tokens());
//区分单人聊天场景
if (interactionContext.isSingle()){
MetaMessage metaMessage = new MetaMessage(userMessage, assistantMessage);
sessionManager.addMetaMessage(interactionContext.getUserId(), metaMessage);
}
break;
} catch (Exception e) {
count++;

View File

@@ -0,0 +1,46 @@
package work.slhaf.agent.core.session;
import lombok.Data;
import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.chat.pojo.MetaMessage;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@Data
public class SessionManager {
private static SessionManager sessionManager;
private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap;
private HashMap<String /*startUserId*/, List<MetaMessage>> multiMetaMessageMap;
public static SessionManager getInstance() {
if (sessionManager == null) {
sessionManager = new SessionManager();
sessionManager.setSingleMetaMessageMap(new HashMap<>());
}
return sessionManager;
}
public void addMetaMessage(String userId, MetaMessage metaMessage) {
if (singleMetaMessageMap.containsKey(userId)) {
singleMetaMessageMap.get(userId).add(metaMessage);
} else {
singleMetaMessageMap.put(userId, new java.util.ArrayList<>());
singleMetaMessageMap.get(userId).add(metaMessage);
}
}
public List<Message> unpackAndClear(String userId) {
List<Message> messages = new ArrayList<>();
for (MetaMessage metaMessage : singleMetaMessageMap.get(userId)) {
messages.add(metaMessage.getUserMessage());
messages.add(metaMessage.getAssistantMessage());
}
singleMetaMessageMap.remove(userId);
return messages;
}
}

View File

@@ -65,7 +65,7 @@ public class MemorySelector implements InteractionModule {
@Override
public void execute(InteractionContext interactionContext) throws IOException, ClassNotFoundException, InterruptedException {
String userId = memoryManager.getUserId(interactionContext.getUserInfo(), interactionContext.getUserNickname());
String userId =interactionContext.getUserId();
//获取主题路径
ExtractorResult extractorResult = memorySelectExtractor.execute(interactionContext);
if (extractorResult.isRecall()) {
@@ -82,14 +82,13 @@ public class MemorySelector implements InteractionModule {
memoryManager.getActivatedSlices().put(userId,memorySlices);
//向上下文设置切片存入标志,条件:对话历史列表不为空;触发了记忆查询
if (!memoryManager.getChatMessages().isEmpty()) {
/*if (!memoryManager.getChatMessages().isEmpty()) {
interactionContext.getModuleContext().put("new_topic", true);
interactionContext.getModuleContext().put("messages_to_store", List.of(memoryManager.getChatMessages()));
}
}*/
}
//设置上下文
interactionContext.getCoreContext().put("memory_slices",memoryManager.getActivatedSlices().get(userId));
interactionContext.getCoreContext().put("static_memory",memoryManager.getStaticMemory(userId));

View File

@@ -43,6 +43,7 @@ public class MemorySelectExtractor extends Model {
public ExtractorResult execute(InteractionContext context) {
//结构化为指定格式
//TODO 将历史消息替换为sessionManager中的用户对应信息列表
ExtractorInput extractorInput = ExtractorInput.builder()
.text(context.getInput())
.date(context.getDateTime().toLocalDate())

View File

@@ -9,12 +9,17 @@ import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.core.session.SessionManager;
import work.slhaf.agent.modules.memory.selector.extractor.MemorySelectExtractor;
import work.slhaf.agent.modules.memory.updater.summarizer.MemorySummarizer;
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult;
import work.slhaf.agent.modules.memory.updater.summarizer.data.TotalSummarizeInput;
import java.io.IOException;
import java.util.List;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@Data
@Slf4j
@@ -22,10 +27,12 @@ public class MemoryUpdater implements InteractionModule {
private static MemoryUpdater memoryUpdater;
private ExecutorService updateExecutor;
private MemoryManager memoryManager;
private InteractionThreadPoolExecutor executor;
private MemorySelectExtractor memorySelectExtractor;
private MemorySummarizer memorySummarizer;
private SessionManager sessionManager;
private MemoryUpdater() {
}
@@ -36,38 +43,82 @@ public class MemoryUpdater implements InteractionModule {
memoryUpdater.setMemoryManager(MemoryManager.getInstance());
memoryUpdater.setMemorySelectExtractor(MemorySelectExtractor.getInstance());
memoryUpdater.setMemorySummarizer(MemorySummarizer.getInstance());
memoryUpdater.setSessionManager(SessionManager.getInstance());
memoryUpdater.setUpdateExecutor(Executors.newSingleThreadExecutor());
}
return memoryUpdater;
}
@Override
public void execute(InteractionContext interactionContext) {
public void execute(InteractionContext interactionContext) throws InterruptedException {
//TODO 需要保持压缩上下文、更新总摘要、更新确定性记忆、总结所有切片后更新dialogMap
if (interactionContext.isFinished()) {
return;
}
//如果token 大于阈值,则更新记忆
JSONObject moduleContext = interactionContext.getModuleContext();
if (moduleContext.getIntValue("total_token") > 24000 || (moduleContext.containsKey("new_topic") && moduleContext.getBooleanValue("new_topic"))) {
executor.execute(() -> {
//整理切片
List<Message> chatMessages = moduleContext.getList("messages_to_store", Message.class);
//进行摘要、判断是否为私密记忆、生成主题路径
updateExecutor.execute(() -> {
//如果token 大于阈值,则更新记忆
JSONObject moduleContext = interactionContext.getModuleContext();
if (moduleContext.getIntValue("total_token") > 24000) {
//更新单聊记忆同时从chatMessages中去掉单聊记忆
try {
SummarizeResult summarizeResult = memorySummarizer.execute(chatMessages);
//整理为切片并存储
MemorySlice memorySlice = new MemorySlice();
updateSingleChatSlices(interactionContext);
//更新多人场景下的记忆
updateMultiChatSlices(interactionContext);
//更新确定性记忆
executor.execute(() -> {
// memoryManager.insertSlice();
} catch (Exception e) {
log.error("记忆更新出错: {}", e.getLocalizedMessage());
});
} catch (InterruptedException e) {
log.error("记忆更新线程出错: {}", e.getLocalizedMessage());
}
});
}
//更新确定性记忆
executor.execute(() -> {
}
});
}
private void updateMultiChatSlices(InteractionContext interactionContext) {
//TODO 更新多人场景对话记忆
//此时chatMessages中不再包含单聊记录直接执行摘要以及切片插入
}
private void updateSingleChatSlices(InteractionContext interactionContext) throws InterruptedException {
//更新单聊记忆同时从chatMessages中去掉单聊记忆
Set<String> userIdSet = new HashSet<>(sessionManager.getSingleMetaMessageMap().keySet());
List<Callable<Void>> tasks = new ArrayList<>();
//多人聊天?
for (String id : userIdSet) {
tasks.add(() -> {
List<Message> messages = sessionManager.unpackAndClear(id);
try {
SummarizeResult summarizeResult = memorySummarizer.execute(new TotalSummarizeInput(messages, memoryManager.getTopicTree()));
MemorySlice memorySlice = getMemorySlice(interactionContext, summarizeResult, messages);
memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath());
//从chatMessages中移除单聊记录
memoryManager.cleanMessage(messages);
} catch (Exception e) {
log.error("记忆更新出错: {}", e.getLocalizedMessage());
}
return null;
});
}
executor.invokeAll(tasks);
}
private static MemorySlice getMemorySlice(InteractionContext interactionContext, SummarizeResult summarizeResult, List<Message> chatMessages) {
MemorySlice memorySlice = new MemorySlice();
memorySlice.setPrivate(summarizeResult.isPrivate());
memorySlice.setSummary(summarizeResult.getSummary());
memorySlice.setChatMessages(chatMessages);
memorySlice.setStartUserId(interactionContext.getUserId());
List<List<String>> relatedTopicPathList = new ArrayList<>();
for (String string : summarizeResult.getRelatedTopicPath()) {
List<String> list = Arrays.stream(string.split("->")).toList();
relatedTopicPathList.add(list);
}
memorySlice.setRelatedTopics(relatedTopicPathList);
return memorySlice;
}
}

View File

@@ -14,6 +14,7 @@ import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor;
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult;
import work.slhaf.agent.modules.memory.updater.summarizer.data.TotalSummarizeInput;
import java.io.IOException;
import java.util.ArrayList;
@@ -41,15 +42,15 @@ public class MemorySummarizer extends Model {
return memorySummarizer;
}
public SummarizeResult execute(List<Message> chatMessages) throws InterruptedException {
public SummarizeResult execute(TotalSummarizeInput input) throws InterruptedException {
//进行长文本批量摘要
singleMessageSummarize(chatMessages);
singleMessageSummarize(input.getChatMessages());
//进行整体摘要并返回结果
return multiMessageSummarize(chatMessages);
return multiMessageSummarize(input);
}
private SummarizeResult multiMessageSummarize(List<Message> chatMessages) {
String messageStr = JSONUtil.toJsonPrettyStr(chatMessages);
private SummarizeResult multiMessageSummarize(TotalSummarizeInput input) {
String messageStr = JSONUtil.toJsonPrettyStr(input);
return multiSummarizeExecute(prompts.get(1),messageStr);
}

View File

@@ -2,9 +2,12 @@ package work.slhaf.agent.modules.memory.updater.summarizer.data;
import lombok.Data;
import java.util.List;
@Data
public class SummarizeResult {
private String summary;
private String topicPath;
private List<String> relatedTopicPath;
private boolean isPrivate;
}

View File

@@ -0,0 +1,14 @@
package work.slhaf.agent.modules.memory.updater.summarizer.data;
import lombok.AllArgsConstructor;
import lombok.Data;
import work.slhaf.agent.common.chat.pojo.Message;
import java.util.List;
@AllArgsConstructor
@Data
public class TotalSummarizeInput {
private List<Message> chatMessages;
private String topicTree;
}

View File

@@ -30,9 +30,11 @@ public class PreprocessExecutor {
public InteractionContext execute(InteractionInputData inputData) {
InteractionContext context = new InteractionContext();
String userId = memoryManager.getUserId(inputData.getUserInfo(), inputData.getUserNickName());
context.setUserInfo(inputData.getUserInfo());
context.setUserId(userId);
context.setUserNickname(inputData.getUserNickName());
context.setUserInfo(inputData.getUserInfo());
context.setDateTime(inputData.getLocalDateTime());
context.setFinished(false);
@@ -41,8 +43,9 @@ public class PreprocessExecutor {
context.setCoreContext(new JSONObject());
context.getCoreContext().put("text", inputData.getContent());
context.getCoreContext().put("datetime", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME));
context.getCoreContext().put("character",memoryManager.getCharacter());
context.getCoreContext().put("character", memoryManager.getCharacter());
context.getCoreContext().put("user_nick", inputData.getUserNickName());
context.getCoreContext().put("user_id", userId);
context.setModuleContext(new JSONObject());