mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
refactor(memory): manage state serialization via StateCenter in MemoryRuntime
This commit is contained in:
@@ -1,33 +1,36 @@
|
||||
package work.slhaf.partner.module.memory.runtime;
|
||||
|
||||
import com.alibaba.fastjson2.JSONArray;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import work.slhaf.partner.common.config.Config;
|
||||
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
||||
import work.slhaf.partner.core.cognition.CognitionCapability;
|
||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||
import work.slhaf.partner.core.memory.pojo.SliceRef;
|
||||
import work.slhaf.partner.framework.agent.config.AgentConfigLoader;
|
||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
import java.time.LocalDate;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
class MemoryRuntimeTest {
|
||||
|
||||
private AgentConfigLoader previousLoader;
|
||||
private String runtimeAgentId;
|
||||
|
||||
@BeforeAll
|
||||
public static void beforeAll(@TempDir Path tempDir) {
|
||||
System.setProperty("user.home", tempDir.toAbsolutePath().toString());
|
||||
@@ -53,29 +56,41 @@ class MemoryRuntimeTest {
|
||||
field.set(target, value);
|
||||
}
|
||||
|
||||
private static PartnerAgentConfigLoader testLoader(String agentId) {
|
||||
PartnerAgentConfigLoader loader = new PartnerAgentConfigLoader();
|
||||
Config config = new Config();
|
||||
config.setAgentId(agentId);
|
||||
Config.WebSocketConfig webSocketConfig = new Config.WebSocketConfig();
|
||||
webSocketConfig.setPort(18080);
|
||||
config.setWebSocketConfig(webSocketConfig);
|
||||
loader.setConfig(config);
|
||||
loader.setModelConfigMap(new HashMap<>());
|
||||
return loader;
|
||||
}
|
||||
|
||||
private static Message message(String content) {
|
||||
return new Message(Message.Character.USER, content);
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() throws Exception {
|
||||
AgentConfigLoader.INSTANCE = previousLoader;
|
||||
if (runtimeAgentId != null) {
|
||||
Files.deleteIfExists(Path.of("data/memory", runtimeAgentId + "-memory-runtime.memory"));
|
||||
Files.deleteIfExists(Path.of("data/memory", runtimeAgentId + "-memory-runtime-temp.memory"));
|
||||
}
|
||||
private static CognitionCapability stubCognitionCapability(List<Message> chatMessages) {
|
||||
Lock lock = new ReentrantLock();
|
||||
return new CognitionCapability() {
|
||||
@Override
|
||||
public void initiateTurn(String input, String target, String... skippedModules) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public work.slhaf.partner.core.cognition.ContextWorkspace contextWorkspace() {
|
||||
return new work.slhaf.partner.core.cognition.ContextWorkspace();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Message> getChatMessages() {
|
||||
return chatMessages;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Message> snapshotChatMessages() {
|
||||
return chatMessages;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void rollChatMessagesWithSnapshot(int snapshotSize, int retainDivisor) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Lock getMessageLock() {
|
||||
return lock;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -98,16 +113,16 @@ class MemoryRuntimeTest {
|
||||
assertTrue(invokeSliceMessages(runtime, unit, emptySlice).isEmpty());
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldBindTopicToLatestMemorySliceInsteadOfFirstSlice() throws Exception {
|
||||
runtimeAgentId = "runtime-test-" + UUID.randomUUID();
|
||||
previousLoader = AgentConfigLoader.INSTANCE;
|
||||
AgentConfigLoader.INSTANCE = testLoader(runtimeAgentId);
|
||||
|
||||
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
|
||||
|
||||
MemoryRuntime runtime = new MemoryRuntime();
|
||||
setField(runtime, "memoryCapability", memoryCapability);
|
||||
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
|
||||
|
||||
MemoryUnit unit = new MemoryUnit("unit-99");
|
||||
unit.getConversationMessages().addAll(List.of(
|
||||
@@ -122,6 +137,7 @@ class MemoryRuntimeTest {
|
||||
MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 2L);
|
||||
|
||||
unit.getSlices().addAll(List.of(firstSlice, secondSlice));
|
||||
memoryCapability.remember(unit);
|
||||
|
||||
runtime.recordMemory(unit, "topic/main", List.of("topic/related"));
|
||||
|
||||
@@ -132,6 +148,62 @@ class MemoryRuntimeTest {
|
||||
topicSlices.get("topic/related").stream().map(SliceRef::getSliceId).toList());
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldRoundTripTopicAndDateIndexesViaState() throws Exception {
|
||||
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
|
||||
MemoryRuntime runtime = new MemoryRuntime();
|
||||
setField(runtime, "memoryCapability", memoryCapability);
|
||||
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
|
||||
|
||||
MemoryUnit unit = new MemoryUnit("unit-100");
|
||||
unit.getConversationMessages().addAll(List.of(
|
||||
message("m0"),
|
||||
message("m1"),
|
||||
message("m2"),
|
||||
message("m3")
|
||||
));
|
||||
MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 2, "first", 86_400_000L);
|
||||
MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 172_800_000L);
|
||||
unit.getSlices().addAll(List.of(firstSlice, secondSlice));
|
||||
memoryCapability.remember(unit);
|
||||
|
||||
runtime.recordMemory(unit, "topic/main", List.of("topic/related"));
|
||||
|
||||
JSONObject state = JSONObject.parseObject(runtime.convert().toString());
|
||||
JSONArray topicSlices = state.getJSONArray("topic_slices");
|
||||
assertEquals(2, topicSlices.size());
|
||||
JSONObject mainTopic = topicSlices.stream()
|
||||
.map(JSONObject.class::cast)
|
||||
.filter(item -> "topic/main".equals(item.getString("topic_path")))
|
||||
.findFirst()
|
||||
.orElseThrow();
|
||||
assertEquals("slice-2", mainTopic.getJSONArray("refs").getJSONObject(0).getString("slice_id"));
|
||||
|
||||
JSONArray dateIndex = state.getJSONArray("date_index");
|
||||
assertEquals(2, dateIndex.size());
|
||||
JSONObject secondDate = dateIndex.stream()
|
||||
.map(JSONObject.class::cast)
|
||||
.filter(item -> "1970-01-03".equals(item.getString("date")))
|
||||
.findFirst()
|
||||
.orElseThrow();
|
||||
assertEquals("slice-2", secondDate.getJSONArray("refs").getJSONObject(0).getString("slice_id"));
|
||||
|
||||
MemoryRuntime restored = new MemoryRuntime();
|
||||
setField(restored, "memoryCapability", memoryCapability);
|
||||
setField(restored, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
|
||||
restored.load(state);
|
||||
|
||||
List<ActivatedMemorySlice> topicResult = restored.queryActivatedMemoryByTopicPath("topic/main");
|
||||
assertEquals(1, topicResult.size());
|
||||
assertEquals("slice-2", topicResult.getFirst().getSliceId());
|
||||
assertEquals(List.of("m2", "m3"), topicResult.getFirst().getMessages().stream().map(Message::getContent).toList());
|
||||
|
||||
List<ActivatedMemorySlice> dateResult = restored.queryActivatedMemoryByDate(LocalDate.parse("1970-01-03"));
|
||||
assertEquals(1, dateResult.size());
|
||||
assertEquals("slice-2", dateResult.getFirst().getSliceId());
|
||||
assertEquals("second", dateResult.getFirst().getSummary());
|
||||
}
|
||||
|
||||
private static final class StubMemoryCapability implements MemoryCapability {
|
||||
private final String sessionId;
|
||||
private final Map<String, MemoryUnit> units = new HashMap<>();
|
||||
@@ -140,6 +212,10 @@ class MemoryRuntimeTest {
|
||||
this.sessionId = sessionId;
|
||||
}
|
||||
|
||||
private void remember(MemoryUnit unit) {
|
||||
units.put(unit.getId(), unit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public MemoryUnit getMemoryUnit(String unitId) {
|
||||
return units.get(unitId);
|
||||
|
||||
Reference in New Issue
Block a user