feat(memory): 实现记忆切片持久化并优化记忆存储结构- 新增 ChatClient 类实现与大模型的交互

- 添加了chat包,用于后续大模型对接
- 更新 MemoryGraph 类,增加用户对话缓存和当前对话压缩上下文
- 修改 MemoryNode 类,实现记忆切片的序列化和反序列化
- 更新 MemorySlice 类,增加多用户相关字段和方法,将切片内容从SliceData移动至MemorySlice
- 删除未使用的 SliceData 类
- 添加日志依赖和异常处理,新的异常类NullSliceListException
This commit is contained in:
2025-04-11 21:50:11 +08:00
parent 24d4510270
commit c28979b495
14 changed files with 413 additions and 45 deletions

15
pom.xml
View File

@@ -49,6 +49,21 @@
<version>RELEASE</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>2.0.17</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>1.5.17</version>
</dependency>
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>5.8.36</version>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,70 @@
package work.slhaf.chat;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.NoArgsConstructor;
import work.slhaf.chat.constant.Constant;
import work.slhaf.chat.pojo.ChatBody;
import work.slhaf.chat.pojo.ChatResponse;
import work.slhaf.chat.pojo.Message;
import work.slhaf.chat.pojo.PrimaryChatResponse;
import java.util.List;
@Data
@NoArgsConstructor
public class ChatClient {
private String clientId;
private String url;
private String apikey;
private String model;
private int top_p;
private int temperature;
private int max_tokens;
public ChatClient(String url, String apikey, String model) {
this.url = url;
this.apikey = apikey;
this.model = model;
}
public ChatResponse runChat(List<Message> messages) {
HttpRequest request = HttpRequest.post(url);
request.header("Content-Type", "application/json");
request.header("Authorization", "Bearer " + apikey);
ChatBody body;
if (top_p > 0) {
body = ChatBody.builder()
.model(model)
.messages(messages)
.top_p(top_p)
.temperature(temperature)
.max_tokens(max_tokens)
.build();
} else {
body = ChatBody.builder()
.model(model)
.messages(messages)
.build();
}
HttpResponse response = request.body(JSONUtil.toJsonStr(body)).execute();
ChatResponse finalResponse;
PrimaryChatResponse primaryChatResponse = JSONUtil.toBean(response.body(), PrimaryChatResponse.class);
finalResponse = ChatResponse.builder()
.type(Constant.Response.SUCCESS)
.message(primaryChatResponse.getChoices().get(0).getMessage().getContent())
.usageBean(primaryChatResponse.getUsage())
.build();
response.close();
return finalResponse;
}
}

View File

@@ -0,0 +1,22 @@
package work.slhaf.chat.constant;
public class Constant {
public static class Character {
public static final String USER = "user";
public static final String SYSTEM = "system";
public static final String ASSISTANT = "assistant";
}
public static class Model {
public static final String DEEP_SEEK_CHAT = "deepseek-chat";
public static final String GLM_4_FLASH = "glm-4_flash";
public static final String GLM_4_PLUS = "glm-4_plus";
public static final String GLM_4_0520 = "glm-4_0520";
}
public static class Response {
public static final String SUCCESS = "success";
public static final String ERROR = "error";
}
}

View File

@@ -0,0 +1,25 @@
package work.slhaf.chat.pojo;
import lombok.*;
import java.util.List;
@Builder
@Data
@AllArgsConstructor
@NoArgsConstructor
public class ChatBody {
@NonNull
private String model;
@NonNull
private List<Message> messages;
@Builder.Default
private int temperature = 1;
@Builder.Default
private int top_p = 1;
private boolean stream;
@Builder.Default
private int max_tokens = 1024;
private int presence_penalty;
private int frequency_penalty;
}

View File

@@ -0,0 +1,16 @@
package work.slhaf.chat.pojo;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class ChatResponse {
private String type;
private String message;
private PrimaryChatResponse.UsageBean usageBean;
}

View File

@@ -0,0 +1,14 @@
package work.slhaf.chat.pojo;
import lombok.*;
@Builder
@Data
@AllArgsConstructor
@NoArgsConstructor
public class Message {
@NonNull
private String role;
@NonNull
private String content;
}

View File

@@ -0,0 +1,111 @@
package work.slhaf.chat.pojo;
import lombok.Getter;
import lombok.Setter;
import java.util.List;
@Getter
@Setter
public class PrimaryChatResponse {
/**
* id
*/
private String id;
/**
* object
*/
private String object;
/**
* created
*/
private int created;
/**
* model
*/
private String model;
/**
* choices
*/
private List<ChoicesBean> choices;
/**
* usage
*/
private UsageBean usage;
/**
* system_fingerprint
*/
private String system_fingerprint;
@Setter
@Getter
public static class UsageBean {
/**
* prompt_tokens
*/
private int prompt_tokens;
/**
* completion_tokens
*/
private int completion_tokens;
/**
* total_tokens
*/
private int total_tokens;
/**
* prompt_cache_hit_tokens
*/
private int prompt_cache_hit_tokens;
/**
* prompt_cache_miss_tokens
*/
private int prompt_cache_miss_tokens;
@Override
public String toString() {
return "UsageBean{" +
"prompt_tokens=" + prompt_tokens +
", completion_tokens=" + completion_tokens +
", total_tokens=" + total_tokens +
", prompt_cache_hit_tokens=" + prompt_cache_hit_tokens +
", prompt_cache_miss_tokens=" + prompt_cache_miss_tokens +
'}';
}
}
@Setter
@Getter
public static class ChoicesBean {
/**
* index
*/
private int index;
/**
* message
*/
private MessageBean message;
/**
* logprobs
*/
private Object logprobs;
/**
* finish_reason
*/
private String finish_reason;
@Setter
@Getter
public static class MessageBean {
/**
* role
*/
private String role;
/**
* content
*/
private String content;
}
}
}

View File

@@ -1,6 +1,7 @@
package work.slhaf.memory;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.memory.content.MemorySlice;
import work.slhaf.memory.exception.UnExistedTopicException;
import work.slhaf.memory.node.MemoryNode;
@@ -15,6 +16,7 @@ import java.time.LocalDateTime;
import java.util.*;
@Data
@Slf4j
public class MemoryGraph implements Serializable {
@Serial
@@ -43,14 +45,25 @@ public class MemoryGraph implements Serializable {
/**
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键总结为值
* 该部分作为'主LLM'system prompt常驻
* 该部分作为近两日的整体对话缓存, 不区分用户
*/
private HashMap<LocalDateTime, String> dialogMap;
/**
* 近两日的区分用户的对话总结缓存在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质
*/
private HashMap<LocalDateTime,HashMap<String/*userId*/,String>> userDialogMap;
/**
* 当前对话的活动性总结, 拥有比dialogMap更丰富的全文细节, 作为当前对话token超限时的必要上下文压缩存储
*/
private List<String> currentCompressedSessionContext;
/**
* 存储确定性记忆, 如'用户爱好'等确定性信息
* 该部分作为'主LLM'system prompt常驻
*/
private HashMap<String, LinkedHashMap<LocalDate, String>> staticMemory;
private HashMap<String /*userId*/, HashMap<String /*memoryKey*/,String /*memoryValue*/>> staticMemory;
public MemoryGraph(String id) {
this.id = id;
@@ -98,7 +111,7 @@ public class MemoryGraph implements Serializable {
try (ObjectInputStream ois = new ObjectInputStream(
new FileInputStream(filePath.toFile()))) {
MemoryGraph graph = (MemoryGraph) ois.readObject();
System.out.println("MemoryGraph 已从文件加载: " + filePath);
log.info("MemoryGraph 已从文件加载: " + filePath);
return graph;
}
}
@@ -115,7 +128,7 @@ public class MemoryGraph implements Serializable {
}
}
public void insertMemory(List<String> topicPath, MemorySlice slice) {
public void insertMemory(List<String> topicPath, MemorySlice slice) throws IOException, ClassNotFoundException {
topicPath = new ArrayList<>(topicPath);
//查看是否存在根主题节点
String rootTopic = topicPath.getFirst();
@@ -165,12 +178,14 @@ public class MemoryGraph implements Serializable {
updateDateIndex(now, slice);
updateDialogMap(slice);
node.saveMemorySliceList();
}
private void updateDialogMap(MemorySlice slice) {
String summary = slice.getSliceData().getSummary();
String summary = slice.getSummary();
LocalDateTime now = LocalDateTime.now();
//移除两天前的上下文补充(切片总结)
//更新dialogMap
//移除两天前的上下文缓存(切片总结)
List<LocalDateTime> keysToRemove = new ArrayList<>();
dialogMap.forEach((k, v) -> {
if (now.minusDays(2).isAfter(k)){
@@ -180,7 +195,21 @@ public class MemoryGraph implements Serializable {
for (LocalDateTime dateTime : keysToRemove) {
dialogMap.remove(dateTime);
}
keysToRemove.clear();
//放入新缓存
dialogMap.put(now,summary);
//更新userDialogMap
//移除两天前上下文缓存(切片总结)
userDialogMap.forEach((k,v) -> {
if (now.minusDays(2).isAfter(k)){
keysToRemove.add(k);
}
});
for (LocalDateTime dateTime : keysToRemove) {
userDialogMap.remove(dateTime);
}
//放入新缓存
userDialogMap.get(now).put(slice.getStartUser(),slice.getSummary());
}
private void updateDateIndex(LocalDate now, MemorySlice slice) {
@@ -211,7 +240,7 @@ public class MemoryGraph implements Serializable {
}
public List<MemorySlice> selectMemoryByPath(List<String> topicPath) {
public List<MemorySlice> selectMemoryByPath(List<String> topicPath) throws IOException, ClassNotFoundException {
List<MemorySlice> targetSliceList = new ArrayList<>();
topicPath = new ArrayList<>(topicPath);
String targetTopic = topicPath.getLast();

View File

@@ -1,21 +1,56 @@
package work.slhaf.memory.content;
import lombok.Data;
import work.slhaf.chat.pojo.Message;
import java.io.Serializable;
import java.util.List;
@Data
public class MemorySlice implements Serializable, Comparable<MemorySlice> {
//关联的完整对话的id
/**
* 关联的完整对话的id
*/
private String memoryId;
//该切片在关联的完整对话中的顺序, 由时间戳确定
/**
* 该切片在关联的完整对话中的顺序, 由时间戳确定
*/
private Long timestamp;
private String slicePath;
/**
* 格式为"<日期>.slice", 如2025-04-11.slice
*/
private String summary;
private List<Message> chatMessages;
/**
* 关联的其他主题, 即"邻近节点(联系)"
*/
private List<List<String>> relatedTopics;
//关联完整对话中的前序切片, 排序为键,完整路径为值
private MemorySlice sliceBefore;
private MemorySlice sliceAfter;
/**
* 关联完整对话中的前序切片, 排序为键,完整路径为值
*/
private MemorySlice sliceBefore, sliceAfter;
/**
* 多用户设定
* 发起该切片对话的用户
*/
private String startUser;
/**
* 该切片涉及到的用户
*/
private List<String> involvedUsers;
/**
* 是否仅供发起用户作为记忆参考
*/
private boolean isPrivate;
@Override
public int compareTo(MemorySlice memorySlice) {
@@ -27,12 +62,4 @@ public class MemorySlice implements Serializable, Comparable<MemorySlice> {
return 0;
}
public SliceData getSliceData(){
//todo: 待实现获取逻辑
return new SliceData();
}
public void saveSlice(SliceData sliceData){
//todo: 待实现存储逻辑, 该逻辑内将设置`slicePath`
}
}

View File

@@ -1,10 +0,0 @@
package work.slhaf.memory.content;
import com.alibaba.fastjson2.JSONArray;
import lombok.Data;
@Data
public class SliceData {
private String summary;
private JSONArray content;
}

View File

@@ -0,0 +1,7 @@
package work.slhaf.memory.exception;
public class NullSliceListException extends RuntimeException {
public NullSliceListException(String message) {
super(message);
}
}

View File

@@ -1,17 +1,29 @@
package work.slhaf.memory.node;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.memory.content.MemorySlice;
import work.slhaf.memory.exception.NullSliceListException;
import java.io.Serializable;
import java.io.*;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
@Data
@Slf4j
public class MemoryNode implements Serializable, Comparable<MemoryNode> {
//记忆节点所属日期
private static String SLICE_DATA_DIR = "./data/slice/";
/**
* 记忆节点所属日期, 以日期为文件名在硬盘存储记忆数据(如 2025-04-11.slice)
*/
private LocalDate localDate;
//该日期对应的全部记忆切片
/**
* 该日期对应的全部记忆切片
*/
private List<MemorySlice> memorySliceList;
@Override
@@ -23,4 +35,35 @@ public class MemoryNode implements Serializable, Comparable<MemoryNode> {
}
return 0;
}
public List<MemorySlice> getMemorySliceList() throws IOException, ClassNotFoundException {
//检查是否存在对应文件
File file = new File(SLICE_DATA_DIR+this.getLocalDate()+".slice");
if (file.exists()){
this.memorySliceList = deserialize(file);
}else {
this.memorySliceList = new ArrayList<>();
}
return this.memorySliceList;
}
public void saveMemorySliceList() throws IOException {
if (memorySliceList == null){
throw new NullSliceListException("memorySliceList为NULL! 检查实现逻辑!");
}
File file = new File(SLICE_DATA_DIR+this.getLocalDate()+".slice");
try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(file))){
oos.writeObject(this.memorySliceList);
}
//取消切片挂载, 释放内存
this.memorySliceList = null;
}
private List<MemorySlice> deserialize(File file) throws IOException, ClassNotFoundException {
try(ObjectInputStream ois = new ObjectInputStream(new FileInputStream(file))) {
List<MemorySlice> sliceList = (List<MemorySlice>) ois.readObject();
log.info("读取记忆切片成功");
return sliceList;
}
}
}

View File

@@ -7,8 +7,8 @@ import work.slhaf.memory.content.MemorySlice;
import work.slhaf.memory.node.MemoryNode;
import work.slhaf.memory.node.TopicNode;
import java.io.IOException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
@@ -28,7 +28,7 @@ public class InsertTest {
}
@Test
public void testInsertMemory_NewRootTopic() {
public void testInsertMemory_NewRootTopic() throws IOException, ClassNotFoundException {
// 准备测试数据
List<String> topicPath = new LinkedList<>(Arrays.asList("Programming", "Java", "Collections"));
MemorySlice slice = createTestMemorySlice("slice1");
@@ -54,7 +54,7 @@ public class InsertTest {
}
@Test
public void testInsertMemory_ExistingTopicPath() {
public void testInsertMemory_ExistingTopicPath() throws IOException, ClassNotFoundException {
// 准备初始数据
List<String> topicPath1 = new LinkedList<>(Arrays.asList("Programming", "Java", "Collections"));
MemorySlice slice1 = createTestMemorySlice("slice1");
@@ -75,7 +75,7 @@ public class InsertTest {
}
@Test
public void testInsertMemory_DifferentDays() {
public void testInsertMemory_DifferentDays() throws IOException, ClassNotFoundException {
// 准备测试数据
List<String> topicPath = new LinkedList<>(Arrays.asList("Math", "Algebra"));
MemorySlice slice1 = createTestMemorySlice("slice1");
@@ -101,7 +101,7 @@ public class InsertTest {
}
@Test
public void testInsertMemory_PartialExistingPath() {
public void testInsertMemory_PartialExistingPath() throws IOException, ClassNotFoundException {
// 准备初始数据 - 创建部分路径
List<String> topicPath1 = new LinkedList<>(Arrays.asList("Science", "Physics"));
MemorySlice slice1 = createTestMemorySlice("slice1");
@@ -129,11 +129,10 @@ public class InsertTest {
}
@Test
public void testSerializationConsistency() {
public void testSerializationConsistency() throws IOException, ClassNotFoundException {
// 构造 MemorySlice
MemorySlice slice = new MemorySlice();
slice.setMemoryId("001");
slice.setSlicePath("/demo/path");
List<String> topicPath = Arrays.asList("生活", "学习", "Java");
@@ -160,7 +159,6 @@ public class InsertTest {
// 校验MemorySlice 内容一致
MemorySlice deserializedSlice = javaNode.getMemoryNodes().get(0).getMemorySliceList().get(0);
assertEquals("001", deserializedSlice.getMemoryId());
assertEquals("/demo/path", deserializedSlice.getSlicePath());
}
}

View File

@@ -8,6 +8,7 @@ import work.slhaf.memory.exception.UnExistedTopicException;
import work.slhaf.memory.node.MemoryNode;
import work.slhaf.memory.node.TopicNode;
import java.io.IOException;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
@@ -21,7 +22,7 @@ class SearchTest {
// 初始化测试环境,模拟插入基础数据
@BeforeEach
void setUp() {
void setUp() throws IOException, ClassNotFoundException {
memoryGraph = new MemoryGraph("testGraph");
// 构建基础主题路径:根主题 -> 编程 -> Java
@@ -42,7 +43,7 @@ class SearchTest {
// 场景1查询存在的完整主题路径含相关主题
@Test
void selectMemory_shouldReturnTargetAndRelatedAndParentMemories() {
void selectMemory_shouldReturnTargetAndRelatedAndParentMemories() throws IOException, ClassNotFoundException {
// 准备相关主题数据:根主题 -> 算法 -> 排序
List<String> sortPath = new ArrayList<>();
sortPath.add("算法");
@@ -81,7 +82,7 @@ class SearchTest {
// 场景3无相关主题时仅返回目标节点和父节点记忆
@Test
void selectMemory_withoutRelatedTopics_shouldReturnTargetAndParent() {
void selectMemory_withoutRelatedTopics_shouldReturnTargetAndParent() throws IOException, ClassNotFoundException {
// 插入父级记忆:根主题 -> 编程
List<String> parentPath = new ArrayList<>();
parentPath.add("编程");
@@ -102,7 +103,7 @@ class SearchTest {
// 场景4验证日期排序应优先取最新日期的邻近记忆
@Test
void selectMemory_shouldGetLatestRelatedMemory() {
void selectMemory_shouldGetLatestRelatedMemory() throws IOException, ClassNotFoundException {
// 准备相关主题路径:根主题 -> 数据库
List<String> dbPath = new ArrayList<>();
dbPath.add("数据库");