diff --git a/src/main/java/work/slhaf/agent/common/config/Config.java b/src/main/java/work/slhaf/agent/common/config/Config.java index b1438d00..dacbb6b2 100644 --- a/src/main/java/work/slhaf/agent/common/config/Config.java +++ b/src/main/java/work/slhaf/agent/common/config/Config.java @@ -27,7 +27,7 @@ public class Config { private static Config config; private String agentId; - private String basicCharacter; +// private String basicCharacter; private WebSocketConfig webSocketConfig; @@ -48,8 +48,8 @@ public class Config { System.out.print("输入智能体名称: "); config.setAgentId(scanner.nextLine()); - System.out.print("输入智能体基础角色设定: "); - config.setBasicCharacter(scanner.nextLine()); +// System.out.print("输入智能体基础角色设定: "); +// config.setBasicCharacter(scanner.nextLine()); System.out.println("(注意! 设定角色之后修改主配置文件将不会影响现有记忆,除非同时更换agentId)"); diff --git a/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java b/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java index 2a6f6111..02283dbd 100644 --- a/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java +++ b/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java @@ -81,7 +81,7 @@ public class MemoryGraph extends PersistableObject { * 存储确定性记忆, 如'用户爱好'等确定性信息 * 该部分作为'主LLM'system prompt常驻 */ - private HashMap> staticMemory; +// private HashMap> staticMemory; /** * memorySliceCache计数器,每日清空 @@ -121,12 +121,12 @@ public class MemoryGraph extends PersistableObject { */ private Set selectedSlices; - public MemoryGraph(String id, String basicCharacter) { + public MemoryGraph(String id) { this.id = id; this.topicNodes = new HashMap<>(); this.existedTopics = new HashMap<>(); this.currentDateDialogSlices = new HashMap<>(); - this.staticMemory = new HashMap<>(); +// this.staticMemory = new HashMap<>(); this.memoryNodeCacheCounter = new ConcurrentHashMap<>(); this.memorySliceCache = new ConcurrentHashMap<>(); this.modelPrompt = new HashMap<>(); @@ -139,7 +139,7 @@ public class MemoryGraph extends PersistableObject { this.dateIndex = new HashMap<>(); } - public static MemoryGraph getInstance(String id, String basicCharacter) throws IOException, ClassNotFoundException { + public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException { if (memoryGraph == null) { synchronized (MemoryGraph.class) { // 检查存储目录是否存在,不存在则创建 @@ -150,7 +150,7 @@ public class MemoryGraph extends PersistableObject { memoryGraph = deserialize(id); } else { FileUtils.createParentDirectories(filePath.toFile().getParentFile()); - memoryGraph = new MemoryGraph(id, basicCharacter); + memoryGraph = new MemoryGraph(id); memoryGraph.serialize(); } log.info("MemoryGraph注册完毕..."); diff --git a/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java b/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java index e700ed64..2aef5ce7 100644 --- a/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java +++ b/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java @@ -45,7 +45,7 @@ public class MemoryManager extends PersistableObject { if (memoryManager == null) { Config config = Config.getConfig(); memoryManager = new MemoryManager(); - memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId(), config.getBasicCharacter())); + memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId())); memoryManager.setActivatedSlices(new HashMap<>()); memoryManager.setShutdownHook(); log.info("[MemoryManager] MemoryManager注册完毕..."); @@ -115,9 +115,9 @@ public class MemoryManager extends PersistableObject { return memoryGraph.getTopicTree(); } - public ConcurrentHashMap getStaticMemory(String userId) { +/* public ConcurrentHashMap getStaticMemory(String userId) { return memoryGraph.getStaticMemory().get(userId); - } + }*/ public HashMap getDialogMap() { return memoryGraph.getDialogMap(); @@ -145,12 +145,12 @@ public class MemoryManager extends PersistableObject { messageCleanLock.unlock(); } - public void insertStaticMemory(String userId, Map newStaticMemory) { +/* public void insertStaticMemory(String userId, Map newStaticMemory) { if (!memoryGraph.getStaticMemory().containsKey(userId)) { memoryGraph.getStaticMemory().put(userId, new ConcurrentHashMap<>()); } memoryGraph.getStaticMemory().get(userId).putAll(newStaticMemory); - } + }*/ public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) { memoryGraph.updateDialogMap(dateTime, newDialogCache); @@ -173,4 +173,37 @@ public class MemoryManager extends PersistableObject { } return null; } + + public String getActivatedSlicesStr(String userId) { + StringBuilder str = new StringBuilder(); + if (memoryManager.getActivatedSlices().containsKey(userId)) { + memoryManager.getActivatedSlices().get(userId).forEach(slice -> { + str.append("\n\n").append("[").append(slice.getDate()).append("]\n") + .append(slice.getSummary()); + }); + } + return str.toString(); + } + + public String getDialogMapStr() { + StringBuilder str = new StringBuilder(); + memoryGraph.getDialogMap().forEach((dateTime, dialog) -> { + str.append("\n\n").append("[").append(dateTime).append("]\n") + .append(dialog); + }); + return str.toString(); + } + + public String getUserDialogMapStr(String userId) { + StringBuilder str = new StringBuilder(); + Collection dialogMapValues = memoryGraph.getDialogMap().values(); + memoryGraph.getUserDialogMap().get(userId).forEach((dateTime, dialog) -> { + if (dialogMapValues.contains(dialog)) { + return; + } + str.append("\n\n").append("[").append(dateTime).append("]\n") + .append(dialog); + }); + return str.toString(); + } } diff --git a/src/main/java/work/slhaf/agent/module/common/ModelConstant.java b/src/main/java/work/slhaf/agent/module/common/ModelConstant.java index 2b4342da..e8048bd5 100644 --- a/src/main/java/work/slhaf/agent/module/common/ModelConstant.java +++ b/src/main/java/work/slhaf/agent/module/common/ModelConstant.java @@ -9,7 +9,7 @@ public class ModelConstant { } public static class CharacterPrefix { - public static final String SYSTEM = "[system] "; + public static final String SYSTEM = "[SYSTEM] "; } } diff --git a/src/main/java/work/slhaf/agent/module/modules/core/CoreModel.java b/src/main/java/work/slhaf/agent/module/modules/core/CoreModel.java index b75eac86..5a385857 100644 --- a/src/main/java/work/slhaf/agent/module/modules/core/CoreModel.java +++ b/src/main/java/work/slhaf/agent/module/modules/core/CoreModel.java @@ -153,18 +153,45 @@ public class CoreModel extends Model implements InteractionModule { private void setAppendedPromptMessage(List appendPrompt) { Message appendDeclareMessage = Message.builder() .role(ChatConstant.Character.USER) - .content(ModelConstant.CharacterPrefix.SYSTEM + "以下为追加字段声明,可能包含用户的输入字段和你需要在回应中添加的输出字段.") +// .content(ModelConstant.CharacterPrefix.SYSTEM + "以下为追加字段声明,可能包含用户的输入字段和你需要在回应中添加的输出字段.") + .content(ModelConstant.CharacterPrefix.SYSTEM + "以下为你的相关认知内容,可在对话中参考") .build(); this.appendedMessages.add(appendDeclareMessage); for (AppendPromptData data : appendPrompt) { - StringBuilder str = new StringBuilder(data.getComment()).append("\r\n"); - data.getAppendedPrompt().forEach((k, v) -> str.append(k).append(": ").append(v).append("\r\n")); - appendedMessages.add(new Message(ChatConstant.Character.USER, str.toString())); + setStartMessage(data); + setContentMessage(data); + setEndMessage(data); } Message appendEndMessage = Message.builder() .role(ChatConstant.Character.USER) - .content(ModelConstant.CharacterPrefix.SYSTEM + "追加字段声明结束,接下来为用户的真实输入。") + .content(ModelConstant.CharacterPrefix.SYSTEM + "相关认知内容结束,接下来是‘你’——‘Partner’与用户的真正交互") .build(); this.appendedMessages.add(appendEndMessage); } + + private void setEndMessage(AppendPromptData data) { + Message endMessage = Message.builder() + .role(ChatConstant.Character.USER) + .content(ModelConstant.CharacterPrefix.SYSTEM + data.getComment() + "认知补充结束.") + .build(); + appendedMessages.add(endMessage); + } + + private void setContentMessage(AppendPromptData data) { + data.getAppendedPrompt().forEach((k, v) -> { + Message contentMessage = Message.builder() + .role(ChatConstant.Character.USER) + .content(ModelConstant.CharacterPrefix.SYSTEM + k + v) + .build(); + appendedMessages.add(contentMessage); + }); + } + + private void setStartMessage(AppendPromptData data) { + Message startMessage = Message.builder() + .role(ChatConstant.Character.USER) + .content(ModelConstant.CharacterPrefix.SYSTEM + data.getComment() + "以下为" + data.getComment() + "相关认知.") + .build(); + appendedMessages.add(startMessage); + } } diff --git a/src/main/java/work/slhaf/agent/module/modules/memory/selector/MemorySelector.java b/src/main/java/work/slhaf/agent/module/modules/memory/selector/MemorySelector.java index c3766396..5c22bbb6 100644 --- a/src/main/java/work/slhaf/agent/module/modules/memory/selector/MemorySelector.java +++ b/src/main/java/work/slhaf/agent/module/modules/memory/selector/MemorySelector.java @@ -23,6 +23,7 @@ import work.slhaf.agent.shared.memory.EvaluatedSlice; import java.io.IOException; import java.time.LocalDate; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.List; @@ -89,21 +90,22 @@ public class MemorySelector implements InteractionModule, AppendPrompt { //获取主题路径 ExtractorResult extractorResult = memorySelectExtractor.execute(interactionContext); if (extractorResult.isRecall() || !extractorResult.getMatches().isEmpty()) { - selectAndEvaluateMemory(interactionContext, extractorResult, userId); - } - if (extractorResult.isRecall()) { - memoryManager.getActivatedSlices().clear(); + memoryManager.getActivatedSlices().get(userId).clear(); + List evaluatedSlices = selectAndEvaluateMemory(interactionContext, extractorResult); + memoryManager.updateActivatedSlices(userId, evaluatedSlices); } //设置上下文 - setModuleContext(interactionContext, userId); +// setCoreContext(interactionContext); //设置追加提示词 setAppendedPrompt(interactionContext); + setModuleContextRecall(interactionContext); log.debug("[MemorySelector] 记忆回溯结果: {}", interactionContext); } - private void selectAndEvaluateMemory(InteractionContext interactionContext, ExtractorResult extractorResult, String userId) throws IOException, ClassNotFoundException, InterruptedException { + private List selectAndEvaluateMemory(InteractionContext interactionContext, ExtractorResult extractorResult) throws IOException, ClassNotFoundException, InterruptedException { log.debug("[MemorySelector] 触发记忆回溯..."); //查找切片 + String userId = interactionContext.getUserId(); List memoryResultList = new ArrayList<>(); setMemoryResultList(memoryResultList, extractorResult.getMatches(), userId); //评估切片 @@ -115,18 +117,19 @@ public class MemorySelector implements InteractionModule, AppendPrompt { log.debug("[MemorySelector] 切片评估输入: {}", evaluatorInput); List memorySlices = sliceSelectEvaluator.execute(evaluatorInput); log.debug("[MemorySelector] 切片评估结果: {}", memorySlices); - memoryManager.updateActivatedSlices(userId, memorySlices); + return memorySlices; } - private void setModuleContext(InteractionContext interactionContext, String userId) { + /*private void setCoreContext(InteractionContext interactionContext) { + String userId = interactionContext.getUserId(); interactionContext.getCoreContext().put("memory_slices", memoryManager.getActivatedSlices().get(userId)); - interactionContext.getCoreContext().put("static_memory", memoryManager.getStaticMemory(userId)); +// interactionContext.getCoreContext().put("static_memory", memoryManager.getStaticMemory(userId)); interactionContext.getCoreContext().put("dialog_map", memoryManager.getDialogMap()); interactionContext.getCoreContext().put("user_dialog_map", memoryManager.getUserDialogMap(userId)); - setModuleContextRecall(interactionContext, userId); - } + }*/ - private void setModuleContextRecall(InteractionContext interactionContext, String userId) { + private void setModuleContextRecall(InteractionContext interactionContext) { + String userId = interactionContext.getUserId(); boolean recall; if (memoryManager.getActivatedSlices().get(userId) == null) { recall = false; @@ -150,6 +153,7 @@ public class MemorySelector implements InteractionModule, AppendPrompt { default -> null; }; if (memoryResult == null) continue; + removeDuplicateSlice(memoryResult); memoryResultList.add(memoryResult); } catch (UnExistedDateIndexException | UnExistedTopicException e) { log.error("[MemorySelector] 不存在的记忆索引! 请尝试更换更合适的主题提取LLM!", e); @@ -168,6 +172,12 @@ public class MemorySelector implements InteractionModule, AppendPrompt { } } + private void removeDuplicateSlice(MemoryResult memoryResult) { + Collection values = memoryManager.getDialogMap().values(); + memoryResult.getRelatedMemorySliceResult().removeIf(m -> values.contains(m.getSummary())); + memoryResult.getMemorySliceResult().removeIf(m -> values.contains(m.getMemorySlice().getSummary())); + } + private boolean removeOrNot(MemorySlice memorySlice, String userId) { if (memorySlice.isPrivate()) { return memorySlice.getStartUserId().equals(userId); @@ -177,14 +187,30 @@ public class MemorySelector implements InteractionModule, AppendPrompt { @Override public void setAppendedPrompt(InteractionContext context) { - HashMap map = new HashMap<>(); - map.put("memory_slices", "本次对话可参考的记忆切片"); - map.put("static_memory", "关于本次对话对象的稳定记忆"); - map.put("dialog_map", "近两日的与所有用户的对话缓存"); - map.put("user_dialog_map", "与当前用户的近两日对话缓存"); + String userId = context.getUserId(); + HashMap map = getPromptDataMap(userId); AppendPromptData data = new AppendPromptData(); - data.setComment("[system] 追加字段: 记忆模块"); + data.setComment("[记忆模块]"); data.setAppendedPrompt(map); context.setAppendedPrompt(data); } + + private HashMap getPromptDataMap(String userId) { + HashMap map = new HashMap<>(); + String dialogMapStr = memoryManager.getDialogMapStr(); + if (!dialogMapStr.isEmpty()) { + map.put("[记忆缓存] <你最近两日和所有聊天者的对话记忆印象>", dialogMapStr); + } + + String userDialogMapStr = memoryManager.getUserDialogMapStr(userId); + if (!userDialogMapStr.isEmpty()) { + map.put("[用户记忆缓存] <与最新一条消息的发送者的近两天对话记忆印象, 可能与[记忆缓存]稍有重复>", "与当前用户的近两日对话缓存"); + } + + String sliceStr = memoryManager.getActivatedSlicesStr(userId); + if (!sliceStr.isEmpty()){ + map.put("[记忆切片] <你与最新一条消息的发送者的相关回忆, 不会与[记忆缓存]重复, 如果有重复你也可以指出来()>", sliceStr); + } + return map; + } } diff --git a/src/test/java/memory/MemoryTest.java b/src/test/java/memory/MemoryTest.java index 4fafe7ef..814e2398 100644 --- a/src/test/java/memory/MemoryTest.java +++ b/src/test/java/memory/MemoryTest.java @@ -12,7 +12,7 @@ public class MemoryTest { //@Test public void test1() { String basicCharacter = ""; - MemoryGraph graph = new MemoryGraph("test", basicCharacter); + MemoryGraph graph = new MemoryGraph("test"); HashMap topicMap = new HashMap<>(); TopicNode root1 = new TopicNode();