diff --git a/src/main/java/work/slhaf/agent/common/model/ModelConstant.java b/src/main/java/work/slhaf/agent/common/model/ModelConstant.java index bae1dd5d..a0bde30e 100644 --- a/src/main/java/work/slhaf/agent/common/model/ModelConstant.java +++ b/src/main/java/work/slhaf/agent/common/model/ModelConstant.java @@ -304,4 +304,6 @@ public class ModelConstant { """; public static final String BASE_SUMMARIZER_PROMPT = """ """; + public static final String STATIC_MEMORY_EXTRACTOR_PROMPT = """ + """; } diff --git a/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java b/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java index 47627f64..7b0f094c 100644 --- a/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java +++ b/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java @@ -63,7 +63,7 @@ public class MemoryGraph extends PersistableObject { /** * 当前对话的活动性总结, 拥有比dialogMap更丰富的全文细节, 作为当前对话token超限时的必要上下文压缩存储 */ - private List currentCompressedSessionContext; +// private List currentCompressedSessionContext; /** * 存储确定性记忆, 如'用户爱好'等确定性信息 @@ -123,7 +123,7 @@ public class MemoryGraph extends PersistableObject { this.selectedSlices = new HashSet<>(); this.users = new ArrayList<>(); this.userDialogMap = new ConcurrentHashMap<>(); - this.currentCompressedSessionContext = new ArrayList<>(); +// this.currentCompressedSessionContext = new ArrayList<>(); this.dialogMap = new HashMap<>(); } @@ -255,24 +255,9 @@ public class MemoryGraph extends PersistableObject { String summary = slice.getSummary(); LocalDateTime now = LocalDateTime.now(); - //更新dialogMap ------------------------- - //移除两天前的上下文缓存(切片总结) - List keysToRemove = new ArrayList<>(); - dialogMap.forEach((k, v) -> { - if (now.minusDays(2).isAfter(k)) { - keysToRemove.add(k); - } - }); - for (LocalDateTime dateTime : keysToRemove) { - dialogMap.remove(dateTime); - } - keysToRemove.clear(); - //放入新缓存 - dialogMap.put(now, summary); - //--------------------------------------- - //更新userDialogMap //移除两天前上下文缓存(切片总结) + List keysToRemove = new ArrayList<>(); userDialogMap.forEach((k, v) -> { v.forEach((i, j) -> { if (now.minusDays(2).isAfter(i)) { @@ -288,7 +273,7 @@ public class MemoryGraph extends PersistableObject { //放入新缓存 userDialogMap .computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>()) - .merge(now, slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal); + .merge(now, summary, (oldVal, newVal) -> oldVal + " " + newVal); } @@ -484,5 +469,21 @@ public class MemoryGraph extends PersistableObject { } + public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) { + List keysToRemove = new ArrayList<>(); + + dialogMap.forEach((k, v) -> { + if (dateTime.minusDays(2).isAfter(k)) { + keysToRemove.add(k); + } + }); + for (LocalDateTime temp : keysToRemove) { + dialogMap.remove(temp); + } + keysToRemove.clear(); + //放入新缓存 + dialogMap.put(dateTime, newDialogCache); + + } } diff --git a/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java b/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java index c3972959..010ae4ec 100644 --- a/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java +++ b/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java @@ -130,4 +130,12 @@ public class MemoryManager implements InteractionModule { memoryGraph.getChatMessages().removeAll(messages); messageCleanLock.unlock(); } + + public void insertStaticMemory(String userId, Map newStaticMemory) { + memoryGraph.getStaticMemory().get(userId).putAll(newStaticMemory); + } + + public void updateDialogMap(LocalDateTime dateTime,String newDialogCache) { + memoryGraph.updateDialogMap(dateTime, newDialogCache); + } } diff --git a/src/main/java/work/slhaf/agent/core/session/SessionManager.java b/src/main/java/work/slhaf/agent/core/session/SessionManager.java index 33990118..79a5b6ea 100644 --- a/src/main/java/work/slhaf/agent/core/session/SessionManager.java +++ b/src/main/java/work/slhaf/agent/core/session/SessionManager.java @@ -14,7 +14,6 @@ public class SessionManager { private static SessionManager sessionManager; private HashMap> singleMetaMessageMap; - private HashMap> multiMetaMessageMap; public static SessionManager getInstance() { if (sessionManager == null) { diff --git a/src/main/java/work/slhaf/agent/modules/memory/selector/extractor/MemorySelectExtractor.java b/src/main/java/work/slhaf/agent/modules/memory/selector/extractor/MemorySelectExtractor.java index 1d47f265..c1c74d95 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/selector/extractor/MemorySelectExtractor.java +++ b/src/main/java/work/slhaf/agent/modules/memory/selector/extractor/MemorySelectExtractor.java @@ -5,15 +5,19 @@ import com.alibaba.fastjson2.JSONObject; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; +import work.slhaf.agent.common.chat.pojo.Message; +import work.slhaf.agent.common.chat.pojo.MetaMessage; import work.slhaf.agent.common.config.Config; import work.slhaf.agent.common.model.Model; import work.slhaf.agent.common.model.ModelConstant; import work.slhaf.agent.core.interaction.data.InteractionContext; import work.slhaf.agent.core.memory.MemoryManager; +import work.slhaf.agent.core.session.SessionManager; import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorInput; import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorResult; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import static work.slhaf.agent.common.util.ExtractUtil.extractJson; @@ -26,6 +30,7 @@ public class MemorySelectExtractor extends Model { private static MemorySelectExtractor memorySelectExtractor; private MemoryManager memoryManager; + private SessionManager sessionManager; private MemorySelectExtractor() { } @@ -35,6 +40,7 @@ public class MemorySelectExtractor extends Model { Config config = Config.getConfig(); memorySelectExtractor = new MemorySelectExtractor(); memorySelectExtractor.setMemoryManager(MemoryManager.getInstance()); + memorySelectExtractor.setSessionManager(SessionManager.getInstance()); setModel(config, memorySelectExtractor, MODEL_KEY, ModelConstant.SELECT_EXTRACTOR_PROMPT); } @@ -43,11 +49,16 @@ public class MemorySelectExtractor extends Model { public ExtractorResult execute(InteractionContext context) { //结构化为指定格式 - //TODO 将历史消息替换为sessionManager中的用户对应信息列表 + List chatMessages = new ArrayList<>(); + for (MetaMessage metaMessage : sessionManager.getSingleMetaMessageMap().get(context.getUserId())) { + chatMessages.add(metaMessage.getUserMessage()); + chatMessages.add(metaMessage.getAssistantMessage()); + } + ExtractorInput extractorInput = ExtractorInput.builder() .text(context.getInput()) .date(context.getDateTime().toLocalDate()) - .history(memoryManager.getChatMessages()) + .history(chatMessages) .topic_tree(memoryManager.getTopicTree()) .build(); String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage()); diff --git a/src/main/java/work/slhaf/agent/modules/memory/updater/MemoryUpdater.java b/src/main/java/work/slhaf/agent/modules/memory/updater/MemoryUpdater.java index 5db3b11a..8e0328f8 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/updater/MemoryUpdater.java +++ b/src/main/java/work/slhaf/agent/modules/memory/updater/MemoryUpdater.java @@ -11,11 +11,14 @@ import work.slhaf.agent.core.memory.MemoryManager; import work.slhaf.agent.core.memory.pojo.MemorySlice; import work.slhaf.agent.core.session.SessionManager; import work.slhaf.agent.modules.memory.selector.extractor.MemorySelectExtractor; +import work.slhaf.agent.modules.memory.updater.static_extractor.StaticMemoryExtractor; +import work.slhaf.agent.modules.memory.updater.static_extractor.data.StaticMemoryExtractInput; import work.slhaf.agent.modules.memory.updater.summarizer.MemorySummarizer; import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult; -import work.slhaf.agent.modules.memory.updater.summarizer.data.TotalSummarizeInput; +import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeInput; import java.io.IOException; +import java.time.LocalDateTime; import java.util.*; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; @@ -33,6 +36,7 @@ public class MemoryUpdater implements InteractionModule { private MemorySelectExtractor memorySelectExtractor; private MemorySummarizer memorySummarizer; private SessionManager sessionManager; + private StaticMemoryExtractor staticMemoryExtractor; private MemoryUpdater() { } @@ -45,13 +49,13 @@ public class MemoryUpdater implements InteractionModule { memoryUpdater.setMemorySummarizer(MemorySummarizer.getInstance()); memoryUpdater.setSessionManager(SessionManager.getInstance()); memoryUpdater.setUpdateExecutor(Executors.newSingleThreadExecutor()); + memoryUpdater.setStaticMemoryExtractor(StaticMemoryExtractor.getInstance()); } return memoryUpdater; } @Override - public void execute(InteractionContext interactionContext) throws InterruptedException { - //TODO 需要保持压缩上下文、更新总摘要、更新确定性记忆、总结所有切片后,更新dialogMap + public void execute(InteractionContext interactionContext) { if (interactionContext.isFinished()) { return; } @@ -59,16 +63,14 @@ public class MemoryUpdater implements InteractionModule { //如果token 大于阈值,则更新记忆 JSONObject moduleContext = interactionContext.getModuleContext(); if (moduleContext.getIntValue("total_token") > 24000) { - //更新单聊记忆,同时从chatMessages中去掉单聊记忆 + String userId = interactionContext.getUserId(); + HashMap singleMemorySummary = new HashMap<>(); try { - updateSingleChatSlices(interactionContext); - //更新多人场景下的记忆 - updateMultiChatSlices(interactionContext); - //更新确定性记忆 - executor.execute(() -> { - - }); - } catch (InterruptedException e) { + //更新单聊记忆以及该场景中对应的确定性记忆,同时从chatMessages中去掉单聊记忆 + updateSingleChatSlices(userId, singleMemorySummary); + //更新多人场景下的记忆及相关的确定性记忆 + updateMultiChatSlices(userId, singleMemorySummary); + } catch (InterruptedException | IOException | ClassNotFoundException e) { log.error("记忆更新线程出错: {}", e.getLocalizedMessage()); } @@ -78,41 +80,67 @@ public class MemoryUpdater implements InteractionModule { } - private void updateMultiChatSlices(InteractionContext interactionContext) { - //TODO 更新多人场景对话记忆 + private void updateMultiChatSlices(String userId, HashMap singleMemorySummary) throws InterruptedException, IOException, ClassNotFoundException { //此时chatMessages中不再包含单聊记录,直接执行摘要以及切片插入 - + //对剩下的多人聊天记录进行进行摘要 + executor.execute(() -> { + try { + SummarizeResult summarizeResult = memorySummarizer.execute(new SummarizeInput(memoryManager.getChatMessages(), memoryManager.getTopicTree())); + MemorySlice memorySlice = getMemorySlice(userId, summarizeResult, memoryManager.getChatMessages()); + memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath()); + //更新总dialogMap + singleMemorySummary.put("total", summarizeResult.getSummary()); + memoryManager.updateDialogMap(LocalDateTime.now(), memorySummarizer.executeTotalSummary(singleMemorySummary)); + } catch (IOException | ClassNotFoundException | InterruptedException e) { + log.error("多人场景记忆更新失败: {}", e.getLocalizedMessage()); + } + }); } - private void updateSingleChatSlices(InteractionContext interactionContext) throws InterruptedException { + private void updateSingleChatSlices(String interactionContext, HashMap singleMemorySummary) throws InterruptedException { //更新单聊记忆,同时从chatMessages中去掉单聊记忆 Set userIdSet = new HashSet<>(sessionManager.getSingleMetaMessageMap().keySet()); List> tasks = new ArrayList<>(); //多人聊天? for (String id : userIdSet) { + List messages = sessionManager.unpackAndClear(id); tasks.add(() -> { - List messages = sessionManager.unpackAndClear(id); try { - SummarizeResult summarizeResult = memorySummarizer.execute(new TotalSummarizeInput(messages, memoryManager.getTopicTree())); + //单聊记忆更新 + SummarizeResult summarizeResult = memorySummarizer.execute(new SummarizeInput(messages, memoryManager.getTopicTree())); MemorySlice memorySlice = getMemorySlice(interactionContext, summarizeResult, messages); + //插入时userDialogMap已经进行更新 memoryManager.insertSlice(memorySlice, summarizeResult.getTopicPath()); //从chatMessages中移除单聊记录 memoryManager.cleanMessage(messages); + //添加至singleMemorySummary + singleMemorySummary.put(id, summarizeResult.getSummary()); } catch (Exception e) { - log.error("记忆更新出错: {}", e.getLocalizedMessage()); + log.error("单聊记忆更新出错: {}", e.getLocalizedMessage()); } return null; }); + + tasks.add(() -> { + StaticMemoryExtractInput input = StaticMemoryExtractInput.builder() + .userId(id) + .messages(messages) + .existedStaticMemory(memoryManager.getStaticMemory(id)) + .build(); + Map staticMemoryResult = staticMemoryExtractor.execute(input); + memoryManager.insertStaticMemory(id, staticMemoryResult); + return null; + }); } executor.invokeAll(tasks); } - private static MemorySlice getMemorySlice(InteractionContext interactionContext, SummarizeResult summarizeResult, List chatMessages) { + private static MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List chatMessages) { MemorySlice memorySlice = new MemorySlice(); memorySlice.setPrivate(summarizeResult.isPrivate()); memorySlice.setSummary(summarizeResult.getSummary()); memorySlice.setChatMessages(chatMessages); - memorySlice.setStartUserId(interactionContext.getUserId()); + memorySlice.setStartUserId(userId); List> relatedTopicPathList = new ArrayList<>(); for (String string : summarizeResult.getRelatedTopicPath()) { List list = Arrays.stream(string.split("->")).toList(); diff --git a/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/StaticMemoryExtractor.java b/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/StaticMemoryExtractor.java new file mode 100644 index 00000000..90dbc3e6 --- /dev/null +++ b/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/StaticMemoryExtractor.java @@ -0,0 +1,43 @@ +package work.slhaf.agent.modules.memory.updater.static_extractor; + +import cn.hutool.json.JSONUtil; +import com.alibaba.fastjson2.JSONObject; +import lombok.Data; +import lombok.EqualsAndHashCode; +import work.slhaf.agent.common.chat.pojo.ChatResponse; +import work.slhaf.agent.common.config.Config; +import work.slhaf.agent.common.model.Model; +import work.slhaf.agent.common.model.ModelConstant; +import work.slhaf.agent.modules.memory.updater.static_extractor.data.StaticMemoryExtractInput; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +@EqualsAndHashCode(callSuper = true) +@Data +public class StaticMemoryExtractor extends Model { + + private static StaticMemoryExtractor staticMemoryExtractor; + + private static final String MODEL_KEY = "static_memory_extractor"; + + + public static StaticMemoryExtractor getInstance() throws IOException, ClassNotFoundException { + if (staticMemoryExtractor == null) { + staticMemoryExtractor = new StaticMemoryExtractor(); + setModel(Config.getConfig(), staticMemoryExtractor, MODEL_KEY, ModelConstant.STATIC_MEMORY_EXTRACTOR_PROMPT); + } + return staticMemoryExtractor; + } + + public Map execute(StaticMemoryExtractInput input) { + ChatResponse response = singleChat(JSONUtil.toJsonPrettyStr(input)); + JSONObject jsonObject = JSONObject.parseObject(response.getMessage()); + Map result = new HashMap<>(); + jsonObject.forEach((k, v) -> { + result.put(k, (String) v); + }); + return result; + } +} diff --git a/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/data/StaticMemoryExtractInput.java b/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/data/StaticMemoryExtractInput.java new file mode 100644 index 00000000..690872b2 --- /dev/null +++ b/src/main/java/work/slhaf/agent/modules/memory/updater/static_extractor/data/StaticMemoryExtractInput.java @@ -0,0 +1,17 @@ +package work.slhaf.agent.modules.memory.updater.static_extractor.data; + +import lombok.Builder; +import lombok.Data; +import work.slhaf.agent.common.chat.pojo.Message; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +@Data +@Builder +public class StaticMemoryExtractInput { + private String userId; + private List messages; + private Map existedStaticMemory; +} diff --git a/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/MemorySummarizer.java b/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/MemorySummarizer.java index 896fe0ad..87a20285 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/MemorySummarizer.java +++ b/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/MemorySummarizer.java @@ -14,14 +14,17 @@ import work.slhaf.agent.common.model.Model; import work.slhaf.agent.common.model.ModelConstant; import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor; import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeResult; -import work.slhaf.agent.modules.memory.updater.summarizer.data.TotalSummarizeInput; +import work.slhaf.agent.modules.memory.updater.summarizer.data.SummarizeInput; import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.TimeUnit; +import static work.slhaf.agent.common.util.ExtractUtil.extractJson; + @EqualsAndHashCode(callSuper = true) @Data @Slf4j @@ -42,16 +45,16 @@ public class MemorySummarizer extends Model { return memorySummarizer; } - public SummarizeResult execute(TotalSummarizeInput input) throws InterruptedException { - //进行长文本批量摘要 - singleMessageSummarize(input.getChatMessages()); - //进行整体摘要并返回结果 - return multiMessageSummarize(input); + public SummarizeResult execute(SummarizeInput input) throws InterruptedException { + //进行长文本批量摘要 + singleMessageSummarize(input.getChatMessages()); + //进行整体摘要并返回结果 + return multiMessageSummarize(input); } - private SummarizeResult multiMessageSummarize(TotalSummarizeInput input) { + private SummarizeResult multiMessageSummarize(SummarizeInput input) { String messageStr = JSONUtil.toJsonPrettyStr(input); - return multiSummarizeExecute(prompts.get(1),messageStr); + return multiSummarizeExecute(prompts.get(1), messageStr); } private SummarizeResult multiSummarizeExecute(String prompt, String messageStr) { @@ -73,7 +76,7 @@ public class MemorySummarizer extends Model { } } } - executor.invokeAll(tasks,30, TimeUnit.SECONDS); + executor.invokeAll(tasks, 30, TimeUnit.SECONDS); } private @NonNull String singleSummarizeExecute(String prompt, String content) { @@ -88,4 +91,9 @@ public class MemorySummarizer extends Model { } + public String executeTotalSummary(HashMap singleMemorySummary) { + ChatResponse response = chatClient.runChat(List.of(new Message(ChatConstant.Character.SYSTEM, prompts.get(2)), + new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary)))); + return JSONObject.parseObject(extractJson(response.getMessage())).getString("value"); + } } diff --git a/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/data/TotalSummarizeInput.java b/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/data/SummarizeInput.java similarity index 89% rename from src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/data/TotalSummarizeInput.java rename to src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/data/SummarizeInput.java index 5f74ddc1..d36ad98f 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/data/TotalSummarizeInput.java +++ b/src/main/java/work/slhaf/agent/modules/memory/updater/summarizer/data/SummarizeInput.java @@ -8,7 +8,7 @@ import java.util.List; @AllArgsConstructor @Data -public class TotalSummarizeInput { +public class SummarizeInput { private List chatMessages; private String topicTree; }