记忆模块初步完成

- 为 MemorySlice 添加字段 embedding, embedded 为后续向量匹配提供基础
- 在 PreprocessExecutor 中添加了针对上下文的预填充,现经过记忆模块后,可直接交由主模型进行处理
- 实现了 MemorySelectExtractor, SliceEvaluator, MemorySelector 为主的记忆模块, 并新增了必要的实体类
- 为 MemorySelectExtractor, SliceEvaluator 设计了提示词
This commit is contained in:
2025-04-23 19:27:11 +08:00
parent f31176336d
commit 4e28adbc52
29 changed files with 513 additions and 274 deletions

View File

@@ -6,7 +6,7 @@ import work.slhaf.agent.core.interaction.data.InteractionInputData;
import java.io.IOException;
public class Main {
public static void main(String[] args) throws IOException, ClassNotFoundException {
public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
Agent agent = Agent.initialize();
InteractionInputData inputData = new InteractionInputData();

View File

@@ -38,7 +38,7 @@ public class Agent implements TaskCallback {
* 接收用户输入,包装为标准输入数据类
* @param inputData
*/
public void receiveUserInput(InteractionInputData inputData) throws IOException, ClassNotFoundException {
public void receiveUserInput(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException {
inputData.setLocalDateTime(LocalDateTime.now());
interactionHub.call(inputData);
}

View File

@@ -1,8 +1,10 @@
package work.slhaf.agent.common.chat.pojo;
import com.alibaba.fastjson2.JSONObject;
import lombok.*;
import java.util.List;
import java.util.Map;
@Builder
@Data

View File

@@ -4,7 +4,7 @@ import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import work.slhaf.agent.core.model.CoreModel;
import work.slhaf.agent.core.module.CoreModel;
import work.slhaf.agent.modules.memory.MemorySelectExtractor;
import work.slhaf.agent.modules.memory.MemorySelector;
import work.slhaf.agent.modules.memory.MemoryUpdater;

View File

@@ -4,18 +4,103 @@ public class ModelConstant {
public static final String CORE_MODEL_PROMPT = """
""";
public static final String SLICE_EVALUATOR_PROMPT = """
记忆切片选择器提示词(最终版)
功能说明
你需要根据用户输入的JSON数据分析其中的`text`(当前输入内容)、`history`(对话历史)和`memory_slices`(可用记忆切片)选出相关记忆切片。当text内容与history明显不相关时应以text为主要判断依据。
输入字段说明
• `text`: 用户当前输入的文本内容(首要分析对象)
• `history`: 用户与助手的对话历史记录(辅助参考)
• `memory_slices`: 可用的记忆切片列表,每个切片包含:
- `summary`: 切片内容摘要
- `id`: 切片唯一标识(时间戳)
- `date`: 切片所属日期
核心判断逻辑
1. 主题连续性检测:
IF 满足以下任一条件:
• text包含明显的新主题关键词"另外问下""突然想到"等转折词)
• text与history最后3轮对话的语义相关性<30%
• history为空
THEN 进入「独立分析模式」:
• 仅基于text内容匹配memory_slices
• 忽略history上下文
2. 常规模式:
ELSE:
• 综合text和history最近2-3轮内容进行联合判断
输出规则
{
"results": [id1, id2...]
}
完整示例
示例1(独立分析模式)
输入:{
"text": "突然想到,之前讨论的量子计算进展现在怎么样了?",
"history": [/* 10轮关于新冠疫苗的讨论 */],
"memory_slices": [
{"summary": "量子计算机近期突破IBM发布433量子位处理器", "id": 1672537000},
{"summary": "新冠疫苗加强针接种指南", "id": 1672623400}
]
}
输出:{
"results": [1672537000]
}
示例2(强相关延续)
输入:{
"text": "React 18的新特性具体有哪些",
"history": [
{"role": "user", "content": "现在前端框架怎么选?"},
{"role": "assistant", "content": "建议考虑React、Vue..."}
],
"memory_slices": [
{"summary": "React 18更新详解并发渲染、自动批处理等", "id": 1672709800},
{"summary": "Vue3组合式API教程", "id": 1672796200}
]
}
输出:{
"results": [1672709800]
}
示例3(模糊关联)
输入:{
"text": "这个方案的设计思路",
"history": [/* 5轮关于A项目的技术方案讨论 */],
"memory_slices": [
{"summary": "A项目架构设计V3.2", "id": 1672882600},
{"summary": "B项目风险评估报告", "id": 1672969000}
]
}
输出:{
"results": [1672882600]
}
最终注意事项
1. 匹配优先级:
独立分析模式text关键词 > 语义相似度 > 日期
常规模式:上下文关联度 > text关键词 > 历史延续性
2. 结果排序规则:
• 匹配度高的在前
• 同等匹配度时,时间近的在前
• 完全匹配优先于部分匹配
3. 直接输出JSON字符串
""";
public static final String TOPIC_EXTRACTOR_PROMPT = """
MemorySelectExtractor 提示词
功能说明
你需要根据用户输入的JSON数据分析其`text`和`history`字段内容判断是否需要通过主题路径或日期进行记忆查询并返回标准化格式的JSON响应。
注意:你只需要返回对应的JSON文本
注意:你只需要直接输出对应的JSON字符串
输入字段说明
• `text`: 用户当前输入的文本内容
• `topic_tree`: 当前可用的主题树结构(多层级结构,需返回从根节点到目标节点的完整路径)
• `topic_tree`: 当前可用的主题树结构(多层级结构,需返回从根节点([root])到目标节点的完整路径)
• `date`: 当前对话发生的日期(用于时间推理)
@@ -25,7 +110,7 @@ public class ModelConstant {
输出规则
1. 基本响应格式:
{
"recall": boolean,
"recall": boolean, //不存在匹配项则为false, 存在则为true
"matches": [
// 匹配项列表
]
@@ -73,7 +158,7 @@ public class ModelConstant {
输入:{
"text": "关于NodeJS的并发处理还有哪些要注意的",
"topic_tree": "
编程
编程[root]
├── JavaScript
│ ├── NodeJS
│ │ ├── 并发处理
@@ -96,7 +181,7 @@ public class ModelConstant {
输入:{
"text": "现在我想了解Express中间件的原理",
"topic_tree": "
编程
编程[root]
├── JavaScript
│ ├── NodeJS
│ │ ├── 并发处理
@@ -121,7 +206,7 @@ public class ModelConstant {
输入:{
"text": "2024-04-15讨论的Python内容和现在的Express需求",
"topic_tree": "
编程
编程[root]
├── JavaScript
│ ├── NodeJS
│ │ ├── 并发处理
@@ -147,7 +232,7 @@ public class ModelConstant {
输入:{
"text": "上周说的那个JavaScript特性",
"topic_tree": "
编程
编程[root]
├── JavaScript
│ ├── NodeJS
│ │ ├── 并发处理

View File

@@ -8,7 +8,7 @@ import work.slhaf.agent.core.interaction.TaskCallback;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.interaction.data.InteractionInputData;
import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.model.CoreModel;
import work.slhaf.agent.core.module.CoreModel;
import work.slhaf.agent.modules.preprocess.PreprocessExecutor;
import work.slhaf.agent.modules.task.TaskScheduler;
@@ -35,7 +35,7 @@ public class InteractionHub {
return interactionHub;
}
public void call(InteractionInputData inputData) throws IOException, ClassNotFoundException {
public void call(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException {
//预处理
InteractionContext interactionContext = PreprocessExecutor.getInstance().execute(inputData);
//加载模块

View File

@@ -5,5 +5,5 @@ import work.slhaf.agent.core.interaction.data.InteractionContext;
import java.io.IOException;
public interface InteractionModule {
void execute(InteractionContext context) throws IOException, ClassNotFoundException;
void execute(InteractionContext context) throws IOException, ClassNotFoundException, InterruptedException;
}

View File

@@ -102,6 +102,11 @@ public class MemoryGraph extends PersistableObject {
*/
private List<User> users;
/**
* 已被选中的切片时间戳集合,需要及时清理
*/
private Set<Long> selectedSlices;
public MemoryGraph(String id) {
this.id = id;
this.topicNodes = new HashMap<>();
@@ -111,6 +116,7 @@ public class MemoryGraph extends PersistableObject {
this.memoryNodeCacheCounter = new ConcurrentHashMap<>();
this.memorySliceCache = new ConcurrentHashMap<>();
this.modelPrompt = new HashMap<>();
this.selectedSlices = new HashSet<>();
}
public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException {
@@ -188,7 +194,7 @@ public class MemoryGraph extends PersistableObject {
node = new MemoryNode();
node.setLocalDate(now);
node.setMemoryNodeId(UUID.randomUUID().toString());
node.setMemorySliceList(new ArrayList<>());
node.setMemorySliceList(new CopyOnWriteArrayList<>());
lastTopicNode.getMemoryNodes().add(node);
lastTopicNode.getMemoryNodes().sort(null);
}
@@ -312,7 +318,7 @@ public class MemoryGraph extends PersistableObject {
if (memorySliceCache.containsKey(topicPath)) {
return memorySliceCache.get(topicPath);
}
List<MemorySliceResult> targetSliceList = new ArrayList<>();
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
topicPath = new ArrayList<>(topicPath);
String targetTopic = topicPath.getLast();
TopicNode targetParentNode = getTargetParentNode(topicPath, targetTopic);
@@ -323,10 +329,14 @@ public class MemoryGraph extends PersistableObject {
for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) {
List<MemorySlice> endpointMemorySliceList = memoryNode.loadMemorySliceList();
for (MemorySlice memorySlice : endpointMemorySliceList) {
if (selectedSlices.contains(memorySlice.getTimestamp())){
continue;
}
sliceResult.setSliceBefore(memorySlice.getSliceBefore());
sliceResult.setMemorySlice(memorySlice);
sliceResult.setSliceAfter(memorySlice.getSliceAfter());
targetSliceList.add(sliceResult);
selectedSlices.add(memorySlice.getTimestamp());
}
for (MemorySlice memorySlice : endpointMemorySliceList) {
if (memorySlice.getRelatedTopics() != null) {
@@ -345,17 +355,11 @@ public class MemoryGraph extends PersistableObject {
TopicNode tempTargetParentNode = getTargetParentNode(tempTopicPath, tempTargetTopic);
//获取终点节点及其最新记忆节点
TopicNode tempTargetNode = tempTargetParentNode.getTopicNodes().get(tempTopicPath.getLast());
List<MemoryNode> tempMemoryNodes = tempTargetNode.getMemoryNodes();
if (!tempMemoryNodes.isEmpty()) {
relatedMemorySlice.addAll(tempMemoryNodes.getFirst().loadMemorySliceList());
}
setRelatedMemorySlices(tempTargetNode, relatedMemorySlice);
}
//邻近记忆节点 父级
List<MemoryNode> targetParentMemoryNodes = targetParentNode.getMemoryNodes();
if (!targetParentMemoryNodes.isEmpty()) {
relatedMemorySlice.addAll(targetParentMemoryNodes.getFirst().loadMemorySliceList());
}
setRelatedMemorySlices(targetParentNode, relatedMemorySlice);
//将上述结果包装为MemoryResult
memoryResult.setRelatedMemorySliceResult(relatedMemorySlice);
@@ -365,6 +369,19 @@ public class MemoryGraph extends PersistableObject {
return memoryResult;
}
private void setRelatedMemorySlices(TopicNode targetParentNode, List<MemorySlice> relatedMemorySlice) throws IOException, ClassNotFoundException {
List<MemoryNode> targetParentMemoryNodes = targetParentNode.getMemoryNodes();
if (!targetParentMemoryNodes.isEmpty()) {
for (MemorySlice memorySlice : targetParentMemoryNodes.getFirst().loadMemorySliceList()) {
if (selectedSlices.contains(memorySlice.getTimestamp())) {
continue;
}
relatedMemorySlice.add(memorySlice);
selectedSlices.add(memorySlice.getTimestamp());
}
}
}
private void updateCache(List<String> topicPath, MemoryResult memoryResult) {
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
if (tempCount >= 5) {
@@ -390,12 +407,16 @@ public class MemoryGraph extends PersistableObject {
public MemoryResult selectMemory(LocalDate date) {
MemoryResult memoryResult = new MemoryResult();
List<MemorySliceResult> targetSliceList = new ArrayList<>();
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
for (List<MemorySlice> value : dateIndex.get(date).values()) {
for (MemorySlice memorySlice : value) {
if (selectedSlices.contains(memorySlice.getTimestamp())){
continue;
}
MemorySliceResult memorySliceResult = new MemorySliceResult();
memorySliceResult.setMemorySlice(memorySlice);
targetSliceList.add(memorySliceResult);
selectedSlices.add(memorySlice.getTimestamp());
}
}
memoryResult.setMemorySliceResult(targetSliceList);

View File

@@ -2,6 +2,7 @@ package work.slhaf.agent.core.memory;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.core.interaction.InteractionModule;
import work.slhaf.agent.core.interaction.data.InteractionContext;
@@ -50,6 +51,10 @@ public class MemoryManager implements InteractionModule {
return memoryGraph.selectMemory(date);
}
public void cleanSelectedSliceFilter(){
memoryGraph.getSelectedSlices().clear();
}
public String getUserId(String userInfo,String nickName) {
String userId = null;
for (User user : memoryGraph.getUsers()) {
@@ -65,6 +70,10 @@ public class MemoryManager implements InteractionModule {
return userId;
}
public List<Message> getChatMessages(){
return memoryGraph.getChatMessages();
}
private static User setNewUser(String userInfo, String nickName) {
User newUser = new User();
newUser.setUuid(UUID.randomUUID().toString());
@@ -75,4 +84,7 @@ public class MemoryManager implements InteractionModule {
return newUser;
}
public String getTopicTree() {
return memoryManager.getTopicTree();
}
}

View File

@@ -13,6 +13,7 @@ import java.nio.file.Path;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
@EqualsAndHashCode(callSuper = true)
@Data
@@ -37,7 +38,7 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
/**
* 该日期对应的全部记忆切片
*/
private List<MemorySlice> memorySliceList;
private CopyOnWriteArrayList<MemorySlice> memorySliceList;
@Override
public int compareTo(MemoryNode memoryNode) {
@@ -56,7 +57,7 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
this.memorySliceList = deserialize(file);
}else {
//逻辑正常的话这部分应该不会出现除非在insertMemory中进行save操作之前出现异常中断了方法但程序却没有结束
this.memorySliceList = new ArrayList<>();
this.memorySliceList = new CopyOnWriteArrayList<>();
}
return this.memorySliceList;
}
@@ -74,9 +75,9 @@ public class MemoryNode extends PersistableObject implements Comparable<MemoryNo
this.memorySliceList = null;
}
private List<MemorySlice> deserialize(File file) throws IOException, ClassNotFoundException {
private CopyOnWriteArrayList<MemorySlice> deserialize(File file) throws IOException, ClassNotFoundException {
try(ObjectInputStream ois = new ObjectInputStream(new FileInputStream(file))) {
List<MemorySlice> sliceList = (List<MemorySlice>) ois.readObject();
CopyOnWriteArrayList<MemorySlice> sliceList = (CopyOnWriteArrayList<MemorySlice>) ois.readObject();
log.info("读取记忆切片成功");
return sliceList;
}

View File

@@ -3,9 +3,10 @@ package work.slhaf.agent.core.memory.pojo;
import lombok.Data;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
@Data
public class MemoryResult {
private List<MemorySliceResult> memorySliceResult;
private CopyOnWriteArrayList<MemorySliceResult> memorySliceResult;
private List<MemorySlice> relatedMemorySliceResult;
}

View File

@@ -57,6 +57,16 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
*/
private boolean isPrivate;
/**
* 摘要向量化结果
*/
private float[] summaryEmbedding;
/**
* 是否向量化
*/
private boolean embedded;
@Override
public int compareTo(MemorySlice memorySlice) {
if (memorySlice.getTimestamp() > this.getTimestamp()) {

View File

@@ -1,4 +1,4 @@
package work.slhaf.agent.core.model;
package work.slhaf.agent.core.module;
import lombok.Data;
import lombok.EqualsAndHashCode;

View File

@@ -42,7 +42,7 @@ public class AgentWebSocketServer extends WebSocketServer implements MessageSend
userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接
try {
agent.receiveUserInput(inputData);
} catch (IOException | ClassNotFoundException e) {
} catch (IOException | ClassNotFoundException | InterruptedException e) {
throw new RuntimeException(e);
}
}

View File

@@ -1,20 +1,30 @@
package work.slhaf.agent.modules.memory;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.modules.memory.data.extractor.ExtractorInput;
import work.slhaf.agent.modules.memory.data.extractor.ExtractorResult;
import java.io.IOException;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class MemorySelectExtractor extends Model {
public static final String MODEL_KEY = "topic_extractor";
private static MemorySelectExtractor memorySelectExtractor;
private MemoryManager memoryManager;
private MemorySelectExtractor() {
}
@@ -22,14 +32,33 @@ public class MemorySelectExtractor extends Model {
if (memorySelectExtractor == null) {
Config config = Config.getConfig();
memorySelectExtractor = new MemorySelectExtractor();
memorySelectExtractor.setMemoryManager(MemoryManager.getInstance());
setModel(config, memorySelectExtractor, MODEL_KEY, ModelConstant.TOPIC_EXTRACTOR_PROMPT);
}
return memorySelectExtractor;
}
public JSONObject execute(String input) {
return JSONObject.parseObject(singleChat(input).getMessage());
public ExtractorResult execute(InteractionContext context) {
//结构化为指定格式
ExtractorInput extractorInput = ExtractorInput.builder()
.text(context.getInput())
.date(context.getDateTime().toLocalDate())
.history(memoryManager.getChatMessages())
.topic_tree(memoryManager.getTopicTree())
.build();
String responseStr = singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage();
ExtractorResult extractorResult;
try {
extractorResult = JSONObject.parseObject(responseStr, ExtractorResult.class);
} catch (Exception e) {
log.error("主题提取出错: {}", e.getLocalizedMessage());
extractorResult = new ExtractorResult();
extractorResult.setRecall(false);
extractorResult.setMatches(List.of());
}
return extractorResult;
}
public static class Constant {

View File

@@ -1,14 +1,21 @@
package work.slhaf.agent.modules.memory;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import work.slhaf.agent.core.interaction.InteractionModule;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatedSlice;
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorInput;
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorResult;
import work.slhaf.agent.modules.memory.data.extractor.ExtractorMatchData;
import work.slhaf.agent.modules.memory.data.extractor.ExtractorResult;
import java.io.IOException;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
@Data
public class MemorySelector implements InteractionModule {
@@ -33,24 +40,53 @@ public class MemorySelector implements InteractionModule {
}
@Override
public void execute(InteractionContext interactionContext) throws IOException, ClassNotFoundException {
public void execute(InteractionContext interactionContext) throws IOException, ClassNotFoundException, InterruptedException {
//获取主题路径
JSONObject extractorResult = memorySelectExtractor.execute(interactionContext.getInput());
String selectType = extractorResult.getString("type");
//根据主结果进行操作查找切片
MemoryResult memoryResult = switch (selectType) {
case MemorySelectExtractor.Constant.DATE ->
memoryManager.selectMemory(LocalDate.parse(extractorResult.getString(MemorySelectExtractor.Constant.DATE)));
case MemorySelectExtractor.Constant.TOPIC ->
memoryManager.selectMemory(MemorySelectExtractor.Constant.TOPIC);
ExtractorResult extractorResult = memorySelectExtractor.execute(interactionContext);
if (extractorResult.isRecall()) {
//查找切片
List<MemoryResult> memoryResultList = new ArrayList<>();
setMemoryResultList(memoryResultList, extractorResult.getMatches(), interactionContext.getUserInfo(), interactionContext.getUserNickname());
//评估切片
EvaluatorInput evaluatorInput = EvaluatorInput.builder()
.input(interactionContext.getInput())
.memoryResults(memoryResultList)
.messages(memoryManager.getChatMessages())
.build();
List<EvaluatedSlice> memorySlices = sliceEvaluator.execute(evaluatorInput);
//设置上下文
interactionContext.getModuleContext().put("memory_slices",memorySlices);
}
}
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userInfo, String nickName) throws IOException, ClassNotFoundException {
for (ExtractorMatchData match : matches) {
MemoryResult memoryResult = switch (match.getType()) {
case ExtractorMatchData.Constant.DATE -> memoryManager.selectMemory(match.getText());
case ExtractorMatchData.Constant.TOPIC -> memoryManager.selectMemory(LocalDate.parse(match.getText()));
default -> null;
};
//评估切片
if (memoryResult == null) {
memoryResult = sliceEvaluator.execute(memoryResult,interactionContext);
if (memoryResult == null) continue;
memoryResultList.add(memoryResult);
}
//清理切片记录
memoryManager.cleanSelectedSliceFilter();
//根据userInfo过滤是否为私人记忆
for (MemoryResult memoryResult : memoryResultList) {
//过滤终点记忆
memoryResult.getMemorySliceResult().removeIf(m -> removeOrNot(m.getMemorySlice(), userInfo, nickName));
//过滤邻近记忆
memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userInfo, nickName));
}
}
//设置上下文
private boolean removeOrNot(MemorySlice memorySlice, String userInfo, String nickName) {
if (memorySlice.isPrivate()) {
String userId = memoryManager.getUserId(userInfo, nickName);
return memorySlice.getStartUserId().equals(userId);
}
return true;
}
}

View File

@@ -1,22 +1,26 @@
package work.slhaf.agent.modules.memory;
import cn.hutool.core.date.DateUtil;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.interaction.InteractionThreadPoolExecutor;
import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.core.memory.pojo.MemorySliceResult;
import work.slhaf.agent.modules.memory.data.SliceSummary;
import work.slhaf.agent.modules.memory.data.evaluator.*;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.TimeUnit;
@EqualsAndHashCode(callSuper = true)
@Data
@@ -26,6 +30,7 @@ public class SliceEvaluator extends Model {
private static SliceEvaluator sliceEvaluator;
private MemoryManager memoryManager;
private InteractionThreadPoolExecutor executor;
private SliceEvaluator() {
}
@@ -42,19 +47,46 @@ public class SliceEvaluator extends Model {
return sliceEvaluator;
}
public MemoryResult execute(MemoryResult memoryResult, InteractionContext context) {
public List<EvaluatedSlice> execute(EvaluatorInput evaluatorInput) throws InterruptedException {
List<MemoryResult> memoryResultList = evaluatorInput.getMemoryResults();
List<Callable<Void>> tasks = new ArrayList<>();
Queue<EvaluatedSlice> queue = new ConcurrentLinkedDeque<>();
for (MemoryResult memoryResult : memoryResultList) {
tasks.add(() -> {
List<SliceSummary> sliceSummaryList = new ArrayList<>();
setSliceSummaryList(memoryResult, context, sliceSummaryList);
String primaryJsonStr = singleChat(JSONUtil.toJsonStr(sliceSummaryList)).getMessage();
//TODO 解析并转换为过滤后的MemoryResult
//映射查找键值
Map<Long, SliceSummary> map = new HashMap<>();
setSliceSummaryList(memoryResult, sliceSummaryList, map);
try {
EvaluatorBatchInput batchInput = EvaluatorBatchInput.builder()
.text(evaluatorInput.getInput())
.memory_slices(sliceSummaryList)
.history(evaluatorInput.getMessages())
.build();
EvaluatorResult evaluatorResult = JSONObject.parseObject(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage(), EvaluatorResult.class);
for (Long result : evaluatorResult.getResults()) {
SliceSummary sliceSummary = map.get(result);
EvaluatedSlice evaluatedSlice = EvaluatedSlice.builder()
.summary(sliceSummary.getSummary())
.date(sliceSummary.getDate())
.build();
queue.offer(evaluatedSlice);
}
} catch (Exception e) {
log.error("切片评估: {}", e.getLocalizedMessage());
}
return null;
});
}
private void setSliceSummaryList(MemoryResult memoryResult, InteractionContext context, List<SliceSummary> sliceSummaryList) {
executor.invokeAll(tasks, 30, TimeUnit.SECONDS);
return queue.stream().toList();
}
private void setSliceSummaryList(MemoryResult memoryResult, List<SliceSummary> sliceSummaryList, Map<Long, SliceSummary> map) {
for (MemorySliceResult memorySliceResult : memoryResult.getMemorySliceResult()) {
//判断是否为发起用户
if (accessible(memorySliceResult.getMemorySlice(), context)) {
SliceSummary sliceSummary = new SliceSummary();
sliceSummary.setId(memorySliceResult.getMemorySlice().getTimestamp());
String stringBuilder = memorySliceResult.getSliceBefore().getSummary() +
@@ -63,9 +95,11 @@ public class SliceEvaluator extends Model {
"\r\n" +
memorySliceResult.getSliceAfter().getSummary();
sliceSummary.setSummary(stringBuilder);
Long timestamp = memorySliceResult.getMemorySlice().getTimestamp();
sliceSummary.setDate(DateUtil.date(timestamp).toLocalDateTime().toLocalDate());
sliceSummaryList.add(sliceSummary);
}
map.put(timestamp, sliceSummary);
}
for (MemorySlice memorySlice : memoryResult.getRelatedMemorySliceResult()) {
@@ -74,23 +108,9 @@ public class SliceEvaluator extends Model {
sliceSummary.setSummary(memorySlice.getSummary());
sliceSummaryList.add(sliceSummary);
map.put(memorySlice.getTimestamp(), sliceSummary);
}
}
private boolean accessible(MemorySlice slice, InteractionContext context) {
boolean ok;
String startUserId = slice.getStartUserId();
String userInfo = context.getUserInfo();
String nickName = context.getUserNickname();
if (memoryManager.getUserId(userInfo, nickName).equals(startUserId)) {
ok = true;
} else {
ok = !slice.isPrivate();
}
return ok;
}
}

View File

@@ -1,9 +0,0 @@
package work.slhaf.agent.modules.memory.data;
import lombok.Data;
@Data
public class SliceSummary {
private String summary;
private Long id;
}

View File

@@ -0,0 +1,16 @@
package work.slhaf.agent.modules.memory.data.evaluator;
import lombok.Builder;
import lombok.Data;
import work.slhaf.agent.common.chat.pojo.Message;
import java.time.LocalDate;
import java.util.List;
@Data
@Builder
public class EvaluatedSlice {
// private List<Message> chatMessages;
private LocalDate date;
private String summary;
}

View File

@@ -0,0 +1,15 @@
package work.slhaf.agent.modules.memory.data.evaluator;
import lombok.Builder;
import lombok.Data;
import work.slhaf.agent.common.chat.pojo.Message;
import java.util.List;
@Data
@Builder
public class EvaluatorBatchInput {
private String text;
private List<Message> history;
private List<SliceSummary> memory_slices;
}

View File

@@ -0,0 +1,16 @@
package work.slhaf.agent.modules.memory.data.evaluator;
import lombok.Builder;
import lombok.Data;
import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.core.memory.pojo.MemoryResult;
import java.util.List;
@Data
@Builder
public class EvaluatorInput {
private String input;
private List<Message> messages;
private List<MemoryResult> memoryResults;
}

View File

@@ -0,0 +1,10 @@
package work.slhaf.agent.modules.memory.data.evaluator;
import lombok.Data;
import java.util.List;
@Data
public class EvaluatorResult {
private List<Long> results;
}

View File

@@ -0,0 +1,12 @@
package work.slhaf.agent.modules.memory.data.evaluator;
import lombok.Data;
import java.time.LocalDate;
@Data
public class SliceSummary {
private String summary;
private Long id;
private LocalDate date;
}

View File

@@ -0,0 +1,17 @@
package work.slhaf.agent.modules.memory.data.extractor;
import lombok.Builder;
import lombok.Data;
import work.slhaf.agent.common.chat.pojo.Message;
import java.time.LocalDate;
import java.util.List;
@Data
@Builder
public class ExtractorInput {
private String text;
private String topic_tree;
private LocalDate date;
private List<Message> history;
}

View File

@@ -0,0 +1,14 @@
package work.slhaf.agent.modules.memory.data.extractor;
import lombok.Data;
@Data
public class ExtractorMatchData {
private String type;
private String text;
public static class Constant {
public static final String DATE = "date";
public static final String TOPIC = "topic";
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.agent.modules.memory.data.extractor;
import lombok.Data;
import java.util.List;
@Data
public class ExtractorResult {
private boolean recall;
private List<ExtractorMatchData> matches;
}

View File

@@ -1,18 +1,29 @@
package work.slhaf.agent.modules.preprocess;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.interaction.data.InteractionInputData;
import work.slhaf.agent.core.memory.MemoryManager;
import java.io.IOException;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
@Data
public class PreprocessExecutor {
private static PreprocessExecutor preprocessExecutor;
private PreprocessExecutor(){}
private MemoryManager memoryManager;
public static PreprocessExecutor getInstance() {
private PreprocessExecutor() {
}
public static PreprocessExecutor getInstance() throws IOException, ClassNotFoundException {
if (preprocessExecutor == null) {
preprocessExecutor = new PreprocessExecutor();
preprocessExecutor.setMemoryManager(MemoryManager.getInstance());
}
return preprocessExecutor;
}
@@ -28,6 +39,9 @@ public class PreprocessExecutor {
context.setInput(inputData.getContent());
context.setModuleContext(new JSONObject());
context.getModuleContext().put("text", inputData.getContent());
context.getModuleContext().put("datetime", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME));
context.getModuleContext().put("character",memoryManager.getMemoryGraph().getModelPrompt());
return context;
}

View File

@@ -4,194 +4,100 @@ import org.junit.jupiter.api.Test;
import work.slhaf.agent.common.chat.ChatClient;
import work.slhaf.agent.common.chat.constant.ChatConstant;
import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.model.ModelConstant;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
public class AITest {
@Test
public void test1() {
public void topicExtractorTest() {
String input = """
{
"text": "之前处理过Node.js的并发问题还有Express中间件开发",
"topic_tree": "
编程 (3)[root]
├── JavaScript (3)
├── NodeJS (2)
├── 并发处理 (0)
└── 事件循环 (0)
└── Express (1)
└── 中间件 (0)
└── Python (2)",
"date": "2024-04-10"
"text": "回到昨天讨论的TypeScript装饰器实现AOP结合现在需要的微服务熔断机制",
"topic_tree": "编程\\n├── JavaScript\\n│ ├── NodeJS\\n│ └── TypeScript\\n│ ├── 类型系统\\n│ ├── 装饰器\\n│ │ ├── 类装饰器\\n│ │ └── 方法装饰器\\n│ └── 编译配置\\n└── 系统设计\\n ├── 微服务\\n │ ├── 服务发现\\n │ └── 熔断机制\\n └── 消息队列",
"date": "2024-05-20",
"history": [
{
"role": "user",
"content": "我们的TS项目需要实现日志切面装饰器方案和中间件方案哪个更合适"
},
{
"role": "assistant",
"content": "两种方案各有优劣装饰器的优势在于1) 声明式编程 2) 精确到方法级别 3) 元编程能力。中间件方案则更适合请求级别的处理。具体到实现细节方法装饰器可以通过Reflect Metadata..."
},
{
"role": "user",
"content": "在微服务架构中如何设计跨服务的统一日志收集特别是Kubernetes环境下的实现"
},
{
"role": "assistant",
"content": "K8s环境下的日志方案需要考虑1) DaemonSet部署Fluentd 2) 应用层的日志规范 3) EFK栈的索引策略。对于NodeJS应用建议使用Winston配合..."\s
},
{
"role": "user",
"content": "现在遇到服务雪崩问题,需要实现熔断降级"
}
]
}
""";
run(input);
run(input, ModelConstant.TOPIC_EXTRACTOR_PROMPT);
}
private void run(String input) {
@Test
public void sliceEvaluatorTest(){
String input = """
{
"text": "请结合我们之前讨论的美联储加息影响分析下当前黄金ETF和国债逆回购的组合策略",
"history": [
{
"role": "user",
"content": "美联储连续加息对A股有什么影响"
},
{
"role": "assistant",
"content": "历史数据分析显示:\\n1. 北上资金流动:利空流动性敏感板块\\n2. 汇率压力:增加出口企业汇兑收益\\n3. 行业分化:利好银行/出口,利空地产/消费\\n具体机制..."
},
{
"role": "user",
"content": "黄金作为避险资产现在可以配置吗?"
},
{
"role": "assistant",
"content": "黄金配置建议:\\n• 实际利率是核心影响因素\\n• 短期受美元指数压制\\n• 长期抗通胀属性仍在\\n建议比例..."
}
],
"memory_slices": [
{
"summary": "美联储加息周期的大类资产配置策略研究包含历史回测数据2004-2006、2015-2018两次加息周期中黄金、美债、新兴市场股市的表现以及当前特殊环境高通胀+地缘冲突)下的策略调整建议。",
"id": 1685587200,
"date": "2025-06-20"
},
{
"summary": "黄金ETF与国债逆回购的组合优化模型通过波动率分析给出了不同风险偏好下的最优配置比例特别讨论了在流动性紧张时期如何利用逆回购对冲黄金波动。",
"id": 1685673600,
"date": "2025-06-21"
},
{
"summary": "A股行业轮动与美联储政策的相关性研究建立了包含利率敏感度、外资持仓比例、出口依赖度等因子的分析框架并给出了当前环境下的行业配置建议。",
"id": 1685760000,
"date": "2025-06-22"
},
{
"summary": "跨境资本流动监测指标体系详解包含利差模型、风险偏好指标VIX指数、以及中国特有的资本管制有效性分析。",
"id": 1685846400,
"date": "2025-06-23"
}
]
}
""";
run(input,ModelConstant.SLICE_EVALUATOR_PROMPT);
}
private void run(String input, String prompt) {
ChatClient client = new ChatClient("https://open.bigmodel.cn/api/paas/v4/chat/completions", "3db444552530b7742b0c53425fb93dcc.LcVwYjByht9AC3N9", "glm-4-flash-250414");
List<Message> messages = new ArrayList<>();
messages.add(new Message(ChatConstant.Character.SYSTEM, """
MemorySelectExtractor 提示词
功能说明
你需要根据用户输入的JSON数据分析其`text`和`history`字段内容判断是否需要通过主题路径或日期进行记忆查询并返回标准化格式的JSON响应。
注意你只需要返回对应的JSON文本
输入字段说明
• `text`: 用户当前输入的文本内容
• `topic_tree`: 当前可用的主题树结构(多层级结构,需返回从根节点到目标节点的完整路径)
• `date`: 当前对话发生的日期(用于时间推理)
• `history`: 用户与LLM的完整对话历史用于主题连续性判断
输出规则
1. 基本响应格式:
{
"recall": boolean,
"matches": [
// 匹配项列表
]
}
2. 主题提取规则:
• 当当前`text`涉及新主题(与`history`最后N轮对话主题明显不同
◦ 必须进行主题提取
◦ 匹配`topic_tree`中最接近的完整路径(从根节点到目标节点,如"编程->JavaScript->NodeJS->并发处理"
• 当主题与历史对话连续时:
◦ 除非包含明确的新子主题,否则不重复提取相同主题路径
3. 日期提取规则(保持不变):
• 仅接受具体日期YYYY-MM-DD格式
• 拒绝所有模糊日期表达
4. 特殊处理:
• 当检测到主题切换但无法匹配`topic_tree`时:
{
"recall": false,
"matches": []
}
• 当历史对话为空时:
◦ 视为新主题,按常规规则处理
决策流程
1. 首先分析`history`判断当前对话主题上下文
2. 然后分析`text`
a. 检测是否包含具体日期→添加date类型
b. 检测是否包含新主题→添加topic类型
3. 最终综合判断`recall`值
完整示例
示例1主题延续
输入:{
"text": "关于NodeJS的并发处理还有哪些要注意的",
"topic_tree": "
编程
├── JavaScript
│ ├── NodeJS
│ │ ├── 并发处理
│ │ └── 事件循环
│ └── Express
│ └── 中间件
└── Python",
"date": "2024-04-20",
"history": [
{"role": "user", "content": "说说NodeJS的并发处理机制"},
{"role": "assistant", "content": "NodeJS的并发处理主要通过..."}
]
}
输出:{
"recall": false,
"matches": []
}
示例2主题切换
输入:{
"text": "现在我想了解Express中间件的原理",
"topic_tree": "
编程
├── JavaScript
│ ├── NodeJS
│ │ ├── 并发处理
│ │ └── 事件循环
│ └── Express
│ └── 中间件
└── Python",
"date": "2024-04-20",
"history": [
{"role": "user", "content": "NodeJS的并发处理怎么实现"},
{"role": "assistant", "content": "需要..."}
]
}
输出:{
"recall": true,
"matches": [
{"type": "topic", "text": "编程->JavaScript->Express->中间件"}
]
}
示例3混合情况
输入:{
"text": "2024-04-15讨论的Python内容和现在的Express需求",
"topic_tree": "
编程
├── JavaScript
│ ├── NodeJS
│ │ ├── 并发处理
│ │ └── 事件循环
│ └── Express
│ └── 中间件
└── Python",
"date": "2024-04-20",
"history": [
{"role": "user", "content": "需要了解Express框架"},
{"role": "assistant", "content": "Express是..."}
]
}
输出:{
"recall": true,
"matches": [
{"type": "date", "text": "2024-04-15"},
{"type": "topic", "text": "编程->Python"}
]
}
示例4模糊日期
输入:{
"text": "上周说的那个JavaScript特性",
"topic_tree": "
编程
├── JavaScript
│ ├── NodeJS
│ │ ├── 并发处理
│ │ └── 事件循环
│ └── Express
│ └── 中间件
└── Python",
"date": "2024-04-20",
"history": [...]
}
输出:{
"recall": false,
"matches": []
}
"""));
messages.add(new Message(ChatConstant.Character.SYSTEM, prompt));
messages.add(new Message(ChatConstant.Character.USER, input));
System.out.println(client.runChat(messages).getMessage());

View File

@@ -39,7 +39,7 @@ class SearchTest {
MemorySlice oldJavaMemory = createMemorySlice("javaOld");
MemoryNode oldNode = new MemoryNode();
oldNode.setLocalDate(yesterday);
oldNode.setMemorySliceList(List.of(oldJavaMemory));
// oldNode.setMemorySliceList(List.of(oldJavaMemory));
}
// 场景1查询存在的完整主题路径含相关主题
@@ -121,7 +121,7 @@ class SearchTest {
// 创建昨日记忆节点并添加到主题节点
MemoryNode oldMemoryNode = new MemoryNode();
oldMemoryNode.setLocalDate(yesterday);
oldMemoryNode.setMemorySliceList(new ArrayList<>(List.of(oldDbMemory)));
// oldMemoryNode.setMemorySliceList(new ArrayList<>(List.of(oldDbMemory)));
dbTopicNode.getMemoryNodes().add(oldMemoryNode);
// 对记忆节点进行日期排序根据compareTo方法