- MemoryGraph 新增输出主题树功能

- 将 TopicExtractor 重命名为 MemorySelectExtractor ,并添加了提示词
- 记忆模块开发工作进行中
- 新增 SliceSummary 类,服务于记忆模块
This commit is contained in:
2025-04-20 23:07:22 +08:00
parent 7594a1c43b
commit cb85192c50
24 changed files with 449 additions and 87 deletions

View File

@@ -1,12 +1,20 @@
package work.slhaf;
import work.slhaf.agent.Agent;
import work.slhaf.agent.core.interaction.data.InteractionInputData;
import java.io.IOException;
public class Main {
public static void main(String[] args) throws IOException {
public static void main(String[] args) throws IOException, ClassNotFoundException {
Agent agent = Agent.initialize();
agent.receiveUserInput("111","222","hello");
InteractionInputData inputData = new InteractionInputData();
inputData.setContent("hello");
inputData.setPlatform("cli");
inputData.setUserInfo("owner");
inputData.setUserNickName("master");
agent.receiveUserInput(inputData);
}
}

View File

@@ -36,13 +36,9 @@ public class Agent implements TaskCallback {
/**
* 接收用户输入,包装为标准输入数据类
* @param input
* @param inputData
*/
public void receiveUserInput(String userNickName,String userInfo,String input) throws IOException {
InteractionInputData inputData = new InteractionInputData();
inputData.setContent(input);
inputData.setUserInfo(userInfo);
inputData.setUserNickName(userNickName);
public void receiveUserInput(InteractionInputData inputData) throws IOException, ClassNotFoundException {
inputData.setLocalDateTime(LocalDateTime.now());
interactionHub.call(inputData);
}

View File

@@ -5,13 +5,12 @@ import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import work.slhaf.agent.core.model.CoreModel;
import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.modules.memory.MemorySelectExtractor;
import work.slhaf.agent.modules.memory.MemorySelector;
import work.slhaf.agent.modules.memory.MemoryUpdater;
import work.slhaf.agent.modules.memory.SliceEvaluator;
import work.slhaf.agent.modules.task.TaskEvaluator;
import work.slhaf.agent.modules.task.TaskScheduler;
import work.slhaf.agent.modules.topic.TopicExtractor;
import java.io.File;
import java.io.IOException;
@@ -67,7 +66,7 @@ public class Config {
private static void generatePipelineConfig() {
List<ModuleConfig> moduleConfigList = List.of(
new ModuleConfig(TopicExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null),
new ModuleConfig(MemorySelectExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null),
new ModuleConfig(MemorySelector.class.getName(), ModuleConfig.Constant.INTERNAL, null),
new ModuleConfig(CoreModel.class.getName(),ModuleConfig.Constant.INTERNAL,null),
new ModuleConfig(MemoryUpdater.class.getName(),ModuleConfig.Constant.INTERNAL,null),
@@ -100,7 +99,7 @@ public class Config {
}
case 3 -> {
System.out.println("TopicExtractor:");
yield TopicExtractor.MODEL_KEY;
yield MemorySelectExtractor.MODEL_KEY;
}
default -> throw new RuntimeException();
};

View File

@@ -6,6 +6,56 @@ public class ModelConstant {
public static final String SLICE_EVALUATOR_PROMPT = """
""";
public static final String TOPIC_EXTRACTOR_PROMPT = """
# MemorySelectExtractor 提示词
## 功能说明
你需要根据用户输入的JSON数据分析其`text`字段内容判断是否需要通过主题路径或日期进行记忆查询并返回标准化格式的JSON响应。
## 输入字段说明
- `text`: 用户输入的文本内容
- `topic_tree`: 当前可用的主题树结构(括号内数字表示子主题数量)
- `date`: 当前对话发生的日期(用于时间推理)
## 输出规则
1. 当文本涉及明确主题路径时:
- 使用`"type": "topic"`
- `text`字段格式为"根主题->子主题->子子主题"(必须**完全匹配**topic_tree中的层级包括从[root]到目标主题的完整路径)
- 示例:{
"type": "topic",
"text": "工作->项目A->需求文档"
}
2. 当文本包含明确可推算的日期时:
- 使用`"type": "date"`
- 日期格式必须为"YYYY-MM-DD"
- 仅接受具体日期(不接受"上周"等模糊表达)
- 示例:{
"type": "date",
"text": "2024-04-15"
}
3. 当不需要查询或无法确定时:
- 使用`"type": "none"`
- 示例:{
"type": "none"
}
## 完整示例
用户输入:{
"text": "还记得我们讨论过游戏引擎的物理系统实现吗?",
"topic_tree": "
技术 (3)[root]
├── 游戏开发 (2)
│ ├── 图形渲染 (1)
│ └── 物理系统 (0)
└── 人工智能 (1)",
"date": "2024-04-20"
}
正确响应:{
"type": "topic",
"text": "技术->游戏开发->物理系统"
}
""";
public static final String TASK_EVALUATOR_PROMPT = """
""";

View File

@@ -7,8 +7,8 @@ import work.slhaf.agent.core.interaction.InteractionModulesLoader;
import work.slhaf.agent.core.interaction.TaskCallback;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.interaction.data.InteractionInputData;
import work.slhaf.agent.core.model.CoreModel;
import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.model.CoreModel;
import work.slhaf.agent.modules.preprocess.PreprocessExecutor;
import work.slhaf.agent.modules.task.TaskScheduler;
@@ -35,7 +35,7 @@ public class InteractionHub {
return interactionHub;
}
public void call(InteractionInputData inputData) throws IOException {
public void call(InteractionInputData inputData) throws IOException, ClassNotFoundException {
//预处理
InteractionContext interactionContext = PreprocessExecutor.getInstance().execute(inputData);
//加载模块
@@ -43,6 +43,6 @@ public class InteractionHub {
for (InteractionModule interactionModule : interactionModules) {
interactionModule.execute(interactionContext);
}
callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getMessage());
callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getString("message"));
}
}

View File

@@ -2,6 +2,8 @@ package work.slhaf.agent.core.interaction;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import java.io.IOException;
public interface InteractionModule {
void execute(InteractionContext context);
void execute(InteractionContext context) throws IOException, ClassNotFoundException;
}

View File

@@ -2,12 +2,8 @@ package work.slhaf.agent.core.interaction.data;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import work.slhaf.agent.common.chat.pojo.ChatResponse;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.modules.task.data.TaskData;
import java.time.LocalDateTime;
import java.util.List;
@Data
public class InteractionContext {
@@ -17,10 +13,7 @@ public class InteractionContext {
protected boolean finished;
protected String input;
protected JSONObject tempResult;
protected ChatResponse coreResponse;
protected List<MemorySlice> memorySlices;
protected List<String> topicPath;
protected List<TaskData> taskDataList;
protected JSONObject moduleContext;
protected JSONObject coreResponse;
}

View File

@@ -10,4 +10,5 @@ public class InteractionInputData {
private String userNickName;
private String content;
private LocalDateTime localDateTime;
private String platform;
}

View File

@@ -8,10 +8,7 @@ import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.core.memory.exception.UnExistedTopicException;
import work.slhaf.agent.core.memory.node.MemoryNode;
import work.slhaf.agent.core.memory.node.TopicNode;
import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.core.memory.pojo.MemorySliceResult;
import work.slhaf.agent.core.memory.pojo.PersistableObject;
import work.slhaf.agent.core.memory.pojo.*;
import java.io.*;
import java.nio.file.Files;
@@ -100,6 +97,11 @@ public class MemoryGraph extends PersistableObject {
*/
private List<Message> chatMessages;
/**
* 用户列表
*/
private List<User> users;
public MemoryGraph(String id) {
this.id = id;
this.topicNodes = new HashMap<>();
@@ -266,7 +268,7 @@ public class MemoryGraph extends PersistableObject {
//放入新缓存
userDialogMap
.computeIfAbsent(now, k -> new ConcurrentHashMap<>())
.merge(slice.getStartUser(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
.merge(slice.getStartUserId(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
}
@@ -298,7 +300,8 @@ public class MemoryGraph extends PersistableObject {
}
public MemoryResult selectMemory(List<String> topicPath) throws IOException, ClassNotFoundException {
public MemoryResult selectMemory(String topicPathStr) throws IOException, ClassNotFoundException {
List<String> topicPath = List.of(topicPathStr.split("->"));
MemoryResult memoryResult = new MemoryResult();
//每日刷新缓存
@@ -319,7 +322,6 @@ public class MemoryGraph extends PersistableObject {
MemorySliceResult sliceResult = new MemorySliceResult();
for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) {
List<MemorySlice> endpointMemorySliceList = memoryNode.loadMemorySliceList();
// targetSliceList.addAll(endpointMemorySliceList);
for (MemorySlice memorySlice : endpointMemorySliceList) {
sliceResult.setSliceBefore(memorySlice.getSliceBefore());
sliceResult.setMemorySlice(memorySlice);
@@ -420,5 +422,28 @@ public class MemoryGraph extends PersistableObject {
}
return targetParentNode;
}
public void printTopicTree() {
for (Map.Entry<String, TopicNode> entry : topicNodes.entrySet()) {
String rootName = entry.getKey();
TopicNode rootNode = entry.getValue();
System.out.println(rootName+"[root]");
printSubTopicsTreeFormat(rootNode, "", true);
}
}
private void printSubTopicsTreeFormat(TopicNode node, String prefix, boolean isLast) {
if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return;
List<Map.Entry<String, TopicNode>> entries = new ArrayList<>(node.getTopicNodes().entrySet());
for (int i = 0; i < entries.size(); i++) {
boolean last = (i == entries.size() - 1);
Map.Entry<String, TopicNode> entry = entries.get(i);
System.out.println(prefix + (last ? "└── " : "├── ") + entry.getKey());
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : ""), last);
}
}
}

View File

@@ -5,9 +5,15 @@ import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.core.interaction.InteractionModule;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.User;
import work.slhaf.agent.modules.memory.SliceEvaluator;
import java.io.IOException;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
@Data
@Slf4j
@@ -36,4 +42,37 @@ public class MemoryManager implements InteractionModule {
return memoryManager;
}
public MemoryResult selectMemory(String path) throws IOException, ClassNotFoundException {
return memoryGraph.selectMemory(path);
}
public MemoryResult selectMemory(LocalDate date) {
return memoryGraph.selectMemory(date);
}
public String getUserId(String userInfo,String nickName) {
String userId = null;
for (User user : memoryGraph.getUsers()) {
if (user.getInfo().contains(userInfo)){
userId = user.getUuid();
}
}
if (userId == null) {
User newUser = setNewUser(userInfo, nickName);
memoryGraph.getUsers().add(newUser);
userId = newUser.getUuid();
}
return userId;
}
private static User setNewUser(String userInfo, String nickName) {
User newUser = new User();
newUser.setUuid(UUID.randomUUID().toString());
List<String> infoList = new ArrayList<>();
infoList.add(userInfo);
newUser.setInfo(infoList);
newUser.setNickName(nickName);
return newUser;
}
}

View File

@@ -45,12 +45,12 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
* 多用户设定
* 发起该切片对话的用户
*/
private String startUser;
private String startUserId;
/**
* 该切片涉及到的用户uuid
*/
private List<String> involvedUsers;
private List<String> involvedUserIds;
/**
* 是否仅供发起用户作为记忆参考

View File

@@ -38,6 +38,6 @@ public class CoreModel extends Model implements InteractionModule {
//TODO 需要拼接上下文之后再发送给主模型
ChatResponse res = runChat(interactionContext.getInput());
interactionContext.setCoreResponse(res);
// interactionContext.setCoreResponse();
}
}

View File

@@ -41,8 +41,8 @@ public class AgentWebSocketServer extends WebSocketServer implements MessageSend
InteractionInputData inputData = JSONObject.parseObject(s, InteractionInputData.class);
userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接
try {
agent.receiveUserInput(inputData.getUserNickName(), inputData.getUserInfo(), inputData.getContent());
} catch (IOException e) {
agent.receiveUserInput(inputData);
} catch (IOException | ClassNotFoundException e) {
throw new RuntimeException(e);
}
}

View File

@@ -0,0 +1,41 @@
package work.slhaf.agent.modules.memory;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import java.io.IOException;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemorySelectExtractor extends Model {
public static final String MODEL_KEY = "topic_extractor";
private static MemorySelectExtractor memorySelectExtractor;
private MemorySelectExtractor() {
}
public static MemorySelectExtractor getInstance() throws IOException, ClassNotFoundException {
if (memorySelectExtractor == null) {
Config config = Config.getConfig();
memorySelectExtractor = new MemorySelectExtractor();
setModel(config, memorySelectExtractor, MODEL_KEY, ModelConstant.TOPIC_EXTRACTOR_PROMPT);
}
return memorySelectExtractor;
}
public JSONObject execute(String input) {
return JSONObject.parseObject(singleChat(input).getMessage());
}
public static class Constant {
public static final String NONE = "none";
public static final String DATE = "date";
public static final String TOPIC = "topic";
}
}

View File

@@ -1,11 +1,14 @@
package work.slhaf.agent.modules.memory;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import work.slhaf.agent.core.interaction.InteractionModule;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.memory.pojo.MemoryResult;
import java.io.IOException;
import java.time.LocalDate;
@Data
public class MemorySelector implements InteractionModule {
@@ -14,20 +17,40 @@ public class MemorySelector implements InteractionModule {
private MemoryManager memoryManager;
private SliceEvaluator sliceEvaluator;
private MemorySelectExtractor memorySelectExtractor;
private MemorySelector(){}
private MemorySelector() {
}
public static MemorySelector getInstance() throws IOException, ClassNotFoundException {
if (memorySelector == null) {
memorySelector = new MemorySelector();
memorySelector.setMemoryManager(MemoryManager.getInstance());
memorySelector.setSliceEvaluator(SliceEvaluator.getInstance());
memorySelector.setMemorySelectExtractor(MemorySelectExtractor.getInstance());
}
return memorySelector;
}
@Override
public void execute(InteractionContext interactionContext) {
public void execute(InteractionContext interactionContext) throws IOException, ClassNotFoundException {
//获取主题路径
JSONObject extractorResult = memorySelectExtractor.execute(interactionContext.getInput());
String selectType = extractorResult.getString("type");
//根据主结果进行操作查找切片
MemoryResult memoryResult = switch (selectType) {
case MemorySelectExtractor.Constant.DATE ->
memoryManager.selectMemory(LocalDate.parse(extractorResult.getString(MemorySelectExtractor.Constant.DATE)));
case MemorySelectExtractor.Constant.TOPIC ->
memoryManager.selectMemory(MemorySelectExtractor.Constant.TOPIC);
default -> null;
};
//评估切片
if (memoryResult == null) {
memoryResult = sliceEvaluator.execute(memoryResult,interactionContext);
}
//设置上下文
}
}

View File

@@ -15,6 +15,7 @@ public class MemoryUpdater implements InteractionModule {
private MemoryManager memoryManager;
private InteractionThreadPoolExecutor executor;
private MemorySelectExtractor memorySelectExtractor;
private MemoryUpdater(){}
@@ -22,6 +23,7 @@ public class MemoryUpdater implements InteractionModule {
if (memoryUpdater == null) {
memoryUpdater = new MemoryUpdater();
memoryUpdater.setMemoryManager(MemoryManager.getInstance());
memoryUpdater.setMemorySelectExtractor(MemorySelectExtractor.getInstance());
}
return memoryUpdater;
}

View File

@@ -1,13 +1,22 @@
package work.slhaf.agent.modules.memory;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.core.memory.pojo.MemorySliceResult;
import work.slhaf.agent.modules.memory.data.SliceSummary;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
@@ -16,19 +25,72 @@ public class SliceEvaluator extends Model {
public static final String MODEL_KEY = "slice_evaluator";
private static SliceEvaluator sliceEvaluator;
private MemoryManager memoryManager;
private SliceEvaluator(){}
private SliceEvaluator() {
}
public static SliceEvaluator getInstance() throws IOException, ClassNotFoundException {
if (sliceEvaluator == null) {
Config config = Config.getConfig();
sliceEvaluator = new SliceEvaluator();
setModel(config,sliceEvaluator, MODEL_KEY, ModelConstant.SLICE_EVALUATOR_PROMPT);
sliceEvaluator.setMemoryManager(MemoryManager.getInstance());
setModel(config, sliceEvaluator, MODEL_KEY, ModelConstant.SLICE_EVALUATOR_PROMPT);
log.info("SliceEvaluator注册完毕...");
}
return sliceEvaluator;
}
public MemoryResult execute(MemoryResult memoryResult, InteractionContext context) {
List<SliceSummary> sliceSummaryList = new ArrayList<>();
setSliceSummaryList(memoryResult, context, sliceSummaryList);
String primaryJsonStr = singleChat(JSONUtil.toJsonStr(sliceSummaryList)).getMessage();
//TODO 解析并转换为过滤后的MemoryResult
return null;
}
private void setSliceSummaryList(MemoryResult memoryResult, InteractionContext context, List<SliceSummary> sliceSummaryList) {
for (MemorySliceResult memorySliceResult : memoryResult.getMemorySliceResult()) {
//判断是否为发起用户
if (accessible(memorySliceResult.getMemorySlice(), context)) {
SliceSummary sliceSummary = new SliceSummary();
sliceSummary.setId(memorySliceResult.getMemorySlice().getTimestamp());
String stringBuilder = memorySliceResult.getSliceBefore().getSummary() +
"\r\n" +
memorySliceResult.getMemorySlice().getSummary() +
"\r\n" +
memorySliceResult.getSliceAfter().getSummary();
sliceSummary.setSummary(stringBuilder);
sliceSummaryList.add(sliceSummary);
}
}
for (MemorySlice memorySlice : memoryResult.getRelatedMemorySliceResult()) {
SliceSummary sliceSummary = new SliceSummary();
sliceSummary.setId(memorySlice.getTimestamp());
sliceSummary.setSummary(memorySlice.getSummary());
sliceSummaryList.add(sliceSummary);
}
}
private boolean accessible(MemorySlice slice, InteractionContext context) {
boolean ok;
String startUserId = slice.getStartUserId();
String userInfo = context.getUserInfo();
String nickName = context.getUserNickname();
if (memoryManager.getUserId(userInfo, nickName).equals(startUserId)) {
ok = true;
} else {
ok = !slice.isPrivate();
}
return ok;
}
}

View File

@@ -0,0 +1,9 @@
package work.slhaf.agent.modules.memory.data;
import lombok.Data;
@Data
public class SliceSummary {
private String summary;
private Long id;
}

View File

@@ -1,5 +1,6 @@
package work.slhaf.agent.modules.preprocess;
import com.alibaba.fastjson2.JSONObject;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.interaction.data.InteractionInputData;
@@ -26,6 +27,8 @@ public class PreprocessExecutor {
context.setFinished(false);
context.setInput(inputData.getContent());
context.setModuleContext(new JSONObject());
return context;
}
}

View File

@@ -3,7 +3,6 @@ package work.slhaf.agent.modules.task;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.config.ModelConfig;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;

View File

@@ -1,41 +0,0 @@
package work.slhaf.agent.modules.topic;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.agent.common.chat.constant.ChatConstant;
import work.slhaf.agent.common.chat.pojo.ChatResponse;
import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.core.interaction.InteractionModule;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import java.io.IOException;
@EqualsAndHashCode(callSuper = true)
@Data
public class TopicExtractor extends Model implements InteractionModule {
public static final String MODEL_KEY = "topic_extractor";
private static TopicExtractor topicExtractor;
private TopicExtractor() {
}
public static TopicExtractor getInstance() throws IOException, ClassNotFoundException {
if (topicExtractor == null) {
Config config = Config.getConfig();
topicExtractor = new TopicExtractor();
setModel(config, topicExtractor, MODEL_KEY, ModelConstant.TOPIC_EXTRACTOR_PROMPT);
}
return topicExtractor;
}
@Override
public void execute(InteractionContext interactionContext) {
String primaryMessageResponse = singleChat(interactionContext.getInput()).getMessage();
}
}

View File

@@ -0,0 +1,89 @@
package memory;
import org.junit.jupiter.api.Test;
import work.slhaf.agent.common.chat.ChatClient;
import work.slhaf.agent.common.chat.constant.ChatConstant;
import work.slhaf.agent.common.chat.pojo.Message;
import java.util.ArrayList;
import java.util.List;
public class AITest {
@Test
public void test1(){
ChatClient client = new ChatClient("https://open.bigmodel.cn/api/paas/v4/chat/completions","3db444552530b7742b0c53425fb93dcc.LcVwYjByht9AC3N9","glm-4-flash");
List<Message> messages = new ArrayList<>();
messages.add(new Message(ChatConstant.Character.SYSTEM, """
# MemorySelectExtractor 提示词
## 功能说明
你需要根据用户输入的JSON数据分析其`text`字段内容判断是否需要通过主题路径或日期进行记忆查询并返回标准化格式的JSON响应。
## 输入字段说明
- `text`: 用户输入的文本内容
- `topic_tree`: 当前可用的主题树结构(括号内数字表示子主题数量)
- `date`: 当前对话发生的日期(用于时间推理)
## 输出规则
1. 当文本涉及明确主题路径时:
- 使用`"type": "topic"`
- `text`字段格式为"根主题->子主题->子子主题"(必须**完全匹配**topic_tree中的层级包括从[root]到目标主题的完整路径)
- 示例:{
"type": "topic",
"text": "工作->项目A->需求文档"
}
2. 当文本包含明确可推算的日期时:
- 使用`"type": "date"`
- 日期格式必须为"YYYY-MM-DD"
- 仅接受具体日期(不接受"上周"等模糊表达)
- 示例:{
"type": "date",
"text": "2024-04-15"
}
3. 当不需要查询或无法确定时:
- 使用`"type": "none"`
- 示例:{
"type": "none"
}
## 完整示例
用户输入:{
"text": "还记得我们讨论过游戏引擎的物理系统实现吗?",
"topic_tree": "
技术 (3)[root]
├── 游戏开发 (2)
│ ├── 图形渲染 (1)
│ └── 物理系统 (0)
└── 人工智能 (1)",
"date": "2024-04-20"
}
正确响应:{
"type": "topic",
"text": "技术->游戏开发->物理系统"
}
"""));
messages.add(new Message(ChatConstant.Character.USER, """
{
"text": "上周似乎发生了什么重要的事??",
"topic_tree": "
汽车工程 (4)[root]
├── 动力系统 (3)
│ ├── 发动机 (1)
│ └── 新能源电池 (2)
│ ├── 测试标准 (1)
│ └── 安全规范 (1)
└── 车身设计 (1)
软件开发 (3)[root]
质量管理 (2)[root]
├── ISO认证 (1)
└── 行业标准 (1)",
"date": "2024-04-20"
}
"""));
System.out.println(client.runChat(messages).getMessage());
}
}

View File

@@ -0,0 +1,62 @@
package memory;
import cn.hutool.core.date.LocalDateTimeUtil;
import org.junit.jupiter.api.Test;
import work.slhaf.agent.core.memory.MemoryGraph;
import work.slhaf.agent.core.memory.node.TopicNode;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
public class MemoryTest {
@Test
public void test1() {
MemoryGraph graph = new MemoryGraph("test");
HashMap<String, TopicNode> topicMap = new HashMap<>();
TopicNode root1 = new TopicNode();
root1.setTopicNodes(new ConcurrentHashMap<>());
TopicNode sub1 = new TopicNode();
sub1.setTopicNodes(new ConcurrentHashMap<>());
TopicNode sub2 = new TopicNode();
sub2.setTopicNodes(new ConcurrentHashMap<>());
TopicNode subsub1 = new TopicNode();
subsub1.setTopicNodes(new ConcurrentHashMap<>());
// 构造结构root -> sub1 -> subsub1, root -> sub2
sub1.getTopicNodes().put("子子主题1", subsub1);
root1.getTopicNodes().put("子主题1", sub1);
root1.getTopicNodes().put("子主题2", sub2);
topicMap.put("根主题1", root1);
// 添加 root2
TopicNode root2 = new TopicNode();
root2.setTopicNodes(new ConcurrentHashMap<>());
TopicNode sub3 = new TopicNode();
sub3.setTopicNodes(new ConcurrentHashMap<>());
// 构造结构root2 -> sub3
root2.getTopicNodes().put("子主题3", sub3);
topicMap.put("根主题2", root2);
// 输出
graph.setTopicNodes(topicMap);
graph.printTopicTree();
}
@Test
public void test2(){
System.out.println(LocalDate.now());
}
}

View File

@@ -59,7 +59,7 @@ class SearchTest {
List<String> queryPath = new ArrayList<>();
queryPath.add("算法");
queryPath.add("排序");
MemoryResult results = memoryGraph.selectMemory(queryPath);
// MemoryResult results = memoryGraph.selectMemory(queryPath);
// 验证结果应包含:
// 1. 目标节点所有记忆java1
@@ -77,7 +77,7 @@ class SearchTest {
invalidPath.add("不存在的主题");
assertThrows(UnExistedTopicException.class, () -> {
memoryGraph.selectMemory(invalidPath);
// memoryGraph.selectMemory(invalidPath);
});
}
@@ -94,7 +94,7 @@ class SearchTest {
List<String> queryPath = new ArrayList<>();
queryPath.add("编程");
queryPath.add("Java");
MemoryResult results = memoryGraph.selectMemory(queryPath);
// MemoryResult results = memoryGraph.selectMemory(queryPath);
// 应包含Java记忆 + 父级最新记忆
// assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())));
@@ -136,7 +136,7 @@ class SearchTest {
// 执行查询
List<String> queryPath = createTopicPath("编程", "Java");
MemoryResult results = memoryGraph.selectMemory(queryPath);
// MemoryResult results = memoryGraph.selectMemory(queryPath);
// 验证结果应包含最新关联记忆dbNew
// assertTrue(results.stream().anyMatch(m -> "dbNew".equals(m.getMemoryId())),