From cb85192c50f343dd997ffefa70890605f92b7367 Mon Sep 17 00:00:00 2001 From: slhaf Date: Sun, 20 Apr 2025 23:07:22 +0800 Subject: [PATCH] =?UTF-8?q?-=20MemoryGraph=20=E6=96=B0=E5=A2=9E=E8=BE=93?= =?UTF-8?q?=E5=87=BA=E4=B8=BB=E9=A2=98=E6=A0=91=E5=8A=9F=E8=83=BD=20-=20?= =?UTF-8?q?=E5=B0=86=20TopicExtractor=20=E9=87=8D=E5=91=BD=E5=90=8D?= =?UTF-8?q?=E4=B8=BA=20MemorySelectExtractor=20,=E5=B9=B6=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E4=BA=86=E6=8F=90=E7=A4=BA=E8=AF=8D=20-=20=E8=AE=B0?= =?UTF-8?q?=E5=BF=86=E6=A8=A1=E5=9D=97=E5=BC=80=E5=8F=91=E5=B7=A5=E4=BD=9C?= =?UTF-8?q?=E8=BF=9B=E8=A1=8C=E4=B8=AD=20-=20=E6=96=B0=E5=A2=9E=20SliceSum?= =?UTF-8?q?mary=20=E7=B1=BB=EF=BC=8C=E6=9C=8D=E5=8A=A1=E4=BA=8E=E8=AE=B0?= =?UTF-8?q?=E5=BF=86=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/work/slhaf/Main.java | 12 ++- src/main/java/work/slhaf/agent/Agent.java | 8 +- .../slhaf/agent/common/config/Config.java | 7 +- .../agent/common/model/ModelConstant.java | 50 +++++++++++ .../work/slhaf/agent/core/InteractionHub.java | 6 +- .../core/interaction/InteractionModule.java | 4 +- .../interaction/data/InteractionContext.java | 11 +-- .../data/InteractionInputData.java | 1 + .../slhaf/agent/core/memory/MemoryGraph.java | 39 ++++++-- .../agent/core/memory/MemoryManager.java | 39 ++++++++ .../agent/core/memory/pojo/MemorySlice.java | 4 +- .../slhaf/agent/core/model/CoreModel.java | 2 +- .../agent/gateway/AgentWebSocketServer.java | 4 +- .../modules/memory/MemorySelectExtractor.java | 41 +++++++++ .../agent/modules/memory/MemorySelector.java | 27 +++++- .../agent/modules/memory/MemoryUpdater.java | 2 + .../agent/modules/memory/SliceEvaluator.java | 66 +++++++++++++- .../modules/memory/data/SliceSummary.java | 9 ++ .../preprocess/PreprocessExecutor.java | 3 + .../agent/modules/task/TaskEvaluator.java | 1 - .../agent/modules/topic/TopicExtractor.java | 41 --------- src/test/java/memory/AITest.java | 89 +++++++++++++++++++ src/test/java/memory/MemoryTest.java | 62 +++++++++++++ src/test/java/memory/SearchTest.java | 8 +- 24 files changed, 449 insertions(+), 87 deletions(-) create mode 100644 src/main/java/work/slhaf/agent/modules/memory/MemorySelectExtractor.java create mode 100644 src/main/java/work/slhaf/agent/modules/memory/data/SliceSummary.java delete mode 100644 src/main/java/work/slhaf/agent/modules/topic/TopicExtractor.java create mode 100644 src/test/java/memory/AITest.java create mode 100644 src/test/java/memory/MemoryTest.java diff --git a/src/main/java/work/slhaf/Main.java b/src/main/java/work/slhaf/Main.java index 09cac49a..ae0564cd 100644 --- a/src/main/java/work/slhaf/Main.java +++ b/src/main/java/work/slhaf/Main.java @@ -1,12 +1,20 @@ package work.slhaf; import work.slhaf.agent.Agent; +import work.slhaf.agent.core.interaction.data.InteractionInputData; import java.io.IOException; public class Main { - public static void main(String[] args) throws IOException { + public static void main(String[] args) throws IOException, ClassNotFoundException { Agent agent = Agent.initialize(); - agent.receiveUserInput("111","222","hello"); + + InteractionInputData inputData = new InteractionInputData(); + inputData.setContent("hello"); + inputData.setPlatform("cli"); + inputData.setUserInfo("owner"); + inputData.setUserNickName("master"); + + agent.receiveUserInput(inputData); } } \ No newline at end of file diff --git a/src/main/java/work/slhaf/agent/Agent.java b/src/main/java/work/slhaf/agent/Agent.java index b803e29d..556f7ef6 100644 --- a/src/main/java/work/slhaf/agent/Agent.java +++ b/src/main/java/work/slhaf/agent/Agent.java @@ -36,13 +36,9 @@ public class Agent implements TaskCallback { /** * 接收用户输入,包装为标准输入数据类 - * @param input + * @param inputData */ - public void receiveUserInput(String userNickName,String userInfo,String input) throws IOException { - InteractionInputData inputData = new InteractionInputData(); - inputData.setContent(input); - inputData.setUserInfo(userInfo); - inputData.setUserNickName(userNickName); + public void receiveUserInput(InteractionInputData inputData) throws IOException, ClassNotFoundException { inputData.setLocalDateTime(LocalDateTime.now()); interactionHub.call(inputData); } diff --git a/src/main/java/work/slhaf/agent/common/config/Config.java b/src/main/java/work/slhaf/agent/common/config/Config.java index fbbd2411..40f381a4 100644 --- a/src/main/java/work/slhaf/agent/common/config/Config.java +++ b/src/main/java/work/slhaf/agent/common/config/Config.java @@ -5,13 +5,12 @@ import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import work.slhaf.agent.core.model.CoreModel; -import work.slhaf.agent.core.memory.MemoryManager; +import work.slhaf.agent.modules.memory.MemorySelectExtractor; import work.slhaf.agent.modules.memory.MemorySelector; import work.slhaf.agent.modules.memory.MemoryUpdater; import work.slhaf.agent.modules.memory.SliceEvaluator; import work.slhaf.agent.modules.task.TaskEvaluator; import work.slhaf.agent.modules.task.TaskScheduler; -import work.slhaf.agent.modules.topic.TopicExtractor; import java.io.File; import java.io.IOException; @@ -67,7 +66,7 @@ public class Config { private static void generatePipelineConfig() { List moduleConfigList = List.of( - new ModuleConfig(TopicExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null), + new ModuleConfig(MemorySelectExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null), new ModuleConfig(MemorySelector.class.getName(), ModuleConfig.Constant.INTERNAL, null), new ModuleConfig(CoreModel.class.getName(),ModuleConfig.Constant.INTERNAL,null), new ModuleConfig(MemoryUpdater.class.getName(),ModuleConfig.Constant.INTERNAL,null), @@ -100,7 +99,7 @@ public class Config { } case 3 -> { System.out.println("TopicExtractor:"); - yield TopicExtractor.MODEL_KEY; + yield MemorySelectExtractor.MODEL_KEY; } default -> throw new RuntimeException(); }; diff --git a/src/main/java/work/slhaf/agent/common/model/ModelConstant.java b/src/main/java/work/slhaf/agent/common/model/ModelConstant.java index a43652fe..114b085c 100644 --- a/src/main/java/work/slhaf/agent/common/model/ModelConstant.java +++ b/src/main/java/work/slhaf/agent/common/model/ModelConstant.java @@ -6,6 +6,56 @@ public class ModelConstant { public static final String SLICE_EVALUATOR_PROMPT = """ """; public static final String TOPIC_EXTRACTOR_PROMPT = """ + # MemorySelectExtractor 提示词 + + ## 功能说明 + 你需要根据用户输入的JSON数据,分析其`text`字段内容,判断是否需要通过主题路径或日期进行记忆查询,并返回标准化格式的JSON响应。 + + ## 输入字段说明 + - `text`: 用户输入的文本内容 + - `topic_tree`: 当前可用的主题树结构(括号内数字表示子主题数量) + - `date`: 当前对话发生的日期(用于时间推理) + + ## 输出规则 + 1. 当文本涉及明确主题路径时: + - 使用`"type": "topic"` + - `text`字段格式为"根主题->子主题->子子主题"(必须**完全匹配**topic_tree中的层级,包括从[root]到目标主题的完整路径) + - 示例:{ + "type": "topic", + "text": "工作->项目A->需求文档" + } + + 2. 当文本包含明确可推算的日期时: + - 使用`"type": "date"` + - 日期格式必须为"YYYY-MM-DD" + - 仅接受具体日期(不接受"上周"等模糊表达) + - 示例:{ + "type": "date", + "text": "2024-04-15" + } + + 3. 当不需要查询或无法确定时: + - 使用`"type": "none"` + - 示例:{ + "type": "none" + } + + ## 完整示例 + 用户输入:{ + "text": "还记得我们讨论过游戏引擎的物理系统实现吗?", + "topic_tree": " + 技术 (3)[root] + ├── 游戏开发 (2) + │ ├── 图形渲染 (1) + │ └── 物理系统 (0) + └── 人工智能 (1)", + "date": "2024-04-20" + } + + 正确响应:{ + "type": "topic", + "text": "技术->游戏开发->物理系统" + } """; public static final String TASK_EVALUATOR_PROMPT = """ """; diff --git a/src/main/java/work/slhaf/agent/core/InteractionHub.java b/src/main/java/work/slhaf/agent/core/InteractionHub.java index ab8a976c..cd47a5cc 100644 --- a/src/main/java/work/slhaf/agent/core/InteractionHub.java +++ b/src/main/java/work/slhaf/agent/core/InteractionHub.java @@ -7,8 +7,8 @@ import work.slhaf.agent.core.interaction.InteractionModulesLoader; import work.slhaf.agent.core.interaction.TaskCallback; import work.slhaf.agent.core.interaction.data.InteractionContext; import work.slhaf.agent.core.interaction.data.InteractionInputData; -import work.slhaf.agent.core.model.CoreModel; import work.slhaf.agent.core.memory.MemoryManager; +import work.slhaf.agent.core.model.CoreModel; import work.slhaf.agent.modules.preprocess.PreprocessExecutor; import work.slhaf.agent.modules.task.TaskScheduler; @@ -35,7 +35,7 @@ public class InteractionHub { return interactionHub; } - public void call(InteractionInputData inputData) throws IOException { + public void call(InteractionInputData inputData) throws IOException, ClassNotFoundException { //预处理 InteractionContext interactionContext = PreprocessExecutor.getInstance().execute(inputData); //加载模块 @@ -43,6 +43,6 @@ public class InteractionHub { for (InteractionModule interactionModule : interactionModules) { interactionModule.execute(interactionContext); } - callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getMessage()); + callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getString("message")); } } diff --git a/src/main/java/work/slhaf/agent/core/interaction/InteractionModule.java b/src/main/java/work/slhaf/agent/core/interaction/InteractionModule.java index 48d80e65..db2de53b 100644 --- a/src/main/java/work/slhaf/agent/core/interaction/InteractionModule.java +++ b/src/main/java/work/slhaf/agent/core/interaction/InteractionModule.java @@ -2,6 +2,8 @@ package work.slhaf.agent.core.interaction; import work.slhaf.agent.core.interaction.data.InteractionContext; +import java.io.IOException; + public interface InteractionModule { - void execute(InteractionContext context); + void execute(InteractionContext context) throws IOException, ClassNotFoundException; } diff --git a/src/main/java/work/slhaf/agent/core/interaction/data/InteractionContext.java b/src/main/java/work/slhaf/agent/core/interaction/data/InteractionContext.java index f4680eb5..9725b954 100644 --- a/src/main/java/work/slhaf/agent/core/interaction/data/InteractionContext.java +++ b/src/main/java/work/slhaf/agent/core/interaction/data/InteractionContext.java @@ -2,12 +2,8 @@ package work.slhaf.agent.core.interaction.data; import com.alibaba.fastjson2.JSONObject; import lombok.Data; -import work.slhaf.agent.common.chat.pojo.ChatResponse; -import work.slhaf.agent.core.memory.pojo.MemorySlice; -import work.slhaf.agent.modules.task.data.TaskData; import java.time.LocalDateTime; -import java.util.List; @Data public class InteractionContext { @@ -17,10 +13,7 @@ public class InteractionContext { protected boolean finished; protected String input; - protected JSONObject tempResult; - protected ChatResponse coreResponse; - protected List memorySlices; - protected List topicPath; - protected List taskDataList; + protected JSONObject moduleContext; + protected JSONObject coreResponse; } diff --git a/src/main/java/work/slhaf/agent/core/interaction/data/InteractionInputData.java b/src/main/java/work/slhaf/agent/core/interaction/data/InteractionInputData.java index 274b3249..76c4e5c0 100644 --- a/src/main/java/work/slhaf/agent/core/interaction/data/InteractionInputData.java +++ b/src/main/java/work/slhaf/agent/core/interaction/data/InteractionInputData.java @@ -10,4 +10,5 @@ public class InteractionInputData { private String userNickName; private String content; private LocalDateTime localDateTime; + private String platform; } diff --git a/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java b/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java index 6aaf17a5..4ad2da02 100644 --- a/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java +++ b/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java @@ -8,10 +8,7 @@ import work.slhaf.agent.common.chat.pojo.Message; import work.slhaf.agent.core.memory.exception.UnExistedTopicException; import work.slhaf.agent.core.memory.node.MemoryNode; import work.slhaf.agent.core.memory.node.TopicNode; -import work.slhaf.agent.core.memory.pojo.MemoryResult; -import work.slhaf.agent.core.memory.pojo.MemorySlice; -import work.slhaf.agent.core.memory.pojo.MemorySliceResult; -import work.slhaf.agent.core.memory.pojo.PersistableObject; +import work.slhaf.agent.core.memory.pojo.*; import java.io.*; import java.nio.file.Files; @@ -100,6 +97,11 @@ public class MemoryGraph extends PersistableObject { */ private List chatMessages; + /** + * 用户列表 + */ + private List users; + public MemoryGraph(String id) { this.id = id; this.topicNodes = new HashMap<>(); @@ -266,7 +268,7 @@ public class MemoryGraph extends PersistableObject { //放入新缓存 userDialogMap .computeIfAbsent(now, k -> new ConcurrentHashMap<>()) - .merge(slice.getStartUser(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal); + .merge(slice.getStartUserId(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal); } @@ -298,7 +300,8 @@ public class MemoryGraph extends PersistableObject { } - public MemoryResult selectMemory(List topicPath) throws IOException, ClassNotFoundException { + public MemoryResult selectMemory(String topicPathStr) throws IOException, ClassNotFoundException { + List topicPath = List.of(topicPathStr.split("->")); MemoryResult memoryResult = new MemoryResult(); //每日刷新缓存 @@ -319,7 +322,6 @@ public class MemoryGraph extends PersistableObject { MemorySliceResult sliceResult = new MemorySliceResult(); for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) { List endpointMemorySliceList = memoryNode.loadMemorySliceList(); -// targetSliceList.addAll(endpointMemorySliceList); for (MemorySlice memorySlice : endpointMemorySliceList) { sliceResult.setSliceBefore(memorySlice.getSliceBefore()); sliceResult.setMemorySlice(memorySlice); @@ -420,5 +422,28 @@ public class MemoryGraph extends PersistableObject { } return targetParentNode; } + + public void printTopicTree() { + for (Map.Entry entry : topicNodes.entrySet()) { + String rootName = entry.getKey(); + TopicNode rootNode = entry.getValue(); + System.out.println(rootName+"[root]"); + printSubTopicsTreeFormat(rootNode, "", true); + } + } + + private void printSubTopicsTreeFormat(TopicNode node, String prefix, boolean isLast) { + if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return; + + List> entries = new ArrayList<>(node.getTopicNodes().entrySet()); + for (int i = 0; i < entries.size(); i++) { + boolean last = (i == entries.size() - 1); + Map.Entry entry = entries.get(i); + System.out.println(prefix + (last ? "└── " : "├── ") + entry.getKey()); + printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), last); + } + } + + } diff --git a/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java b/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java index 14298985..aa18cb34 100644 --- a/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java +++ b/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java @@ -5,9 +5,15 @@ import lombok.extern.slf4j.Slf4j; import work.slhaf.agent.common.config.Config; import work.slhaf.agent.core.interaction.InteractionModule; import work.slhaf.agent.core.interaction.data.InteractionContext; +import work.slhaf.agent.core.memory.pojo.MemoryResult; +import work.slhaf.agent.core.memory.pojo.User; import work.slhaf.agent.modules.memory.SliceEvaluator; import java.io.IOException; +import java.time.LocalDate; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; @Data @Slf4j @@ -36,4 +42,37 @@ public class MemoryManager implements InteractionModule { return memoryManager; } + public MemoryResult selectMemory(String path) throws IOException, ClassNotFoundException { + return memoryGraph.selectMemory(path); + } + + public MemoryResult selectMemory(LocalDate date) { + return memoryGraph.selectMemory(date); + } + + public String getUserId(String userInfo,String nickName) { + String userId = null; + for (User user : memoryGraph.getUsers()) { + if (user.getInfo().contains(userInfo)){ + userId = user.getUuid(); + } + } + if (userId == null) { + User newUser = setNewUser(userInfo, nickName); + memoryGraph.getUsers().add(newUser); + userId = newUser.getUuid(); + } + return userId; + } + + private static User setNewUser(String userInfo, String nickName) { + User newUser = new User(); + newUser.setUuid(UUID.randomUUID().toString()); + List infoList = new ArrayList<>(); + infoList.add(userInfo); + newUser.setInfo(infoList); + newUser.setNickName(nickName); + return newUser; + } + } diff --git a/src/main/java/work/slhaf/agent/core/memory/pojo/MemorySlice.java b/src/main/java/work/slhaf/agent/core/memory/pojo/MemorySlice.java index 79b69b8d..f223ab0b 100644 --- a/src/main/java/work/slhaf/agent/core/memory/pojo/MemorySlice.java +++ b/src/main/java/work/slhaf/agent/core/memory/pojo/MemorySlice.java @@ -45,12 +45,12 @@ public class MemorySlice extends PersistableObject implements Comparable involvedUsers; + private List involvedUserIds; /** * 是否仅供发起用户作为记忆参考 diff --git a/src/main/java/work/slhaf/agent/core/model/CoreModel.java b/src/main/java/work/slhaf/agent/core/model/CoreModel.java index e7c0d0e5..3659b118 100644 --- a/src/main/java/work/slhaf/agent/core/model/CoreModel.java +++ b/src/main/java/work/slhaf/agent/core/model/CoreModel.java @@ -38,6 +38,6 @@ public class CoreModel extends Model implements InteractionModule { //TODO 需要拼接上下文之后再发送给主模型 ChatResponse res = runChat(interactionContext.getInput()); - interactionContext.setCoreResponse(res); +// interactionContext.setCoreResponse(); } } diff --git a/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java b/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java index 922ee896..bb8b9ca4 100644 --- a/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java +++ b/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java @@ -41,8 +41,8 @@ public class AgentWebSocketServer extends WebSocketServer implements MessageSend InteractionInputData inputData = JSONObject.parseObject(s, InteractionInputData.class); userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接 try { - agent.receiveUserInput(inputData.getUserNickName(), inputData.getUserInfo(), inputData.getContent()); - } catch (IOException e) { + agent.receiveUserInput(inputData); + } catch (IOException | ClassNotFoundException e) { throw new RuntimeException(e); } } diff --git a/src/main/java/work/slhaf/agent/modules/memory/MemorySelectExtractor.java b/src/main/java/work/slhaf/agent/modules/memory/MemorySelectExtractor.java new file mode 100644 index 00000000..1cb04155 --- /dev/null +++ b/src/main/java/work/slhaf/agent/modules/memory/MemorySelectExtractor.java @@ -0,0 +1,41 @@ +package work.slhaf.agent.modules.memory; + +import com.alibaba.fastjson2.JSONObject; +import lombok.Data; +import lombok.EqualsAndHashCode; +import work.slhaf.agent.common.config.Config; +import work.slhaf.agent.common.model.Model; +import work.slhaf.agent.common.model.ModelConstant; + +import java.io.IOException; + +@EqualsAndHashCode(callSuper = true) +@Data +public class MemorySelectExtractor extends Model { + public static final String MODEL_KEY = "topic_extractor"; + private static MemorySelectExtractor memorySelectExtractor; + + private MemorySelectExtractor() { + } + + public static MemorySelectExtractor getInstance() throws IOException, ClassNotFoundException { + if (memorySelectExtractor == null) { + Config config = Config.getConfig(); + memorySelectExtractor = new MemorySelectExtractor(); + setModel(config, memorySelectExtractor, MODEL_KEY, ModelConstant.TOPIC_EXTRACTOR_PROMPT); + } + + return memorySelectExtractor; + } + + public JSONObject execute(String input) { + return JSONObject.parseObject(singleChat(input).getMessage()); + } + + public static class Constant { + public static final String NONE = "none"; + public static final String DATE = "date"; + public static final String TOPIC = "topic"; + } + +} diff --git a/src/main/java/work/slhaf/agent/modules/memory/MemorySelector.java b/src/main/java/work/slhaf/agent/modules/memory/MemorySelector.java index dd3f86ab..1e29dd70 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/MemorySelector.java +++ b/src/main/java/work/slhaf/agent/modules/memory/MemorySelector.java @@ -1,11 +1,14 @@ package work.slhaf.agent.modules.memory; +import com.alibaba.fastjson2.JSONObject; import lombok.Data; import work.slhaf.agent.core.interaction.InteractionModule; import work.slhaf.agent.core.interaction.data.InteractionContext; import work.slhaf.agent.core.memory.MemoryManager; +import work.slhaf.agent.core.memory.pojo.MemoryResult; import java.io.IOException; +import java.time.LocalDate; @Data public class MemorySelector implements InteractionModule { @@ -14,20 +17,40 @@ public class MemorySelector implements InteractionModule { private MemoryManager memoryManager; private SliceEvaluator sliceEvaluator; + private MemorySelectExtractor memorySelectExtractor; - private MemorySelector(){} + private MemorySelector() { + } public static MemorySelector getInstance() throws IOException, ClassNotFoundException { if (memorySelector == null) { memorySelector = new MemorySelector(); memorySelector.setMemoryManager(MemoryManager.getInstance()); memorySelector.setSliceEvaluator(SliceEvaluator.getInstance()); + memorySelector.setMemorySelectExtractor(MemorySelectExtractor.getInstance()); } return memorySelector; } @Override - public void execute(InteractionContext interactionContext) { + public void execute(InteractionContext interactionContext) throws IOException, ClassNotFoundException { + //获取主题路径 + JSONObject extractorResult = memorySelectExtractor.execute(interactionContext.getInput()); + String selectType = extractorResult.getString("type"); + //根据主结果进行操作查找切片 + MemoryResult memoryResult = switch (selectType) { + case MemorySelectExtractor.Constant.DATE -> + memoryManager.selectMemory(LocalDate.parse(extractorResult.getString(MemorySelectExtractor.Constant.DATE))); + case MemorySelectExtractor.Constant.TOPIC -> + memoryManager.selectMemory(MemorySelectExtractor.Constant.TOPIC); + default -> null; + }; + //评估切片 + if (memoryResult == null) { + memoryResult = sliceEvaluator.execute(memoryResult,interactionContext); + } + + //设置上下文 } } diff --git a/src/main/java/work/slhaf/agent/modules/memory/MemoryUpdater.java b/src/main/java/work/slhaf/agent/modules/memory/MemoryUpdater.java index 6c36f40c..1dd0687d 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/MemoryUpdater.java +++ b/src/main/java/work/slhaf/agent/modules/memory/MemoryUpdater.java @@ -15,6 +15,7 @@ public class MemoryUpdater implements InteractionModule { private MemoryManager memoryManager; private InteractionThreadPoolExecutor executor; + private MemorySelectExtractor memorySelectExtractor; private MemoryUpdater(){} @@ -22,6 +23,7 @@ public class MemoryUpdater implements InteractionModule { if (memoryUpdater == null) { memoryUpdater = new MemoryUpdater(); memoryUpdater.setMemoryManager(MemoryManager.getInstance()); + memoryUpdater.setMemorySelectExtractor(MemorySelectExtractor.getInstance()); } return memoryUpdater; } diff --git a/src/main/java/work/slhaf/agent/modules/memory/SliceEvaluator.java b/src/main/java/work/slhaf/agent/modules/memory/SliceEvaluator.java index e7e19478..1b85d42f 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/SliceEvaluator.java +++ b/src/main/java/work/slhaf/agent/modules/memory/SliceEvaluator.java @@ -1,13 +1,22 @@ package work.slhaf.agent.modules.memory; +import cn.hutool.json.JSONUtil; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import work.slhaf.agent.common.config.Config; import work.slhaf.agent.common.model.Model; import work.slhaf.agent.common.model.ModelConstant; +import work.slhaf.agent.core.interaction.data.InteractionContext; +import work.slhaf.agent.core.memory.MemoryManager; +import work.slhaf.agent.core.memory.pojo.MemoryResult; +import work.slhaf.agent.core.memory.pojo.MemorySlice; +import work.slhaf.agent.core.memory.pojo.MemorySliceResult; +import work.slhaf.agent.modules.memory.data.SliceSummary; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; @EqualsAndHashCode(callSuper = true) @Data @@ -16,19 +25,72 @@ public class SliceEvaluator extends Model { public static final String MODEL_KEY = "slice_evaluator"; private static SliceEvaluator sliceEvaluator; + private MemoryManager memoryManager; - private SliceEvaluator(){} + private SliceEvaluator() { + } public static SliceEvaluator getInstance() throws IOException, ClassNotFoundException { if (sliceEvaluator == null) { Config config = Config.getConfig(); sliceEvaluator = new SliceEvaluator(); - setModel(config,sliceEvaluator, MODEL_KEY, ModelConstant.SLICE_EVALUATOR_PROMPT); + sliceEvaluator.setMemoryManager(MemoryManager.getInstance()); + setModel(config, sliceEvaluator, MODEL_KEY, ModelConstant.SLICE_EVALUATOR_PROMPT); log.info("SliceEvaluator注册完毕..."); } return sliceEvaluator; } + public MemoryResult execute(MemoryResult memoryResult, InteractionContext context) { + List sliceSummaryList = new ArrayList<>(); + setSliceSummaryList(memoryResult, context, sliceSummaryList); + String primaryJsonStr = singleChat(JSONUtil.toJsonStr(sliceSummaryList)).getMessage(); + //TODO 解析并转换为过滤后的MemoryResult + + return null; + } + + private void setSliceSummaryList(MemoryResult memoryResult, InteractionContext context, List sliceSummaryList) { + for (MemorySliceResult memorySliceResult : memoryResult.getMemorySliceResult()) { + //判断是否为发起用户 + if (accessible(memorySliceResult.getMemorySlice(), context)) { + SliceSummary sliceSummary = new SliceSummary(); + sliceSummary.setId(memorySliceResult.getMemorySlice().getTimestamp()); + String stringBuilder = memorySliceResult.getSliceBefore().getSummary() + + "\r\n" + + memorySliceResult.getMemorySlice().getSummary() + + "\r\n" + + memorySliceResult.getSliceAfter().getSummary(); + sliceSummary.setSummary(stringBuilder); + + sliceSummaryList.add(sliceSummary); + } + } + + for (MemorySlice memorySlice : memoryResult.getRelatedMemorySliceResult()) { + SliceSummary sliceSummary = new SliceSummary(); + sliceSummary.setId(memorySlice.getTimestamp()); + sliceSummary.setSummary(memorySlice.getSummary()); + + sliceSummaryList.add(sliceSummary); + } + } + + + private boolean accessible(MemorySlice slice, InteractionContext context) { + boolean ok; + String startUserId = slice.getStartUserId(); + String userInfo = context.getUserInfo(); + String nickName = context.getUserNickname(); + + if (memoryManager.getUserId(userInfo, nickName).equals(startUserId)) { + ok = true; + } else { + ok = !slice.isPrivate(); + } + + return ok; + } } diff --git a/src/main/java/work/slhaf/agent/modules/memory/data/SliceSummary.java b/src/main/java/work/slhaf/agent/modules/memory/data/SliceSummary.java new file mode 100644 index 00000000..3b15763a --- /dev/null +++ b/src/main/java/work/slhaf/agent/modules/memory/data/SliceSummary.java @@ -0,0 +1,9 @@ +package work.slhaf.agent.modules.memory.data; + +import lombok.Data; + +@Data +public class SliceSummary { + private String summary; + private Long id; +} diff --git a/src/main/java/work/slhaf/agent/modules/preprocess/PreprocessExecutor.java b/src/main/java/work/slhaf/agent/modules/preprocess/PreprocessExecutor.java index d7ff071c..d02755b3 100644 --- a/src/main/java/work/slhaf/agent/modules/preprocess/PreprocessExecutor.java +++ b/src/main/java/work/slhaf/agent/modules/preprocess/PreprocessExecutor.java @@ -1,5 +1,6 @@ package work.slhaf.agent.modules.preprocess; +import com.alibaba.fastjson2.JSONObject; import work.slhaf.agent.core.interaction.data.InteractionContext; import work.slhaf.agent.core.interaction.data.InteractionInputData; @@ -26,6 +27,8 @@ public class PreprocessExecutor { context.setFinished(false); context.setInput(inputData.getContent()); + context.setModuleContext(new JSONObject()); + return context; } } diff --git a/src/main/java/work/slhaf/agent/modules/task/TaskEvaluator.java b/src/main/java/work/slhaf/agent/modules/task/TaskEvaluator.java index 28338654..f2e98d70 100644 --- a/src/main/java/work/slhaf/agent/modules/task/TaskEvaluator.java +++ b/src/main/java/work/slhaf/agent/modules/task/TaskEvaluator.java @@ -3,7 +3,6 @@ package work.slhaf.agent.modules.task; import lombok.Data; import lombok.EqualsAndHashCode; import work.slhaf.agent.common.config.Config; -import work.slhaf.agent.common.config.ModelConfig; import work.slhaf.agent.common.model.Model; import work.slhaf.agent.common.model.ModelConstant; diff --git a/src/main/java/work/slhaf/agent/modules/topic/TopicExtractor.java b/src/main/java/work/slhaf/agent/modules/topic/TopicExtractor.java deleted file mode 100644 index 6ec36f5a..00000000 --- a/src/main/java/work/slhaf/agent/modules/topic/TopicExtractor.java +++ /dev/null @@ -1,41 +0,0 @@ -package work.slhaf.agent.modules.topic; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import work.slhaf.agent.common.chat.constant.ChatConstant; -import work.slhaf.agent.common.chat.pojo.ChatResponse; -import work.slhaf.agent.common.chat.pojo.Message; -import work.slhaf.agent.common.config.Config; -import work.slhaf.agent.common.model.Model; -import work.slhaf.agent.common.model.ModelConstant; -import work.slhaf.agent.core.interaction.InteractionModule; -import work.slhaf.agent.core.interaction.data.InteractionContext; - -import java.io.IOException; - -@EqualsAndHashCode(callSuper = true) -@Data -public class TopicExtractor extends Model implements InteractionModule { - public static final String MODEL_KEY = "topic_extractor"; - private static TopicExtractor topicExtractor; - - private TopicExtractor() { - } - - public static TopicExtractor getInstance() throws IOException, ClassNotFoundException { - if (topicExtractor == null) { - Config config = Config.getConfig(); - topicExtractor = new TopicExtractor(); - setModel(config, topicExtractor, MODEL_KEY, ModelConstant.TOPIC_EXTRACTOR_PROMPT); - } - - return topicExtractor; - } - - @Override - public void execute(InteractionContext interactionContext) { - String primaryMessageResponse = singleChat(interactionContext.getInput()).getMessage(); - - } - -} diff --git a/src/test/java/memory/AITest.java b/src/test/java/memory/AITest.java new file mode 100644 index 00000000..6b8a1831 --- /dev/null +++ b/src/test/java/memory/AITest.java @@ -0,0 +1,89 @@ +package memory; + +import org.junit.jupiter.api.Test; +import work.slhaf.agent.common.chat.ChatClient; +import work.slhaf.agent.common.chat.constant.ChatConstant; +import work.slhaf.agent.common.chat.pojo.Message; + +import java.util.ArrayList; +import java.util.List; + +public class AITest { + @Test + public void test1(){ + ChatClient client = new ChatClient("https://open.bigmodel.cn/api/paas/v4/chat/completions","3db444552530b7742b0c53425fb93dcc.LcVwYjByht9AC3N9","glm-4-flash"); + List messages = new ArrayList<>(); + messages.add(new Message(ChatConstant.Character.SYSTEM, """ + # MemorySelectExtractor 提示词 + + ## 功能说明 + 你需要根据用户输入的JSON数据,分析其`text`字段内容,判断是否需要通过主题路径或日期进行记忆查询,并返回标准化格式的JSON响应。 + + ## 输入字段说明 + - `text`: 用户输入的文本内容 + - `topic_tree`: 当前可用的主题树结构(括号内数字表示子主题数量) + - `date`: 当前对话发生的日期(用于时间推理) + + ## 输出规则 + 1. 当文本涉及明确主题路径时: + - 使用`"type": "topic"` + - `text`字段格式为"根主题->子主题->子子主题"(必须**完全匹配**topic_tree中的层级,包括从[root]到目标主题的完整路径) + - 示例:{ + "type": "topic", + "text": "工作->项目A->需求文档" + } + + 2. 当文本包含明确可推算的日期时: + - 使用`"type": "date"` + - 日期格式必须为"YYYY-MM-DD" + - 仅接受具体日期(不接受"上周"等模糊表达) + - 示例:{ + "type": "date", + "text": "2024-04-15" + } + + 3. 当不需要查询或无法确定时: + - 使用`"type": "none"` + - 示例:{ + "type": "none" + } + + ## 完整示例 + 用户输入:{ + "text": "还记得我们讨论过游戏引擎的物理系统实现吗?", + "topic_tree": " + 技术 (3)[root] + ├── 游戏开发 (2) + │ ├── 图形渲染 (1) + │ └── 物理系统 (0) + └── 人工智能 (1)", + "date": "2024-04-20" + } + + 正确响应:{ + "type": "topic", + "text": "技术->游戏开发->物理系统" + } + """)); + + messages.add(new Message(ChatConstant.Character.USER, """ + { + "text": "上周似乎发生了什么重要的事??", + "topic_tree": " + 汽车工程 (4)[root] + ├── 动力系统 (3) + │ ├── 发动机 (1) + │ └── 新能源电池 (2) + │ ├── 测试标准 (1) + │ └── 安全规范 (1) + └── 车身设计 (1) + 软件开发 (3)[root] + 质量管理 (2)[root] + ├── ISO认证 (1) + └── 行业标准 (1)", + "date": "2024-04-20" + } + """)); + System.out.println(client.runChat(messages).getMessage()); + } +} diff --git a/src/test/java/memory/MemoryTest.java b/src/test/java/memory/MemoryTest.java new file mode 100644 index 00000000..4e75f2e8 --- /dev/null +++ b/src/test/java/memory/MemoryTest.java @@ -0,0 +1,62 @@ +package memory; + +import cn.hutool.core.date.LocalDateTimeUtil; +import org.junit.jupiter.api.Test; +import work.slhaf.agent.core.memory.MemoryGraph; +import work.slhaf.agent.core.memory.node.TopicNode; + +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.HashMap; +import java.util.concurrent.ConcurrentHashMap; + +public class MemoryTest { + +@Test +public void test1() { + MemoryGraph graph = new MemoryGraph("test"); + HashMap topicMap = new HashMap<>(); + + TopicNode root1 = new TopicNode(); + root1.setTopicNodes(new ConcurrentHashMap<>()); + + TopicNode sub1 = new TopicNode(); + sub1.setTopicNodes(new ConcurrentHashMap<>()); + + TopicNode sub2 = new TopicNode(); + sub2.setTopicNodes(new ConcurrentHashMap<>()); + + TopicNode subsub1 = new TopicNode(); + subsub1.setTopicNodes(new ConcurrentHashMap<>()); + + // 构造结构:root -> sub1 -> subsub1, root -> sub2 + sub1.getTopicNodes().put("子子主题1", subsub1); + root1.getTopicNodes().put("子主题1", sub1); + root1.getTopicNodes().put("子主题2", sub2); + + topicMap.put("根主题1", root1); + + // 添加 root2 + TopicNode root2 = new TopicNode(); + root2.setTopicNodes(new ConcurrentHashMap<>()); + + TopicNode sub3 = new TopicNode(); + sub3.setTopicNodes(new ConcurrentHashMap<>()); + + // 构造结构:root2 -> sub3 + root2.getTopicNodes().put("子主题3", sub3); + + topicMap.put("根主题2", root2); + + // 输出 + graph.setTopicNodes(topicMap); + graph.printTopicTree(); +} + + + @Test + public void test2(){ + System.out.println(LocalDate.now()); + } + +} diff --git a/src/test/java/memory/SearchTest.java b/src/test/java/memory/SearchTest.java index 2b60d21e..4fcf8628 100644 --- a/src/test/java/memory/SearchTest.java +++ b/src/test/java/memory/SearchTest.java @@ -59,7 +59,7 @@ class SearchTest { List queryPath = new ArrayList<>(); queryPath.add("算法"); queryPath.add("排序"); - MemoryResult results = memoryGraph.selectMemory(queryPath); +// MemoryResult results = memoryGraph.selectMemory(queryPath); // 验证结果应包含: // 1. 目标节点所有记忆(java1) @@ -77,7 +77,7 @@ class SearchTest { invalidPath.add("不存在的主题"); assertThrows(UnExistedTopicException.class, () -> { - memoryGraph.selectMemory(invalidPath); +// memoryGraph.selectMemory(invalidPath); }); } @@ -94,7 +94,7 @@ class SearchTest { List queryPath = new ArrayList<>(); queryPath.add("编程"); queryPath.add("Java"); - MemoryResult results = memoryGraph.selectMemory(queryPath); +// MemoryResult results = memoryGraph.selectMemory(queryPath); // 应包含:Java记忆 + 父级最新记忆 // assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId()))); @@ -136,7 +136,7 @@ class SearchTest { // 执行查询 List queryPath = createTopicPath("编程", "Java"); - MemoryResult results = memoryGraph.selectMemory(queryPath); +// MemoryResult results = memoryGraph.selectMemory(queryPath); // 验证结果应包含最新关联记忆(dbNew) // assertTrue(results.stream().anyMatch(m -> "dbNew".equals(m.getMemoryId())),