mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
refactor(agent): 明确模块化设计流程,具体逻辑待实现
- 调整配置文件路径 - 新增 InteractionModulesLoader 用于动态加载交互模块,加载扩展模块待实现 - 修复 MemoryGraph 和 MemoryNode 的部分逻辑 - 改进 ModelConfig 类,支持单独配置文件, 用于动态加载模块 - 新增 PreprocessExecutor 和 TaskEvaluator模块, 待后续实现
This commit is contained in:
@@ -8,5 +8,6 @@ import java.io.IOException;
|
||||
public class Main {
|
||||
public static void main(String[] args) throws IOException {
|
||||
Agent agent = Agent.initialize();
|
||||
agent.receiveUserInput("111","222","hello");
|
||||
}
|
||||
}
|
||||
@@ -23,9 +23,9 @@ public class Agent implements TaskCallback {
|
||||
public static Agent initialize() throws IOException {
|
||||
if (agent == null) {
|
||||
//加载配置
|
||||
Config config = Config.load();
|
||||
Config config = Config.getConfig();
|
||||
agent = new Agent();
|
||||
agent.setInteractionHub(InteractionHub.initialize(config));
|
||||
agent.setInteractionHub(InteractionHub.initialize());
|
||||
agent.registerTaskCallback();
|
||||
agent.setMessageSender(new AgentWebSocketServer(config.getWebSocketConfig().getPort(),agent));
|
||||
log.info("Agent 加载完毕..");
|
||||
@@ -37,7 +37,7 @@ public class Agent implements TaskCallback {
|
||||
* 接收用户输入,包装为标准输入数据类
|
||||
* @param input
|
||||
*/
|
||||
public void receiveUserInput(String userNickName,String userInfo,String input){
|
||||
public void receiveUserInput(String userNickName,String userInfo,String input) throws IOException {
|
||||
InteractionInputData inputData = new InteractionInputData();
|
||||
inputData.setContent(input);
|
||||
inputData.setUserInfo(userInfo);
|
||||
@@ -53,7 +53,7 @@ public class Agent implements TaskCallback {
|
||||
*/
|
||||
public void sendToUser(String userInfo,String output){
|
||||
System.out.println(output);
|
||||
messageSender.sendMessage(userInfo,output);
|
||||
// messageSender.sendMessage(userInfo,output);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -5,6 +5,7 @@ import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import work.slhaf.agent.core.model.CoreModel;
|
||||
import work.slhaf.agent.modules.memory.MemoryManager;
|
||||
import work.slhaf.agent.modules.memory.SliceEvaluator;
|
||||
import work.slhaf.agent.modules.task.TaskScheduler;
|
||||
import work.slhaf.agent.modules.topic.TopicExtractor;
|
||||
@@ -12,81 +13,100 @@ import work.slhaf.agent.modules.topic.TopicExtractor;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Scanner;
|
||||
|
||||
@Data
|
||||
@Slf4j
|
||||
public class Config {
|
||||
|
||||
private static final String CONFIG_FILE_PATH = "./data/config/config.json";
|
||||
private static final String CONFIG_FILE_PATH = "./config/config.json";
|
||||
private static Config config;
|
||||
|
||||
private String agentId;
|
||||
|
||||
private HashMap<String, ModelConfig> modelConfig;
|
||||
|
||||
private WebSocketConfig webSocketConfig;
|
||||
|
||||
public static Config load() throws IOException {
|
||||
private List<ModuleConfig> moduleConfigList;
|
||||
|
||||
private Config() {
|
||||
}
|
||||
|
||||
public static Config getConfig() throws IOException {
|
||||
if (config == null) {
|
||||
File file = new File(CONFIG_FILE_PATH);
|
||||
if (file.exists()) {
|
||||
config = JSONUtil.readJSONObject(file, StandardCharsets.UTF_8).toBean(Config.class);
|
||||
} else {
|
||||
Config tempConfig = new Config();
|
||||
config = new Config();
|
||||
Scanner scanner = new Scanner(System.in);
|
||||
|
||||
System.out.print("输入智能体名称: ");
|
||||
tempConfig.setAgentId(scanner.nextLine());
|
||||
config.setAgentId(scanner.nextLine());
|
||||
|
||||
System.out.println("\r\n--------模型配置--------\r\n");
|
||||
HashMap<String, ModelConfig> modelConfig = new HashMap<>();
|
||||
for (int i = 0; i < 4; i++) {
|
||||
String modelKey = switch (i) {
|
||||
case 0 -> {
|
||||
System.out.println("CoreModel:");
|
||||
yield CoreModel.MODEL_KEY;
|
||||
}
|
||||
case 1 -> {
|
||||
System.out.println("SliceEvaluator:");
|
||||
yield SliceEvaluator.MODEL_KEY;
|
||||
}
|
||||
case 2 -> {
|
||||
System.out.println("TaskTrigger:");
|
||||
yield TaskScheduler.MODEL_KEY;
|
||||
}
|
||||
case 3 -> {
|
||||
System.out.println("TopicExtractor:");
|
||||
yield TopicExtractor.MODEL_KEY;
|
||||
}
|
||||
default -> throw new RuntimeException();
|
||||
};
|
||||
System.out.println(modelKey);
|
||||
ModelConfig temp = new ModelConfig();
|
||||
System.out.print("apikey: ");
|
||||
temp.setApikey(scanner.nextLine());
|
||||
System.out.print("baseUrl: ");
|
||||
temp.setBaseUrl(scanner.nextLine());
|
||||
System.out.print("model: ");
|
||||
temp.setModel(scanner.nextLine());
|
||||
|
||||
modelConfig.put(modelKey, temp);
|
||||
}
|
||||
tempConfig.setModelConfig(modelConfig);
|
||||
generateModelConfig(scanner);
|
||||
|
||||
System.out.println("\r\n--------服务配置--------\r\n");
|
||||
System.out.print("WebSocket port: ");
|
||||
WebSocketConfig wsConfig = new WebSocketConfig();
|
||||
wsConfig.setPort(scanner.nextInt());
|
||||
generateWsSocketConfig(scanner);
|
||||
|
||||
System.out.println("\r\n--------模块链配置--------\r\n");
|
||||
generatePipelineConfig();
|
||||
|
||||
//保存配置文件
|
||||
String str = JSONUtil.toJsonPrettyStr(tempConfig);
|
||||
FileUtils.writeStringToFile(file,str,StandardCharsets.UTF_8);
|
||||
String str = JSONUtil.toJsonPrettyStr(config);
|
||||
FileUtils.writeStringToFile(file, str, StandardCharsets.UTF_8);
|
||||
log.info("配置已保存");
|
||||
config = tempConfig;
|
||||
}
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
private static void generatePipelineConfig() {
|
||||
List<ModuleConfig> moduleConfigList = List.of(
|
||||
new ModuleConfig(TopicExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
||||
new ModuleConfig(MemoryManager.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
||||
new ModuleConfig(TaskScheduler.class.getName(), ModuleConfig.Constant.INTERNAL, null)
|
||||
);
|
||||
config.setModuleConfigList(moduleConfigList);
|
||||
}
|
||||
|
||||
private static void generateWsSocketConfig(Scanner scanner) {
|
||||
System.out.print("WebSocket port: ");
|
||||
WebSocketConfig wsConfig = new WebSocketConfig();
|
||||
wsConfig.setPort(scanner.nextInt());
|
||||
config.setWebSocketConfig(wsConfig);
|
||||
}
|
||||
|
||||
private static void generateModelConfig(Scanner scanner) throws IOException {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
String modelKey = switch (i) {
|
||||
case 0 -> {
|
||||
System.out.println("CoreModel:");
|
||||
yield CoreModel.MODEL_KEY;
|
||||
}
|
||||
case 1 -> {
|
||||
System.out.println("SliceEvaluator:");
|
||||
yield SliceEvaluator.MODEL_KEY;
|
||||
}
|
||||
case 2 -> {
|
||||
System.out.println("TaskTrigger:");
|
||||
yield TaskScheduler.MODEL_KEY;
|
||||
}
|
||||
case 3 -> {
|
||||
System.out.println("TopicExtractor:");
|
||||
yield TopicExtractor.MODEL_KEY;
|
||||
}
|
||||
default -> throw new RuntimeException();
|
||||
};
|
||||
ModelConfig modelConfig = new ModelConfig();
|
||||
System.out.print("apikey: ");
|
||||
modelConfig.setApikey(scanner.nextLine());
|
||||
System.out.print("baseUrl: ");
|
||||
modelConfig.setBaseUrl(scanner.nextLine());
|
||||
System.out.print("model: ");
|
||||
modelConfig.setModel(scanner.nextLine());
|
||||
modelConfig.generateConfig(modelKey);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,40 @@
|
||||
package work.slhaf.agent.common.config;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.HashMap;
|
||||
|
||||
@Data
|
||||
public class ModelConfig {
|
||||
|
||||
private static final String MODEL_CONFIG_DIR_PATH = "./config/model/";
|
||||
private static final HashMap<String, ModelConfig> modelConfigMap = new HashMap<>();
|
||||
|
||||
private String apikey;
|
||||
private String baseUrl;
|
||||
private String model;
|
||||
|
||||
public void generateConfig(String filename) throws IOException {
|
||||
String str = JSONUtil.toJsonPrettyStr(this);
|
||||
File file = new File(MODEL_CONFIG_DIR_PATH + filename + ".json");
|
||||
FileUtils.writeStringToFile(file, str, StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
public static ModelConfig load(String modelKey) {
|
||||
if (!modelConfigMap.containsKey(modelKey)) {
|
||||
modelConfigMap.put(modelKey,loadConfig(modelKey));
|
||||
}
|
||||
|
||||
return modelConfigMap.get(modelKey);
|
||||
}
|
||||
|
||||
private static ModelConfig loadConfig(String modelKey) {
|
||||
File file = new File(MODEL_CONFIG_DIR_PATH+modelKey+".json");
|
||||
return JSONUtil.readJSONObject(file,StandardCharsets.UTF_8).toBean(ModelConfig.class);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
package work.slhaf.agent.common.config;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public class ModuleConfig {
|
||||
private String className;
|
||||
private String type;
|
||||
private String path;
|
||||
|
||||
public static class Constant {
|
||||
public static final String INTERNAL = "internal";
|
||||
public static final String EXTERNAL = "external";
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.common.config.ModelConfig;
|
||||
import work.slhaf.agent.modules.memory.MemoryGraph;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@@ -17,9 +18,9 @@ public class Model {
|
||||
protected String prompt;
|
||||
protected List<Message> messages;
|
||||
|
||||
protected static void setModel(Config config, Model model, String model_key, String prompt) {
|
||||
MemoryGraph memoryGraph = MemoryGraph.initialize(config.getAgentId());
|
||||
ModelConfig modelConfig = config.getModelConfig().get(model_key);
|
||||
protected static void setModel(Config config, Model model, String model_key, String prompt) throws IOException, ClassNotFoundException {
|
||||
MemoryGraph memoryGraph = MemoryGraph.getInstance(config.getAgentId());
|
||||
ModelConfig modelConfig = ModelConfig.load(model_key);
|
||||
if (memoryGraph.getModelPrompt().containsKey(model_key)) {
|
||||
model.setPrompt(memoryGraph.getModelPrompt().get(model_key));
|
||||
} else {
|
||||
|
||||
@@ -3,11 +3,16 @@ package work.slhaf.agent.core;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.core.interation.InteractionModulesLoader;
|
||||
import work.slhaf.agent.core.interation.TaskCallback;
|
||||
import work.slhaf.agent.core.interation.data.InteractionInputData;
|
||||
import work.slhaf.agent.core.model.CoreModel;
|
||||
import work.slhaf.agent.modules.memory.MemoryManager;
|
||||
import work.slhaf.agent.modules.task.TaskScheduler;
|
||||
import work.slhaf.module.InteractionModule;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Slf4j
|
||||
@@ -21,18 +26,16 @@ public class InteractionHub {
|
||||
private MemoryManager memoryManager;
|
||||
private TaskScheduler taskScheduler;
|
||||
|
||||
public static InteractionHub initialize(Config config) {
|
||||
public static InteractionHub initialize() throws IOException {
|
||||
if (interactionHub == null) {
|
||||
interactionHub = new InteractionHub();
|
||||
interactionHub.setCoreModel(CoreModel.initialize(config));
|
||||
interactionHub.setMemoryManager(MemoryManager.initialize(config));
|
||||
interactionHub.setTaskScheduler(TaskScheduler.initialize(config));
|
||||
log.info("InteractionHub注册完毕...");
|
||||
}
|
||||
return interactionHub;
|
||||
}
|
||||
|
||||
public void call(InteractionInputData inputData) {
|
||||
public void call(InteractionInputData inputData) throws IOException {
|
||||
List<InteractionModule> interactionModules = InteractionModulesLoader.registerInteractionModules();
|
||||
|
||||
callback.onTaskFinished(null, null);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
package work.slhaf.agent.core.interation;
|
||||
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.common.config.ModuleConfig;
|
||||
import work.slhaf.module.InteractionModule;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class InteractionModulesLoader {
|
||||
public static List<InteractionModule> registerInteractionModules() throws IOException {
|
||||
List<InteractionModule> moduleList = new ArrayList<>();
|
||||
List<ModuleConfig> moduleConfigList = Config.getConfig().getModuleConfigList();
|
||||
for (ModuleConfig moduleConfig : moduleConfigList) {
|
||||
if (ModuleConfig.Constant.INTERNAL.equals(moduleConfig.getType())) {
|
||||
moduleList.add(loadInternalModule(moduleConfig.getClassName()));
|
||||
}
|
||||
}
|
||||
return moduleList;
|
||||
}
|
||||
|
||||
private static InteractionModule loadInternalModule(String moduleName) {
|
||||
try {
|
||||
Class<?> clazz = Class.forName(moduleName);
|
||||
|
||||
//TODO 后续需要规范`getInstance`方法的实现
|
||||
return (InteractionModule) clazz.getMethod("getInstance").invoke(null);
|
||||
} catch (ClassNotFoundException | InvocationTargetException | IllegalAccessException | NoSuchMethodException e) {
|
||||
throw new RuntimeException("Fail to load internal module: " + moduleName,e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,8 @@ import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.common.model.Model;
|
||||
import work.slhaf.agent.common.model.ModelConstant;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@Slf4j
|
||||
@@ -15,8 +17,11 @@ public class CoreModel extends Model {
|
||||
public static final String MODEL_KEY = "core_model";
|
||||
private static CoreModel coreModel;
|
||||
|
||||
public static CoreModel initialize(Config config) {
|
||||
private CoreModel(){}
|
||||
|
||||
public static CoreModel getInstance() throws IOException, ClassNotFoundException {
|
||||
if (coreModel == null) {
|
||||
Config config = Config.getConfig();
|
||||
coreModel = new CoreModel();
|
||||
coreModel.setPrompt(ModelConstant.CORE_MODEL_PROMPT);
|
||||
setModel(config, coreModel, MODEL_KEY, coreModel.getPrompt());
|
||||
|
||||
@@ -10,6 +10,7 @@ import work.slhaf.agent.Agent;
|
||||
import work.slhaf.agent.core.interation.data.InteractionInputData;
|
||||
import work.slhaf.agent.core.interation.data.InteractionOutputData;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
@@ -39,7 +40,11 @@ public class AgentWebSocketServer extends WebSocketServer implements MessageSend
|
||||
public void onMessage(WebSocket webSocket, String s) {
|
||||
InteractionInputData inputData = JSONObject.parseObject(s, InteractionInputData.class);
|
||||
userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接
|
||||
agent.receiveUserInput(inputData.getUserNickName(), inputData.getUserInfo(), inputData.getContent());
|
||||
try {
|
||||
agent.receiveUserInput(inputData.getUserNickName(), inputData.getUserInfo(), inputData.getContent());
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -3,6 +3,7 @@ package work.slhaf.agent.modules.memory;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import work.slhaf.agent.common.chat.pojo.Message;
|
||||
import work.slhaf.agent.modules.memory.exception.UnExistedTopicException;
|
||||
import work.slhaf.agent.modules.memory.node.MemoryNode;
|
||||
@@ -36,7 +37,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
* key: 根主题名称 value: 根主题节点
|
||||
*/
|
||||
private HashMap<String, TopicNode> topicNodes;
|
||||
public static MemoryGraph memoryGraph;
|
||||
private static MemoryGraph memoryGraph;
|
||||
|
||||
/**
|
||||
* 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值
|
||||
@@ -110,31 +111,27 @@ public class MemoryGraph extends PersistableObject {
|
||||
this.modelPrompt = new HashMap<>();
|
||||
}
|
||||
|
||||
public static MemoryGraph initialize(String id) {
|
||||
public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException {
|
||||
// 检查存储目录是否存在,不存在则创建
|
||||
createStorageDirectory();
|
||||
|
||||
Path filePath = getFilePath(id);
|
||||
if (memoryGraph == null && Files.exists(filePath)) {
|
||||
try {
|
||||
// 从文件加载
|
||||
if (memoryGraph == null) {
|
||||
Path filePath = getFilePath(id);
|
||||
if (Files.exists(filePath)) {
|
||||
memoryGraph = deserialize(id);
|
||||
} catch (Exception e) {
|
||||
log.error("加载序列化文件失败,创建新实例");
|
||||
System.exit(1);
|
||||
}else {
|
||||
FileUtils.createParentDirectories(filePath.toFile().getParentFile());
|
||||
memoryGraph = new MemoryGraph(id);
|
||||
memoryGraph.serialize();
|
||||
}
|
||||
} else {
|
||||
// 创建新实例
|
||||
memoryGraph = new MemoryGraph(id);
|
||||
log.info("MemoryGraph注册完毕...");
|
||||
}
|
||||
log.info("MemoryGraph注册完毕...");
|
||||
|
||||
return memoryGraph;
|
||||
}
|
||||
|
||||
public void serialize() {
|
||||
public void serialize() throws IOException {
|
||||
Path filePath = getFilePath(this.id);
|
||||
|
||||
Files.createDirectories(Path.of(STORAGE_DIR));
|
||||
try (ObjectOutputStream oos = new ObjectOutputStream(
|
||||
new FileOutputStream(filePath.toFile()))) {
|
||||
oos.writeObject(this);
|
||||
@@ -193,7 +190,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
lastTopicNode.getMemoryNodes().add(node);
|
||||
lastTopicNode.getMemoryNodes().sort(null);
|
||||
}
|
||||
node.getMemorySliceList().add(slice);
|
||||
node.loadMemorySliceList().add(slice);
|
||||
|
||||
//生成relatedTopicPath
|
||||
for (List<String> relatedTopic : slice.getRelatedTopics()) {
|
||||
@@ -321,7 +318,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
//终点记忆节点
|
||||
MemorySliceResult sliceResult = new MemorySliceResult();
|
||||
for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) {
|
||||
List<MemorySlice> endpointMemorySliceList = memoryNode.getMemorySliceList();
|
||||
List<MemorySlice> endpointMemorySliceList = memoryNode.loadMemorySliceList();
|
||||
// targetSliceList.addAll(endpointMemorySliceList);
|
||||
for (MemorySlice memorySlice : endpointMemorySliceList) {
|
||||
sliceResult.setSliceBefore(memorySlice.getSliceBefore());
|
||||
@@ -348,14 +345,14 @@ public class MemoryGraph extends PersistableObject {
|
||||
TopicNode tempTargetNode = tempTargetParentNode.getTopicNodes().get(tempTopicPath.getLast());
|
||||
List<MemoryNode> tempMemoryNodes = tempTargetNode.getMemoryNodes();
|
||||
if (!tempMemoryNodes.isEmpty()) {
|
||||
relatedMemorySlice.addAll(tempMemoryNodes.getFirst().getMemorySliceList());
|
||||
relatedMemorySlice.addAll(tempMemoryNodes.getFirst().loadMemorySliceList());
|
||||
}
|
||||
}
|
||||
|
||||
//邻近记忆节点 父级
|
||||
List<MemoryNode> targetParentMemoryNodes = targetParentNode.getMemoryNodes();
|
||||
if (!targetParentMemoryNodes.isEmpty()) {
|
||||
relatedMemorySlice.addAll(targetParentMemoryNodes.getFirst().getMemorySliceList());
|
||||
relatedMemorySlice.addAll(targetParentMemoryNodes.getFirst().loadMemorySliceList());
|
||||
}
|
||||
|
||||
//将上述结果包装为MemoryResult
|
||||
|
||||
@@ -3,20 +3,32 @@ package work.slhaf.agent.modules.memory;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.module.InteractionContext;
|
||||
import work.slhaf.module.InteractionModule;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
@Data
|
||||
@Slf4j
|
||||
public class MemoryManager {
|
||||
public class MemoryManager implements InteractionModule {
|
||||
|
||||
private static MemoryManager memoryManager;
|
||||
|
||||
private MemoryGraph memoryGraph;
|
||||
private SliceEvaluator sliceEvaluator;
|
||||
|
||||
public static MemoryManager initialize(Config config){
|
||||
private MemoryManager(){}
|
||||
|
||||
@Override
|
||||
public void execute(InteractionContext interactionContext) {
|
||||
|
||||
}
|
||||
|
||||
public static MemoryManager getInstance() throws IOException, ClassNotFoundException {
|
||||
if (memoryManager == null) {
|
||||
Config config = Config.getConfig();
|
||||
memoryManager = new MemoryManager();
|
||||
memoryManager.setMemoryGraph(MemoryGraph.initialize(config.getAgentId()));
|
||||
memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId()));
|
||||
memoryManager.setSliceEvaluator(SliceEvaluator.initialize(config));
|
||||
log.info("MemoryManager注册完毕...");
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.common.model.Model;
|
||||
import work.slhaf.agent.common.model.ModelConstant;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@Slf4j
|
||||
@@ -15,7 +17,9 @@ public class SliceEvaluator extends Model {
|
||||
|
||||
private static SliceEvaluator sliceEvaluator;
|
||||
|
||||
public static SliceEvaluator initialize(Config config) {
|
||||
private SliceEvaluator(){}
|
||||
|
||||
public static SliceEvaluator initialize(Config config) throws IOException, ClassNotFoundException {
|
||||
|
||||
if (sliceEvaluator == null) {
|
||||
sliceEvaluator = new SliceEvaluator();
|
||||
|
||||
@@ -8,6 +8,8 @@ import work.slhaf.agent.modules.memory.pojo.MemorySlice;
|
||||
import work.slhaf.agent.modules.memory.pojo.PersistableObject;
|
||||
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.time.LocalDate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
@@ -20,7 +22,7 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private static String SLICE_DATA_DIR = "./data/slice/";
|
||||
private static String SLICE_DATA_DIR = "./data/memory/slice/";
|
||||
|
||||
/**
|
||||
* 记忆节点唯一标识, 用于作为实际文件名, 如(xxxx-xxxxx-xxxxx.slice)
|
||||
@@ -47,7 +49,7 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
|
||||
return 0;
|
||||
}
|
||||
|
||||
public List<MemorySlice> getMemorySliceList() throws IOException, ClassNotFoundException {
|
||||
public List<MemorySlice> loadMemorySliceList() throws IOException, ClassNotFoundException {
|
||||
//检查是否存在对应文件
|
||||
File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
|
||||
if (file.exists()){
|
||||
@@ -64,6 +66,7 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
|
||||
throw new NullSliceListException("memorySliceList为NULL! 检查实现逻辑!");
|
||||
}
|
||||
File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
|
||||
Files.createDirectories(Path.of(SLICE_DATA_DIR));
|
||||
try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(file))){
|
||||
oos.writeObject(this.memorySliceList);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
package work.slhaf.agent.modules.preprocess;
|
||||
|
||||
import work.slhaf.agent.core.interation.data.InteractionInputData;
|
||||
import work.slhaf.module.InteractionContext;
|
||||
|
||||
public class PreprocessExecutor {
|
||||
|
||||
private static PreprocessExecutor preprocessExecutor;
|
||||
|
||||
private PreprocessExecutor(){}
|
||||
|
||||
public static PreprocessExecutor getInstance() {
|
||||
if (preprocessExecutor == null) {
|
||||
preprocessExecutor = new PreprocessExecutor();
|
||||
}
|
||||
return preprocessExecutor;
|
||||
}
|
||||
|
||||
public InteractionContext execute(InteractionInputData inputData) {
|
||||
InteractionContext context = new InteractionContext();
|
||||
context.setDateTime(inputData.getLocalDateTime());
|
||||
context.setFinished(false);
|
||||
context.setInput(inputData.getContent());
|
||||
context.setUserInfo(inputData.getUserInfo());
|
||||
context.setUserNickname(inputData.getUserNickName());
|
||||
|
||||
return context;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
package work.slhaf.agent.modules.task;
|
||||
|
||||
public class TaskEvaluator {
|
||||
}
|
||||
@@ -6,16 +6,23 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.common.model.Model;
|
||||
import work.slhaf.agent.common.model.ModelConstant;
|
||||
import work.slhaf.module.InteractionContext;
|
||||
import work.slhaf.module.InteractionModule;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@Slf4j
|
||||
public class TaskScheduler extends Model {
|
||||
public class TaskScheduler extends Model implements InteractionModule {
|
||||
public static final String MODEL_KEY = "task_trigger";
|
||||
private static TaskScheduler taskScheduler;
|
||||
public static TaskScheduler initialize(Config config) {
|
||||
|
||||
private TaskScheduler(){}
|
||||
|
||||
public static TaskScheduler getInstance() throws IOException, ClassNotFoundException {
|
||||
if (taskScheduler == null) {
|
||||
Config config = Config.getConfig();
|
||||
taskScheduler = new TaskScheduler();
|
||||
taskScheduler.setPrompt(ModelConstant.SLICE_EVALUATOR_PROMPT);
|
||||
setModel(config, taskScheduler, MODEL_KEY, taskScheduler.getPrompt());
|
||||
@@ -25,4 +32,8 @@ public class TaskScheduler extends Model {
|
||||
return taskScheduler;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(InteractionContext interactionContext) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,22 +5,33 @@ import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.common.model.Model;
|
||||
import work.slhaf.agent.common.model.ModelConstant;
|
||||
import work.slhaf.module.InteractionContext;
|
||||
import work.slhaf.module.InteractionModule;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class TopicExtractor extends Model {
|
||||
public class TopicExtractor extends Model implements InteractionModule {
|
||||
public static final String MODEL_KEY = "topic_extractor";
|
||||
private static TopicExtractor topicExtractor;
|
||||
|
||||
public static TopicExtractor initialize(Config config) {
|
||||
private TopicExtractor() {
|
||||
}
|
||||
|
||||
public static TopicExtractor getInstance() throws IOException, ClassNotFoundException {
|
||||
if (topicExtractor == null) {
|
||||
Config config = Config.getConfig();
|
||||
topicExtractor = new TopicExtractor();
|
||||
topicExtractor.setPrompt(ModelConstant.SLICE_EVALUATOR_PROMPT);
|
||||
setModel(config,topicExtractor, MODEL_KEY, topicExtractor.getPrompt());
|
||||
setModel(config, topicExtractor, MODEL_KEY, topicExtractor.getPrompt());
|
||||
}
|
||||
|
||||
return topicExtractor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(InteractionContext interactionContext) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,8 +49,8 @@ public class InsertTest {
|
||||
assertEquals(1, collectionsNode.getMemoryNodes().size());
|
||||
MemoryNode memoryNode = collectionsNode.getMemoryNodes().get(0);
|
||||
assertEquals(LocalDate.now(), memoryNode.getLocalDate());
|
||||
assertEquals(1, memoryNode.getMemorySliceList().size());
|
||||
assertEquals(slice, memoryNode.getMemorySliceList().get(0));
|
||||
assertEquals(1, memoryNode.loadMemorySliceList().size());
|
||||
assertEquals(slice, memoryNode.loadMemorySliceList().get(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -71,7 +71,7 @@ public class InsertTest {
|
||||
.getTopicNodes().get("Collections");
|
||||
|
||||
assertEquals(1, collectionsNode.getMemoryNodes().size()); // 同一天应该只有一个MemoryNode
|
||||
assertEquals(2, collectionsNode.getMemoryNodes().get(0).getMemorySliceList().size()); // 但有两个MemorySlice
|
||||
assertEquals(2, collectionsNode.getMemoryNodes().get(0).loadMemorySliceList().size()); // 但有两个MemorySlice
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -141,7 +141,7 @@ public class InsertTest {
|
||||
memoryGraph.serialize();
|
||||
|
||||
// 反序列化
|
||||
MemoryGraph loadedGraph = MemoryGraph.initialize(testId);
|
||||
MemoryGraph loadedGraph = MemoryGraph.getInstance(testId);
|
||||
|
||||
// 校验:topic 是否存在
|
||||
assertNotNull(loadedGraph.getTopicNodes().get("生活"));
|
||||
@@ -157,7 +157,7 @@ public class InsertTest {
|
||||
assertFalse(javaNode.getMemoryNodes().isEmpty());
|
||||
|
||||
// 校验:MemorySlice 内容一致
|
||||
MemorySlice deserializedSlice = javaNode.getMemoryNodes().get(0).getMemorySliceList().get(0);
|
||||
MemorySlice deserializedSlice = javaNode.getMemoryNodes().get(0).loadMemorySliceList().get(0);
|
||||
assertEquals("001", deserializedSlice.getMemoryId());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user