推进记忆模块

- 暂时移除 currentCompressedSessionContext 字段,其功能目前来看与userDialogMap与dialogMap重复,且与当前设想逻辑结合存在困难
- 添加了 dialogMap 更新逻辑
- 添加静态记忆提取功能,在更新单聊、多聊记忆时同步提取(实际运行存在先后,但放到一起了)
- 新增 Model 子类所需的 prompt 待添加
- 主模型相关 prompt 待调整
This commit is contained in:
2025-05-06 23:09:24 +08:00
parent b8b5661d79
commit 3dd21f840e
10 changed files with 170 additions and 53 deletions

View File

@@ -304,4 +304,6 @@ public class ModelConstant {
"""; """;
public static final String BASE_SUMMARIZER_PROMPT = """ public static final String BASE_SUMMARIZER_PROMPT = """
"""; """;
public static final String STATIC_MEMORY_EXTRACTOR_PROMPT = """
""";
} }

View File

@@ -63,7 +63,7 @@ public class MemoryGraph extends PersistableObject {
/** /**
* 当前对话的活动性总结, 拥有比dialogMap更丰富的全文细节, 作为当前对话token超限时的必要上下文压缩存储 * 当前对话的活动性总结, 拥有比dialogMap更丰富的全文细节, 作为当前对话token超限时的必要上下文压缩存储
*/ */
private List<String> currentCompressedSessionContext; // private List<String> currentCompressedSessionContext;
/** /**
* 存储确定性记忆, 如'用户爱好'等确定性信息 * 存储确定性记忆, 如'用户爱好'等确定性信息
@@ -123,7 +123,7 @@ public class MemoryGraph extends PersistableObject {
this.selectedSlices = new HashSet<>(); this.selectedSlices = new HashSet<>();
this.users = new ArrayList<>(); this.users = new ArrayList<>();
this.userDialogMap = new ConcurrentHashMap<>(); this.userDialogMap = new ConcurrentHashMap<>();
this.currentCompressedSessionContext = new ArrayList<>(); // this.currentCompressedSessionContext = new ArrayList<>();
this.dialogMap = new HashMap<>(); this.dialogMap = new HashMap<>();
} }
@@ -255,24 +255,9 @@ public class MemoryGraph extends PersistableObject {
String summary = slice.getSummary(); String summary = slice.getSummary();
LocalDateTime now = LocalDateTime.now(); LocalDateTime now = LocalDateTime.now();
//更新dialogMap -------------------------
//移除两天前的上下文缓存(切片总结)
List<LocalDateTime> keysToRemove = new ArrayList<>();
dialogMap.forEach((k, v) -> {
if (now.minusDays(2).isAfter(k)) {
keysToRemove.add(k);
}
});
for (LocalDateTime dateTime : keysToRemove) {
dialogMap.remove(dateTime);
}
keysToRemove.clear();
//放入新缓存
dialogMap.put(now, summary);
//---------------------------------------
//更新userDialogMap //更新userDialogMap
//移除两天前上下文缓存(切片总结) //移除两天前上下文缓存(切片总结)
List<LocalDateTime> keysToRemove = new ArrayList<>();
userDialogMap.forEach((k, v) -> { userDialogMap.forEach((k, v) -> {
v.forEach((i, j) -> { v.forEach((i, j) -> {
if (now.minusDays(2).isAfter(i)) { if (now.minusDays(2).isAfter(i)) {
@@ -288,7 +273,7 @@ public class MemoryGraph extends PersistableObject {
//放入新缓存 //放入新缓存
userDialogMap userDialogMap
.computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>()) .computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>())
.merge(now, slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal); .merge(now, summary, (oldVal, newVal) -> oldVal + " " + newVal);
} }
@@ -484,5 +469,21 @@ public class MemoryGraph extends PersistableObject {
} }
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
List<LocalDateTime> keysToRemove = new ArrayList<>();
dialogMap.forEach((k, v) -> {
if (dateTime.minusDays(2).isAfter(k)) {
keysToRemove.add(k);
}
});
for (LocalDateTime temp : keysToRemove) {
dialogMap.remove(temp);
}
keysToRemove.clear();
//放入新缓存
dialogMap.put(dateTime, newDialogCache);
}
} }

View File

@@ -130,4 +130,12 @@ public class MemoryManager implements InteractionModule {
memoryGraph.getChatMessages().removeAll(messages); memoryGraph.getChatMessages().removeAll(messages);
messageCleanLock.unlock(); messageCleanLock.unlock();
} }
public void insertStaticMemory(String userId, Map<String, String> newStaticMemory) {
memoryGraph.getStaticMemory().get(userId).putAll(newStaticMemory);
}
public void updateDialogMap(LocalDateTime dateTime,String newDialogCache) {
memoryGraph.updateDialogMap(dateTime, newDialogCache);
}
} }

View File

@@ -14,7 +14,6 @@ public class SessionManager {
private static SessionManager sessionManager; private static SessionManager sessionManager;
private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap; private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap;
private HashMap<String /*startUserId*/, List<MetaMessage>> multiMetaMessageMap;
public static SessionManager getInstance() { public static SessionManager getInstance() {
if (sessionManager == null) { if (sessionManager == null) {

View File

@@ -5,15 +5,19 @@ import com.alibaba.fastjson2.JSONObject;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.chat.pojo.MetaMessage;
import work.slhaf.agent.common.config.Config; import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model; import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant; import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.core.interaction.data.InteractionContext; import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.MemoryManager; import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.session.SessionManager;
import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorInput; import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorInput;
import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorResult; import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorResult;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import static work.slhaf.agent.common.util.ExtractUtil.extractJson; import static work.slhaf.agent.common.util.ExtractUtil.extractJson;
@@ -26,6 +30,7 @@ public class MemorySelectExtractor extends Model {
private static MemorySelectExtractor memorySelectExtractor; private static MemorySelectExtractor memorySelectExtractor;
private MemoryManager memoryManager; private MemoryManager memoryManager;
private SessionManager sessionManager;
private MemorySelectExtractor() { private MemorySelectExtractor() {
} }
@@ -35,6 +40,7 @@ public class MemorySelectExtractor extends Model {
Config config = Config.getConfig(); Config config = Config.getConfig();
memorySelectExtractor = new MemorySelectExtractor(); memorySelectExtractor = new MemorySelectExtractor();
memorySelectExtractor.setMemoryManager(MemoryManager.getInstance()); memorySelectExtractor.setMemoryManager(MemoryManager.getInstance());
memorySelectExtractor.setSessionManager(SessionManager.getInstance());
setModel(config, memorySelectExtractor, MODEL_KEY, ModelConstant.SELECT_EXTRACTOR_PROMPT); setModel(config, memorySelectExtractor, MODEL_KEY, ModelConstant.SELECT_EXTRACTOR_PROMPT);
} }
@@ -43,11 +49,16 @@ public class MemorySelectExtractor extends Model {
public ExtractorResult execute(InteractionContext context) { public ExtractorResult execute(InteractionContext context) {
//结构化为指定格式 //结构化为指定格式
//TODO 将历史消息替换为sessionManager中的用户对应信息列表 List<Message> chatMessages = new ArrayList<>();
for (MetaMessage metaMessage : sessionManager.getSingleMetaMessageMap().get(context.getUserId())) {
chatMessages.add(metaMessage.getUserMessage());
chatMessages.add(metaMessage.getAssistantMessage());
}
ExtractorInput extractorInput = ExtractorInput.builder() ExtractorInput extractorInput = ExtractorInput.builder()
.text(context.getInput()) .text(context.getInput())
.date(context.getDateTime().toLocalDate()) .date(context.getDateTime().toLocalDate())
.history(memoryManager.getChatMessages()) .history(chatMessages)
.topic_tree(memoryManager.getTopicTree()) .topic_tree(memoryManager.getTopicTree())
.build(); .build();
String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage()); String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage());

View File

@@ -11,11 +11,14 @@ import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.memory.pojo.MemorySlice; import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.core.session.SessionManager; import work.slhaf.agent.core.session.SessionManager;
import work.slhaf.agent.modules.memory.selector.extractor.MemorySelectExtractor; import work.slhaf.agent.modules.memory.selector.extractor.MemorySelectExtractor;
import work.slhaf.agent.modules.memory.updater.static_extractor.StaticMemoryExtractor;
import work.slhaf.agent.modules.memory.updater.static_extractor.data.StaticMemoryExtractInput;
import work.slhaf.agent.modules.memory.updater.summarizer.MemorySummarizer; import work.slhaf.agent.modules.memory.updater.summarizer.MemorySummarizer;
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult; import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult;
import work.slhaf.agent.modules.memory.updater.summarizer.data.TotalSummarizeInput; import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeInput;
import java.io.IOException; import java.io.IOException;
import java.time.LocalDateTime;
import java.util.*; import java.util.*;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
@@ -33,6 +36,7 @@ public class MemoryUpdater implements InteractionModule {
private MemorySelectExtractor memorySelectExtractor; private MemorySelectExtractor memorySelectExtractor;
private MemorySummarizer memorySummarizer; private MemorySummarizer memorySummarizer;
private SessionManager sessionManager; private SessionManager sessionManager;
private StaticMemoryExtractor staticMemoryExtractor;
private MemoryUpdater() { private MemoryUpdater() {
} }
@@ -45,13 +49,13 @@ public class MemoryUpdater implements InteractionModule {
memoryUpdater.setMemorySummarizer(MemorySummarizer.getInstance()); memoryUpdater.setMemorySummarizer(MemorySummarizer.getInstance());
memoryUpdater.setSessionManager(SessionManager.getInstance()); memoryUpdater.setSessionManager(SessionManager.getInstance());
memoryUpdater.setUpdateExecutor(Executors.newSingleThreadExecutor()); memoryUpdater.setUpdateExecutor(Executors.newSingleThreadExecutor());
memoryUpdater.setStaticMemoryExtractor(StaticMemoryExtractor.getInstance());
} }
return memoryUpdater; return memoryUpdater;
} }
@Override @Override
public void execute(InteractionContext interactionContext) throws InterruptedException { public void execute(InteractionContext interactionContext) {
//TODO 需要保持压缩上下文、更新总摘要、更新确定性记忆、总结所有切片后更新dialogMap
if (interactionContext.isFinished()) { if (interactionContext.isFinished()) {
return; return;
} }
@@ -59,16 +63,14 @@ public class MemoryUpdater implements InteractionModule {
//如果token 大于阈值,则更新记忆 //如果token 大于阈值,则更新记忆
JSONObject moduleContext = interactionContext.getModuleContext(); JSONObject moduleContext = interactionContext.getModuleContext();
if (moduleContext.getIntValue("total_token") > 24000) { if (moduleContext.getIntValue("total_token") > 24000) {
//更新单聊记忆同时从chatMessages中去掉单聊记忆 String userId = interactionContext.getUserId();
HashMap<String, String> singleMemorySummary = new HashMap<>();
try { try {
updateSingleChatSlices(interactionContext); //更新单聊记忆以及该场景中对应的确定性记忆同时从chatMessages中去掉单聊记忆
//更新多人场景下的记忆 updateSingleChatSlices(userId, singleMemorySummary);
updateMultiChatSlices(interactionContext); //更新多人场景下的记忆及相关的确定性记忆
//更新确定性记忆 updateMultiChatSlices(userId, singleMemorySummary);
executor.execute(() -> { } catch (InterruptedException | IOException | ClassNotFoundException e) {
});
} catch (InterruptedException e) {
log.error("记忆更新线程出错: {}", e.getLocalizedMessage()); log.error("记忆更新线程出错: {}", e.getLocalizedMessage());
} }
@@ -78,41 +80,67 @@ public class MemoryUpdater implements InteractionModule {
} }
private void updateMultiChatSlices(InteractionContext interactionContext) { private void updateMultiChatSlices(String userId, HashMap<String, String> singleMemorySummary) throws InterruptedException, IOException, ClassNotFoundException {
//TODO 更新多人场景对话记忆
//此时chatMessages中不再包含单聊记录直接执行摘要以及切片插入 //此时chatMessages中不再包含单聊记录直接执行摘要以及切片插入
//对剩下的多人聊天记录进行进行摘要
executor.execute(() -> {
try {
SummarizeResult summarizeResult = memorySummarizer.execute(new SummarizeInput(memoryManager.getChatMessages(), memoryManager.getTopicTree()));
MemorySlice memorySlice = getMemorySlice(userId, summarizeResult, memoryManager.getChatMessages());
memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath());
//更新总dialogMap
singleMemorySummary.put("total", summarizeResult.getSummary());
memoryManager.updateDialogMap(LocalDateTime.now(), memorySummarizer.executeTotalSummary(singleMemorySummary));
} catch (IOException | ClassNotFoundException | InterruptedException e) {
log.error("多人场景记忆更新失败: {}", e.getLocalizedMessage());
}
});
} }
private void updateSingleChatSlices(InteractionContext interactionContext) throws InterruptedException { private void updateSingleChatSlices(String interactionContext, HashMap<String, String> singleMemorySummary) throws InterruptedException {
//更新单聊记忆同时从chatMessages中去掉单聊记忆 //更新单聊记忆同时从chatMessages中去掉单聊记忆
Set<String> userIdSet = new HashSet<>(sessionManager.getSingleMetaMessageMap().keySet()); Set<String> userIdSet = new HashSet<>(sessionManager.getSingleMetaMessageMap().keySet());
List<Callable<Void>> tasks = new ArrayList<>(); List<Callable<Void>> tasks = new ArrayList<>();
//多人聊天? //多人聊天?
for (String id : userIdSet) { for (String id : userIdSet) {
tasks.add(() -> {
List<Message> messages = sessionManager.unpackAndClear(id); List<Message> messages = sessionManager.unpackAndClear(id);
tasks.add(() -> {
try { try {
SummarizeResult summarizeResult = memorySummarizer.execute(new TotalSummarizeInput(messages, memoryManager.getTopicTree())); //单聊记忆更新
SummarizeResult summarizeResult = memorySummarizer.execute(new SummarizeInput(messages, memoryManager.getTopicTree()));
MemorySlice memorySlice = getMemorySlice(interactionContext, summarizeResult, messages); MemorySlice memorySlice = getMemorySlice(interactionContext, summarizeResult, messages);
//插入时userDialogMap已经进行更新
memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath()); memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath());
//从chatMessages中移除单聊记录 //从chatMessages中移除单聊记录
memoryManager.cleanMessage(messages); memoryManager.cleanMessage(messages);
//添加至singleMemorySummary
singleMemorySummary.put(id, summarizeResult.getSummary());
} catch (Exception e) { } catch (Exception e) {
log.error("记忆更新出错: {}", e.getLocalizedMessage()); log.error("单聊记忆更新出错: {}", e.getLocalizedMessage());
} }
return null; return null;
}); });
tasks.add(() -> {
StaticMemoryExtractInput input = StaticMemoryExtractInput.builder()
.userId(id)
.messages(messages)
.existedStaticMemory(memoryManager.getStaticMemory(id))
.build();
Map<String, String> staticMemoryResult = staticMemoryExtractor.execute(input);
memoryManager.insertStaticMemory(id, staticMemoryResult);
return null;
});
} }
executor.invokeAll(tasks); executor.invokeAll(tasks);
} }
private static MemorySlice getMemorySlice(InteractionContext interactionContext, SummarizeResult summarizeResult, List<Message> chatMessages) { private static MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List<Message> chatMessages) {
MemorySlice memorySlice = new MemorySlice(); MemorySlice memorySlice = new MemorySlice();
memorySlice.setPrivate(summarizeResult.isPrivate()); memorySlice.setPrivate(summarizeResult.isPrivate());
memorySlice.setSummary(summarizeResult.getSummary()); memorySlice.setSummary(summarizeResult.getSummary());
memorySlice.setChatMessages(chatMessages); memorySlice.setChatMessages(chatMessages);
memorySlice.setStartUserId(interactionContext.getUserId()); memorySlice.setStartUserId(userId);
List<List<String>> relatedTopicPathList = new ArrayList<>(); List<List<String>> relatedTopicPathList = new ArrayList<>();
for (String string : summarizeResult.getRelatedTopicPath()) { for (String string : summarizeResult.getRelatedTopicPath()) {
List<String> list = Arrays.stream(string.split("->")).toList(); List<String> list = Arrays.stream(string.split("->")).toList();

View File

@@ -0,0 +1,43 @@
package work.slhaf.agent.modules.memory.updater.static_extractor;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.agent.common.chat.pojo.ChatResponse;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.modules.memory.updater.static_extractor.data.StaticMemoryExtractInput;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
@EqualsAndHashCode(callSuper = true)
@Data
public class StaticMemoryExtractor extends Model {
private static StaticMemoryExtractor staticMemoryExtractor;
private static final String MODEL_KEY = "static_memory_extractor";
public static StaticMemoryExtractor getInstance() throws IOException, ClassNotFoundException {
if (staticMemoryExtractor == null) {
staticMemoryExtractor = new StaticMemoryExtractor();
setModel(Config.getConfig(), staticMemoryExtractor, MODEL_KEY, ModelConstant.STATIC_MEMORY_EXTRACTOR_PROMPT);
}
return staticMemoryExtractor;
}
public Map<String, String> execute(StaticMemoryExtractInput input) {
ChatResponse response = singleChat(JSONUtil.toJsonPrettyStr(input));
JSONObject jsonObject = JSONObject.parseObject(response.getMessage());
Map<String, String> result = new HashMap<>();
jsonObject.forEach((k, v) -> {
result.put(k, (String) v);
});
return result;
}
}

View File

@@ -0,0 +1,17 @@
package work.slhaf.agent.modules.memory.updater.static_extractor.data;
import lombok.Builder;
import lombok.Data;
import work.slhaf.agent.common.chat.pojo.Message;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Data
@Builder
public class StaticMemoryExtractInput {
private String userId;
private List<Message> messages;
private Map<String,String> existedStaticMemory;
}

View File

@@ -14,14 +14,17 @@ import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant; import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor; import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor;
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult; import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult;
import work.slhaf.agent.modules.memory.updater.summarizer.data.TotalSummarizeInput; import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeInput;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static work.slhaf.agent.common.util.ExtractUtil.extractJson;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
@Slf4j @Slf4j
@@ -42,14 +45,14 @@ public class MemorySummarizer extends Model {
return memorySummarizer; return memorySummarizer;
} }
public SummarizeResult execute(TotalSummarizeInput input) throws InterruptedException { public SummarizeResult execute(SummarizeInput input) throws InterruptedException {
//进行长文本批量摘要 //进行长文本批量摘要
singleMessageSummarize(input.getChatMessages()); singleMessageSummarize(input.getChatMessages());
//进行整体摘要并返回结果 //进行整体摘要并返回结果
return multiMessageSummarize(input); return multiMessageSummarize(input);
} }
private SummarizeResult multiMessageSummarize(TotalSummarizeInput input) { private SummarizeResult multiMessageSummarize(SummarizeInput input) {
String messageStr = JSONUtil.toJsonPrettyStr(input); String messageStr = JSONUtil.toJsonPrettyStr(input);
return multiSummarizeExecute(prompts.get(1), messageStr); return multiSummarizeExecute(prompts.get(1), messageStr);
} }
@@ -88,4 +91,9 @@ public class MemorySummarizer extends Model {
} }
public String executeTotalSummary(HashMap<String, String> singleMemorySummary) {
ChatResponse response = chatClient.runChat(List.of(new Message(ChatConstant.Character.SYSTEM, prompts.get(2)),
new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))));
return JSONObject.parseObject(extractJson(response.getMessage())).getString("value");
}
} }

View File

@@ -8,7 +8,7 @@ import java.util.List;
@AllArgsConstructor @AllArgsConstructor
@Data @Data
public class TotalSummarizeInput { public class SummarizeInput {
private List<Message> chatMessages; private List<Message> chatMessages;
private String topicTree; private String topicTree;
} }