mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
- MemoryGraph 新增输出主题树功能
- 将 TopicExtractor 重命名为 MemorySelectExtractor ,并添加了提示词 - 记忆模块开发工作进行中 - 新增 SliceSummary 类,服务于记忆模块
This commit is contained in:
@@ -1,12 +1,20 @@
|
|||||||
package work.slhaf;
|
package work.slhaf;
|
||||||
|
|
||||||
import work.slhaf.agent.Agent;
|
import work.slhaf.agent.Agent;
|
||||||
|
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
public static void main(String[] args) throws IOException {
|
public static void main(String[] args) throws IOException, ClassNotFoundException {
|
||||||
Agent agent = Agent.initialize();
|
Agent agent = Agent.initialize();
|
||||||
agent.receiveUserInput("111","222","hello");
|
|
||||||
|
InteractionInputData inputData = new InteractionInputData();
|
||||||
|
inputData.setContent("hello");
|
||||||
|
inputData.setPlatform("cli");
|
||||||
|
inputData.setUserInfo("owner");
|
||||||
|
inputData.setUserNickName("master");
|
||||||
|
|
||||||
|
agent.receiveUserInput(inputData);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -36,13 +36,9 @@ public class Agent implements TaskCallback {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 接收用户输入,包装为标准输入数据类
|
* 接收用户输入,包装为标准输入数据类
|
||||||
* @param input
|
* @param inputData
|
||||||
*/
|
*/
|
||||||
public void receiveUserInput(String userNickName,String userInfo,String input) throws IOException {
|
public void receiveUserInput(InteractionInputData inputData) throws IOException, ClassNotFoundException {
|
||||||
InteractionInputData inputData = new InteractionInputData();
|
|
||||||
inputData.setContent(input);
|
|
||||||
inputData.setUserInfo(userInfo);
|
|
||||||
inputData.setUserNickName(userNickName);
|
|
||||||
inputData.setLocalDateTime(LocalDateTime.now());
|
inputData.setLocalDateTime(LocalDateTime.now());
|
||||||
interactionHub.call(inputData);
|
interactionHub.call(inputData);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,13 +5,12 @@ import lombok.Data;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import work.slhaf.agent.core.model.CoreModel;
|
import work.slhaf.agent.core.model.CoreModel;
|
||||||
import work.slhaf.agent.core.memory.MemoryManager;
|
import work.slhaf.agent.modules.memory.MemorySelectExtractor;
|
||||||
import work.slhaf.agent.modules.memory.MemorySelector;
|
import work.slhaf.agent.modules.memory.MemorySelector;
|
||||||
import work.slhaf.agent.modules.memory.MemoryUpdater;
|
import work.slhaf.agent.modules.memory.MemoryUpdater;
|
||||||
import work.slhaf.agent.modules.memory.SliceEvaluator;
|
import work.slhaf.agent.modules.memory.SliceEvaluator;
|
||||||
import work.slhaf.agent.modules.task.TaskEvaluator;
|
import work.slhaf.agent.modules.task.TaskEvaluator;
|
||||||
import work.slhaf.agent.modules.task.TaskScheduler;
|
import work.slhaf.agent.modules.task.TaskScheduler;
|
||||||
import work.slhaf.agent.modules.topic.TopicExtractor;
|
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
@@ -67,7 +66,7 @@ public class Config {
|
|||||||
|
|
||||||
private static void generatePipelineConfig() {
|
private static void generatePipelineConfig() {
|
||||||
List<ModuleConfig> moduleConfigList = List.of(
|
List<ModuleConfig> moduleConfigList = List.of(
|
||||||
new ModuleConfig(TopicExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
new ModuleConfig(MemorySelectExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
||||||
new ModuleConfig(MemorySelector.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
new ModuleConfig(MemorySelector.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
||||||
new ModuleConfig(CoreModel.class.getName(),ModuleConfig.Constant.INTERNAL,null),
|
new ModuleConfig(CoreModel.class.getName(),ModuleConfig.Constant.INTERNAL,null),
|
||||||
new ModuleConfig(MemoryUpdater.class.getName(),ModuleConfig.Constant.INTERNAL,null),
|
new ModuleConfig(MemoryUpdater.class.getName(),ModuleConfig.Constant.INTERNAL,null),
|
||||||
@@ -100,7 +99,7 @@ public class Config {
|
|||||||
}
|
}
|
||||||
case 3 -> {
|
case 3 -> {
|
||||||
System.out.println("TopicExtractor:");
|
System.out.println("TopicExtractor:");
|
||||||
yield TopicExtractor.MODEL_KEY;
|
yield MemorySelectExtractor.MODEL_KEY;
|
||||||
}
|
}
|
||||||
default -> throw new RuntimeException();
|
default -> throw new RuntimeException();
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -6,6 +6,56 @@ public class ModelConstant {
|
|||||||
public static final String SLICE_EVALUATOR_PROMPT = """
|
public static final String SLICE_EVALUATOR_PROMPT = """
|
||||||
""";
|
""";
|
||||||
public static final String TOPIC_EXTRACTOR_PROMPT = """
|
public static final String TOPIC_EXTRACTOR_PROMPT = """
|
||||||
|
# MemorySelectExtractor 提示词
|
||||||
|
|
||||||
|
## 功能说明
|
||||||
|
你需要根据用户输入的JSON数据,分析其`text`字段内容,判断是否需要通过主题路径或日期进行记忆查询,并返回标准化格式的JSON响应。
|
||||||
|
|
||||||
|
## 输入字段说明
|
||||||
|
- `text`: 用户输入的文本内容
|
||||||
|
- `topic_tree`: 当前可用的主题树结构(括号内数字表示子主题数量)
|
||||||
|
- `date`: 当前对话发生的日期(用于时间推理)
|
||||||
|
|
||||||
|
## 输出规则
|
||||||
|
1. 当文本涉及明确主题路径时:
|
||||||
|
- 使用`"type": "topic"`
|
||||||
|
- `text`字段格式为"根主题->子主题->子子主题"(必须**完全匹配**topic_tree中的层级,包括从[root]到目标主题的完整路径)
|
||||||
|
- 示例:{
|
||||||
|
"type": "topic",
|
||||||
|
"text": "工作->项目A->需求文档"
|
||||||
|
}
|
||||||
|
|
||||||
|
2. 当文本包含明确可推算的日期时:
|
||||||
|
- 使用`"type": "date"`
|
||||||
|
- 日期格式必须为"YYYY-MM-DD"
|
||||||
|
- 仅接受具体日期(不接受"上周"等模糊表达)
|
||||||
|
- 示例:{
|
||||||
|
"type": "date",
|
||||||
|
"text": "2024-04-15"
|
||||||
|
}
|
||||||
|
|
||||||
|
3. 当不需要查询或无法确定时:
|
||||||
|
- 使用`"type": "none"`
|
||||||
|
- 示例:{
|
||||||
|
"type": "none"
|
||||||
|
}
|
||||||
|
|
||||||
|
## 完整示例
|
||||||
|
用户输入:{
|
||||||
|
"text": "还记得我们讨论过游戏引擎的物理系统实现吗?",
|
||||||
|
"topic_tree": "
|
||||||
|
技术 (3)[root]
|
||||||
|
├── 游戏开发 (2)
|
||||||
|
│ ├── 图形渲染 (1)
|
||||||
|
│ └── 物理系统 (0)
|
||||||
|
└── 人工智能 (1)",
|
||||||
|
"date": "2024-04-20"
|
||||||
|
}
|
||||||
|
|
||||||
|
正确响应:{
|
||||||
|
"type": "topic",
|
||||||
|
"text": "技术->游戏开发->物理系统"
|
||||||
|
}
|
||||||
""";
|
""";
|
||||||
public static final String TASK_EVALUATOR_PROMPT = """
|
public static final String TASK_EVALUATOR_PROMPT = """
|
||||||
""";
|
""";
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import work.slhaf.agent.core.interaction.InteractionModulesLoader;
|
|||||||
import work.slhaf.agent.core.interaction.TaskCallback;
|
import work.slhaf.agent.core.interaction.TaskCallback;
|
||||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||||
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
||||||
import work.slhaf.agent.core.model.CoreModel;
|
|
||||||
import work.slhaf.agent.core.memory.MemoryManager;
|
import work.slhaf.agent.core.memory.MemoryManager;
|
||||||
|
import work.slhaf.agent.core.model.CoreModel;
|
||||||
import work.slhaf.agent.modules.preprocess.PreprocessExecutor;
|
import work.slhaf.agent.modules.preprocess.PreprocessExecutor;
|
||||||
import work.slhaf.agent.modules.task.TaskScheduler;
|
import work.slhaf.agent.modules.task.TaskScheduler;
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ public class InteractionHub {
|
|||||||
return interactionHub;
|
return interactionHub;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void call(InteractionInputData inputData) throws IOException {
|
public void call(InteractionInputData inputData) throws IOException, ClassNotFoundException {
|
||||||
//预处理
|
//预处理
|
||||||
InteractionContext interactionContext = PreprocessExecutor.getInstance().execute(inputData);
|
InteractionContext interactionContext = PreprocessExecutor.getInstance().execute(inputData);
|
||||||
//加载模块
|
//加载模块
|
||||||
@@ -43,6 +43,6 @@ public class InteractionHub {
|
|||||||
for (InteractionModule interactionModule : interactionModules) {
|
for (InteractionModule interactionModule : interactionModules) {
|
||||||
interactionModule.execute(interactionContext);
|
interactionModule.execute(interactionContext);
|
||||||
}
|
}
|
||||||
callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getMessage());
|
callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getString("message"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package work.slhaf.agent.core.interaction;
|
|||||||
|
|
||||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
public interface InteractionModule {
|
public interface InteractionModule {
|
||||||
void execute(InteractionContext context);
|
void execute(InteractionContext context) throws IOException, ClassNotFoundException;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,12 +2,8 @@ package work.slhaf.agent.core.interaction.data;
|
|||||||
|
|
||||||
import com.alibaba.fastjson2.JSONObject;
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import work.slhaf.agent.common.chat.pojo.ChatResponse;
|
|
||||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
|
||||||
import work.slhaf.agent.modules.task.data.TaskData;
|
|
||||||
|
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDateTime;
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class InteractionContext {
|
public class InteractionContext {
|
||||||
@@ -17,10 +13,7 @@ public class InteractionContext {
|
|||||||
|
|
||||||
protected boolean finished;
|
protected boolean finished;
|
||||||
protected String input;
|
protected String input;
|
||||||
protected JSONObject tempResult;
|
|
||||||
protected ChatResponse coreResponse;
|
|
||||||
|
|
||||||
protected List<MemorySlice> memorySlices;
|
protected JSONObject moduleContext;
|
||||||
protected List<String> topicPath;
|
protected JSONObject coreResponse;
|
||||||
protected List<TaskData> taskDataList;
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,4 +10,5 @@ public class InteractionInputData {
|
|||||||
private String userNickName;
|
private String userNickName;
|
||||||
private String content;
|
private String content;
|
||||||
private LocalDateTime localDateTime;
|
private LocalDateTime localDateTime;
|
||||||
|
private String platform;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,10 +8,7 @@ import work.slhaf.agent.common.chat.pojo.Message;
|
|||||||
import work.slhaf.agent.core.memory.exception.UnExistedTopicException;
|
import work.slhaf.agent.core.memory.exception.UnExistedTopicException;
|
||||||
import work.slhaf.agent.core.memory.node.MemoryNode;
|
import work.slhaf.agent.core.memory.node.MemoryNode;
|
||||||
import work.slhaf.agent.core.memory.node.TopicNode;
|
import work.slhaf.agent.core.memory.node.TopicNode;
|
||||||
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
import work.slhaf.agent.core.memory.pojo.*;
|
||||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
|
||||||
import work.slhaf.agent.core.memory.pojo.MemorySliceResult;
|
|
||||||
import work.slhaf.agent.core.memory.pojo.PersistableObject;
|
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
@@ -100,6 +97,11 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
*/
|
*/
|
||||||
private List<Message> chatMessages;
|
private List<Message> chatMessages;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 用户列表
|
||||||
|
*/
|
||||||
|
private List<User> users;
|
||||||
|
|
||||||
public MemoryGraph(String id) {
|
public MemoryGraph(String id) {
|
||||||
this.id = id;
|
this.id = id;
|
||||||
this.topicNodes = new HashMap<>();
|
this.topicNodes = new HashMap<>();
|
||||||
@@ -266,7 +268,7 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
//放入新缓存
|
//放入新缓存
|
||||||
userDialogMap
|
userDialogMap
|
||||||
.computeIfAbsent(now, k -> new ConcurrentHashMap<>())
|
.computeIfAbsent(now, k -> new ConcurrentHashMap<>())
|
||||||
.merge(slice.getStartUser(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
|
.merge(slice.getStartUserId(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -298,7 +300,8 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public MemoryResult selectMemory(List<String> topicPath) throws IOException, ClassNotFoundException {
|
public MemoryResult selectMemory(String topicPathStr) throws IOException, ClassNotFoundException {
|
||||||
|
List<String> topicPath = List.of(topicPathStr.split("->"));
|
||||||
MemoryResult memoryResult = new MemoryResult();
|
MemoryResult memoryResult = new MemoryResult();
|
||||||
|
|
||||||
//每日刷新缓存
|
//每日刷新缓存
|
||||||
@@ -319,7 +322,6 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
MemorySliceResult sliceResult = new MemorySliceResult();
|
MemorySliceResult sliceResult = new MemorySliceResult();
|
||||||
for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) {
|
for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) {
|
||||||
List<MemorySlice> endpointMemorySliceList = memoryNode.loadMemorySliceList();
|
List<MemorySlice> endpointMemorySliceList = memoryNode.loadMemorySliceList();
|
||||||
// targetSliceList.addAll(endpointMemorySliceList);
|
|
||||||
for (MemorySlice memorySlice : endpointMemorySliceList) {
|
for (MemorySlice memorySlice : endpointMemorySliceList) {
|
||||||
sliceResult.setSliceBefore(memorySlice.getSliceBefore());
|
sliceResult.setSliceBefore(memorySlice.getSliceBefore());
|
||||||
sliceResult.setMemorySlice(memorySlice);
|
sliceResult.setMemorySlice(memorySlice);
|
||||||
@@ -420,5 +422,28 @@ public class MemoryGraph extends PersistableObject {
|
|||||||
}
|
}
|
||||||
return targetParentNode;
|
return targetParentNode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void printTopicTree() {
|
||||||
|
for (Map.Entry<String, TopicNode> entry : topicNodes.entrySet()) {
|
||||||
|
String rootName = entry.getKey();
|
||||||
|
TopicNode rootNode = entry.getValue();
|
||||||
|
System.out.println(rootName+"[root]");
|
||||||
|
printSubTopicsTreeFormat(rootNode, "", true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void printSubTopicsTreeFormat(TopicNode node, String prefix, boolean isLast) {
|
||||||
|
if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return;
|
||||||
|
|
||||||
|
List<Map.Entry<String, TopicNode>> entries = new ArrayList<>(node.getTopicNodes().entrySet());
|
||||||
|
for (int i = 0; i < entries.size(); i++) {
|
||||||
|
boolean last = (i == entries.size() - 1);
|
||||||
|
Map.Entry<String, TopicNode> entry = entries.get(i);
|
||||||
|
System.out.println(prefix + (last ? "└── " : "├── ") + entry.getKey());
|
||||||
|
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), last);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,15 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import work.slhaf.agent.common.config.Config;
|
import work.slhaf.agent.common.config.Config;
|
||||||
import work.slhaf.agent.core.interaction.InteractionModule;
|
import work.slhaf.agent.core.interaction.InteractionModule;
|
||||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||||
|
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||||
|
import work.slhaf.agent.core.memory.pojo.User;
|
||||||
import work.slhaf.agent.modules.memory.SliceEvaluator;
|
import work.slhaf.agent.modules.memory.SliceEvaluator;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.time.LocalDate;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -36,4 +42,37 @@ public class MemoryManager implements InteractionModule {
|
|||||||
return memoryManager;
|
return memoryManager;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MemoryResult selectMemory(String path) throws IOException, ClassNotFoundException {
|
||||||
|
return memoryGraph.selectMemory(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
public MemoryResult selectMemory(LocalDate date) {
|
||||||
|
return memoryGraph.selectMemory(date);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getUserId(String userInfo,String nickName) {
|
||||||
|
String userId = null;
|
||||||
|
for (User user : memoryGraph.getUsers()) {
|
||||||
|
if (user.getInfo().contains(userInfo)){
|
||||||
|
userId = user.getUuid();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (userId == null) {
|
||||||
|
User newUser = setNewUser(userInfo, nickName);
|
||||||
|
memoryGraph.getUsers().add(newUser);
|
||||||
|
userId = newUser.getUuid();
|
||||||
|
}
|
||||||
|
return userId;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static User setNewUser(String userInfo, String nickName) {
|
||||||
|
User newUser = new User();
|
||||||
|
newUser.setUuid(UUID.randomUUID().toString());
|
||||||
|
List<String> infoList = new ArrayList<>();
|
||||||
|
infoList.add(userInfo);
|
||||||
|
newUser.setInfo(infoList);
|
||||||
|
newUser.setNickName(nickName);
|
||||||
|
return newUser;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,12 +45,12 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
|
|||||||
* 多用户设定
|
* 多用户设定
|
||||||
* 发起该切片对话的用户
|
* 发起该切片对话的用户
|
||||||
*/
|
*/
|
||||||
private String startUser;
|
private String startUserId;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 该切片涉及到的用户uuid
|
* 该切片涉及到的用户uuid
|
||||||
*/
|
*/
|
||||||
private List<String> involvedUsers;
|
private List<String> involvedUserIds;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 是否仅供发起用户作为记忆参考
|
* 是否仅供发起用户作为记忆参考
|
||||||
|
|||||||
@@ -38,6 +38,6 @@ public class CoreModel extends Model implements InteractionModule {
|
|||||||
//TODO 需要拼接上下文之后再发送给主模型
|
//TODO 需要拼接上下文之后再发送给主模型
|
||||||
|
|
||||||
ChatResponse res = runChat(interactionContext.getInput());
|
ChatResponse res = runChat(interactionContext.getInput());
|
||||||
interactionContext.setCoreResponse(res);
|
// interactionContext.setCoreResponse();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,8 +41,8 @@ public class AgentWebSocketServer extends WebSocketServer implements MessageSend
|
|||||||
InteractionInputData inputData = JSONObject.parseObject(s, InteractionInputData.class);
|
InteractionInputData inputData = JSONObject.parseObject(s, InteractionInputData.class);
|
||||||
userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接
|
userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接
|
||||||
try {
|
try {
|
||||||
agent.receiveUserInput(inputData.getUserNickName(), inputData.getUserInfo(), inputData.getContent());
|
agent.receiveUserInput(inputData);
|
||||||
} catch (IOException e) {
|
} catch (IOException | ClassNotFoundException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,41 @@
|
|||||||
|
package work.slhaf.agent.modules.memory;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import work.slhaf.agent.common.config.Config;
|
||||||
|
import work.slhaf.agent.common.model.Model;
|
||||||
|
import work.slhaf.agent.common.model.ModelConstant;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@Data
|
||||||
|
public class MemorySelectExtractor extends Model {
|
||||||
|
public static final String MODEL_KEY = "topic_extractor";
|
||||||
|
private static MemorySelectExtractor memorySelectExtractor;
|
||||||
|
|
||||||
|
private MemorySelectExtractor() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static MemorySelectExtractor getInstance() throws IOException, ClassNotFoundException {
|
||||||
|
if (memorySelectExtractor == null) {
|
||||||
|
Config config = Config.getConfig();
|
||||||
|
memorySelectExtractor = new MemorySelectExtractor();
|
||||||
|
setModel(config, memorySelectExtractor, MODEL_KEY, ModelConstant.TOPIC_EXTRACTOR_PROMPT);
|
||||||
|
}
|
||||||
|
|
||||||
|
return memorySelectExtractor;
|
||||||
|
}
|
||||||
|
|
||||||
|
public JSONObject execute(String input) {
|
||||||
|
return JSONObject.parseObject(singleChat(input).getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Constant {
|
||||||
|
public static final String NONE = "none";
|
||||||
|
public static final String DATE = "date";
|
||||||
|
public static final String TOPIC = "topic";
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,11 +1,14 @@
|
|||||||
package work.slhaf.agent.modules.memory;
|
package work.slhaf.agent.modules.memory;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import work.slhaf.agent.core.interaction.InteractionModule;
|
import work.slhaf.agent.core.interaction.InteractionModule;
|
||||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||||
import work.slhaf.agent.core.memory.MemoryManager;
|
import work.slhaf.agent.core.memory.MemoryManager;
|
||||||
|
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.time.LocalDate;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class MemorySelector implements InteractionModule {
|
public class MemorySelector implements InteractionModule {
|
||||||
@@ -14,20 +17,40 @@ public class MemorySelector implements InteractionModule {
|
|||||||
|
|
||||||
private MemoryManager memoryManager;
|
private MemoryManager memoryManager;
|
||||||
private SliceEvaluator sliceEvaluator;
|
private SliceEvaluator sliceEvaluator;
|
||||||
|
private MemorySelectExtractor memorySelectExtractor;
|
||||||
|
|
||||||
private MemorySelector(){}
|
private MemorySelector() {
|
||||||
|
}
|
||||||
|
|
||||||
public static MemorySelector getInstance() throws IOException, ClassNotFoundException {
|
public static MemorySelector getInstance() throws IOException, ClassNotFoundException {
|
||||||
if (memorySelector == null) {
|
if (memorySelector == null) {
|
||||||
memorySelector = new MemorySelector();
|
memorySelector = new MemorySelector();
|
||||||
memorySelector.setMemoryManager(MemoryManager.getInstance());
|
memorySelector.setMemoryManager(MemoryManager.getInstance());
|
||||||
memorySelector.setSliceEvaluator(SliceEvaluator.getInstance());
|
memorySelector.setSliceEvaluator(SliceEvaluator.getInstance());
|
||||||
|
memorySelector.setMemorySelectExtractor(MemorySelectExtractor.getInstance());
|
||||||
}
|
}
|
||||||
return memorySelector;
|
return memorySelector;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void execute(InteractionContext interactionContext) {
|
public void execute(InteractionContext interactionContext) throws IOException, ClassNotFoundException {
|
||||||
|
//获取主题路径
|
||||||
|
JSONObject extractorResult = memorySelectExtractor.execute(interactionContext.getInput());
|
||||||
|
String selectType = extractorResult.getString("type");
|
||||||
|
//根据主结果进行操作查找切片
|
||||||
|
MemoryResult memoryResult = switch (selectType) {
|
||||||
|
case MemorySelectExtractor.Constant.DATE ->
|
||||||
|
memoryManager.selectMemory(LocalDate.parse(extractorResult.getString(MemorySelectExtractor.Constant.DATE)));
|
||||||
|
case MemorySelectExtractor.Constant.TOPIC ->
|
||||||
|
memoryManager.selectMemory(MemorySelectExtractor.Constant.TOPIC);
|
||||||
|
default -> null;
|
||||||
|
};
|
||||||
|
//评估切片
|
||||||
|
if (memoryResult == null) {
|
||||||
|
memoryResult = sliceEvaluator.execute(memoryResult,interactionContext);
|
||||||
|
}
|
||||||
|
|
||||||
|
//设置上下文
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ public class MemoryUpdater implements InteractionModule {
|
|||||||
|
|
||||||
private MemoryManager memoryManager;
|
private MemoryManager memoryManager;
|
||||||
private InteractionThreadPoolExecutor executor;
|
private InteractionThreadPoolExecutor executor;
|
||||||
|
private MemorySelectExtractor memorySelectExtractor;
|
||||||
|
|
||||||
private MemoryUpdater(){}
|
private MemoryUpdater(){}
|
||||||
|
|
||||||
@@ -22,6 +23,7 @@ public class MemoryUpdater implements InteractionModule {
|
|||||||
if (memoryUpdater == null) {
|
if (memoryUpdater == null) {
|
||||||
memoryUpdater = new MemoryUpdater();
|
memoryUpdater = new MemoryUpdater();
|
||||||
memoryUpdater.setMemoryManager(MemoryManager.getInstance());
|
memoryUpdater.setMemoryManager(MemoryManager.getInstance());
|
||||||
|
memoryUpdater.setMemorySelectExtractor(MemorySelectExtractor.getInstance());
|
||||||
}
|
}
|
||||||
return memoryUpdater;
|
return memoryUpdater;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,22 @@
|
|||||||
package work.slhaf.agent.modules.memory;
|
package work.slhaf.agent.modules.memory;
|
||||||
|
|
||||||
|
import cn.hutool.json.JSONUtil;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import work.slhaf.agent.common.config.Config;
|
import work.slhaf.agent.common.config.Config;
|
||||||
import work.slhaf.agent.common.model.Model;
|
import work.slhaf.agent.common.model.Model;
|
||||||
import work.slhaf.agent.common.model.ModelConstant;
|
import work.slhaf.agent.common.model.ModelConstant;
|
||||||
|
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||||
|
import work.slhaf.agent.core.memory.MemoryManager;
|
||||||
|
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||||
|
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||||
|
import work.slhaf.agent.core.memory.pojo.MemorySliceResult;
|
||||||
|
import work.slhaf.agent.modules.memory.data.SliceSummary;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Data
|
@Data
|
||||||
@@ -16,19 +25,72 @@ public class SliceEvaluator extends Model {
|
|||||||
public static final String MODEL_KEY = "slice_evaluator";
|
public static final String MODEL_KEY = "slice_evaluator";
|
||||||
|
|
||||||
private static SliceEvaluator sliceEvaluator;
|
private static SliceEvaluator sliceEvaluator;
|
||||||
|
private MemoryManager memoryManager;
|
||||||
|
|
||||||
private SliceEvaluator(){}
|
private SliceEvaluator() {
|
||||||
|
}
|
||||||
|
|
||||||
public static SliceEvaluator getInstance() throws IOException, ClassNotFoundException {
|
public static SliceEvaluator getInstance() throws IOException, ClassNotFoundException {
|
||||||
if (sliceEvaluator == null) {
|
if (sliceEvaluator == null) {
|
||||||
Config config = Config.getConfig();
|
Config config = Config.getConfig();
|
||||||
sliceEvaluator = new SliceEvaluator();
|
sliceEvaluator = new SliceEvaluator();
|
||||||
setModel(config,sliceEvaluator, MODEL_KEY, ModelConstant.SLICE_EVALUATOR_PROMPT);
|
sliceEvaluator.setMemoryManager(MemoryManager.getInstance());
|
||||||
|
setModel(config, sliceEvaluator, MODEL_KEY, ModelConstant.SLICE_EVALUATOR_PROMPT);
|
||||||
log.info("SliceEvaluator注册完毕...");
|
log.info("SliceEvaluator注册完毕...");
|
||||||
}
|
}
|
||||||
|
|
||||||
return sliceEvaluator;
|
return sliceEvaluator;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MemoryResult execute(MemoryResult memoryResult, InteractionContext context) {
|
||||||
|
List<SliceSummary> sliceSummaryList = new ArrayList<>();
|
||||||
|
setSliceSummaryList(memoryResult, context, sliceSummaryList);
|
||||||
|
String primaryJsonStr = singleChat(JSONUtil.toJsonStr(sliceSummaryList)).getMessage();
|
||||||
|
//TODO 解析并转换为过滤后的MemoryResult
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setSliceSummaryList(MemoryResult memoryResult, InteractionContext context, List<SliceSummary> sliceSummaryList) {
|
||||||
|
for (MemorySliceResult memorySliceResult : memoryResult.getMemorySliceResult()) {
|
||||||
|
//判断是否为发起用户
|
||||||
|
if (accessible(memorySliceResult.getMemorySlice(), context)) {
|
||||||
|
SliceSummary sliceSummary = new SliceSummary();
|
||||||
|
sliceSummary.setId(memorySliceResult.getMemorySlice().getTimestamp());
|
||||||
|
String stringBuilder = memorySliceResult.getSliceBefore().getSummary() +
|
||||||
|
"\r\n" +
|
||||||
|
memorySliceResult.getMemorySlice().getSummary() +
|
||||||
|
"\r\n" +
|
||||||
|
memorySliceResult.getSliceAfter().getSummary();
|
||||||
|
sliceSummary.setSummary(stringBuilder);
|
||||||
|
|
||||||
|
sliceSummaryList.add(sliceSummary);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (MemorySlice memorySlice : memoryResult.getRelatedMemorySliceResult()) {
|
||||||
|
SliceSummary sliceSummary = new SliceSummary();
|
||||||
|
sliceSummary.setId(memorySlice.getTimestamp());
|
||||||
|
sliceSummary.setSummary(memorySlice.getSummary());
|
||||||
|
|
||||||
|
sliceSummaryList.add(sliceSummary);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private boolean accessible(MemorySlice slice, InteractionContext context) {
|
||||||
|
boolean ok;
|
||||||
|
String startUserId = slice.getStartUserId();
|
||||||
|
String userInfo = context.getUserInfo();
|
||||||
|
String nickName = context.getUserNickname();
|
||||||
|
|
||||||
|
if (memoryManager.getUserId(userInfo, nickName).equals(startUserId)) {
|
||||||
|
ok = true;
|
||||||
|
} else {
|
||||||
|
ok = !slice.isPrivate();
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
package work.slhaf.agent.modules.memory.data;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class SliceSummary {
|
||||||
|
private String summary;
|
||||||
|
private Long id;
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
package work.slhaf.agent.modules.preprocess;
|
package work.slhaf.agent.modules.preprocess;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||||
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
||||||
|
|
||||||
@@ -26,6 +27,8 @@ public class PreprocessExecutor {
|
|||||||
context.setFinished(false);
|
context.setFinished(false);
|
||||||
context.setInput(inputData.getContent());
|
context.setInput(inputData.getContent());
|
||||||
|
|
||||||
|
context.setModuleContext(new JSONObject());
|
||||||
|
|
||||||
return context;
|
return context;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package work.slhaf.agent.modules.task;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import work.slhaf.agent.common.config.Config;
|
import work.slhaf.agent.common.config.Config;
|
||||||
import work.slhaf.agent.common.config.ModelConfig;
|
|
||||||
import work.slhaf.agent.common.model.Model;
|
import work.slhaf.agent.common.model.Model;
|
||||||
import work.slhaf.agent.common.model.ModelConstant;
|
import work.slhaf.agent.common.model.ModelConstant;
|
||||||
|
|
||||||
|
|||||||
@@ -1,41 +0,0 @@
|
|||||||
package work.slhaf.agent.modules.topic;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import work.slhaf.agent.common.chat.constant.ChatConstant;
|
|
||||||
import work.slhaf.agent.common.chat.pojo.ChatResponse;
|
|
||||||
import work.slhaf.agent.common.chat.pojo.Message;
|
|
||||||
import work.slhaf.agent.common.config.Config;
|
|
||||||
import work.slhaf.agent.common.model.Model;
|
|
||||||
import work.slhaf.agent.common.model.ModelConstant;
|
|
||||||
import work.slhaf.agent.core.interaction.InteractionModule;
|
|
||||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
|
||||||
@Data
|
|
||||||
public class TopicExtractor extends Model implements InteractionModule {
|
|
||||||
public static final String MODEL_KEY = "topic_extractor";
|
|
||||||
private static TopicExtractor topicExtractor;
|
|
||||||
|
|
||||||
private TopicExtractor() {
|
|
||||||
}
|
|
||||||
|
|
||||||
public static TopicExtractor getInstance() throws IOException, ClassNotFoundException {
|
|
||||||
if (topicExtractor == null) {
|
|
||||||
Config config = Config.getConfig();
|
|
||||||
topicExtractor = new TopicExtractor();
|
|
||||||
setModel(config, topicExtractor, MODEL_KEY, ModelConstant.TOPIC_EXTRACTOR_PROMPT);
|
|
||||||
}
|
|
||||||
|
|
||||||
return topicExtractor;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(InteractionContext interactionContext) {
|
|
||||||
String primaryMessageResponse = singleChat(interactionContext.getInput()).getMessage();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
89
src/test/java/memory/AITest.java
Normal file
89
src/test/java/memory/AITest.java
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package memory;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import work.slhaf.agent.common.chat.ChatClient;
|
||||||
|
import work.slhaf.agent.common.chat.constant.ChatConstant;
|
||||||
|
import work.slhaf.agent.common.chat.pojo.Message;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class AITest {
|
||||||
|
@Test
|
||||||
|
public void test1(){
|
||||||
|
ChatClient client = new ChatClient("https://open.bigmodel.cn/api/paas/v4/chat/completions","3db444552530b7742b0c53425fb93dcc.LcVwYjByht9AC3N9","glm-4-flash");
|
||||||
|
List<Message> messages = new ArrayList<>();
|
||||||
|
messages.add(new Message(ChatConstant.Character.SYSTEM, """
|
||||||
|
# MemorySelectExtractor 提示词
|
||||||
|
|
||||||
|
## 功能说明
|
||||||
|
你需要根据用户输入的JSON数据,分析其`text`字段内容,判断是否需要通过主题路径或日期进行记忆查询,并返回标准化格式的JSON响应。
|
||||||
|
|
||||||
|
## 输入字段说明
|
||||||
|
- `text`: 用户输入的文本内容
|
||||||
|
- `topic_tree`: 当前可用的主题树结构(括号内数字表示子主题数量)
|
||||||
|
- `date`: 当前对话发生的日期(用于时间推理)
|
||||||
|
|
||||||
|
## 输出规则
|
||||||
|
1. 当文本涉及明确主题路径时:
|
||||||
|
- 使用`"type": "topic"`
|
||||||
|
- `text`字段格式为"根主题->子主题->子子主题"(必须**完全匹配**topic_tree中的层级,包括从[root]到目标主题的完整路径)
|
||||||
|
- 示例:{
|
||||||
|
"type": "topic",
|
||||||
|
"text": "工作->项目A->需求文档"
|
||||||
|
}
|
||||||
|
|
||||||
|
2. 当文本包含明确可推算的日期时:
|
||||||
|
- 使用`"type": "date"`
|
||||||
|
- 日期格式必须为"YYYY-MM-DD"
|
||||||
|
- 仅接受具体日期(不接受"上周"等模糊表达)
|
||||||
|
- 示例:{
|
||||||
|
"type": "date",
|
||||||
|
"text": "2024-04-15"
|
||||||
|
}
|
||||||
|
|
||||||
|
3. 当不需要查询或无法确定时:
|
||||||
|
- 使用`"type": "none"`
|
||||||
|
- 示例:{
|
||||||
|
"type": "none"
|
||||||
|
}
|
||||||
|
|
||||||
|
## 完整示例
|
||||||
|
用户输入:{
|
||||||
|
"text": "还记得我们讨论过游戏引擎的物理系统实现吗?",
|
||||||
|
"topic_tree": "
|
||||||
|
技术 (3)[root]
|
||||||
|
├── 游戏开发 (2)
|
||||||
|
│ ├── 图形渲染 (1)
|
||||||
|
│ └── 物理系统 (0)
|
||||||
|
└── 人工智能 (1)",
|
||||||
|
"date": "2024-04-20"
|
||||||
|
}
|
||||||
|
|
||||||
|
正确响应:{
|
||||||
|
"type": "topic",
|
||||||
|
"text": "技术->游戏开发->物理系统"
|
||||||
|
}
|
||||||
|
"""));
|
||||||
|
|
||||||
|
messages.add(new Message(ChatConstant.Character.USER, """
|
||||||
|
{
|
||||||
|
"text": "上周似乎发生了什么重要的事??",
|
||||||
|
"topic_tree": "
|
||||||
|
汽车工程 (4)[root]
|
||||||
|
├── 动力系统 (3)
|
||||||
|
│ ├── 发动机 (1)
|
||||||
|
│ └── 新能源电池 (2)
|
||||||
|
│ ├── 测试标准 (1)
|
||||||
|
│ └── 安全规范 (1)
|
||||||
|
└── 车身设计 (1)
|
||||||
|
软件开发 (3)[root]
|
||||||
|
质量管理 (2)[root]
|
||||||
|
├── ISO认证 (1)
|
||||||
|
└── 行业标准 (1)",
|
||||||
|
"date": "2024-04-20"
|
||||||
|
}
|
||||||
|
"""));
|
||||||
|
System.out.println(client.runChat(messages).getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
62
src/test/java/memory/MemoryTest.java
Normal file
62
src/test/java/memory/MemoryTest.java
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package memory;
|
||||||
|
|
||||||
|
import cn.hutool.core.date.LocalDateTimeUtil;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import work.slhaf.agent.core.memory.MemoryGraph;
|
||||||
|
import work.slhaf.agent.core.memory.node.TopicNode;
|
||||||
|
|
||||||
|
import java.time.LocalDate;
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
public class MemoryTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test1() {
|
||||||
|
MemoryGraph graph = new MemoryGraph("test");
|
||||||
|
HashMap<String, TopicNode> topicMap = new HashMap<>();
|
||||||
|
|
||||||
|
TopicNode root1 = new TopicNode();
|
||||||
|
root1.setTopicNodes(new ConcurrentHashMap<>());
|
||||||
|
|
||||||
|
TopicNode sub1 = new TopicNode();
|
||||||
|
sub1.setTopicNodes(new ConcurrentHashMap<>());
|
||||||
|
|
||||||
|
TopicNode sub2 = new TopicNode();
|
||||||
|
sub2.setTopicNodes(new ConcurrentHashMap<>());
|
||||||
|
|
||||||
|
TopicNode subsub1 = new TopicNode();
|
||||||
|
subsub1.setTopicNodes(new ConcurrentHashMap<>());
|
||||||
|
|
||||||
|
// 构造结构:root -> sub1 -> subsub1, root -> sub2
|
||||||
|
sub1.getTopicNodes().put("子子主题1", subsub1);
|
||||||
|
root1.getTopicNodes().put("子主题1", sub1);
|
||||||
|
root1.getTopicNodes().put("子主题2", sub2);
|
||||||
|
|
||||||
|
topicMap.put("根主题1", root1);
|
||||||
|
|
||||||
|
// 添加 root2
|
||||||
|
TopicNode root2 = new TopicNode();
|
||||||
|
root2.setTopicNodes(new ConcurrentHashMap<>());
|
||||||
|
|
||||||
|
TopicNode sub3 = new TopicNode();
|
||||||
|
sub3.setTopicNodes(new ConcurrentHashMap<>());
|
||||||
|
|
||||||
|
// 构造结构:root2 -> sub3
|
||||||
|
root2.getTopicNodes().put("子主题3", sub3);
|
||||||
|
|
||||||
|
topicMap.put("根主题2", root2);
|
||||||
|
|
||||||
|
// 输出
|
||||||
|
graph.setTopicNodes(topicMap);
|
||||||
|
graph.printTopicTree();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test2(){
|
||||||
|
System.out.println(LocalDate.now());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -59,7 +59,7 @@ class SearchTest {
|
|||||||
List<String> queryPath = new ArrayList<>();
|
List<String> queryPath = new ArrayList<>();
|
||||||
queryPath.add("算法");
|
queryPath.add("算法");
|
||||||
queryPath.add("排序");
|
queryPath.add("排序");
|
||||||
MemoryResult results = memoryGraph.selectMemory(queryPath);
|
// MemoryResult results = memoryGraph.selectMemory(queryPath);
|
||||||
|
|
||||||
// 验证结果应包含:
|
// 验证结果应包含:
|
||||||
// 1. 目标节点所有记忆(java1)
|
// 1. 目标节点所有记忆(java1)
|
||||||
@@ -77,7 +77,7 @@ class SearchTest {
|
|||||||
invalidPath.add("不存在的主题");
|
invalidPath.add("不存在的主题");
|
||||||
|
|
||||||
assertThrows(UnExistedTopicException.class, () -> {
|
assertThrows(UnExistedTopicException.class, () -> {
|
||||||
memoryGraph.selectMemory(invalidPath);
|
// memoryGraph.selectMemory(invalidPath);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,7 +94,7 @@ class SearchTest {
|
|||||||
List<String> queryPath = new ArrayList<>();
|
List<String> queryPath = new ArrayList<>();
|
||||||
queryPath.add("编程");
|
queryPath.add("编程");
|
||||||
queryPath.add("Java");
|
queryPath.add("Java");
|
||||||
MemoryResult results = memoryGraph.selectMemory(queryPath);
|
// MemoryResult results = memoryGraph.selectMemory(queryPath);
|
||||||
|
|
||||||
// 应包含:Java记忆 + 父级最新记忆
|
// 应包含:Java记忆 + 父级最新记忆
|
||||||
// assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())));
|
// assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())));
|
||||||
@@ -136,7 +136,7 @@ class SearchTest {
|
|||||||
|
|
||||||
// 执行查询
|
// 执行查询
|
||||||
List<String> queryPath = createTopicPath("编程", "Java");
|
List<String> queryPath = createTopicPath("编程", "Java");
|
||||||
MemoryResult results = memoryGraph.selectMemory(queryPath);
|
// MemoryResult results = memoryGraph.selectMemory(queryPath);
|
||||||
|
|
||||||
// 验证结果应包含最新关联记忆(dbNew)
|
// 验证结果应包含最新关联记忆(dbNew)
|
||||||
// assertTrue(results.stream().anyMatch(m -> "dbNew".equals(m.getMemoryId())),
|
// assertTrue(results.stream().anyMatch(m -> "dbNew".equals(m.getMemoryId())),
|
||||||
|
|||||||
Reference in New Issue
Block a user