test(memory): test with new memory behavior

This commit is contained in:
2026-04-07 17:23:14 +08:00
parent b80ff8400c
commit 874488ea79
4 changed files with 199 additions and 153 deletions

View File

@@ -29,7 +29,6 @@ import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResul
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.List; import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
@@ -171,7 +170,7 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
log.debug("[MemoryUpdater] 记忆更新-总结流程-输入: {}", JSONObject.toJSONString(summarizeInput)); log.debug("[MemoryUpdater] 记忆更新-总结流程-输入: {}", JSONObject.toJSONString(summarizeInput));
SummarizeResult summarizeResult = summarize(summarizeInput); SummarizeResult summarizeResult = summarize(summarizeInput);
log.debug("[MemoryUpdater] 记忆更新-总结流程-输出: {}", JSONObject.toJSONString(summarizeResult)); log.debug("[MemoryUpdater] 记忆更新-总结流程-输出: {}", JSONObject.toJSONString(summarizeResult));
MemoryUnit memoryUnit = buildMemoryUnit(chatSnapshot, summarizeResult); MemoryUnit memoryUnit = memoryCapability.updateMemoryUnit(chatSnapshot, summarizeResult.getSummary());
memoryRuntime.recordMemory( memoryRuntime.recordMemory(
memoryUnit, memoryUnit,
summarizeResult.getTopicPath(), summarizeResult.getTopicPath(),
@@ -187,28 +186,6 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
return multiSummarizer.execute(summarizeInput); return multiSummarizer.execute(summarizeInput);
} }
// TODO update memory unit via memory capability
private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) {
String memoryId = memoryCapability.getMemorySessionId();
String resolvedMemoryId = memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId;
MemoryUnit existingUnit = memoryCapability.getMemoryUnit(resolvedMemoryId);
List<Message> existingMessages = existingUnit.getConversationMessages();
int startIndex = existingMessages.size();
MemorySlice memorySlice = new MemorySlice(
startIndex,
startIndex + chatMessages.size(),
summarizeResult.getSummary()
);
MemoryUnit memoryUnit = new MemoryUnit(resolvedMemoryId);
memoryUnit.updateTimestamp();
memoryUnit.getConversationMessages().addAll(chatMessages);
memoryUnit.getSlices().add(memorySlice);
return memoryUnit;
}
@Override @Override
public int order() { public int order() {
return 7; return 7;

View File

@@ -1,92 +1,86 @@
package work.slhaf.partner.core.memory; package work.slhaf.partner.core.memory;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import work.slhaf.partner.common.config.Config; import org.junit.jupiter.api.io.TempDir;
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit; import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.framework.agent.config.AgentConfigLoader;
import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.model.pojo.Message;
import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertTrue;
class MemoryCoreTest { class MemoryCoreTest {
private AgentConfigLoader previousLoader; private static MemoryCore memoryCore;
private String agentId;
private static PartnerAgentConfigLoader testLoader(String agentId) { @BeforeAll
PartnerAgentConfigLoader loader = new PartnerAgentConfigLoader(); static void beforeAll(@TempDir Path tempDir) {
Config config = new Config(); System.setProperty("user.home", tempDir.toAbsolutePath().toString());
config.setAgentId(agentId); memoryCore = new MemoryCore();
Config.WebSocketConfig webSocketConfig = new Config.WebSocketConfig();
webSocketConfig.setPort(18080);
config.setWebSocketConfig(webSocketConfig);
loader.setConfig(config);
loader.setModelConfigMap(new HashMap<>());
return loader;
} }
@AfterEach @BeforeEach
void tearDown() throws Exception { void setUp() {
AgentConfigLoader.INSTANCE = previousLoader; memoryCore.refreshMemorySession();
if (agentId != null) {
Files.deleteIfExists(Path.of("data/memory", agentId + "-memory-core.memory"));
Files.deleteIfExists(Path.of("data/memory", agentId + "-temp-memory-core.memory"));
}
} }
@Test @Test
void shouldNormalizeSliceEndIndexUsingExclusiveUpperBound() throws Exception { void shouldCreateFirstSliceFromChatMessages() {
agentId = "memory-core-test-" + UUID.randomUUID(); String sessionId = memoryCore.getMemorySessionId();
previousLoader = AgentConfigLoader.INSTANCE;
AgentConfigLoader.INSTANCE = testLoader(agentId);
MemoryCore memoryCore = new MemoryCore(); MemoryUnit updatedUnit = memoryCore.updateMemoryUnit(List.of(
MemorySlice slice = MemorySlice.restore("slice-1", 1, 99, null, 1L);
MemoryUnit unit = new MemoryUnit("unit-1");
unit.getConversationMessages().addAll(List.of(
new Message(Message.Character.USER, "m0"), new Message(Message.Character.USER, "m0"),
new Message(Message.Character.USER, "m1"), new Message(Message.Character.USER, "m1"),
new Message(Message.Character.USER, "m2") new Message(Message.Character.USER, "m2")
)); ), "first-summary");
unit.getSlices().add(slice);
memoryCore.saveMemoryUnit(unit); assertEquals(sessionId, updatedUnit.getId());
assertEquals(List.of("m0", "m1", "m2"),
updatedUnit.getConversationMessages().stream().map(Message::getContent).toList());
assertEquals(1, updatedUnit.getSlices().size());
MemorySlice savedSlice = memoryCore.getMemorySlice("unit-1", "slice-1"); MemorySlice firstSlice = updatedUnit.getSlices().getFirst();
assertEquals(1, savedSlice.getStartIndex()); assertNotNull(firstSlice.getId());
assertEquals(3, savedSlice.getEndIndex()); assertEquals(0, firstSlice.getStartIndex());
assertEquals(3, firstSlice.getEndIndex());
assertEquals("first-summary", firstSlice.getSummary());
assertTrue(updatedUnit.getTimestamp() > 0);
assertTrue(firstSlice.getTimestamp() > 0);
} }
@Test @Test
void shouldFillMissingTimestampsWhenSavingMemoryUnit() throws Exception { void shouldAppendMessagesAndCreateNextSlice() {
agentId = "memory-core-test-" + UUID.randomUUID(); String sessionId = memoryCore.getMemorySessionId();
previousLoader = AgentConfigLoader.INSTANCE;
AgentConfigLoader.INSTANCE = testLoader(agentId);
MemoryCore memoryCore = new MemoryCore(); memoryCore.updateMemoryUnit(List.of(
new Message(Message.Character.USER, "m0")
), "first-summary");
MemorySlice slice = MemorySlice.restore("slice-1", 0, 1, "summary", 0L); MemoryUnit updatedUnit = memoryCore.updateMemoryUnit(List.of(
MemoryUnit unit = new MemoryUnit("unit-1"); new Message(Message.Character.ASSISTANT, "m1"),
unit.getConversationMessages().add(new Message(Message.Character.USER, "m0")); new Message(Message.Character.USER, "m2")
unit.getSlices().add(slice); ), "second-summary");
memoryCore.saveMemoryUnit(unit); assertEquals(sessionId, updatedUnit.getId());
assertEquals(List.of("m0", "m1", "m2"),
updatedUnit.getConversationMessages().stream().map(Message::getContent).toList());
assertEquals(2, updatedUnit.getSlices().size());
MemoryUnit savedUnit = memoryCore.getMemoryUnit("unit-1"); MemorySlice appendedSlice = updatedUnit.getSlices().getLast();
MemorySlice savedSlice = memoryCore.getMemorySlice("unit-1", "slice-1"); assertNotNull(appendedSlice.getId());
assertTrue(savedUnit.getTimestamp() > 0); assertEquals(1, appendedSlice.getStartIndex());
assertEquals(savedUnit.getTimestamp(), savedSlice.getTimestamp()); assertEquals(3, appendedSlice.getEndIndex());
assertEquals("second-summary", appendedSlice.getSummary());
assertTrue(appendedSlice.getTimestamp() > 0);
MemorySlice loadedSlice = memoryCore.getMemorySlice(sessionId, appendedSlice.getId());
assertNotNull(loadedSlice);
assertEquals(1, loadedSlice.getStartIndex());
assertEquals(3, loadedSlice.getEndIndex());
assertEquals("second-summary", loadedSlice.getSummary());
} }
} }

View File

@@ -1,7 +1,9 @@
package work.slhaf.partner.module.memory.runtime; package work.slhaf.partner.module.memory.runtime;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import work.slhaf.partner.common.config.Config; import work.slhaf.partner.common.config.Config;
import work.slhaf.partner.common.config.PartnerAgentConfigLoader; import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.MemoryCapability;
@@ -26,6 +28,11 @@ class MemoryRuntimeTest {
private AgentConfigLoader previousLoader; private AgentConfigLoader previousLoader;
private String runtimeAgentId; private String runtimeAgentId;
@BeforeAll
public static void beforeAll(@TempDir Path tempDir) {
System.setProperty("user.home", tempDir.toAbsolutePath().toString());
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices(MemoryRuntime runtime) throws Exception { private static Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices(MemoryRuntime runtime) throws Exception {
Field field = MemoryRuntime.class.getDeclaredField("topicSlices"); Field field = MemoryRuntime.class.getDeclaredField("topicSlices");
@@ -82,16 +89,13 @@ class MemoryRuntimeTest {
message("m3") message("m3")
)); ));
MemorySlice slice = new MemorySlice(); MemorySlice slice = MemorySlice.restore("slice-1", 1, 3, null, 1L);
slice.setStartIndex(1);
slice.setEndIndex(3);
List<Message> messages = invokeSliceMessages(runtime, unit, slice); List<Message> messages = invokeSliceMessages(runtime, unit, slice);
assertEquals(List.of("m1", "m2"), messages.stream().map(Message::getContent).toList()); assertEquals(List.of("m1", "m2"), messages.stream().map(Message::getContent).toList());
slice.setStartIndex(2); MemorySlice emptySlice = MemorySlice.restore("slice-2", 2, 2, null, 2L);
slice.setEndIndex(2); assertTrue(invokeSliceMessages(runtime, unit, emptySlice).isEmpty());
assertTrue(invokeSliceMessages(runtime, unit, slice).isEmpty());
} }
@Test @Test
@@ -105,7 +109,7 @@ class MemoryRuntimeTest {
MemoryRuntime runtime = new MemoryRuntime(); MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability); setField(runtime, "memoryCapability", memoryCapability);
MemoryUnit unit = new MemoryUnit("unit-1"); MemoryUnit unit = new MemoryUnit("unit-99");
unit.getConversationMessages().addAll(List.of( unit.getConversationMessages().addAll(List.of(
message("m0"), message("m0"),
message("m1"), message("m1"),
@@ -113,19 +117,9 @@ class MemoryRuntimeTest {
message("m3") message("m3")
)); ));
MemorySlice firstSlice = new MemorySlice(); MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 2, "first", 1L);
firstSlice.setId("slice-1");
firstSlice.setStartIndex(0);
firstSlice.setEndIndex(2);
firstSlice.setSummary("first");
firstSlice.setTimestamp(1L);
MemorySlice secondSlice = new MemorySlice(); MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 2L);
secondSlice.setId("slice-2");
secondSlice.setStartIndex(2);
secondSlice.setEndIndex(4);
secondSlice.setSummary("second");
secondSlice.setTimestamp(2L);
unit.getSlices().addAll(List.of(firstSlice, secondSlice)); unit.getSlices().addAll(List.of(firstSlice, secondSlice));
@@ -146,11 +140,6 @@ class MemoryRuntimeTest {
this.sessionId = sessionId; this.sessionId = sessionId;
} }
@Override
public void saveMemoryUnit(MemoryUnit memoryUnit) {
units.put(memoryUnit.getId(), memoryUnit);
}
@Override @Override
public MemoryUnit getMemoryUnit(String unitId) { public MemoryUnit getMemoryUnit(String unitId) {
return units.get(unitId); return units.get(unitId);
@@ -168,6 +157,11 @@ class MemoryRuntimeTest {
.orElse(null); .orElse(null);
} }
@Override
public MemoryUnit updateMemoryUnit(List<Message> chatMessages, String summary) {
return null;
}
@Override @Override
public Collection<MemoryUnit> listMemoryUnits() { public Collection<MemoryUnit> listMemoryUnits() {
return units.values(); return units.values();

View File

@@ -1,30 +1,42 @@
package work.slhaf.partner.module.memory.updater; package work.slhaf.partner.module.memory.updater;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.mockito.Mockito;
import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit; import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
import work.slhaf.partner.module.memory.updater.summarizer.MultiSummarizer;
import work.slhaf.partner.module.memory.updater.summarizer.SingleSummarizer;
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult; import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.nio.file.Path;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
class MemoryUpdaterTest { class MemoryUpdaterTest {
private static MemoryUnit invokeBuildMemoryUnit(MemoryUpdater updater, @BeforeAll
List<Message> chatMessages, static void beforeAll(@TempDir Path tempDir) {
SummarizeResult summarizeResult) throws Exception { System.setProperty("user.home", tempDir.toAbsolutePath().toString());
Method method = MemoryUpdater.class.getDeclaredMethod("buildMemoryUnit", List.class, SummarizeResult.class); }
private static Object invokeUpdateMemory(MemoryUpdater updater, List<Message> chatMessages) throws Exception {
Method method = MemoryUpdater.class.getDeclaredMethod("updateMemory", List.class);
method.setAccessible(true); method.setAccessible(true);
return (MemoryUnit) method.invoke(updater, chatMessages, summarizeResult); return method.invoke(updater, chatMessages);
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@@ -35,15 +47,23 @@ class MemoryUpdaterTest {
return (List<Message>) method.invoke(updater, chatMessages); return (List<Message>) method.invoke(updater, chatMessages);
} }
private static String recordField(Object target, String fieldName) throws Exception {
Field field = target.getClass().getDeclaredField(fieldName);
field.setAccessible(true);
return (String) field.get(target);
}
private static void setField(Object target, String fieldName, Object value) throws Exception { private static void setField(Object target, String fieldName, Object value) throws Exception {
Field field = target.getClass().getDeclaredField(fieldName); Field field = target.getClass().getDeclaredField(fieldName);
field.setAccessible(true); field.setAccessible(true);
field.set(target, value); field.set(target, value);
} }
private static SummarizeResult summarizeResult(String summary) { private static SummarizeResult summarizeResult(String summary, String topicPath, List<String> relatedTopicPath) {
SummarizeResult result = new SummarizeResult(); SummarizeResult result = new SummarizeResult();
result.setSummary(summary); result.setSummary(summary);
result.setTopicPath(topicPath);
result.setRelatedTopicPath(relatedTopicPath);
return result; return result;
} }
@@ -52,69 +72,90 @@ class MemoryUpdaterTest {
} }
@Test @Test
void shouldAppendNewSliceToExistingMemoryUnitWithinSameSession() throws Exception { void shouldDelegateMemoryUpdateToCapabilityAndRuntime() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-1"); StubMemoryCapability memoryCapability = new StubMemoryCapability("session-1");
MemoryUpdater updater = new MemoryUpdater(); MemoryUpdater updater = new MemoryUpdater();
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class);
SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class);
setField(updater, "memoryCapability", memoryCapability); setField(updater, "memoryCapability", memoryCapability);
setField(updater, "memoryRuntime", memoryRuntime);
setField(updater, "multiSummarizer", multiSummarizer);
setField(updater, "singleSummarizer", singleSummarizer);
String sessionId = memoryCapability.getMemorySessionId(); when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
MemoryUnit existingUnit = new MemoryUnit(sessionId); when(multiSummarizer.execute(Mockito.any())).thenReturn(
summarizeResult("new-summary", "topic/main", List.of("topic/related"))
);
MemoryUnit existingUnit = new MemoryUnit("session-1");
existingUnit.getConversationMessages().addAll(List.of( existingUnit.getConversationMessages().addAll(List.of(
message(Message.Character.USER, "old-user"), message(Message.Character.USER, "old-user"),
message(Message.Character.ASSISTANT, "old-assistant") message(Message.Character.ASSISTANT, "old-assistant")
)); ));
MemorySlice existingSlice = new MemorySlice(); existingUnit.getSlices().add(MemorySlice.restore("slice-1", 0, 2, "old-summary", 1L));
existingSlice.setId("slice-1"); memoryCapability.putUnit(existingUnit);
existingSlice.setStartIndex(0);
existingSlice.setEndIndex(2);
existingSlice.setSummary("old-summary");
existingSlice.setTimestamp(1L);
existingUnit.getSlices().add(existingSlice);
memoryCapability.saveMemoryUnit(existingUnit);
MemoryUnit merged = invokeBuildMemoryUnit( Object rollingRecord = invokeUpdateMemory(updater, List.of(
updater,
List.of(
message(Message.Character.USER, "new-user"), message(Message.Character.USER, "new-user"),
message(Message.Character.ASSISTANT, "new-assistant") message(Message.Character.ASSISTANT, "new-assistant")
), ));
summarizeResult("new-summary")
);
assertEquals(sessionId, merged.getId()); MemoryUnit merged = memoryCapability.getMemoryUnit("session-1");
assertEquals(4, merged.getConversationMessages().size());
assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"), assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"),
merged.getConversationMessages().stream().map(Message::getContent).toList()); merged.getConversationMessages().stream().map(Message::getContent).toList());
assertEquals(2, merged.getSlices().size()); assertEquals(2, merged.getSlices().size());
MemorySlice appendedSlice = merged.getSlices().get(1); MemorySlice appendedSlice = merged.getSlices().getLast();
assertNotNull(appendedSlice.getId()); assertNotNull(appendedSlice.getId());
assertEquals(2, appendedSlice.getStartIndex()); assertEquals(2, appendedSlice.getStartIndex());
assertEquals(4, appendedSlice.getEndIndex()); assertEquals(4, appendedSlice.getEndIndex());
assertEquals("new-summary", appendedSlice.getSummary()); assertEquals("new-summary", appendedSlice.getSummary());
assertEquals(List.of("new-user", "new-assistant"),
memoryCapability.lastChatMessages().stream().map(Message::getContent).toList());
assertEquals("new-summary", memoryCapability.lastSummary());
verify(memoryRuntime).recordMemory(eq(merged), eq("topic/main"), eq(List.of("topic/related")));
assertEquals("session-1", recordField(rollingRecord, "unitId"));
assertEquals(appendedSlice.getId(), recordField(rollingRecord, "sliceId"));
assertEquals("new-summary", recordField(rollingRecord, "summary"));
} }
@Test @Test
void shouldCreateNewMemoryUnitForNewSessionId() throws Exception { void shouldCreateFirstSliceForFreshSessionThroughCapability() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-2"); StubMemoryCapability memoryCapability = new StubMemoryCapability("session-2");
MemoryUpdater updater = new MemoryUpdater(); MemoryUpdater updater = new MemoryUpdater();
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class);
SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class);
setField(updater, "memoryCapability", memoryCapability); setField(updater, "memoryCapability", memoryCapability);
setField(updater, "memoryRuntime", memoryRuntime);
setField(updater, "multiSummarizer", multiSummarizer);
setField(updater, "singleSummarizer", singleSummarizer);
MemoryUnit created = invokeBuildMemoryUnit( when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
updater, when(multiSummarizer.execute(Mockito.any())).thenReturn(
List.of( summarizeResult("fresh-summary", "topic/root", List.of())
message(Message.Character.USER, "first"),
message(Message.Character.ASSISTANT, "second")
),
summarizeResult("fresh-summary")
); );
Object rollingRecord = invokeUpdateMemory(updater, List.of(
message(Message.Character.USER, "first"),
message(Message.Character.ASSISTANT, "second")
));
MemoryUnit created = memoryCapability.getMemoryUnit("session-2");
assertNotNull(created);
assertEquals("session-2", created.getId()); assertEquals("session-2", created.getId());
assertEquals(2, created.getConversationMessages().size()); assertEquals(List.of("first", "second"),
created.getConversationMessages().stream().map(Message::getContent).toList());
assertEquals(1, created.getSlices().size()); assertEquals(1, created.getSlices().size());
assertEquals(0, created.getSlices().getFirst().getStartIndex()); assertEquals(0, created.getSlices().getFirst().getStartIndex());
assertEquals(2, created.getSlices().getFirst().getEndIndex()); assertEquals(2, created.getSlices().getFirst().getEndIndex());
assertEquals("fresh-summary", created.getSlices().getFirst().getSummary()); assertEquals("fresh-summary", created.getSlices().getFirst().getSummary());
verify(memoryRuntime).recordMemory(eq(created), eq("topic/root"), eq(List.of()));
assertEquals("session-2", recordField(rollingRecord, "unitId"));
assertEquals(created.getSlices().getFirst().getId(), recordField(rollingRecord, "sliceId"));
assertEquals("fresh-summary", recordField(rollingRecord, "summary"));
} }
@Test @Test
@@ -123,14 +164,14 @@ class MemoryUpdaterTest {
MemoryUpdater updater = new MemoryUpdater(); MemoryUpdater updater = new MemoryUpdater();
setField(updater, "memoryCapability", memoryCapability); setField(updater, "memoryCapability", memoryCapability);
MemoryUnit existingUnit = new MemoryUnit("session-3"); MemoryUnit existingUnit = Mockito.mock(MemoryUnit.class);
existingUnit.getConversationMessages().addAll(List.of( when(existingUnit.getConversationMessages()).thenReturn(List.of(
message(Message.Character.USER, "m1"), message(Message.Character.USER, "m1"),
message(Message.Character.ASSISTANT, "m2"), message(Message.Character.ASSISTANT, "m2"),
message(Message.Character.USER, "m3"), message(Message.Character.USER, "m3"),
message(Message.Character.ASSISTANT, "m4") message(Message.Character.ASSISTANT, "m4")
)); ));
memoryCapability.saveMemoryUnit(existingUnit); memoryCapability.putUnit("session-3", existingUnit);
List<Message> increment = invokeResolveChatIncrement( List<Message> increment = invokeResolveChatIncrement(
updater, updater,
@@ -151,13 +192,13 @@ class MemoryUpdaterTest {
MemoryUpdater updater = new MemoryUpdater(); MemoryUpdater updater = new MemoryUpdater();
setField(updater, "memoryCapability", memoryCapability); setField(updater, "memoryCapability", memoryCapability);
MemoryUnit existingUnit = new MemoryUnit("session-4"); MemoryUnit existingUnit = Mockito.mock(MemoryUnit.class);
existingUnit.getConversationMessages().addAll(List.of( when(existingUnit.getConversationMessages()).thenReturn(List.of(
message(Message.Character.USER, "m1"), message(Message.Character.USER, "m1"),
message(Message.Character.ASSISTANT, "m2"), message(Message.Character.ASSISTANT, "m2"),
message(Message.Character.USER, "m3") message(Message.Character.USER, "m3")
)); ));
memoryCapability.saveMemoryUnit(existingUnit); memoryCapability.putUnit("session-4", existingUnit);
List<Message> increment = invokeResolveChatIncrement( List<Message> increment = invokeResolveChatIncrement(
updater, updater,
@@ -170,19 +211,47 @@ class MemoryUpdaterTest {
assertEquals(List.of(), increment); assertEquals(List.of(), increment);
} }
@Test
void shouldReturnNullWhenUpdateMemoryReceivesEmptySnapshot() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-5");
MemoryUpdater updater = new MemoryUpdater();
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
setField(updater, "memoryCapability", memoryCapability);
setField(updater, "memoryRuntime", memoryRuntime);
Object rollingRecord = invokeUpdateMemory(updater, List.of());
assertNull(rollingRecord);
assertNull(memoryCapability.lastSummary());
Mockito.verifyNoInteractions(memoryRuntime);
}
private static final class StubMemoryCapability implements MemoryCapability { private static final class StubMemoryCapability implements MemoryCapability {
private final String sessionId; private final String sessionId;
private final Map<String, MemoryUnit> units = new HashMap<>(); private final Map<String, MemoryUnit> units = new HashMap<>();
private List<Message> lastChatMessages;
private String lastSummary;
private StubMemoryCapability(String sessionId) { private StubMemoryCapability(String sessionId) {
this.sessionId = sessionId; this.sessionId = sessionId;
} }
@Override private void putUnit(String unitId, MemoryUnit memoryUnit) {
public void saveMemoryUnit(MemoryUnit memoryUnit) { units.put(unitId, memoryUnit);
}
private void putUnit(MemoryUnit memoryUnit) {
units.put(memoryUnit.getId(), memoryUnit); units.put(memoryUnit.getId(), memoryUnit);
} }
private List<Message> lastChatMessages() {
return lastChatMessages;
}
private String lastSummary() {
return lastSummary;
}
@Override @Override
public MemoryUnit getMemoryUnit(String unitId) { public MemoryUnit getMemoryUnit(String unitId) {
return units.get(unitId); return units.get(unitId);
@@ -200,6 +269,18 @@ class MemoryUpdaterTest {
.orElse(null); .orElse(null);
} }
@Override
public MemoryUnit updateMemoryUnit(List<Message> chatMessages, String summary) {
lastChatMessages = List.copyOf(chatMessages);
lastSummary = summary;
MemoryUnit unit = units.computeIfAbsent(sessionId, MemoryUnit::new);
unit.updateTimestamp();
int startIndex = unit.getConversationMessages().size();
unit.getConversationMessages().addAll(chatMessages);
unit.getSlices().add(new MemorySlice(startIndex, startIndex + chatMessages.size(), summary));
return unit;
}
@Override @Override
public Collection<MemoryUnit> listMemoryUnits() { public Collection<MemoryUnit> listMemoryUnits() {
return units.values(); return units.values();