From 34c6b861c8e4376f397a4e47c484a5e5b56ab821 Mon Sep 17 00:00:00 2001 From: slhaf Date: Thu, 17 Apr 2025 23:12:13 +0800 Subject: [PATCH] =?UTF-8?q?refactor(agent):=20=E6=98=8E=E7=A1=AE=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=E5=8C=96=E8=AE=BE=E8=AE=A1=E6=B5=81=E7=A8=8B=EF=BC=8C?= =?UTF-8?q?=E5=85=B7=E4=BD=93=E9=80=BB=E8=BE=91=E5=BE=85=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 调整配置文件路径 - 新增 InteractionModulesLoader 用于动态加载交互模块,加载扩展模块待实现 - 修复 MemoryGraph 和 MemoryNode 的部分逻辑 - 改进 ModelConfig 类,支持单独配置文件, 用于动态加载模块 - 新增 PreprocessExecutor 和 TaskEvaluator模块, 待后续实现 --- .gitignore | 1 + .idea/misc.xml | 4 + pom.xml | 5 + src/main/java/work/slhaf/Main.java | 1 + src/main/java/work/slhaf/agent/Agent.java | 8 +- .../slhaf/agent/common/config/Config.java | 112 +++++++++++------- .../agent/common/config/ModelConfig.java | 30 +++++ .../agent/common/config/ModuleConfig.java | 17 +++ .../work/slhaf/agent/common/model/Model.java | 7 +- .../work/slhaf/agent/core/InteractionHub.java | 13 +- .../interation/InteractionModulesLoader.java | 34 ++++++ .../slhaf/agent/core/model/CoreModel.java | 7 +- .../agent/gateway/AgentWebSocketServer.java | 7 +- .../agent/modules/memory/MemoryGraph.java | 37 +++--- .../agent/modules/memory/MemoryManager.java | 18 ++- .../agent/modules/memory/SliceEvaluator.java | 6 +- .../agent/modules/memory/node/MemoryNode.java | 7 +- .../preprocess/PreprocessExecutor.java | 29 +++++ .../agent/modules/task/TaskEvaluator.java | 4 + .../agent/modules/task/TaskScheduler.java | 15 ++- .../agent/modules/topic/TopicExtractor.java | 17 ++- src/test/java/memory/InsertTest.java | 10 +- 22 files changed, 293 insertions(+), 96 deletions(-) create mode 100644 src/main/java/work/slhaf/agent/common/config/ModuleConfig.java create mode 100644 src/main/java/work/slhaf/agent/core/interation/InteractionModulesLoader.java create mode 100644 src/main/java/work/slhaf/agent/modules/preprocess/PreprocessExecutor.java create mode 100644 src/main/java/work/slhaf/agent/modules/task/TaskEvaluator.java diff --git a/.gitignore b/.gitignore index c2f92312..223d58ec 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,4 @@ build/ ### Mac OS ### .DS_Store /data/ +/config/ diff --git a/.idea/misc.xml b/.idea/misc.xml index fdc35ea8..dddff606 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -8,6 +8,10 @@ + + diff --git a/pom.xml b/pom.xml index 151ea062..e8a47aa2 100644 --- a/pom.xml +++ b/pom.xml @@ -74,6 +74,11 @@ hutool-all 5.8.36 + + work.slhaf + Partner-Modules-Api + 1.0-SNAPSHOT + \ No newline at end of file diff --git a/src/main/java/work/slhaf/Main.java b/src/main/java/work/slhaf/Main.java index 50ac87b3..50d3195b 100644 --- a/src/main/java/work/slhaf/Main.java +++ b/src/main/java/work/slhaf/Main.java @@ -8,5 +8,6 @@ import java.io.IOException; public class Main { public static void main(String[] args) throws IOException { Agent agent = Agent.initialize(); + agent.receiveUserInput("111","222","hello"); } } \ No newline at end of file diff --git a/src/main/java/work/slhaf/agent/Agent.java b/src/main/java/work/slhaf/agent/Agent.java index f5dc45b8..50cbc9e1 100644 --- a/src/main/java/work/slhaf/agent/Agent.java +++ b/src/main/java/work/slhaf/agent/Agent.java @@ -23,9 +23,9 @@ public class Agent implements TaskCallback { public static Agent initialize() throws IOException { if (agent == null) { //加载配置 - Config config = Config.load(); + Config config = Config.getConfig(); agent = new Agent(); - agent.setInteractionHub(InteractionHub.initialize(config)); + agent.setInteractionHub(InteractionHub.initialize()); agent.registerTaskCallback(); agent.setMessageSender(new AgentWebSocketServer(config.getWebSocketConfig().getPort(),agent)); log.info("Agent 加载完毕.."); @@ -37,7 +37,7 @@ public class Agent implements TaskCallback { * 接收用户输入,包装为标准输入数据类 * @param input */ - public void receiveUserInput(String userNickName,String userInfo,String input){ + public void receiveUserInput(String userNickName,String userInfo,String input) throws IOException { InteractionInputData inputData = new InteractionInputData(); inputData.setContent(input); inputData.setUserInfo(userInfo); @@ -53,7 +53,7 @@ public class Agent implements TaskCallback { */ public void sendToUser(String userInfo,String output){ System.out.println(output); - messageSender.sendMessage(userInfo,output); +// messageSender.sendMessage(userInfo,output); } @Override diff --git a/src/main/java/work/slhaf/agent/common/config/Config.java b/src/main/java/work/slhaf/agent/common/config/Config.java index 9350f30e..1ac890b8 100644 --- a/src/main/java/work/slhaf/agent/common/config/Config.java +++ b/src/main/java/work/slhaf/agent/common/config/Config.java @@ -5,6 +5,7 @@ import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import work.slhaf.agent.core.model.CoreModel; +import work.slhaf.agent.modules.memory.MemoryManager; import work.slhaf.agent.modules.memory.SliceEvaluator; import work.slhaf.agent.modules.task.TaskScheduler; import work.slhaf.agent.modules.topic.TopicExtractor; @@ -12,81 +13,100 @@ import work.slhaf.agent.modules.topic.TopicExtractor; import java.io.File; import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.HashMap; +import java.util.List; import java.util.Scanner; @Data @Slf4j public class Config { - private static final String CONFIG_FILE_PATH = "./data/config/config.json"; + private static final String CONFIG_FILE_PATH = "./config/config.json"; private static Config config; private String agentId; - private HashMap modelConfig; - private WebSocketConfig webSocketConfig; - public static Config load() throws IOException { + private List moduleConfigList; + + private Config() { + } + + public static Config getConfig() throws IOException { if (config == null) { File file = new File(CONFIG_FILE_PATH); if (file.exists()) { config = JSONUtil.readJSONObject(file, StandardCharsets.UTF_8).toBean(Config.class); } else { - Config tempConfig = new Config(); + config = new Config(); Scanner scanner = new Scanner(System.in); System.out.print("输入智能体名称: "); - tempConfig.setAgentId(scanner.nextLine()); + config.setAgentId(scanner.nextLine()); System.out.println("\r\n--------模型配置--------\r\n"); - HashMap modelConfig = new HashMap<>(); - for (int i = 0; i < 4; i++) { - String modelKey = switch (i) { - case 0 -> { - System.out.println("CoreModel:"); - yield CoreModel.MODEL_KEY; - } - case 1 -> { - System.out.println("SliceEvaluator:"); - yield SliceEvaluator.MODEL_KEY; - } - case 2 -> { - System.out.println("TaskTrigger:"); - yield TaskScheduler.MODEL_KEY; - } - case 3 -> { - System.out.println("TopicExtractor:"); - yield TopicExtractor.MODEL_KEY; - } - default -> throw new RuntimeException(); - }; - System.out.println(modelKey); - ModelConfig temp = new ModelConfig(); - System.out.print("apikey: "); - temp.setApikey(scanner.nextLine()); - System.out.print("baseUrl: "); - temp.setBaseUrl(scanner.nextLine()); - System.out.print("model: "); - temp.setModel(scanner.nextLine()); - - modelConfig.put(modelKey, temp); - } - tempConfig.setModelConfig(modelConfig); + generateModelConfig(scanner); System.out.println("\r\n--------服务配置--------\r\n"); - System.out.print("WebSocket port: "); - WebSocketConfig wsConfig = new WebSocketConfig(); - wsConfig.setPort(scanner.nextInt()); + generateWsSocketConfig(scanner); + + System.out.println("\r\n--------模块链配置--------\r\n"); + generatePipelineConfig(); //保存配置文件 - String str = JSONUtil.toJsonPrettyStr(tempConfig); - FileUtils.writeStringToFile(file,str,StandardCharsets.UTF_8); + String str = JSONUtil.toJsonPrettyStr(config); + FileUtils.writeStringToFile(file, str, StandardCharsets.UTF_8); log.info("配置已保存"); - config = tempConfig; } } return config; } + + private static void generatePipelineConfig() { + List moduleConfigList = List.of( + new ModuleConfig(TopicExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null), + new ModuleConfig(MemoryManager.class.getName(), ModuleConfig.Constant.INTERNAL, null), + new ModuleConfig(TaskScheduler.class.getName(), ModuleConfig.Constant.INTERNAL, null) + ); + config.setModuleConfigList(moduleConfigList); + } + + private static void generateWsSocketConfig(Scanner scanner) { + System.out.print("WebSocket port: "); + WebSocketConfig wsConfig = new WebSocketConfig(); + wsConfig.setPort(scanner.nextInt()); + config.setWebSocketConfig(wsConfig); + } + + private static void generateModelConfig(Scanner scanner) throws IOException { + for (int i = 0; i < 4; i++) { + String modelKey = switch (i) { + case 0 -> { + System.out.println("CoreModel:"); + yield CoreModel.MODEL_KEY; + } + case 1 -> { + System.out.println("SliceEvaluator:"); + yield SliceEvaluator.MODEL_KEY; + } + case 2 -> { + System.out.println("TaskTrigger:"); + yield TaskScheduler.MODEL_KEY; + } + case 3 -> { + System.out.println("TopicExtractor:"); + yield TopicExtractor.MODEL_KEY; + } + default -> throw new RuntimeException(); + }; + ModelConfig modelConfig = new ModelConfig(); + System.out.print("apikey: "); + modelConfig.setApikey(scanner.nextLine()); + System.out.print("baseUrl: "); + modelConfig.setBaseUrl(scanner.nextLine()); + System.out.print("model: "); + modelConfig.setModel(scanner.nextLine()); + modelConfig.generateConfig(modelKey); + } + } } diff --git a/src/main/java/work/slhaf/agent/common/config/ModelConfig.java b/src/main/java/work/slhaf/agent/common/config/ModelConfig.java index d6b77358..be2c3f07 100644 --- a/src/main/java/work/slhaf/agent/common/config/ModelConfig.java +++ b/src/main/java/work/slhaf/agent/common/config/ModelConfig.java @@ -1,10 +1,40 @@ package work.slhaf.agent.common.config; +import cn.hutool.json.JSONUtil; import lombok.Data; +import org.apache.commons.io.FileUtils; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; @Data public class ModelConfig { + + private static final String MODEL_CONFIG_DIR_PATH = "./config/model/"; + private static final HashMap modelConfigMap = new HashMap<>(); + private String apikey; private String baseUrl; private String model; + + public void generateConfig(String filename) throws IOException { + String str = JSONUtil.toJsonPrettyStr(this); + File file = new File(MODEL_CONFIG_DIR_PATH + filename + ".json"); + FileUtils.writeStringToFile(file, str, StandardCharsets.UTF_8); + } + + public static ModelConfig load(String modelKey) { + if (!modelConfigMap.containsKey(modelKey)) { + modelConfigMap.put(modelKey,loadConfig(modelKey)); + } + + return modelConfigMap.get(modelKey); + } + + private static ModelConfig loadConfig(String modelKey) { + File file = new File(MODEL_CONFIG_DIR_PATH+modelKey+".json"); + return JSONUtil.readJSONObject(file,StandardCharsets.UTF_8).toBean(ModelConfig.class); + } } diff --git a/src/main/java/work/slhaf/agent/common/config/ModuleConfig.java b/src/main/java/work/slhaf/agent/common/config/ModuleConfig.java new file mode 100644 index 00000000..8006c15c --- /dev/null +++ b/src/main/java/work/slhaf/agent/common/config/ModuleConfig.java @@ -0,0 +1,17 @@ +package work.slhaf.agent.common.config; + +import lombok.AllArgsConstructor; +import lombok.Data; + +@Data +@AllArgsConstructor +public class ModuleConfig { + private String className; + private String type; + private String path; + + public static class Constant { + public static final String INTERNAL = "internal"; + public static final String EXTERNAL = "external"; + } +} diff --git a/src/main/java/work/slhaf/agent/common/model/Model.java b/src/main/java/work/slhaf/agent/common/model/Model.java index 4ca6c983..3256b954 100644 --- a/src/main/java/work/slhaf/agent/common/model/Model.java +++ b/src/main/java/work/slhaf/agent/common/model/Model.java @@ -8,6 +8,7 @@ import work.slhaf.agent.common.config.Config; import work.slhaf.agent.common.config.ModelConfig; import work.slhaf.agent.modules.memory.MemoryGraph; +import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -17,9 +18,9 @@ public class Model { protected String prompt; protected List messages; - protected static void setModel(Config config, Model model, String model_key, String prompt) { - MemoryGraph memoryGraph = MemoryGraph.initialize(config.getAgentId()); - ModelConfig modelConfig = config.getModelConfig().get(model_key); + protected static void setModel(Config config, Model model, String model_key, String prompt) throws IOException, ClassNotFoundException { + MemoryGraph memoryGraph = MemoryGraph.getInstance(config.getAgentId()); + ModelConfig modelConfig = ModelConfig.load(model_key); if (memoryGraph.getModelPrompt().containsKey(model_key)) { model.setPrompt(memoryGraph.getModelPrompt().get(model_key)); } else { diff --git a/src/main/java/work/slhaf/agent/core/InteractionHub.java b/src/main/java/work/slhaf/agent/core/InteractionHub.java index 042c3f30..b0e46048 100644 --- a/src/main/java/work/slhaf/agent/core/InteractionHub.java +++ b/src/main/java/work/slhaf/agent/core/InteractionHub.java @@ -3,11 +3,16 @@ package work.slhaf.agent.core; import lombok.Data; import lombok.extern.slf4j.Slf4j; import work.slhaf.agent.common.config.Config; +import work.slhaf.agent.core.interation.InteractionModulesLoader; import work.slhaf.agent.core.interation.TaskCallback; import work.slhaf.agent.core.interation.data.InteractionInputData; import work.slhaf.agent.core.model.CoreModel; import work.slhaf.agent.modules.memory.MemoryManager; import work.slhaf.agent.modules.task.TaskScheduler; +import work.slhaf.module.InteractionModule; + +import java.io.IOException; +import java.util.List; @Data @Slf4j @@ -21,18 +26,16 @@ public class InteractionHub { private MemoryManager memoryManager; private TaskScheduler taskScheduler; - public static InteractionHub initialize(Config config) { + public static InteractionHub initialize() throws IOException { if (interactionHub == null) { interactionHub = new InteractionHub(); - interactionHub.setCoreModel(CoreModel.initialize(config)); - interactionHub.setMemoryManager(MemoryManager.initialize(config)); - interactionHub.setTaskScheduler(TaskScheduler.initialize(config)); log.info("InteractionHub注册完毕..."); } return interactionHub; } - public void call(InteractionInputData inputData) { + public void call(InteractionInputData inputData) throws IOException { + List interactionModules = InteractionModulesLoader.registerInteractionModules(); callback.onTaskFinished(null, null); } diff --git a/src/main/java/work/slhaf/agent/core/interation/InteractionModulesLoader.java b/src/main/java/work/slhaf/agent/core/interation/InteractionModulesLoader.java new file mode 100644 index 00000000..71103391 --- /dev/null +++ b/src/main/java/work/slhaf/agent/core/interation/InteractionModulesLoader.java @@ -0,0 +1,34 @@ +package work.slhaf.agent.core.interation; + +import work.slhaf.agent.common.config.Config; +import work.slhaf.agent.common.config.ModuleConfig; +import work.slhaf.module.InteractionModule; + +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.util.ArrayList; +import java.util.List; + +public class InteractionModulesLoader { + public static List registerInteractionModules() throws IOException { + List moduleList = new ArrayList<>(); + List moduleConfigList = Config.getConfig().getModuleConfigList(); + for (ModuleConfig moduleConfig : moduleConfigList) { + if (ModuleConfig.Constant.INTERNAL.equals(moduleConfig.getType())) { + moduleList.add(loadInternalModule(moduleConfig.getClassName())); + } + } + return moduleList; + } + + private static InteractionModule loadInternalModule(String moduleName) { + try { + Class clazz = Class.forName(moduleName); + + //TODO 后续需要规范`getInstance`方法的实现 + return (InteractionModule) clazz.getMethod("getInstance").invoke(null); + } catch (ClassNotFoundException | InvocationTargetException | IllegalAccessException | NoSuchMethodException e) { + throw new RuntimeException("Fail to load internal module: " + moduleName,e); + } + } +} diff --git a/src/main/java/work/slhaf/agent/core/model/CoreModel.java b/src/main/java/work/slhaf/agent/core/model/CoreModel.java index 03a30ebc..b5667b49 100644 --- a/src/main/java/work/slhaf/agent/core/model/CoreModel.java +++ b/src/main/java/work/slhaf/agent/core/model/CoreModel.java @@ -7,6 +7,8 @@ import work.slhaf.agent.common.config.Config; import work.slhaf.agent.common.model.Model; import work.slhaf.agent.common.model.ModelConstant; +import java.io.IOException; + @EqualsAndHashCode(callSuper = true) @Data @Slf4j @@ -15,8 +17,11 @@ public class CoreModel extends Model { public static final String MODEL_KEY = "core_model"; private static CoreModel coreModel; - public static CoreModel initialize(Config config) { + private CoreModel(){} + + public static CoreModel getInstance() throws IOException, ClassNotFoundException { if (coreModel == null) { + Config config = Config.getConfig(); coreModel = new CoreModel(); coreModel.setPrompt(ModelConstant.CORE_MODEL_PROMPT); setModel(config, coreModel, MODEL_KEY, coreModel.getPrompt()); diff --git a/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java b/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java index 2561a460..15c58448 100644 --- a/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java +++ b/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java @@ -10,6 +10,7 @@ import work.slhaf.agent.Agent; import work.slhaf.agent.core.interation.data.InteractionInputData; import work.slhaf.agent.core.interation.data.InteractionOutputData; +import java.io.IOException; import java.net.InetSocketAddress; import java.util.concurrent.ConcurrentHashMap; @@ -39,7 +40,11 @@ public class AgentWebSocketServer extends WebSocketServer implements MessageSend public void onMessage(WebSocket webSocket, String s) { InteractionInputData inputData = JSONObject.parseObject(s, InteractionInputData.class); userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接 - agent.receiveUserInput(inputData.getUserNickName(), inputData.getUserInfo(), inputData.getContent()); + try { + agent.receiveUserInput(inputData.getUserNickName(), inputData.getUserInfo(), inputData.getContent()); + } catch (IOException e) { + throw new RuntimeException(e); + } } @Override diff --git a/src/main/java/work/slhaf/agent/modules/memory/MemoryGraph.java b/src/main/java/work/slhaf/agent/modules/memory/MemoryGraph.java index e676020d..4c9dd5e3 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/MemoryGraph.java +++ b/src/main/java/work/slhaf/agent/modules/memory/MemoryGraph.java @@ -3,6 +3,7 @@ package work.slhaf.agent.modules.memory; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.io.FileUtils; import work.slhaf.agent.common.chat.pojo.Message; import work.slhaf.agent.modules.memory.exception.UnExistedTopicException; import work.slhaf.agent.modules.memory.node.MemoryNode; @@ -36,7 +37,7 @@ public class MemoryGraph extends PersistableObject { * key: 根主题名称 value: 根主题节点 */ private HashMap topicNodes; - public static MemoryGraph memoryGraph; + private static MemoryGraph memoryGraph; /** * 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值 @@ -110,31 +111,27 @@ public class MemoryGraph extends PersistableObject { this.modelPrompt = new HashMap<>(); } - public static MemoryGraph initialize(String id) { + public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException { // 检查存储目录是否存在,不存在则创建 createStorageDirectory(); - - Path filePath = getFilePath(id); - if (memoryGraph == null && Files.exists(filePath)) { - try { - // 从文件加载 + if (memoryGraph == null) { + Path filePath = getFilePath(id); + if (Files.exists(filePath)) { memoryGraph = deserialize(id); - } catch (Exception e) { - log.error("加载序列化文件失败,创建新实例"); - System.exit(1); + }else { + FileUtils.createParentDirectories(filePath.toFile().getParentFile()); + memoryGraph = new MemoryGraph(id); + memoryGraph.serialize(); } - } else { - // 创建新实例 - memoryGraph = new MemoryGraph(id); + log.info("MemoryGraph注册完毕..."); } - log.info("MemoryGraph注册完毕..."); return memoryGraph; } - public void serialize() { + public void serialize() throws IOException { Path filePath = getFilePath(this.id); - + Files.createDirectories(Path.of(STORAGE_DIR)); try (ObjectOutputStream oos = new ObjectOutputStream( new FileOutputStream(filePath.toFile()))) { oos.writeObject(this); @@ -193,7 +190,7 @@ public class MemoryGraph extends PersistableObject { lastTopicNode.getMemoryNodes().add(node); lastTopicNode.getMemoryNodes().sort(null); } - node.getMemorySliceList().add(slice); + node.loadMemorySliceList().add(slice); //生成relatedTopicPath for (List relatedTopic : slice.getRelatedTopics()) { @@ -321,7 +318,7 @@ public class MemoryGraph extends PersistableObject { //终点记忆节点 MemorySliceResult sliceResult = new MemorySliceResult(); for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) { - List endpointMemorySliceList = memoryNode.getMemorySliceList(); + List endpointMemorySliceList = memoryNode.loadMemorySliceList(); // targetSliceList.addAll(endpointMemorySliceList); for (MemorySlice memorySlice : endpointMemorySliceList) { sliceResult.setSliceBefore(memorySlice.getSliceBefore()); @@ -348,14 +345,14 @@ public class MemoryGraph extends PersistableObject { TopicNode tempTargetNode = tempTargetParentNode.getTopicNodes().get(tempTopicPath.getLast()); List tempMemoryNodes = tempTargetNode.getMemoryNodes(); if (!tempMemoryNodes.isEmpty()) { - relatedMemorySlice.addAll(tempMemoryNodes.getFirst().getMemorySliceList()); + relatedMemorySlice.addAll(tempMemoryNodes.getFirst().loadMemorySliceList()); } } //邻近记忆节点 父级 List targetParentMemoryNodes = targetParentNode.getMemoryNodes(); if (!targetParentMemoryNodes.isEmpty()) { - relatedMemorySlice.addAll(targetParentMemoryNodes.getFirst().getMemorySliceList()); + relatedMemorySlice.addAll(targetParentMemoryNodes.getFirst().loadMemorySliceList()); } //将上述结果包装为MemoryResult diff --git a/src/main/java/work/slhaf/agent/modules/memory/MemoryManager.java b/src/main/java/work/slhaf/agent/modules/memory/MemoryManager.java index bd8cf11d..f91aab71 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/MemoryManager.java +++ b/src/main/java/work/slhaf/agent/modules/memory/MemoryManager.java @@ -3,20 +3,32 @@ package work.slhaf.agent.modules.memory; import lombok.Data; import lombok.extern.slf4j.Slf4j; import work.slhaf.agent.common.config.Config; +import work.slhaf.module.InteractionContext; +import work.slhaf.module.InteractionModule; + +import java.io.IOException; @Data @Slf4j -public class MemoryManager { +public class MemoryManager implements InteractionModule { private static MemoryManager memoryManager; private MemoryGraph memoryGraph; private SliceEvaluator sliceEvaluator; - public static MemoryManager initialize(Config config){ + private MemoryManager(){} + + @Override + public void execute(InteractionContext interactionContext) { + + } + + public static MemoryManager getInstance() throws IOException, ClassNotFoundException { if (memoryManager == null) { + Config config = Config.getConfig(); memoryManager = new MemoryManager(); - memoryManager.setMemoryGraph(MemoryGraph.initialize(config.getAgentId())); + memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId())); memoryManager.setSliceEvaluator(SliceEvaluator.initialize(config)); log.info("MemoryManager注册完毕..."); } diff --git a/src/main/java/work/slhaf/agent/modules/memory/SliceEvaluator.java b/src/main/java/work/slhaf/agent/modules/memory/SliceEvaluator.java index 1c8b815d..6c6a7044 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/SliceEvaluator.java +++ b/src/main/java/work/slhaf/agent/modules/memory/SliceEvaluator.java @@ -7,6 +7,8 @@ import work.slhaf.agent.common.config.Config; import work.slhaf.agent.common.model.Model; import work.slhaf.agent.common.model.ModelConstant; +import java.io.IOException; + @EqualsAndHashCode(callSuper = true) @Data @Slf4j @@ -15,7 +17,9 @@ public class SliceEvaluator extends Model { private static SliceEvaluator sliceEvaluator; - public static SliceEvaluator initialize(Config config) { + private SliceEvaluator(){} + + public static SliceEvaluator initialize(Config config) throws IOException, ClassNotFoundException { if (sliceEvaluator == null) { sliceEvaluator = new SliceEvaluator(); diff --git a/src/main/java/work/slhaf/agent/modules/memory/node/MemoryNode.java b/src/main/java/work/slhaf/agent/modules/memory/node/MemoryNode.java index 8d3782ff..62932ca4 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/node/MemoryNode.java +++ b/src/main/java/work/slhaf/agent/modules/memory/node/MemoryNode.java @@ -8,6 +8,8 @@ import work.slhaf.agent.modules.memory.pojo.MemorySlice; import work.slhaf.agent.modules.memory.pojo.PersistableObject; import java.io.*; +import java.nio.file.Files; +import java.nio.file.Path; import java.time.LocalDate; import java.util.ArrayList; import java.util.List; @@ -20,7 +22,7 @@ public class MemoryNode extends PersistableObject implements Comparable getMemorySliceList() throws IOException, ClassNotFoundException { + public List loadMemorySliceList() throws IOException, ClassNotFoundException { //检查是否存在对应文件 File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice"); if (file.exists()){ @@ -64,6 +66,7 @@ public class MemoryNode extends PersistableObject implements Comparable