mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
- MemoryGraph 新增输出主题树功能
- 将 TopicExtractor 重命名为 MemorySelectExtractor ,并添加了提示词 - 记忆模块开发工作进行中 - 新增 SliceSummary 类,服务于记忆模块
This commit is contained in:
@@ -1,12 +1,20 @@
|
||||
package work.slhaf;
|
||||
|
||||
import work.slhaf.agent.Agent;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class Main {
|
||||
public static void main(String[] args) throws IOException {
|
||||
public static void main(String[] args) throws IOException, ClassNotFoundException {
|
||||
Agent agent = Agent.initialize();
|
||||
agent.receiveUserInput("111","222","hello");
|
||||
|
||||
InteractionInputData inputData = new InteractionInputData();
|
||||
inputData.setContent("hello");
|
||||
inputData.setPlatform("cli");
|
||||
inputData.setUserInfo("owner");
|
||||
inputData.setUserNickName("master");
|
||||
|
||||
agent.receiveUserInput(inputData);
|
||||
}
|
||||
}
|
||||
@@ -36,13 +36,9 @@ public class Agent implements TaskCallback {
|
||||
|
||||
/**
|
||||
* 接收用户输入,包装为标准输入数据类
|
||||
* @param input
|
||||
* @param inputData
|
||||
*/
|
||||
public void receiveUserInput(String userNickName,String userInfo,String input) throws IOException {
|
||||
InteractionInputData inputData = new InteractionInputData();
|
||||
inputData.setContent(input);
|
||||
inputData.setUserInfo(userInfo);
|
||||
inputData.setUserNickName(userNickName);
|
||||
public void receiveUserInput(InteractionInputData inputData) throws IOException, ClassNotFoundException {
|
||||
inputData.setLocalDateTime(LocalDateTime.now());
|
||||
interactionHub.call(inputData);
|
||||
}
|
||||
|
||||
@@ -5,13 +5,12 @@ import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import work.slhaf.agent.core.model.CoreModel;
|
||||
import work.slhaf.agent.core.memory.MemoryManager;
|
||||
import work.slhaf.agent.modules.memory.MemorySelectExtractor;
|
||||
import work.slhaf.agent.modules.memory.MemorySelector;
|
||||
import work.slhaf.agent.modules.memory.MemoryUpdater;
|
||||
import work.slhaf.agent.modules.memory.SliceEvaluator;
|
||||
import work.slhaf.agent.modules.task.TaskEvaluator;
|
||||
import work.slhaf.agent.modules.task.TaskScheduler;
|
||||
import work.slhaf.agent.modules.topic.TopicExtractor;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
@@ -67,7 +66,7 @@ public class Config {
|
||||
|
||||
private static void generatePipelineConfig() {
|
||||
List<ModuleConfig> moduleConfigList = List.of(
|
||||
new ModuleConfig(TopicExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
||||
new ModuleConfig(MemorySelectExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
||||
new ModuleConfig(MemorySelector.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
||||
new ModuleConfig(CoreModel.class.getName(),ModuleConfig.Constant.INTERNAL,null),
|
||||
new ModuleConfig(MemoryUpdater.class.getName(),ModuleConfig.Constant.INTERNAL,null),
|
||||
@@ -100,7 +99,7 @@ public class Config {
|
||||
}
|
||||
case 3 -> {
|
||||
System.out.println("TopicExtractor:");
|
||||
yield TopicExtractor.MODEL_KEY;
|
||||
yield MemorySelectExtractor.MODEL_KEY;
|
||||
}
|
||||
default -> throw new RuntimeException();
|
||||
};
|
||||
|
||||
@@ -6,6 +6,56 @@ public class ModelConstant {
|
||||
public static final String SLICE_EVALUATOR_PROMPT = """
|
||||
""";
|
||||
public static final String TOPIC_EXTRACTOR_PROMPT = """
|
||||
# MemorySelectExtractor 提示词
|
||||
|
||||
## 功能说明
|
||||
你需要根据用户输入的JSON数据,分析其`text`字段内容,判断是否需要通过主题路径或日期进行记忆查询,并返回标准化格式的JSON响应。
|
||||
|
||||
## 输入字段说明
|
||||
- `text`: 用户输入的文本内容
|
||||
- `topic_tree`: 当前可用的主题树结构(括号内数字表示子主题数量)
|
||||
- `date`: 当前对话发生的日期(用于时间推理)
|
||||
|
||||
## 输出规则
|
||||
1. 当文本涉及明确主题路径时:
|
||||
- 使用`"type": "topic"`
|
||||
- `text`字段格式为"根主题->子主题->子子主题"(必须**完全匹配**topic_tree中的层级,包括从[root]到目标主题的完整路径)
|
||||
- 示例:{
|
||||
"type": "topic",
|
||||
"text": "工作->项目A->需求文档"
|
||||
}
|
||||
|
||||
2. 当文本包含明确可推算的日期时:
|
||||
- 使用`"type": "date"`
|
||||
- 日期格式必须为"YYYY-MM-DD"
|
||||
- 仅接受具体日期(不接受"上周"等模糊表达)
|
||||
- 示例:{
|
||||
"type": "date",
|
||||
"text": "2024-04-15"
|
||||
}
|
||||
|
||||
3. 当不需要查询或无法确定时:
|
||||
- 使用`"type": "none"`
|
||||
- 示例:{
|
||||
"type": "none"
|
||||
}
|
||||
|
||||
## 完整示例
|
||||
用户输入:{
|
||||
"text": "还记得我们讨论过游戏引擎的物理系统实现吗?",
|
||||
"topic_tree": "
|
||||
技术 (3)[root]
|
||||
├── 游戏开发 (2)
|
||||
│ ├── 图形渲染 (1)
|
||||
│ └── 物理系统 (0)
|
||||
└── 人工智能 (1)",
|
||||
"date": "2024-04-20"
|
||||
}
|
||||
|
||||
正确响应:{
|
||||
"type": "topic",
|
||||
"text": "技术->游戏开发->物理系统"
|
||||
}
|
||||
""";
|
||||
public static final String TASK_EVALUATOR_PROMPT = """
|
||||
""";
|
||||
|
||||
@@ -7,8 +7,8 @@ import work.slhaf.agent.core.interaction.InteractionModulesLoader;
|
||||
import work.slhaf.agent.core.interaction.TaskCallback;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
||||
import work.slhaf.agent.core.model.CoreModel;
|
||||
import work.slhaf.agent.core.memory.MemoryManager;
|
||||
import work.slhaf.agent.core.model.CoreModel;
|
||||
import work.slhaf.agent.modules.preprocess.PreprocessExecutor;
|
||||
import work.slhaf.agent.modules.task.TaskScheduler;
|
||||
|
||||
@@ -35,7 +35,7 @@ public class InteractionHub {
|
||||
return interactionHub;
|
||||
}
|
||||
|
||||
public void call(InteractionInputData inputData) throws IOException {
|
||||
public void call(InteractionInputData inputData) throws IOException, ClassNotFoundException {
|
||||
//预处理
|
||||
InteractionContext interactionContext = PreprocessExecutor.getInstance().execute(inputData);
|
||||
//加载模块
|
||||
@@ -43,6 +43,6 @@ public class InteractionHub {
|
||||
for (InteractionModule interactionModule : interactionModules) {
|
||||
interactionModule.execute(interactionContext);
|
||||
}
|
||||
callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getMessage());
|
||||
callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getString("message"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package work.slhaf.agent.core.interaction;
|
||||
|
||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public interface InteractionModule {
|
||||
void execute(InteractionContext context);
|
||||
void execute(InteractionContext context) throws IOException, ClassNotFoundException;
|
||||
}
|
||||
|
||||
@@ -2,12 +2,8 @@ package work.slhaf.agent.core.interaction.data;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import work.slhaf.agent.common.chat.pojo.ChatResponse;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.agent.modules.task.data.TaskData;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class InteractionContext {
|
||||
@@ -17,10 +13,7 @@ public class InteractionContext {
|
||||
|
||||
protected boolean finished;
|
||||
protected String input;
|
||||
protected JSONObject tempResult;
|
||||
protected ChatResponse coreResponse;
|
||||
|
||||
protected List<MemorySlice> memorySlices;
|
||||
protected List<String> topicPath;
|
||||
protected List<TaskData> taskDataList;
|
||||
protected JSONObject moduleContext;
|
||||
protected JSONObject coreResponse;
|
||||
}
|
||||
|
||||
@@ -10,4 +10,5 @@ public class InteractionInputData {
|
||||
private String userNickName;
|
||||
private String content;
|
||||
private LocalDateTime localDateTime;
|
||||
private String platform;
|
||||
}
|
||||
|
||||
@@ -8,10 +8,7 @@ import work.slhaf.agent.common.chat.pojo.Message;
|
||||
import work.slhaf.agent.core.memory.exception.UnExistedTopicException;
|
||||
import work.slhaf.agent.core.memory.node.MemoryNode;
|
||||
import work.slhaf.agent.core.memory.node.TopicNode;
|
||||
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySliceResult;
|
||||
import work.slhaf.agent.core.memory.pojo.PersistableObject;
|
||||
import work.slhaf.agent.core.memory.pojo.*;
|
||||
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
@@ -100,6 +97,11 @@ public class MemoryGraph extends PersistableObject {
|
||||
*/
|
||||
private List<Message> chatMessages;
|
||||
|
||||
/**
|
||||
* 用户列表
|
||||
*/
|
||||
private List<User> users;
|
||||
|
||||
public MemoryGraph(String id) {
|
||||
this.id = id;
|
||||
this.topicNodes = new HashMap<>();
|
||||
@@ -266,7 +268,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
//放入新缓存
|
||||
userDialogMap
|
||||
.computeIfAbsent(now, k -> new ConcurrentHashMap<>())
|
||||
.merge(slice.getStartUser(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
|
||||
.merge(slice.getStartUserId(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
|
||||
|
||||
}
|
||||
|
||||
@@ -298,7 +300,8 @@ public class MemoryGraph extends PersistableObject {
|
||||
|
||||
}
|
||||
|
||||
public MemoryResult selectMemory(List<String> topicPath) throws IOException, ClassNotFoundException {
|
||||
public MemoryResult selectMemory(String topicPathStr) throws IOException, ClassNotFoundException {
|
||||
List<String> topicPath = List.of(topicPathStr.split("->"));
|
||||
MemoryResult memoryResult = new MemoryResult();
|
||||
|
||||
//每日刷新缓存
|
||||
@@ -319,7 +322,6 @@ public class MemoryGraph extends PersistableObject {
|
||||
MemorySliceResult sliceResult = new MemorySliceResult();
|
||||
for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) {
|
||||
List<MemorySlice> endpointMemorySliceList = memoryNode.loadMemorySliceList();
|
||||
// targetSliceList.addAll(endpointMemorySliceList);
|
||||
for (MemorySlice memorySlice : endpointMemorySliceList) {
|
||||
sliceResult.setSliceBefore(memorySlice.getSliceBefore());
|
||||
sliceResult.setMemorySlice(memorySlice);
|
||||
@@ -420,5 +422,28 @@ public class MemoryGraph extends PersistableObject {
|
||||
}
|
||||
return targetParentNode;
|
||||
}
|
||||
|
||||
public void printTopicTree() {
|
||||
for (Map.Entry<String, TopicNode> entry : topicNodes.entrySet()) {
|
||||
String rootName = entry.getKey();
|
||||
TopicNode rootNode = entry.getValue();
|
||||
System.out.println(rootName+"[root]");
|
||||
printSubTopicsTreeFormat(rootNode, "", true);
|
||||
}
|
||||
}
|
||||
|
||||
private void printSubTopicsTreeFormat(TopicNode node, String prefix, boolean isLast) {
|
||||
if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return;
|
||||
|
||||
List<Map.Entry<String, TopicNode>> entries = new ArrayList<>(node.getTopicNodes().entrySet());
|
||||
for (int i = 0; i < entries.size(); i++) {
|
||||
boolean last = (i == entries.size() - 1);
|
||||
Map.Entry<String, TopicNode> entry = entries.get(i);
|
||||
System.out.println(prefix + (last ? "└── " : "├── ") + entry.getKey());
|
||||
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), last);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -5,9 +5,15 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.core.interaction.InteractionModule;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||
import work.slhaf.agent.core.memory.pojo.User;
|
||||
import work.slhaf.agent.modules.memory.SliceEvaluator;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.LocalDate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
@Data
|
||||
@Slf4j
|
||||
@@ -36,4 +42,37 @@ public class MemoryManager implements InteractionModule {
|
||||
return memoryManager;
|
||||
}
|
||||
|
||||
public MemoryResult selectMemory(String path) throws IOException, ClassNotFoundException {
|
||||
return memoryGraph.selectMemory(path);
|
||||
}
|
||||
|
||||
public MemoryResult selectMemory(LocalDate date) {
|
||||
return memoryGraph.selectMemory(date);
|
||||
}
|
||||
|
||||
public String getUserId(String userInfo,String nickName) {
|
||||
String userId = null;
|
||||
for (User user : memoryGraph.getUsers()) {
|
||||
if (user.getInfo().contains(userInfo)){
|
||||
userId = user.getUuid();
|
||||
}
|
||||
}
|
||||
if (userId == null) {
|
||||
User newUser = setNewUser(userInfo, nickName);
|
||||
memoryGraph.getUsers().add(newUser);
|
||||
userId = newUser.getUuid();
|
||||
}
|
||||
return userId;
|
||||
}
|
||||
|
||||
private static User setNewUser(String userInfo, String nickName) {
|
||||
User newUser = new User();
|
||||
newUser.setUuid(UUID.randomUUID().toString());
|
||||
List<String> infoList = new ArrayList<>();
|
||||
infoList.add(userInfo);
|
||||
newUser.setInfo(infoList);
|
||||
newUser.setNickName(nickName);
|
||||
return newUser;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -45,12 +45,12 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
|
||||
* 多用户设定
|
||||
* 发起该切片对话的用户
|
||||
*/
|
||||
private String startUser;
|
||||
private String startUserId;
|
||||
|
||||
/**
|
||||
* 该切片涉及到的用户uuid
|
||||
*/
|
||||
private List<String> involvedUsers;
|
||||
private List<String> involvedUserIds;
|
||||
|
||||
/**
|
||||
* 是否仅供发起用户作为记忆参考
|
||||
|
||||
@@ -38,6 +38,6 @@ public class CoreModel extends Model implements InteractionModule {
|
||||
//TODO 需要拼接上下文之后再发送给主模型
|
||||
|
||||
ChatResponse res = runChat(interactionContext.getInput());
|
||||
interactionContext.setCoreResponse(res);
|
||||
// interactionContext.setCoreResponse();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,8 +41,8 @@ public class AgentWebSocketServer extends WebSocketServer implements MessageSend
|
||||
InteractionInputData inputData = JSONObject.parseObject(s, InteractionInputData.class);
|
||||
userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接
|
||||
try {
|
||||
agent.receiveUserInput(inputData.getUserNickName(), inputData.getUserInfo(), inputData.getContent());
|
||||
} catch (IOException e) {
|
||||
agent.receiveUserInput(inputData);
|
||||
} catch (IOException | ClassNotFoundException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
package work.slhaf.agent.modules.memory;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.common.model.Model;
|
||||
import work.slhaf.agent.common.model.ModelConstant;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class MemorySelectExtractor extends Model {
|
||||
public static final String MODEL_KEY = "topic_extractor";
|
||||
private static MemorySelectExtractor memorySelectExtractor;
|
||||
|
||||
private MemorySelectExtractor() {
|
||||
}
|
||||
|
||||
public static MemorySelectExtractor getInstance() throws IOException, ClassNotFoundException {
|
||||
if (memorySelectExtractor == null) {
|
||||
Config config = Config.getConfig();
|
||||
memorySelectExtractor = new MemorySelectExtractor();
|
||||
setModel(config, memorySelectExtractor, MODEL_KEY, ModelConstant.TOPIC_EXTRACTOR_PROMPT);
|
||||
}
|
||||
|
||||
return memorySelectExtractor;
|
||||
}
|
||||
|
||||
public JSONObject execute(String input) {
|
||||
return JSONObject.parseObject(singleChat(input).getMessage());
|
||||
}
|
||||
|
||||
public static class Constant {
|
||||
public static final String NONE = "none";
|
||||
public static final String DATE = "date";
|
||||
public static final String TOPIC = "topic";
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,11 +1,14 @@
|
||||
package work.slhaf.agent.modules.memory;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import work.slhaf.agent.core.interaction.InteractionModule;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||
import work.slhaf.agent.core.memory.MemoryManager;
|
||||
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.LocalDate;
|
||||
|
||||
@Data
|
||||
public class MemorySelector implements InteractionModule {
|
||||
@@ -14,20 +17,40 @@ public class MemorySelector implements InteractionModule {
|
||||
|
||||
private MemoryManager memoryManager;
|
||||
private SliceEvaluator sliceEvaluator;
|
||||
private MemorySelectExtractor memorySelectExtractor;
|
||||
|
||||
private MemorySelector(){}
|
||||
private MemorySelector() {
|
||||
}
|
||||
|
||||
public static MemorySelector getInstance() throws IOException, ClassNotFoundException {
|
||||
if (memorySelector == null) {
|
||||
memorySelector = new MemorySelector();
|
||||
memorySelector.setMemoryManager(MemoryManager.getInstance());
|
||||
memorySelector.setSliceEvaluator(SliceEvaluator.getInstance());
|
||||
memorySelector.setMemorySelectExtractor(MemorySelectExtractor.getInstance());
|
||||
}
|
||||
return memorySelector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(InteractionContext interactionContext) {
|
||||
public void execute(InteractionContext interactionContext) throws IOException, ClassNotFoundException {
|
||||
//获取主题路径
|
||||
JSONObject extractorResult = memorySelectExtractor.execute(interactionContext.getInput());
|
||||
String selectType = extractorResult.getString("type");
|
||||
//根据主结果进行操作查找切片
|
||||
MemoryResult memoryResult = switch (selectType) {
|
||||
case MemorySelectExtractor.Constant.DATE ->
|
||||
memoryManager.selectMemory(LocalDate.parse(extractorResult.getString(MemorySelectExtractor.Constant.DATE)));
|
||||
case MemorySelectExtractor.Constant.TOPIC ->
|
||||
memoryManager.selectMemory(MemorySelectExtractor.Constant.TOPIC);
|
||||
default -> null;
|
||||
};
|
||||
//评估切片
|
||||
if (memoryResult == null) {
|
||||
memoryResult = sliceEvaluator.execute(memoryResult,interactionContext);
|
||||
}
|
||||
|
||||
//设置上下文
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ public class MemoryUpdater implements InteractionModule {
|
||||
|
||||
private MemoryManager memoryManager;
|
||||
private InteractionThreadPoolExecutor executor;
|
||||
private MemorySelectExtractor memorySelectExtractor;
|
||||
|
||||
private MemoryUpdater(){}
|
||||
|
||||
@@ -22,6 +23,7 @@ public class MemoryUpdater implements InteractionModule {
|
||||
if (memoryUpdater == null) {
|
||||
memoryUpdater = new MemoryUpdater();
|
||||
memoryUpdater.setMemoryManager(MemoryManager.getInstance());
|
||||
memoryUpdater.setMemorySelectExtractor(MemorySelectExtractor.getInstance());
|
||||
}
|
||||
return memoryUpdater;
|
||||
}
|
||||
|
||||
@@ -1,13 +1,22 @@
|
||||
package work.slhaf.agent.modules.memory;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.common.model.Model;
|
||||
import work.slhaf.agent.common.model.ModelConstant;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||
import work.slhaf.agent.core.memory.MemoryManager;
|
||||
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySliceResult;
|
||||
import work.slhaf.agent.modules.memory.data.SliceSummary;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@@ -16,13 +25,16 @@ public class SliceEvaluator extends Model {
|
||||
public static final String MODEL_KEY = "slice_evaluator";
|
||||
|
||||
private static SliceEvaluator sliceEvaluator;
|
||||
private MemoryManager memoryManager;
|
||||
|
||||
private SliceEvaluator(){}
|
||||
private SliceEvaluator() {
|
||||
}
|
||||
|
||||
public static SliceEvaluator getInstance() throws IOException, ClassNotFoundException {
|
||||
if (sliceEvaluator == null) {
|
||||
Config config = Config.getConfig();
|
||||
sliceEvaluator = new SliceEvaluator();
|
||||
sliceEvaluator.setMemoryManager(MemoryManager.getInstance());
|
||||
setModel(config, sliceEvaluator, MODEL_KEY, ModelConstant.SLICE_EVALUATOR_PROMPT);
|
||||
log.info("SliceEvaluator注册完毕...");
|
||||
}
|
||||
@@ -30,5 +42,55 @@ public class SliceEvaluator extends Model {
|
||||
return sliceEvaluator;
|
||||
}
|
||||
|
||||
public MemoryResult execute(MemoryResult memoryResult, InteractionContext context) {
|
||||
List<SliceSummary> sliceSummaryList = new ArrayList<>();
|
||||
setSliceSummaryList(memoryResult, context, sliceSummaryList);
|
||||
String primaryJsonStr = singleChat(JSONUtil.toJsonStr(sliceSummaryList)).getMessage();
|
||||
//TODO 解析并转换为过滤后的MemoryResult
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private void setSliceSummaryList(MemoryResult memoryResult, InteractionContext context, List<SliceSummary> sliceSummaryList) {
|
||||
for (MemorySliceResult memorySliceResult : memoryResult.getMemorySliceResult()) {
|
||||
//判断是否为发起用户
|
||||
if (accessible(memorySliceResult.getMemorySlice(), context)) {
|
||||
SliceSummary sliceSummary = new SliceSummary();
|
||||
sliceSummary.setId(memorySliceResult.getMemorySlice().getTimestamp());
|
||||
String stringBuilder = memorySliceResult.getSliceBefore().getSummary() +
|
||||
"\r\n" +
|
||||
memorySliceResult.getMemorySlice().getSummary() +
|
||||
"\r\n" +
|
||||
memorySliceResult.getSliceAfter().getSummary();
|
||||
sliceSummary.setSummary(stringBuilder);
|
||||
|
||||
sliceSummaryList.add(sliceSummary);
|
||||
}
|
||||
}
|
||||
|
||||
for (MemorySlice memorySlice : memoryResult.getRelatedMemorySliceResult()) {
|
||||
SliceSummary sliceSummary = new SliceSummary();
|
||||
sliceSummary.setId(memorySlice.getTimestamp());
|
||||
sliceSummary.setSummary(memorySlice.getSummary());
|
||||
|
||||
sliceSummaryList.add(sliceSummary);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private boolean accessible(MemorySlice slice, InteractionContext context) {
|
||||
boolean ok;
|
||||
String startUserId = slice.getStartUserId();
|
||||
String userInfo = context.getUserInfo();
|
||||
String nickName = context.getUserNickname();
|
||||
|
||||
if (memoryManager.getUserId(userInfo, nickName).equals(startUserId)) {
|
||||
ok = true;
|
||||
} else {
|
||||
ok = !slice.isPrivate();
|
||||
}
|
||||
|
||||
return ok;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
package work.slhaf.agent.modules.memory.data;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class SliceSummary {
|
||||
private String summary;
|
||||
private Long id;
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
package work.slhaf.agent.modules.preprocess;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
||||
|
||||
@@ -26,6 +27,8 @@ public class PreprocessExecutor {
|
||||
context.setFinished(false);
|
||||
context.setInput(inputData.getContent());
|
||||
|
||||
context.setModuleContext(new JSONObject());
|
||||
|
||||
return context;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package work.slhaf.agent.modules.task;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.common.config.ModelConfig;
|
||||
import work.slhaf.agent.common.model.Model;
|
||||
import work.slhaf.agent.common.model.ModelConstant;
|
||||
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
package work.slhaf.agent.modules.topic;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.agent.common.chat.constant.ChatConstant;
|
||||
import work.slhaf.agent.common.chat.pojo.ChatResponse;
|
||||
import work.slhaf.agent.common.chat.pojo.Message;
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.common.model.Model;
|
||||
import work.slhaf.agent.common.model.ModelConstant;
|
||||
import work.slhaf.agent.core.interaction.InteractionModule;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class TopicExtractor extends Model implements InteractionModule {
|
||||
public static final String MODEL_KEY = "topic_extractor";
|
||||
private static TopicExtractor topicExtractor;
|
||||
|
||||
private TopicExtractor() {
|
||||
}
|
||||
|
||||
public static TopicExtractor getInstance() throws IOException, ClassNotFoundException {
|
||||
if (topicExtractor == null) {
|
||||
Config config = Config.getConfig();
|
||||
topicExtractor = new TopicExtractor();
|
||||
setModel(config, topicExtractor, MODEL_KEY, ModelConstant.TOPIC_EXTRACTOR_PROMPT);
|
||||
}
|
||||
|
||||
return topicExtractor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(InteractionContext interactionContext) {
|
||||
String primaryMessageResponse = singleChat(interactionContext.getInput()).getMessage();
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
89
src/test/java/memory/AITest.java
Normal file
89
src/test/java/memory/AITest.java
Normal file
@@ -0,0 +1,89 @@
|
||||
package memory;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import work.slhaf.agent.common.chat.ChatClient;
|
||||
import work.slhaf.agent.common.chat.constant.ChatConstant;
|
||||
import work.slhaf.agent.common.chat.pojo.Message;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class AITest {
|
||||
@Test
|
||||
public void test1(){
|
||||
ChatClient client = new ChatClient("https://open.bigmodel.cn/api/paas/v4/chat/completions","3db444552530b7742b0c53425fb93dcc.LcVwYjByht9AC3N9","glm-4-flash");
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new Message(ChatConstant.Character.SYSTEM, """
|
||||
# MemorySelectExtractor 提示词
|
||||
|
||||
## 功能说明
|
||||
你需要根据用户输入的JSON数据,分析其`text`字段内容,判断是否需要通过主题路径或日期进行记忆查询,并返回标准化格式的JSON响应。
|
||||
|
||||
## 输入字段说明
|
||||
- `text`: 用户输入的文本内容
|
||||
- `topic_tree`: 当前可用的主题树结构(括号内数字表示子主题数量)
|
||||
- `date`: 当前对话发生的日期(用于时间推理)
|
||||
|
||||
## 输出规则
|
||||
1. 当文本涉及明确主题路径时:
|
||||
- 使用`"type": "topic"`
|
||||
- `text`字段格式为"根主题->子主题->子子主题"(必须**完全匹配**topic_tree中的层级,包括从[root]到目标主题的完整路径)
|
||||
- 示例:{
|
||||
"type": "topic",
|
||||
"text": "工作->项目A->需求文档"
|
||||
}
|
||||
|
||||
2. 当文本包含明确可推算的日期时:
|
||||
- 使用`"type": "date"`
|
||||
- 日期格式必须为"YYYY-MM-DD"
|
||||
- 仅接受具体日期(不接受"上周"等模糊表达)
|
||||
- 示例:{
|
||||
"type": "date",
|
||||
"text": "2024-04-15"
|
||||
}
|
||||
|
||||
3. 当不需要查询或无法确定时:
|
||||
- 使用`"type": "none"`
|
||||
- 示例:{
|
||||
"type": "none"
|
||||
}
|
||||
|
||||
## 完整示例
|
||||
用户输入:{
|
||||
"text": "还记得我们讨论过游戏引擎的物理系统实现吗?",
|
||||
"topic_tree": "
|
||||
技术 (3)[root]
|
||||
├── 游戏开发 (2)
|
||||
│ ├── 图形渲染 (1)
|
||||
│ └── 物理系统 (0)
|
||||
└── 人工智能 (1)",
|
||||
"date": "2024-04-20"
|
||||
}
|
||||
|
||||
正确响应:{
|
||||
"type": "topic",
|
||||
"text": "技术->游戏开发->物理系统"
|
||||
}
|
||||
"""));
|
||||
|
||||
messages.add(new Message(ChatConstant.Character.USER, """
|
||||
{
|
||||
"text": "上周似乎发生了什么重要的事??",
|
||||
"topic_tree": "
|
||||
汽车工程 (4)[root]
|
||||
├── 动力系统 (3)
|
||||
│ ├── 发动机 (1)
|
||||
│ └── 新能源电池 (2)
|
||||
│ ├── 测试标准 (1)
|
||||
│ └── 安全规范 (1)
|
||||
└── 车身设计 (1)
|
||||
软件开发 (3)[root]
|
||||
质量管理 (2)[root]
|
||||
├── ISO认证 (1)
|
||||
└── 行业标准 (1)",
|
||||
"date": "2024-04-20"
|
||||
}
|
||||
"""));
|
||||
System.out.println(client.runChat(messages).getMessage());
|
||||
}
|
||||
}
|
||||
62
src/test/java/memory/MemoryTest.java
Normal file
62
src/test/java/memory/MemoryTest.java
Normal file
@@ -0,0 +1,62 @@
|
||||
package memory;
|
||||
|
||||
import cn.hutool.core.date.LocalDateTimeUtil;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import work.slhaf.agent.core.memory.MemoryGraph;
|
||||
import work.slhaf.agent.core.memory.node.TopicNode;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.HashMap;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class MemoryTest {
|
||||
|
||||
@Test
|
||||
public void test1() {
|
||||
MemoryGraph graph = new MemoryGraph("test");
|
||||
HashMap<String, TopicNode> topicMap = new HashMap<>();
|
||||
|
||||
TopicNode root1 = new TopicNode();
|
||||
root1.setTopicNodes(new ConcurrentHashMap<>());
|
||||
|
||||
TopicNode sub1 = new TopicNode();
|
||||
sub1.setTopicNodes(new ConcurrentHashMap<>());
|
||||
|
||||
TopicNode sub2 = new TopicNode();
|
||||
sub2.setTopicNodes(new ConcurrentHashMap<>());
|
||||
|
||||
TopicNode subsub1 = new TopicNode();
|
||||
subsub1.setTopicNodes(new ConcurrentHashMap<>());
|
||||
|
||||
// 构造结构:root -> sub1 -> subsub1, root -> sub2
|
||||
sub1.getTopicNodes().put("子子主题1", subsub1);
|
||||
root1.getTopicNodes().put("子主题1", sub1);
|
||||
root1.getTopicNodes().put("子主题2", sub2);
|
||||
|
||||
topicMap.put("根主题1", root1);
|
||||
|
||||
// 添加 root2
|
||||
TopicNode root2 = new TopicNode();
|
||||
root2.setTopicNodes(new ConcurrentHashMap<>());
|
||||
|
||||
TopicNode sub3 = new TopicNode();
|
||||
sub3.setTopicNodes(new ConcurrentHashMap<>());
|
||||
|
||||
// 构造结构:root2 -> sub3
|
||||
root2.getTopicNodes().put("子主题3", sub3);
|
||||
|
||||
topicMap.put("根主题2", root2);
|
||||
|
||||
// 输出
|
||||
graph.setTopicNodes(topicMap);
|
||||
graph.printTopicTree();
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void test2(){
|
||||
System.out.println(LocalDate.now());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -59,7 +59,7 @@ class SearchTest {
|
||||
List<String> queryPath = new ArrayList<>();
|
||||
queryPath.add("算法");
|
||||
queryPath.add("排序");
|
||||
MemoryResult results = memoryGraph.selectMemory(queryPath);
|
||||
// MemoryResult results = memoryGraph.selectMemory(queryPath);
|
||||
|
||||
// 验证结果应包含:
|
||||
// 1. 目标节点所有记忆(java1)
|
||||
@@ -77,7 +77,7 @@ class SearchTest {
|
||||
invalidPath.add("不存在的主题");
|
||||
|
||||
assertThrows(UnExistedTopicException.class, () -> {
|
||||
memoryGraph.selectMemory(invalidPath);
|
||||
// memoryGraph.selectMemory(invalidPath);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ class SearchTest {
|
||||
List<String> queryPath = new ArrayList<>();
|
||||
queryPath.add("编程");
|
||||
queryPath.add("Java");
|
||||
MemoryResult results = memoryGraph.selectMemory(queryPath);
|
||||
// MemoryResult results = memoryGraph.selectMemory(queryPath);
|
||||
|
||||
// 应包含:Java记忆 + 父级最新记忆
|
||||
// assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())));
|
||||
@@ -136,7 +136,7 @@ class SearchTest {
|
||||
|
||||
// 执行查询
|
||||
List<String> queryPath = createTopicPath("编程", "Java");
|
||||
MemoryResult results = memoryGraph.selectMemory(queryPath);
|
||||
// MemoryResult results = memoryGraph.selectMemory(queryPath);
|
||||
|
||||
// 验证结果应包含最新关联记忆(dbNew)
|
||||
// assertTrue(results.stream().anyMatch(m -> "dbNew".equals(m.getMemoryId())),
|
||||
|
||||
Reference in New Issue
Block a user