refactor(agent): 明确模块化设计流程,具体逻辑待实现

- 调整配置文件路径
- 新增 InteractionModulesLoader 用于动态加载交互模块,加载扩展模块待实现
- 修复 MemoryGraph 和 MemoryNode 的部分逻辑
- 改进 ModelConfig 类,支持单独配置文件, 用于动态加载模块
- 新增 PreprocessExecutor 和 TaskEvaluator模块, 待后续实现
This commit is contained in:
2025-04-17 23:12:13 +08:00
parent 27719b7c11
commit 34c6b861c8
22 changed files with 293 additions and 96 deletions

1
.gitignore vendored
View File

@@ -37,3 +37,4 @@ build/
### Mac OS ###
.DS_Store
/data/
/config/

4
.idea/misc.xml generated
View File

@@ -8,6 +8,10 @@
</list>
</option>
</component>
<component name="PWA">
<option name="enabled" value="true" />
<option name="wasEnabledAtLeastOnce" value="true" />
</component>
<component name="ProjectRootManager" version="2" languageLevel="JDK_21" default="true" project-jdk-name="21" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/out" />
</component>

View File

@@ -74,6 +74,11 @@
<artifactId>hutool-all</artifactId>
<version>5.8.36</version>
</dependency>
<dependency>
<groupId>work.slhaf</groupId>
<artifactId>Partner-Modules-Api</artifactId>
<version>1.0-SNAPSHOT</version>
</dependency>
</dependencies>
</project>

View File

@@ -8,5 +8,6 @@ import java.io.IOException;
public class Main {
public static void main(String[] args) throws IOException {
Agent agent = Agent.initialize();
agent.receiveUserInput("111","222","hello");
}
}

View File

@@ -23,9 +23,9 @@ public class Agent implements TaskCallback {
public static Agent initialize() throws IOException {
if (agent == null) {
//加载配置
Config config = Config.load();
Config config = Config.getConfig();
agent = new Agent();
agent.setInteractionHub(InteractionHub.initialize(config));
agent.setInteractionHub(InteractionHub.initialize());
agent.registerTaskCallback();
agent.setMessageSender(new AgentWebSocketServer(config.getWebSocketConfig().getPort(),agent));
log.info("Agent 加载完毕..");
@@ -37,7 +37,7 @@ public class Agent implements TaskCallback {
* 接收用户输入,包装为标准输入数据类
* @param input
*/
public void receiveUserInput(String userNickName,String userInfo,String input){
public void receiveUserInput(String userNickName,String userInfo,String input) throws IOException {
InteractionInputData inputData = new InteractionInputData();
inputData.setContent(input);
inputData.setUserInfo(userInfo);
@@ -53,7 +53,7 @@ public class Agent implements TaskCallback {
*/
public void sendToUser(String userInfo,String output){
System.out.println(output);
messageSender.sendMessage(userInfo,output);
// messageSender.sendMessage(userInfo,output);
}
@Override

View File

@@ -5,6 +5,7 @@ import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import work.slhaf.agent.core.model.CoreModel;
import work.slhaf.agent.modules.memory.MemoryManager;
import work.slhaf.agent.modules.memory.SliceEvaluator;
import work.slhaf.agent.modules.task.TaskScheduler;
import work.slhaf.agent.modules.topic.TopicExtractor;
@@ -12,36 +13,72 @@ import work.slhaf.agent.modules.topic.TopicExtractor;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Scanner;
@Data
@Slf4j
public class Config {
private static final String CONFIG_FILE_PATH = "./data/config/config.json";
private static final String CONFIG_FILE_PATH = "./config/config.json";
private static Config config;
private String agentId;
private HashMap<String, ModelConfig> modelConfig;
private WebSocketConfig webSocketConfig;
public static Config load() throws IOException {
private List<ModuleConfig> moduleConfigList;
private Config() {
}
public static Config getConfig() throws IOException {
if (config == null) {
File file = new File(CONFIG_FILE_PATH);
if (file.exists()) {
config = JSONUtil.readJSONObject(file, StandardCharsets.UTF_8).toBean(Config.class);
} else {
Config tempConfig = new Config();
config = new Config();
Scanner scanner = new Scanner(System.in);
System.out.print("输入智能体名称: ");
tempConfig.setAgentId(scanner.nextLine());
config.setAgentId(scanner.nextLine());
System.out.println("\r\n--------模型配置--------\r\n");
HashMap<String, ModelConfig> modelConfig = new HashMap<>();
generateModelConfig(scanner);
System.out.println("\r\n--------服务配置--------\r\n");
generateWsSocketConfig(scanner);
System.out.println("\r\n--------模块链配置--------\r\n");
generatePipelineConfig();
//保存配置文件
String str = JSONUtil.toJsonPrettyStr(config);
FileUtils.writeStringToFile(file, str, StandardCharsets.UTF_8);
log.info("配置已保存");
}
}
return config;
}
private static void generatePipelineConfig() {
List<ModuleConfig> moduleConfigList = List.of(
new ModuleConfig(TopicExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null),
new ModuleConfig(MemoryManager.class.getName(), ModuleConfig.Constant.INTERNAL, null),
new ModuleConfig(TaskScheduler.class.getName(), ModuleConfig.Constant.INTERNAL, null)
);
config.setModuleConfigList(moduleConfigList);
}
private static void generateWsSocketConfig(Scanner scanner) {
System.out.print("WebSocket port: ");
WebSocketConfig wsConfig = new WebSocketConfig();
wsConfig.setPort(scanner.nextInt());
config.setWebSocketConfig(wsConfig);
}
private static void generateModelConfig(Scanner scanner) throws IOException {
for (int i = 0; i < 4; i++) {
String modelKey = switch (i) {
case 0 -> {
@@ -62,31 +99,14 @@ public class Config {
}
default -> throw new RuntimeException();
};
System.out.println(modelKey);
ModelConfig temp = new ModelConfig();
ModelConfig modelConfig = new ModelConfig();
System.out.print("apikey: ");
temp.setApikey(scanner.nextLine());
modelConfig.setApikey(scanner.nextLine());
System.out.print("baseUrl: ");
temp.setBaseUrl(scanner.nextLine());
modelConfig.setBaseUrl(scanner.nextLine());
System.out.print("model: ");
temp.setModel(scanner.nextLine());
modelConfig.put(modelKey, temp);
modelConfig.setModel(scanner.nextLine());
modelConfig.generateConfig(modelKey);
}
tempConfig.setModelConfig(modelConfig);
System.out.println("\r\n--------服务配置--------\r\n");
System.out.print("WebSocket port: ");
WebSocketConfig wsConfig = new WebSocketConfig();
wsConfig.setPort(scanner.nextInt());
//保存配置文件
String str = JSONUtil.toJsonPrettyStr(tempConfig);
FileUtils.writeStringToFile(file,str,StandardCharsets.UTF_8);
log.info("配置已保存");
config = tempConfig;
}
}
return config;
}
}

View File

@@ -1,10 +1,40 @@
package work.slhaf.agent.common.config;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import org.apache.commons.io.FileUtils;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
@Data
public class ModelConfig {
private static final String MODEL_CONFIG_DIR_PATH = "./config/model/";
private static final HashMap<String, ModelConfig> modelConfigMap = new HashMap<>();
private String apikey;
private String baseUrl;
private String model;
public void generateConfig(String filename) throws IOException {
String str = JSONUtil.toJsonPrettyStr(this);
File file = new File(MODEL_CONFIG_DIR_PATH + filename + ".json");
FileUtils.writeStringToFile(file, str, StandardCharsets.UTF_8);
}
public static ModelConfig load(String modelKey) {
if (!modelConfigMap.containsKey(modelKey)) {
modelConfigMap.put(modelKey,loadConfig(modelKey));
}
return modelConfigMap.get(modelKey);
}
private static ModelConfig loadConfig(String modelKey) {
File file = new File(MODEL_CONFIG_DIR_PATH+modelKey+".json");
return JSONUtil.readJSONObject(file,StandardCharsets.UTF_8).toBean(ModelConfig.class);
}
}

View File

@@ -0,0 +1,17 @@
package work.slhaf.agent.common.config;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class ModuleConfig {
private String className;
private String type;
private String path;
public static class Constant {
public static final String INTERNAL = "internal";
public static final String EXTERNAL = "external";
}
}

View File

@@ -8,6 +8,7 @@ import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.config.ModelConfig;
import work.slhaf.agent.modules.memory.MemoryGraph;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
@@ -17,9 +18,9 @@ public class Model {
protected String prompt;
protected List<Message> messages;
protected static void setModel(Config config, Model model, String model_key, String prompt) {
MemoryGraph memoryGraph = MemoryGraph.initialize(config.getAgentId());
ModelConfig modelConfig = config.getModelConfig().get(model_key);
protected static void setModel(Config config, Model model, String model_key, String prompt) throws IOException, ClassNotFoundException {
MemoryGraph memoryGraph = MemoryGraph.getInstance(config.getAgentId());
ModelConfig modelConfig = ModelConfig.load(model_key);
if (memoryGraph.getModelPrompt().containsKey(model_key)) {
model.setPrompt(memoryGraph.getModelPrompt().get(model_key));
} else {

View File

@@ -3,11 +3,16 @@ package work.slhaf.agent.core;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.core.interation.InteractionModulesLoader;
import work.slhaf.agent.core.interation.TaskCallback;
import work.slhaf.agent.core.interation.data.InteractionInputData;
import work.slhaf.agent.core.model.CoreModel;
import work.slhaf.agent.modules.memory.MemoryManager;
import work.slhaf.agent.modules.task.TaskScheduler;
import work.slhaf.module.InteractionModule;
import java.io.IOException;
import java.util.List;
@Data
@Slf4j
@@ -21,18 +26,16 @@ public class InteractionHub {
private MemoryManager memoryManager;
private TaskScheduler taskScheduler;
public static InteractionHub initialize(Config config) {
public static InteractionHub initialize() throws IOException {
if (interactionHub == null) {
interactionHub = new InteractionHub();
interactionHub.setCoreModel(CoreModel.initialize(config));
interactionHub.setMemoryManager(MemoryManager.initialize(config));
interactionHub.setTaskScheduler(TaskScheduler.initialize(config));
log.info("InteractionHub注册完毕...");
}
return interactionHub;
}
public void call(InteractionInputData inputData) {
public void call(InteractionInputData inputData) throws IOException {
List<InteractionModule> interactionModules = InteractionModulesLoader.registerInteractionModules();
callback.onTaskFinished(null, null);
}

View File

@@ -0,0 +1,34 @@
package work.slhaf.agent.core.interation;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.config.ModuleConfig;
import work.slhaf.module.InteractionModule;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
public class InteractionModulesLoader {
public static List<InteractionModule> registerInteractionModules() throws IOException {
List<InteractionModule> moduleList = new ArrayList<>();
List<ModuleConfig> moduleConfigList = Config.getConfig().getModuleConfigList();
for (ModuleConfig moduleConfig : moduleConfigList) {
if (ModuleConfig.Constant.INTERNAL.equals(moduleConfig.getType())) {
moduleList.add(loadInternalModule(moduleConfig.getClassName()));
}
}
return moduleList;
}
private static InteractionModule loadInternalModule(String moduleName) {
try {
Class<?> clazz = Class.forName(moduleName);
//TODO 后续需要规范`getInstance`方法的实现
return (InteractionModule) clazz.getMethod("getInstance").invoke(null);
} catch (ClassNotFoundException | InvocationTargetException | IllegalAccessException | NoSuchMethodException e) {
throw new RuntimeException("Fail to load internal module: " + moduleName,e);
}
}
}

View File

@@ -7,6 +7,8 @@ import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import java.io.IOException;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@@ -15,8 +17,11 @@ public class CoreModel extends Model {
public static final String MODEL_KEY = "core_model";
private static CoreModel coreModel;
public static CoreModel initialize(Config config) {
private CoreModel(){}
public static CoreModel getInstance() throws IOException, ClassNotFoundException {
if (coreModel == null) {
Config config = Config.getConfig();
coreModel = new CoreModel();
coreModel.setPrompt(ModelConstant.CORE_MODEL_PROMPT);
setModel(config, coreModel, MODEL_KEY, coreModel.getPrompt());

View File

@@ -10,6 +10,7 @@ import work.slhaf.agent.Agent;
import work.slhaf.agent.core.interation.data.InteractionInputData;
import work.slhaf.agent.core.interation.data.InteractionOutputData;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.concurrent.ConcurrentHashMap;
@@ -39,7 +40,11 @@ public class AgentWebSocketServer extends WebSocketServer implements MessageSend
public void onMessage(WebSocket webSocket, String s) {
InteractionInputData inputData = JSONObject.parseObject(s, InteractionInputData.class);
userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接
try {
agent.receiveUserInput(inputData.getUserNickName(), inputData.getUserInfo(), inputData.getContent());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override

View File

@@ -3,6 +3,7 @@ package work.slhaf.agent.modules.memory;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.modules.memory.exception.UnExistedTopicException;
import work.slhaf.agent.modules.memory.node.MemoryNode;
@@ -36,7 +37,7 @@ public class MemoryGraph extends PersistableObject {
* key: 根主题名称 value: 根主题节点
*/
private HashMap<String, TopicNode> topicNodes;
public static MemoryGraph memoryGraph;
private static MemoryGraph memoryGraph;
/**
* 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值
@@ -110,31 +111,27 @@ public class MemoryGraph extends PersistableObject {
this.modelPrompt = new HashMap<>();
}
public static MemoryGraph initialize(String id) {
public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException {
// 检查存储目录是否存在,不存在则创建
createStorageDirectory();
if (memoryGraph == null) {
Path filePath = getFilePath(id);
if (memoryGraph == null && Files.exists(filePath)) {
try {
// 从文件加载
if (Files.exists(filePath)) {
memoryGraph = deserialize(id);
} catch (Exception e) {
log.error("加载序列化文件失败,创建新实例");
System.exit(1);
}
} else {
// 创建新实例
}else {
FileUtils.createParentDirectories(filePath.toFile().getParentFile());
memoryGraph = new MemoryGraph(id);
memoryGraph.serialize();
}
log.info("MemoryGraph注册完毕...");
}
return memoryGraph;
}
public void serialize() {
public void serialize() throws IOException {
Path filePath = getFilePath(this.id);
Files.createDirectories(Path.of(STORAGE_DIR));
try (ObjectOutputStream oos = new ObjectOutputStream(
new FileOutputStream(filePath.toFile()))) {
oos.writeObject(this);
@@ -193,7 +190,7 @@ public class MemoryGraph extends PersistableObject {
lastTopicNode.getMemoryNodes().add(node);
lastTopicNode.getMemoryNodes().sort(null);
}
node.getMemorySliceList().add(slice);
node.loadMemorySliceList().add(slice);
//生成relatedTopicPath
for (List<String> relatedTopic : slice.getRelatedTopics()) {
@@ -321,7 +318,7 @@ public class MemoryGraph extends PersistableObject {
//终点记忆节点
MemorySliceResult sliceResult = new MemorySliceResult();
for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) {
List<MemorySlice> endpointMemorySliceList = memoryNode.getMemorySliceList();
List<MemorySlice> endpointMemorySliceList = memoryNode.loadMemorySliceList();
// targetSliceList.addAll(endpointMemorySliceList);
for (MemorySlice memorySlice : endpointMemorySliceList) {
sliceResult.setSliceBefore(memorySlice.getSliceBefore());
@@ -348,14 +345,14 @@ public class MemoryGraph extends PersistableObject {
TopicNode tempTargetNode = tempTargetParentNode.getTopicNodes().get(tempTopicPath.getLast());
List<MemoryNode> tempMemoryNodes = tempTargetNode.getMemoryNodes();
if (!tempMemoryNodes.isEmpty()) {
relatedMemorySlice.addAll(tempMemoryNodes.getFirst().getMemorySliceList());
relatedMemorySlice.addAll(tempMemoryNodes.getFirst().loadMemorySliceList());
}
}
//邻近记忆节点 父级
List<MemoryNode> targetParentMemoryNodes = targetParentNode.getMemoryNodes();
if (!targetParentMemoryNodes.isEmpty()) {
relatedMemorySlice.addAll(targetParentMemoryNodes.getFirst().getMemorySliceList());
relatedMemorySlice.addAll(targetParentMemoryNodes.getFirst().loadMemorySliceList());
}
//将上述结果包装为MemoryResult

View File

@@ -3,20 +3,32 @@ package work.slhaf.agent.modules.memory;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.config.Config;
import work.slhaf.module.InteractionContext;
import work.slhaf.module.InteractionModule;
import java.io.IOException;
@Data
@Slf4j
public class MemoryManager {
public class MemoryManager implements InteractionModule {
private static MemoryManager memoryManager;
private MemoryGraph memoryGraph;
private SliceEvaluator sliceEvaluator;
public static MemoryManager initialize(Config config){
private MemoryManager(){}
@Override
public void execute(InteractionContext interactionContext) {
}
public static MemoryManager getInstance() throws IOException, ClassNotFoundException {
if (memoryManager == null) {
Config config = Config.getConfig();
memoryManager = new MemoryManager();
memoryManager.setMemoryGraph(MemoryGraph.initialize(config.getAgentId()));
memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId()));
memoryManager.setSliceEvaluator(SliceEvaluator.initialize(config));
log.info("MemoryManager注册完毕...");
}

View File

@@ -7,6 +7,8 @@ import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import java.io.IOException;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@@ -15,7 +17,9 @@ public class SliceEvaluator extends Model {
private static SliceEvaluator sliceEvaluator;
public static SliceEvaluator initialize(Config config) {
private SliceEvaluator(){}
public static SliceEvaluator initialize(Config config) throws IOException, ClassNotFoundException {
if (sliceEvaluator == null) {
sliceEvaluator = new SliceEvaluator();

View File

@@ -8,6 +8,8 @@ import work.slhaf.agent.modules.memory.pojo.MemorySlice;
import work.slhaf.agent.modules.memory.pojo.PersistableObject;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
@@ -20,7 +22,7 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
@Serial
private static final long serialVersionUID = 1L;
private static String SLICE_DATA_DIR = "./data/slice/";
private static String SLICE_DATA_DIR = "./data/memory/slice/";
/**
* 记忆节点唯一标识, 用于作为实际文件名, 如(xxxx-xxxxx-xxxxx.slice)
@@ -47,7 +49,7 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
return 0;
}
public List<MemorySlice> getMemorySliceList() throws IOException, ClassNotFoundException {
public List<MemorySlice> loadMemorySliceList() throws IOException, ClassNotFoundException {
//检查是否存在对应文件
File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
if (file.exists()){
@@ -64,6 +66,7 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
throw new NullSliceListException("memorySliceList为NULL! 检查实现逻辑!");
}
File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
Files.createDirectories(Path.of(SLICE_DATA_DIR));
try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(file))){
oos.writeObject(this.memorySliceList);
}

View File

@@ -0,0 +1,29 @@
package work.slhaf.agent.modules.preprocess;
import work.slhaf.agent.core.interation.data.InteractionInputData;
import work.slhaf.module.InteractionContext;
public class PreprocessExecutor {
private static PreprocessExecutor preprocessExecutor;
private PreprocessExecutor(){}
public static PreprocessExecutor getInstance() {
if (preprocessExecutor == null) {
preprocessExecutor = new PreprocessExecutor();
}
return preprocessExecutor;
}
public InteractionContext execute(InteractionInputData inputData) {
InteractionContext context = new InteractionContext();
context.setDateTime(inputData.getLocalDateTime());
context.setFinished(false);
context.setInput(inputData.getContent());
context.setUserInfo(inputData.getUserInfo());
context.setUserNickname(inputData.getUserNickName());
return context;
}
}

View File

@@ -0,0 +1,4 @@
package work.slhaf.agent.modules.task;
public class TaskEvaluator {
}

View File

@@ -6,16 +6,23 @@ import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.module.InteractionContext;
import work.slhaf.module.InteractionModule;
import java.io.IOException;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class TaskScheduler extends Model {
public class TaskScheduler extends Model implements InteractionModule {
public static final String MODEL_KEY = "task_trigger";
private static TaskScheduler taskScheduler;
public static TaskScheduler initialize(Config config) {
private TaskScheduler(){}
public static TaskScheduler getInstance() throws IOException, ClassNotFoundException {
if (taskScheduler == null) {
Config config = Config.getConfig();
taskScheduler = new TaskScheduler();
taskScheduler.setPrompt(ModelConstant.SLICE_EVALUATOR_PROMPT);
setModel(config, taskScheduler, MODEL_KEY, taskScheduler.getPrompt());
@@ -25,4 +32,8 @@ public class TaskScheduler extends Model {
return taskScheduler;
}
@Override
public void execute(InteractionContext interactionContext) {
}
}

View File

@@ -5,22 +5,33 @@ import lombok.EqualsAndHashCode;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.module.InteractionContext;
import work.slhaf.module.InteractionModule;
import java.io.IOException;
@EqualsAndHashCode(callSuper = true)
@Data
public class TopicExtractor extends Model {
public class TopicExtractor extends Model implements InteractionModule {
public static final String MODEL_KEY = "topic_extractor";
private static TopicExtractor topicExtractor;
public static TopicExtractor initialize(Config config) {
private TopicExtractor() {
}
public static TopicExtractor getInstance() throws IOException, ClassNotFoundException {
if (topicExtractor == null) {
Config config = Config.getConfig();
topicExtractor = new TopicExtractor();
topicExtractor.setPrompt(ModelConstant.SLICE_EVALUATOR_PROMPT);
setModel(config,topicExtractor, MODEL_KEY, topicExtractor.getPrompt());
setModel(config, topicExtractor, MODEL_KEY, topicExtractor.getPrompt());
}
return topicExtractor;
}
@Override
public void execute(InteractionContext interactionContext) {
}
}

View File

@@ -49,8 +49,8 @@ public class InsertTest {
assertEquals(1, collectionsNode.getMemoryNodes().size());
MemoryNode memoryNode = collectionsNode.getMemoryNodes().get(0);
assertEquals(LocalDate.now(), memoryNode.getLocalDate());
assertEquals(1, memoryNode.getMemorySliceList().size());
assertEquals(slice, memoryNode.getMemorySliceList().get(0));
assertEquals(1, memoryNode.loadMemorySliceList().size());
assertEquals(slice, memoryNode.loadMemorySliceList().get(0));
}
@Test
@@ -71,7 +71,7 @@ public class InsertTest {
.getTopicNodes().get("Collections");
assertEquals(1, collectionsNode.getMemoryNodes().size()); // 同一天应该只有一个MemoryNode
assertEquals(2, collectionsNode.getMemoryNodes().get(0).getMemorySliceList().size()); // 但有两个MemorySlice
assertEquals(2, collectionsNode.getMemoryNodes().get(0).loadMemorySliceList().size()); // 但有两个MemorySlice
}
@Test
@@ -141,7 +141,7 @@ public class InsertTest {
memoryGraph.serialize();
// 反序列化
MemoryGraph loadedGraph = MemoryGraph.initialize(testId);
MemoryGraph loadedGraph = MemoryGraph.getInstance(testId);
// 校验topic 是否存在
assertNotNull(loadedGraph.getTopicNodes().get("生活"));
@@ -157,7 +157,7 @@ public class InsertTest {
assertFalse(javaNode.getMemoryNodes().isEmpty());
// 校验MemorySlice 内容一致
MemorySlice deserializedSlice = javaNode.getMemoryNodes().get(0).getMemorySliceList().get(0);
MemorySlice deserializedSlice = javaNode.getMemoryNodes().get(0).loadMemorySliceList().get(0);
assertEquals("001", deserializedSlice.getMemoryId());
}