mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
refactor(memory): manage state serialization via StateCenter in MemoryRuntime
This commit is contained in:
@@ -1,9 +1,10 @@
|
|||||||
package work.slhaf.partner.module.memory.runtime;
|
package work.slhaf.partner.module.memory.runtime;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson2.JSONArray;
|
||||||
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.jetbrains.annotations.NotNull;
|
||||||
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
|
||||||
import work.slhaf.partner.core.cognition.CognitionCapability;
|
import work.slhaf.partner.core.cognition.CognitionCapability;
|
||||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||||
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
|
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
|
||||||
@@ -11,18 +12,16 @@ import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
|
|||||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||||
import work.slhaf.partner.core.memory.pojo.SliceRef;
|
import work.slhaf.partner.core.memory.pojo.SliceRef;
|
||||||
import work.slhaf.partner.framework.agent.common.entity.PersistableObject;
|
|
||||||
import work.slhaf.partner.framework.agent.config.AgentConfigLoader;
|
|
||||||
import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability;
|
import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability;
|
||||||
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||||
|
import work.slhaf.partner.framework.agent.state.State;
|
||||||
|
import work.slhaf.partner.framework.agent.state.StateSerializable;
|
||||||
|
import work.slhaf.partner.framework.agent.state.StateValue;
|
||||||
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
|
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
|
||||||
|
|
||||||
import java.io.*;
|
|
||||||
import java.nio.file.Files;
|
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.nio.file.Paths;
|
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
import java.time.LocalDate;
|
import java.time.LocalDate;
|
||||||
import java.time.ZoneId;
|
import java.time.ZoneId;
|
||||||
@@ -30,13 +29,9 @@ import java.util.*;
|
|||||||
import java.util.concurrent.CopyOnWriteArrayList;
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
import java.util.concurrent.locks.ReentrantLock;
|
import java.util.concurrent.locks.ReentrantLock;
|
||||||
|
|
||||||
import static work.slhaf.partner.common.Constant.Path.MEMORY_DATA;
|
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class MemoryRuntime extends AbstractAgentModule.Standalone {
|
public class MemoryRuntime extends AbstractAgentModule.Standalone implements StateSerializable {
|
||||||
|
|
||||||
private static final String RUNTIME_KEY = "memory-runtime";
|
|
||||||
|
|
||||||
@InjectCapability
|
@InjectCapability
|
||||||
private MemoryCapability memoryCapability;
|
private MemoryCapability memoryCapability;
|
||||||
@@ -49,9 +44,8 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
|
|||||||
|
|
||||||
@Init
|
@Init
|
||||||
public void init() {
|
public void init() {
|
||||||
loadState();
|
register();
|
||||||
checkAndSetMemoryId();
|
checkAndSetMemoryId();
|
||||||
Runtime.getRuntime().addShutdownHook(new Thread(this::saveStateSafely));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void checkAndSetMemoryId() {
|
private void checkAndSetMemoryId() {
|
||||||
@@ -71,7 +65,6 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
|
|||||||
if (!exists) {
|
if (!exists) {
|
||||||
refs.add(sliceRef);
|
refs.add(sliceRef);
|
||||||
}
|
}
|
||||||
saveState();
|
|
||||||
} finally {
|
} finally {
|
||||||
runtimeLock.unlock();
|
runtimeLock.unlock();
|
||||||
}
|
}
|
||||||
@@ -104,7 +97,6 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
|
|||||||
.addIfAbsent(new SliceRef(memoryUnit.getId(), slice.getId()));
|
.addIfAbsent(new SliceRef(memoryUnit.getId(), slice.getId()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
saveState();
|
|
||||||
} finally {
|
} finally {
|
||||||
runtimeLock.unlock();
|
runtimeLock.unlock();
|
||||||
}
|
}
|
||||||
@@ -192,8 +184,8 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
|
|||||||
return List.of();
|
return List.of();
|
||||||
}
|
}
|
||||||
int size = conversationMessages.size();
|
int size = conversationMessages.size();
|
||||||
int start = Math.max(0, Math.min(memorySlice.getStartIndex(), size));
|
int start = Math.clamp(memorySlice.getStartIndex(), 0, size);
|
||||||
int end = Math.max(start, Math.min(memorySlice.getEndIndex(), size));
|
int end = Math.clamp(memorySlice.getEndIndex(), start, size);
|
||||||
if (start >= end) {
|
if (start >= end) {
|
||||||
return List.of();
|
return List.of();
|
||||||
}
|
}
|
||||||
@@ -220,69 +212,118 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
|
|||||||
return topicPath == null ? "" : topicPath.trim();
|
return topicPath == null ? "" : topicPath.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void loadState() {
|
@Override
|
||||||
Path filePath = getFilePath();
|
@NotNull
|
||||||
if (!Files.exists(filePath)) {
|
public Path statePath() {
|
||||||
return;
|
return Path.of("module", "memory", "topic_based_memory.json");
|
||||||
}
|
}
|
||||||
try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(filePath.toFile()))) {
|
|
||||||
RuntimeState state = (RuntimeState) ois.readObject();
|
@Override
|
||||||
topicSlices = state.topicSlices;
|
public void load(@NotNull JSONObject state) {
|
||||||
dateIndex = state.dateIndex;
|
runtimeLock.lock();
|
||||||
} catch (Exception e) {
|
try {
|
||||||
log.error("[MemoryRuntime] 加载运行态失败", e);
|
|
||||||
topicSlices = new HashMap<>();
|
topicSlices = new HashMap<>();
|
||||||
dateIndex = new HashMap<>();
|
dateIndex = new HashMap<>();
|
||||||
|
|
||||||
|
JSONArray topicSlicesArray = state.getJSONArray("topic_slices");
|
||||||
|
if (topicSlicesArray != null) {
|
||||||
|
for (int i = 0; i < topicSlicesArray.size(); i++) {
|
||||||
|
JSONObject topicObject = topicSlicesArray.getJSONObject(i);
|
||||||
|
if (topicObject == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String topicPath = topicObject.getString("topic_path");
|
||||||
|
if (topicPath == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
topicSlices.put(normalizeTopicPath(topicPath), decodeSliceRefs(topicObject.getJSONArray("refs")));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void saveStateSafely() {
|
JSONArray dateIndexArray = state.getJSONArray("date_index");
|
||||||
runtimeLock.lock();
|
if (dateIndexArray != null) {
|
||||||
|
for (int i = 0; i < dateIndexArray.size(); i++) {
|
||||||
|
JSONObject dateObject = dateIndexArray.getJSONObject(i);
|
||||||
|
if (dateObject == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String date = dateObject.getString("date");
|
||||||
|
if (date == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
try {
|
try {
|
||||||
saveState();
|
dateIndex.put(LocalDate.parse(date), decodeSliceRefs(dateObject.getJSONArray("refs")));
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("[MemoryRuntime] 跳过非法日期索引: {}", date, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} finally {
|
} finally {
|
||||||
runtimeLock.unlock();
|
runtimeLock.unlock();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void saveState() {
|
@Override
|
||||||
Path filePath = getFilePath();
|
public @NotNull State convert() {
|
||||||
Path tempPath = getTempFilePath();
|
runtimeLock.lock();
|
||||||
try {
|
try {
|
||||||
Files.createDirectories(Paths.get(MEMORY_DATA));
|
State state = new State();
|
||||||
FileUtils.createParentDirectories(filePath.toFile());
|
|
||||||
try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(tempPath.toFile()))) {
|
List<StateValue.Obj> topicSliceStates = topicSlices.entrySet().stream()
|
||||||
RuntimeState state = new RuntimeState();
|
.sorted(Map.Entry.comparingByKey())
|
||||||
state.topicSlices = new HashMap<>(topicSlices);
|
.map(entry -> StateValue.obj(Map.of(
|
||||||
state.dateIndex = new HashMap<>(dateIndex);
|
"topic_path", StateValue.str(entry.getKey()),
|
||||||
oos.writeObject(state);
|
"refs", StateValue.arr(encodeSliceRefs(entry.getValue()))
|
||||||
}
|
)))
|
||||||
Files.move(tempPath, filePath, java.nio.file.StandardCopyOption.REPLACE_EXISTING);
|
.toList();
|
||||||
} catch (IOException e) {
|
state.append("topic_slices", StateValue.arr(topicSliceStates));
|
||||||
log.error("[MemoryRuntime] 保存运行态失败", e);
|
|
||||||
|
List<StateValue.Obj> dateIndexStates = dateIndex.entrySet().stream()
|
||||||
|
.sorted(Map.Entry.comparingByKey())
|
||||||
|
.map(entry -> StateValue.obj(Map.of(
|
||||||
|
"date", StateValue.str(entry.getKey().toString()),
|
||||||
|
"refs", StateValue.arr(encodeSliceRefs(entry.getValue()))
|
||||||
|
)))
|
||||||
|
.toList();
|
||||||
|
state.append("date_index", StateValue.arr(dateIndexStates));
|
||||||
|
|
||||||
|
return state;
|
||||||
|
} finally {
|
||||||
|
runtimeLock.unlock();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private Path getFilePath() {
|
private List<StateValue> encodeSliceRefs(List<SliceRef> refs) {
|
||||||
String id = ((PartnerAgentConfigLoader) AgentConfigLoader.INSTANCE).getConfig().getAgentId();
|
return refs.stream()
|
||||||
return Paths.get(MEMORY_DATA, id + "-" + RUNTIME_KEY + ".memory");
|
.map(ref -> (StateValue) StateValue.obj(Map.of(
|
||||||
|
"unit_id", StateValue.str(ref.getUnitId()),
|
||||||
|
"slice_id", StateValue.str(ref.getSliceId())
|
||||||
|
)))
|
||||||
|
.toList();
|
||||||
}
|
}
|
||||||
|
|
||||||
private Path getTempFilePath() {
|
private CopyOnWriteArrayList<SliceRef> decodeSliceRefs(JSONArray refsArray) {
|
||||||
String id = ((PartnerAgentConfigLoader) AgentConfigLoader.INSTANCE).getConfig().getAgentId();
|
CopyOnWriteArrayList<SliceRef> refs = new CopyOnWriteArrayList<>();
|
||||||
return Paths.get(MEMORY_DATA, id + "-" + RUNTIME_KEY + "-temp.memory");
|
if (refsArray == null) {
|
||||||
|
return refs;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < refsArray.size(); i++) {
|
||||||
|
JSONObject refObject = refsArray.getJSONObject(i);
|
||||||
|
if (refObject == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String unitId = refObject.getString("unit_id");
|
||||||
|
String sliceId = refObject.getString("slice_id");
|
||||||
|
if (unitId == null || sliceId == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
refs.addIfAbsent(new SliceRef(unitId, sliceId));
|
||||||
|
}
|
||||||
|
return refs;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final class TopicTreeNode {
|
private static final class TopicTreeNode {
|
||||||
private final Map<String, TopicTreeNode> children = new LinkedHashMap<>();
|
private final Map<String, TopicTreeNode> children = new LinkedHashMap<>();
|
||||||
private int count;
|
private int count;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final class RuntimeState extends PersistableObject {
|
|
||||||
@Serial
|
|
||||||
private static final long serialVersionUID = 1L;
|
|
||||||
|
|
||||||
private Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices = new HashMap<>();
|
|
||||||
private Map<LocalDate, CopyOnWriteArrayList<SliceRef>> dateIndex = new HashMap<>();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,33 +1,36 @@
|
|||||||
package work.slhaf.partner.module.memory.runtime;
|
package work.slhaf.partner.module.memory.runtime;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson2.JSONArray;
|
||||||
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import work.slhaf.partner.common.config.Config;
|
import work.slhaf.partner.core.cognition.CognitionCapability;
|
||||||
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
|
||||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||||
import work.slhaf.partner.core.memory.pojo.SliceRef;
|
import work.slhaf.partner.core.memory.pojo.SliceRef;
|
||||||
import work.slhaf.partner.framework.agent.config.AgentConfigLoader;
|
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||||
|
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
|
||||||
|
|
||||||
import java.lang.reflect.Field;
|
import java.lang.reflect.Field;
|
||||||
import java.lang.reflect.Method;
|
import java.lang.reflect.Method;
|
||||||
import java.nio.file.Files;
|
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.*;
|
import java.time.LocalDate;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.concurrent.CopyOnWriteArrayList;
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
import java.util.concurrent.locks.Lock;
|
||||||
|
import java.util.concurrent.locks.ReentrantLock;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
class MemoryRuntimeTest {
|
class MemoryRuntimeTest {
|
||||||
|
|
||||||
private AgentConfigLoader previousLoader;
|
|
||||||
private String runtimeAgentId;
|
|
||||||
|
|
||||||
@BeforeAll
|
@BeforeAll
|
||||||
public static void beforeAll(@TempDir Path tempDir) {
|
public static void beforeAll(@TempDir Path tempDir) {
|
||||||
System.setProperty("user.home", tempDir.toAbsolutePath().toString());
|
System.setProperty("user.home", tempDir.toAbsolutePath().toString());
|
||||||
@@ -53,29 +56,41 @@ class MemoryRuntimeTest {
|
|||||||
field.set(target, value);
|
field.set(target, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static PartnerAgentConfigLoader testLoader(String agentId) {
|
|
||||||
PartnerAgentConfigLoader loader = new PartnerAgentConfigLoader();
|
|
||||||
Config config = new Config();
|
|
||||||
config.setAgentId(agentId);
|
|
||||||
Config.WebSocketConfig webSocketConfig = new Config.WebSocketConfig();
|
|
||||||
webSocketConfig.setPort(18080);
|
|
||||||
config.setWebSocketConfig(webSocketConfig);
|
|
||||||
loader.setConfig(config);
|
|
||||||
loader.setModelConfigMap(new HashMap<>());
|
|
||||||
return loader;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Message message(String content) {
|
private static Message message(String content) {
|
||||||
return new Message(Message.Character.USER, content);
|
return new Message(Message.Character.USER, content);
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterEach
|
private static CognitionCapability stubCognitionCapability(List<Message> chatMessages) {
|
||||||
void tearDown() throws Exception {
|
Lock lock = new ReentrantLock();
|
||||||
AgentConfigLoader.INSTANCE = previousLoader;
|
return new CognitionCapability() {
|
||||||
if (runtimeAgentId != null) {
|
@Override
|
||||||
Files.deleteIfExists(Path.of("data/memory", runtimeAgentId + "-memory-runtime.memory"));
|
public void initiateTurn(String input, String target, String... skippedModules) {
|
||||||
Files.deleteIfExists(Path.of("data/memory", runtimeAgentId + "-memory-runtime-temp.memory"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public work.slhaf.partner.core.cognition.ContextWorkspace contextWorkspace() {
|
||||||
|
return new work.slhaf.partner.core.cognition.ContextWorkspace();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Message> getChatMessages() {
|
||||||
|
return chatMessages;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Message> snapshotChatMessages() {
|
||||||
|
return chatMessages;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void rollChatMessagesWithSnapshot(int snapshotSize, int retainDivisor) {
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Lock getMessageLock() {
|
||||||
|
return lock;
|
||||||
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -98,16 +113,16 @@ class MemoryRuntimeTest {
|
|||||||
assertTrue(invokeSliceMessages(runtime, unit, emptySlice).isEmpty());
|
assertTrue(invokeSliceMessages(runtime, unit, emptySlice).isEmpty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void tearDown() {
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void shouldBindTopicToLatestMemorySliceInsteadOfFirstSlice() throws Exception {
|
void shouldBindTopicToLatestMemorySliceInsteadOfFirstSlice() throws Exception {
|
||||||
runtimeAgentId = "runtime-test-" + UUID.randomUUID();
|
|
||||||
previousLoader = AgentConfigLoader.INSTANCE;
|
|
||||||
AgentConfigLoader.INSTANCE = testLoader(runtimeAgentId);
|
|
||||||
|
|
||||||
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
|
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
|
||||||
|
|
||||||
MemoryRuntime runtime = new MemoryRuntime();
|
MemoryRuntime runtime = new MemoryRuntime();
|
||||||
setField(runtime, "memoryCapability", memoryCapability);
|
setField(runtime, "memoryCapability", memoryCapability);
|
||||||
|
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
|
||||||
|
|
||||||
MemoryUnit unit = new MemoryUnit("unit-99");
|
MemoryUnit unit = new MemoryUnit("unit-99");
|
||||||
unit.getConversationMessages().addAll(List.of(
|
unit.getConversationMessages().addAll(List.of(
|
||||||
@@ -122,6 +137,7 @@ class MemoryRuntimeTest {
|
|||||||
MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 2L);
|
MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 2L);
|
||||||
|
|
||||||
unit.getSlices().addAll(List.of(firstSlice, secondSlice));
|
unit.getSlices().addAll(List.of(firstSlice, secondSlice));
|
||||||
|
memoryCapability.remember(unit);
|
||||||
|
|
||||||
runtime.recordMemory(unit, "topic/main", List.of("topic/related"));
|
runtime.recordMemory(unit, "topic/main", List.of("topic/related"));
|
||||||
|
|
||||||
@@ -132,6 +148,62 @@ class MemoryRuntimeTest {
|
|||||||
topicSlices.get("topic/related").stream().map(SliceRef::getSliceId).toList());
|
topicSlices.get("topic/related").stream().map(SliceRef::getSliceId).toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldRoundTripTopicAndDateIndexesViaState() throws Exception {
|
||||||
|
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
|
||||||
|
MemoryRuntime runtime = new MemoryRuntime();
|
||||||
|
setField(runtime, "memoryCapability", memoryCapability);
|
||||||
|
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
|
||||||
|
|
||||||
|
MemoryUnit unit = new MemoryUnit("unit-100");
|
||||||
|
unit.getConversationMessages().addAll(List.of(
|
||||||
|
message("m0"),
|
||||||
|
message("m1"),
|
||||||
|
message("m2"),
|
||||||
|
message("m3")
|
||||||
|
));
|
||||||
|
MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 2, "first", 86_400_000L);
|
||||||
|
MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 172_800_000L);
|
||||||
|
unit.getSlices().addAll(List.of(firstSlice, secondSlice));
|
||||||
|
memoryCapability.remember(unit);
|
||||||
|
|
||||||
|
runtime.recordMemory(unit, "topic/main", List.of("topic/related"));
|
||||||
|
|
||||||
|
JSONObject state = JSONObject.parseObject(runtime.convert().toString());
|
||||||
|
JSONArray topicSlices = state.getJSONArray("topic_slices");
|
||||||
|
assertEquals(2, topicSlices.size());
|
||||||
|
JSONObject mainTopic = topicSlices.stream()
|
||||||
|
.map(JSONObject.class::cast)
|
||||||
|
.filter(item -> "topic/main".equals(item.getString("topic_path")))
|
||||||
|
.findFirst()
|
||||||
|
.orElseThrow();
|
||||||
|
assertEquals("slice-2", mainTopic.getJSONArray("refs").getJSONObject(0).getString("slice_id"));
|
||||||
|
|
||||||
|
JSONArray dateIndex = state.getJSONArray("date_index");
|
||||||
|
assertEquals(2, dateIndex.size());
|
||||||
|
JSONObject secondDate = dateIndex.stream()
|
||||||
|
.map(JSONObject.class::cast)
|
||||||
|
.filter(item -> "1970-01-03".equals(item.getString("date")))
|
||||||
|
.findFirst()
|
||||||
|
.orElseThrow();
|
||||||
|
assertEquals("slice-2", secondDate.getJSONArray("refs").getJSONObject(0).getString("slice_id"));
|
||||||
|
|
||||||
|
MemoryRuntime restored = new MemoryRuntime();
|
||||||
|
setField(restored, "memoryCapability", memoryCapability);
|
||||||
|
setField(restored, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
|
||||||
|
restored.load(state);
|
||||||
|
|
||||||
|
List<ActivatedMemorySlice> topicResult = restored.queryActivatedMemoryByTopicPath("topic/main");
|
||||||
|
assertEquals(1, topicResult.size());
|
||||||
|
assertEquals("slice-2", topicResult.getFirst().getSliceId());
|
||||||
|
assertEquals(List.of("m2", "m3"), topicResult.getFirst().getMessages().stream().map(Message::getContent).toList());
|
||||||
|
|
||||||
|
List<ActivatedMemorySlice> dateResult = restored.queryActivatedMemoryByDate(LocalDate.parse("1970-01-03"));
|
||||||
|
assertEquals(1, dateResult.size());
|
||||||
|
assertEquals("slice-2", dateResult.getFirst().getSliceId());
|
||||||
|
assertEquals("second", dateResult.getFirst().getSummary());
|
||||||
|
}
|
||||||
|
|
||||||
private static final class StubMemoryCapability implements MemoryCapability {
|
private static final class StubMemoryCapability implements MemoryCapability {
|
||||||
private final String sessionId;
|
private final String sessionId;
|
||||||
private final Map<String, MemoryUnit> units = new HashMap<>();
|
private final Map<String, MemoryUnit> units = new HashMap<>();
|
||||||
@@ -140,6 +212,10 @@ class MemoryRuntimeTest {
|
|||||||
this.sessionId = sessionId;
|
this.sessionId = sessionId;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void remember(MemoryUnit unit) {
|
||||||
|
units.put(unit.getId(), unit);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MemoryUnit getMemoryUnit(String unitId) {
|
public MemoryUnit getMemoryUnit(String unitId) {
|
||||||
return units.get(unitId);
|
return units.get(unitId);
|
||||||
|
|||||||
Reference in New Issue
Block a user