mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
推进记忆模块
- 暂时移除 currentCompressedSessionContext 字段,其功能目前来看与userDialogMap与dialogMap重复,且与当前设想逻辑结合存在困难 - 添加了 dialogMap 更新逻辑 - 添加静态记忆提取功能,在更新单聊、多聊记忆时同步提取(实际运行存在先后,但放到一起了) - 新增 Model 子类所需的 prompt 待添加 - 主模型相关 prompt 待调整
This commit is contained in:
@@ -304,4 +304,6 @@ public class ModelConstant {
|
||||
""";
|
||||
public static final String BASE_SUMMARIZER_PROMPT = """
|
||||
""";
|
||||
public static final String STATIC_MEMORY_EXTRACTOR_PROMPT = """
|
||||
""";
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
/**
|
||||
* 当前对话的活动性总结, 拥有比dialogMap更丰富的全文细节, 作为当前对话token超限时的必要上下文压缩存储
|
||||
*/
|
||||
private List<String> currentCompressedSessionContext;
|
||||
// private List<String> currentCompressedSessionContext;
|
||||
|
||||
/**
|
||||
* 存储确定性记忆, 如'用户爱好'等确定性信息
|
||||
@@ -123,7 +123,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
this.selectedSlices = new HashSet<>();
|
||||
this.users = new ArrayList<>();
|
||||
this.userDialogMap = new ConcurrentHashMap<>();
|
||||
this.currentCompressedSessionContext = new ArrayList<>();
|
||||
// this.currentCompressedSessionContext = new ArrayList<>();
|
||||
this.dialogMap = new HashMap<>();
|
||||
}
|
||||
|
||||
@@ -255,24 +255,9 @@ public class MemoryGraph extends PersistableObject {
|
||||
String summary = slice.getSummary();
|
||||
LocalDateTime now = LocalDateTime.now();
|
||||
|
||||
//更新dialogMap -------------------------
|
||||
//移除两天前的上下文缓存(切片总结)
|
||||
List<LocalDateTime> keysToRemove = new ArrayList<>();
|
||||
dialogMap.forEach((k, v) -> {
|
||||
if (now.minusDays(2).isAfter(k)) {
|
||||
keysToRemove.add(k);
|
||||
}
|
||||
});
|
||||
for (LocalDateTime dateTime : keysToRemove) {
|
||||
dialogMap.remove(dateTime);
|
||||
}
|
||||
keysToRemove.clear();
|
||||
//放入新缓存
|
||||
dialogMap.put(now, summary);
|
||||
//---------------------------------------
|
||||
|
||||
//更新userDialogMap
|
||||
//移除两天前上下文缓存(切片总结)
|
||||
List<LocalDateTime> keysToRemove = new ArrayList<>();
|
||||
userDialogMap.forEach((k, v) -> {
|
||||
v.forEach((i, j) -> {
|
||||
if (now.minusDays(2).isAfter(i)) {
|
||||
@@ -288,7 +273,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
//放入新缓存
|
||||
userDialogMap
|
||||
.computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>())
|
||||
.merge(now, slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
|
||||
.merge(now, summary, (oldVal, newVal) -> oldVal + " " + newVal);
|
||||
|
||||
}
|
||||
|
||||
@@ -484,5 +469,21 @@ public class MemoryGraph extends PersistableObject {
|
||||
}
|
||||
|
||||
|
||||
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
|
||||
List<LocalDateTime> keysToRemove = new ArrayList<>();
|
||||
|
||||
dialogMap.forEach((k, v) -> {
|
||||
if (dateTime.minusDays(2).isAfter(k)) {
|
||||
keysToRemove.add(k);
|
||||
}
|
||||
});
|
||||
for (LocalDateTime temp : keysToRemove) {
|
||||
dialogMap.remove(temp);
|
||||
}
|
||||
keysToRemove.clear();
|
||||
//放入新缓存
|
||||
dialogMap.put(dateTime, newDialogCache);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -130,4 +130,12 @@ public class MemoryManager implements InteractionModule {
|
||||
memoryGraph.getChatMessages().removeAll(messages);
|
||||
messageCleanLock.unlock();
|
||||
}
|
||||
|
||||
public void insertStaticMemory(String userId, Map<String, String> newStaticMemory) {
|
||||
memoryGraph.getStaticMemory().get(userId).putAll(newStaticMemory);
|
||||
}
|
||||
|
||||
public void updateDialogMap(LocalDateTime dateTime,String newDialogCache) {
|
||||
memoryGraph.updateDialogMap(dateTime, newDialogCache);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ public class SessionManager {
|
||||
private static SessionManager sessionManager;
|
||||
|
||||
private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap;
|
||||
private HashMap<String /*startUserId*/, List<MetaMessage>> multiMetaMessageMap;
|
||||
|
||||
public static SessionManager getInstance() {
|
||||
if (sessionManager == null) {
|
||||
|
||||
@@ -5,15 +5,19 @@ import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.agent.common.chat.pojo.Message;
|
||||
import work.slhaf.agent.common.chat.pojo.MetaMessage;
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.common.model.Model;
|
||||
import work.slhaf.agent.common.model.ModelConstant;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||
import work.slhaf.agent.core.memory.MemoryManager;
|
||||
import work.slhaf.agent.core.session.SessionManager;
|
||||
import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorInput;
|
||||
import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorResult;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static work.slhaf.agent.common.util.ExtractUtil.extractJson;
|
||||
@@ -26,6 +30,7 @@ public class MemorySelectExtractor extends Model {
|
||||
private static MemorySelectExtractor memorySelectExtractor;
|
||||
|
||||
private MemoryManager memoryManager;
|
||||
private SessionManager sessionManager;
|
||||
|
||||
private MemorySelectExtractor() {
|
||||
}
|
||||
@@ -35,6 +40,7 @@ public class MemorySelectExtractor extends Model {
|
||||
Config config = Config.getConfig();
|
||||
memorySelectExtractor = new MemorySelectExtractor();
|
||||
memorySelectExtractor.setMemoryManager(MemoryManager.getInstance());
|
||||
memorySelectExtractor.setSessionManager(SessionManager.getInstance());
|
||||
setModel(config, memorySelectExtractor, MODEL_KEY, ModelConstant.SELECT_EXTRACTOR_PROMPT);
|
||||
}
|
||||
|
||||
@@ -43,11 +49,16 @@ public class MemorySelectExtractor extends Model {
|
||||
|
||||
public ExtractorResult execute(InteractionContext context) {
|
||||
//结构化为指定格式
|
||||
//TODO 将历史消息替换为sessionManager中的用户对应信息列表
|
||||
List<Message> chatMessages = new ArrayList<>();
|
||||
for (MetaMessage metaMessage : sessionManager.getSingleMetaMessageMap().get(context.getUserId())) {
|
||||
chatMessages.add(metaMessage.getUserMessage());
|
||||
chatMessages.add(metaMessage.getAssistantMessage());
|
||||
}
|
||||
|
||||
ExtractorInput extractorInput = ExtractorInput.builder()
|
||||
.text(context.getInput())
|
||||
.date(context.getDateTime().toLocalDate())
|
||||
.history(memoryManager.getChatMessages())
|
||||
.history(chatMessages)
|
||||
.topic_tree(memoryManager.getTopicTree())
|
||||
.build();
|
||||
String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage());
|
||||
|
||||
@@ -11,11 +11,14 @@ import work.slhaf.agent.core.memory.MemoryManager;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.agent.core.session.SessionManager;
|
||||
import work.slhaf.agent.modules.memory.selector.extractor.MemorySelectExtractor;
|
||||
import work.slhaf.agent.modules.memory.updater.static_extractor.StaticMemoryExtractor;
|
||||
import work.slhaf.agent.modules.memory.updater.static_extractor.data.StaticMemoryExtractInput;
|
||||
import work.slhaf.agent.modules.memory.updater.summarizer.MemorySummarizer;
|
||||
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult;
|
||||
import work.slhaf.agent.modules.memory.updater.summarizer.data.TotalSummarizeInput;
|
||||
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeInput;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
@@ -33,6 +36,7 @@ public class MemoryUpdater implements InteractionModule {
|
||||
private MemorySelectExtractor memorySelectExtractor;
|
||||
private MemorySummarizer memorySummarizer;
|
||||
private SessionManager sessionManager;
|
||||
private StaticMemoryExtractor staticMemoryExtractor;
|
||||
|
||||
private MemoryUpdater() {
|
||||
}
|
||||
@@ -45,13 +49,13 @@ public class MemoryUpdater implements InteractionModule {
|
||||
memoryUpdater.setMemorySummarizer(MemorySummarizer.getInstance());
|
||||
memoryUpdater.setSessionManager(SessionManager.getInstance());
|
||||
memoryUpdater.setUpdateExecutor(Executors.newSingleThreadExecutor());
|
||||
memoryUpdater.setStaticMemoryExtractor(StaticMemoryExtractor.getInstance());
|
||||
}
|
||||
return memoryUpdater;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(InteractionContext interactionContext) throws InterruptedException {
|
||||
//TODO 需要保持压缩上下文、更新总摘要、更新确定性记忆、总结所有切片后,更新dialogMap
|
||||
public void execute(InteractionContext interactionContext) {
|
||||
if (interactionContext.isFinished()) {
|
||||
return;
|
||||
}
|
||||
@@ -59,16 +63,14 @@ public class MemoryUpdater implements InteractionModule {
|
||||
//如果token 大于阈值,则更新记忆
|
||||
JSONObject moduleContext = interactionContext.getModuleContext();
|
||||
if (moduleContext.getIntValue("total_token") > 24000) {
|
||||
//更新单聊记忆,同时从chatMessages中去掉单聊记忆
|
||||
String userId = interactionContext.getUserId();
|
||||
HashMap<String, String> singleMemorySummary = new HashMap<>();
|
||||
try {
|
||||
updateSingleChatSlices(interactionContext);
|
||||
//更新多人场景下的记忆
|
||||
updateMultiChatSlices(interactionContext);
|
||||
//更新确定性记忆
|
||||
executor.execute(() -> {
|
||||
|
||||
});
|
||||
} catch (InterruptedException e) {
|
||||
//更新单聊记忆以及该场景中对应的确定性记忆,同时从chatMessages中去掉单聊记忆
|
||||
updateSingleChatSlices(userId, singleMemorySummary);
|
||||
//更新多人场景下的记忆及相关的确定性记忆
|
||||
updateMultiChatSlices(userId, singleMemorySummary);
|
||||
} catch (InterruptedException | IOException | ClassNotFoundException e) {
|
||||
log.error("记忆更新线程出错: {}", e.getLocalizedMessage());
|
||||
}
|
||||
|
||||
@@ -78,41 +80,67 @@ public class MemoryUpdater implements InteractionModule {
|
||||
|
||||
}
|
||||
|
||||
private void updateMultiChatSlices(InteractionContext interactionContext) {
|
||||
//TODO 更新多人场景对话记忆
|
||||
private void updateMultiChatSlices(String userId, HashMap<String, String> singleMemorySummary) throws InterruptedException, IOException, ClassNotFoundException {
|
||||
//此时chatMessages中不再包含单聊记录,直接执行摘要以及切片插入
|
||||
|
||||
//对剩下的多人聊天记录进行进行摘要
|
||||
executor.execute(() -> {
|
||||
try {
|
||||
SummarizeResult summarizeResult = memorySummarizer.execute(new SummarizeInput(memoryManager.getChatMessages(), memoryManager.getTopicTree()));
|
||||
MemorySlice memorySlice = getMemorySlice(userId, summarizeResult, memoryManager.getChatMessages());
|
||||
memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath());
|
||||
//更新总dialogMap
|
||||
singleMemorySummary.put("total", summarizeResult.getSummary());
|
||||
memoryManager.updateDialogMap(LocalDateTime.now(), memorySummarizer.executeTotalSummary(singleMemorySummary));
|
||||
} catch (IOException | ClassNotFoundException | InterruptedException e) {
|
||||
log.error("多人场景记忆更新失败: {}", e.getLocalizedMessage());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void updateSingleChatSlices(InteractionContext interactionContext) throws InterruptedException {
|
||||
private void updateSingleChatSlices(String interactionContext, HashMap<String, String> singleMemorySummary) throws InterruptedException {
|
||||
//更新单聊记忆,同时从chatMessages中去掉单聊记忆
|
||||
Set<String> userIdSet = new HashSet<>(sessionManager.getSingleMetaMessageMap().keySet());
|
||||
List<Callable<Void>> tasks = new ArrayList<>();
|
||||
//多人聊天?
|
||||
for (String id : userIdSet) {
|
||||
List<Message> messages = sessionManager.unpackAndClear(id);
|
||||
tasks.add(() -> {
|
||||
List<Message> messages = sessionManager.unpackAndClear(id);
|
||||
try {
|
||||
SummarizeResult summarizeResult = memorySummarizer.execute(new TotalSummarizeInput(messages, memoryManager.getTopicTree()));
|
||||
//单聊记忆更新
|
||||
SummarizeResult summarizeResult = memorySummarizer.execute(new SummarizeInput(messages, memoryManager.getTopicTree()));
|
||||
MemorySlice memorySlice = getMemorySlice(interactionContext, summarizeResult, messages);
|
||||
//插入时userDialogMap已经进行更新
|
||||
memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath());
|
||||
//从chatMessages中移除单聊记录
|
||||
memoryManager.cleanMessage(messages);
|
||||
//添加至singleMemorySummary
|
||||
singleMemorySummary.put(id, summarizeResult.getSummary());
|
||||
} catch (Exception e) {
|
||||
log.error("记忆更新出错: {}", e.getLocalizedMessage());
|
||||
log.error("单聊记忆更新出错: {}", e.getLocalizedMessage());
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
tasks.add(() -> {
|
||||
StaticMemoryExtractInput input = StaticMemoryExtractInput.builder()
|
||||
.userId(id)
|
||||
.messages(messages)
|
||||
.existedStaticMemory(memoryManager.getStaticMemory(id))
|
||||
.build();
|
||||
Map<String, String> staticMemoryResult = staticMemoryExtractor.execute(input);
|
||||
memoryManager.insertStaticMemory(id, staticMemoryResult);
|
||||
return null;
|
||||
});
|
||||
}
|
||||
executor.invokeAll(tasks);
|
||||
}
|
||||
|
||||
private static MemorySlice getMemorySlice(InteractionContext interactionContext, SummarizeResult summarizeResult, List<Message> chatMessages) {
|
||||
private static MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List<Message> chatMessages) {
|
||||
MemorySlice memorySlice = new MemorySlice();
|
||||
memorySlice.setPrivate(summarizeResult.isPrivate());
|
||||
memorySlice.setSummary(summarizeResult.getSummary());
|
||||
memorySlice.setChatMessages(chatMessages);
|
||||
memorySlice.setStartUserId(interactionContext.getUserId());
|
||||
memorySlice.setStartUserId(userId);
|
||||
List<List<String>> relatedTopicPathList = new ArrayList<>();
|
||||
for (String string : summarizeResult.getRelatedTopicPath()) {
|
||||
List<String> list = Arrays.stream(string.split("->")).toList();
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package work.slhaf.agent.modules.memory.updater.static_extractor;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.agent.common.chat.pojo.ChatResponse;
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.common.model.Model;
|
||||
import work.slhaf.agent.common.model.ModelConstant;
|
||||
import work.slhaf.agent.modules.memory.updater.static_extractor.data.StaticMemoryExtractInput;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class StaticMemoryExtractor extends Model {
|
||||
|
||||
private static StaticMemoryExtractor staticMemoryExtractor;
|
||||
|
||||
private static final String MODEL_KEY = "static_memory_extractor";
|
||||
|
||||
|
||||
public static StaticMemoryExtractor getInstance() throws IOException, ClassNotFoundException {
|
||||
if (staticMemoryExtractor == null) {
|
||||
staticMemoryExtractor = new StaticMemoryExtractor();
|
||||
setModel(Config.getConfig(), staticMemoryExtractor, MODEL_KEY, ModelConstant.STATIC_MEMORY_EXTRACTOR_PROMPT);
|
||||
}
|
||||
return staticMemoryExtractor;
|
||||
}
|
||||
|
||||
public Map<String, String> execute(StaticMemoryExtractInput input) {
|
||||
ChatResponse response = singleChat(JSONUtil.toJsonPrettyStr(input));
|
||||
JSONObject jsonObject = JSONObject.parseObject(response.getMessage());
|
||||
Map<String, String> result = new HashMap<>();
|
||||
jsonObject.forEach((k, v) -> {
|
||||
result.put(k, (String) v);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package work.slhaf.agent.modules.memory.updater.static_extractor.data;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import work.slhaf.agent.common.chat.pojo.Message;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
public class StaticMemoryExtractInput {
|
||||
private String userId;
|
||||
private List<Message> messages;
|
||||
private Map<String,String> existedStaticMemory;
|
||||
}
|
||||
@@ -14,14 +14,17 @@ import work.slhaf.agent.common.model.Model;
|
||||
import work.slhaf.agent.common.model.ModelConstant;
|
||||
import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor;
|
||||
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult;
|
||||
import work.slhaf.agent.modules.memory.updater.summarizer.data.TotalSummarizeInput;
|
||||
import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeInput;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static work.slhaf.agent.common.util.ExtractUtil.extractJson;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@Slf4j
|
||||
@@ -42,16 +45,16 @@ public class MemorySummarizer extends Model {
|
||||
return memorySummarizer;
|
||||
}
|
||||
|
||||
public SummarizeResult execute(TotalSummarizeInput input) throws InterruptedException {
|
||||
//进行长文本批量摘要
|
||||
singleMessageSummarize(input.getChatMessages());
|
||||
//进行整体摘要并返回结果
|
||||
return multiMessageSummarize(input);
|
||||
public SummarizeResult execute(SummarizeInput input) throws InterruptedException {
|
||||
//进行长文本批量摘要
|
||||
singleMessageSummarize(input.getChatMessages());
|
||||
//进行整体摘要并返回结果
|
||||
return multiMessageSummarize(input);
|
||||
}
|
||||
|
||||
private SummarizeResult multiMessageSummarize(TotalSummarizeInput input) {
|
||||
private SummarizeResult multiMessageSummarize(SummarizeInput input) {
|
||||
String messageStr = JSONUtil.toJsonPrettyStr(input);
|
||||
return multiSummarizeExecute(prompts.get(1),messageStr);
|
||||
return multiSummarizeExecute(prompts.get(1), messageStr);
|
||||
}
|
||||
|
||||
private SummarizeResult multiSummarizeExecute(String prompt, String messageStr) {
|
||||
@@ -73,7 +76,7 @@ public class MemorySummarizer extends Model {
|
||||
}
|
||||
}
|
||||
}
|
||||
executor.invokeAll(tasks,30, TimeUnit.SECONDS);
|
||||
executor.invokeAll(tasks, 30, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
private @NonNull String singleSummarizeExecute(String prompt, String content) {
|
||||
@@ -88,4 +91,9 @@ public class MemorySummarizer extends Model {
|
||||
}
|
||||
|
||||
|
||||
public String executeTotalSummary(HashMap<String, String> singleMemorySummary) {
|
||||
ChatResponse response = chatClient.runChat(List.of(new Message(ChatConstant.Character.SYSTEM, prompts.get(2)),
|
||||
new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))));
|
||||
return JSONObject.parseObject(extractJson(response.getMessage())).getString("value");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import java.util.List;
|
||||
|
||||
@AllArgsConstructor
|
||||
@Data
|
||||
public class TotalSummarizeInput {
|
||||
public class SummarizeInput {
|
||||
private List<Message> chatMessages;
|
||||
private String topicTree;
|
||||
}
|
||||
Reference in New Issue
Block a user