diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryUpdater.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryUpdater.java index 62f8d10e..a8e813e8 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryUpdater.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryUpdater.java @@ -29,7 +29,6 @@ import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResul import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; import java.util.List; -import java.util.UUID; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; @@ -171,7 +170,7 @@ public class MemoryUpdater extends AbstractAgentModule.Running chatMessages, SummarizeResult summarizeResult) { - String memoryId = memoryCapability.getMemorySessionId(); - String resolvedMemoryId = memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId; - MemoryUnit existingUnit = memoryCapability.getMemoryUnit(resolvedMemoryId); - List existingMessages = existingUnit.getConversationMessages(); - int startIndex = existingMessages.size(); - - MemorySlice memorySlice = new MemorySlice( - startIndex, - startIndex + chatMessages.size(), - summarizeResult.getSummary() - ); - - MemoryUnit memoryUnit = new MemoryUnit(resolvedMemoryId); - memoryUnit.updateTimestamp(); - memoryUnit.getConversationMessages().addAll(chatMessages); - - memoryUnit.getSlices().add(memorySlice); - return memoryUnit; - } - @Override public int order() { return 7; diff --git a/Partner-Core/src/test/java/work/slhaf/partner/core/memory/MemoryCoreTest.java b/Partner-Core/src/test/java/work/slhaf/partner/core/memory/MemoryCoreTest.java index 4712afdb..86e2f1a8 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/core/memory/MemoryCoreTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/core/memory/MemoryCoreTest.java @@ -1,92 +1,86 @@ package work.slhaf.partner.core.memory; -import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import work.slhaf.partner.common.config.Config; -import work.slhaf.partner.common.config.PartnerAgentConfigLoader; +import org.junit.jupiter.api.io.TempDir; import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemoryUnit; -import work.slhaf.partner.framework.agent.config.AgentConfigLoader; import work.slhaf.partner.framework.agent.model.pojo.Message; -import java.nio.file.Files; import java.nio.file.Path; -import java.util.HashMap; import java.util.List; -import java.util.UUID; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; class MemoryCoreTest { - private AgentConfigLoader previousLoader; - private String agentId; + private static MemoryCore memoryCore; - private static PartnerAgentConfigLoader testLoader(String agentId) { - PartnerAgentConfigLoader loader = new PartnerAgentConfigLoader(); - Config config = new Config(); - config.setAgentId(agentId); - Config.WebSocketConfig webSocketConfig = new Config.WebSocketConfig(); - webSocketConfig.setPort(18080); - config.setWebSocketConfig(webSocketConfig); - loader.setConfig(config); - loader.setModelConfigMap(new HashMap<>()); - return loader; + @BeforeAll + static void beforeAll(@TempDir Path tempDir) { + System.setProperty("user.home", tempDir.toAbsolutePath().toString()); + memoryCore = new MemoryCore(); } - @AfterEach - void tearDown() throws Exception { - AgentConfigLoader.INSTANCE = previousLoader; - if (agentId != null) { - Files.deleteIfExists(Path.of("data/memory", agentId + "-memory-core.memory")); - Files.deleteIfExists(Path.of("data/memory", agentId + "-temp-memory-core.memory")); - } + @BeforeEach + void setUp() { + memoryCore.refreshMemorySession(); } @Test - void shouldNormalizeSliceEndIndexUsingExclusiveUpperBound() throws Exception { - agentId = "memory-core-test-" + UUID.randomUUID(); - previousLoader = AgentConfigLoader.INSTANCE; - AgentConfigLoader.INSTANCE = testLoader(agentId); + void shouldCreateFirstSliceFromChatMessages() { + String sessionId = memoryCore.getMemorySessionId(); - MemoryCore memoryCore = new MemoryCore(); - - MemorySlice slice = MemorySlice.restore("slice-1", 1, 99, null, 1L); - - MemoryUnit unit = new MemoryUnit("unit-1"); - unit.getConversationMessages().addAll(List.of( + MemoryUnit updatedUnit = memoryCore.updateMemoryUnit(List.of( new Message(Message.Character.USER, "m0"), new Message(Message.Character.USER, "m1"), new Message(Message.Character.USER, "m2") - )); - unit.getSlices().add(slice); + ), "first-summary"); - memoryCore.saveMemoryUnit(unit); + assertEquals(sessionId, updatedUnit.getId()); + assertEquals(List.of("m0", "m1", "m2"), + updatedUnit.getConversationMessages().stream().map(Message::getContent).toList()); + assertEquals(1, updatedUnit.getSlices().size()); - MemorySlice savedSlice = memoryCore.getMemorySlice("unit-1", "slice-1"); - assertEquals(1, savedSlice.getStartIndex()); - assertEquals(3, savedSlice.getEndIndex()); + MemorySlice firstSlice = updatedUnit.getSlices().getFirst(); + assertNotNull(firstSlice.getId()); + assertEquals(0, firstSlice.getStartIndex()); + assertEquals(3, firstSlice.getEndIndex()); + assertEquals("first-summary", firstSlice.getSummary()); + assertTrue(updatedUnit.getTimestamp() > 0); + assertTrue(firstSlice.getTimestamp() > 0); } @Test - void shouldFillMissingTimestampsWhenSavingMemoryUnit() throws Exception { - agentId = "memory-core-test-" + UUID.randomUUID(); - previousLoader = AgentConfigLoader.INSTANCE; - AgentConfigLoader.INSTANCE = testLoader(agentId); + void shouldAppendMessagesAndCreateNextSlice() { + String sessionId = memoryCore.getMemorySessionId(); - MemoryCore memoryCore = new MemoryCore(); + memoryCore.updateMemoryUnit(List.of( + new Message(Message.Character.USER, "m0") + ), "first-summary"); - MemorySlice slice = MemorySlice.restore("slice-1", 0, 1, "summary", 0L); - MemoryUnit unit = new MemoryUnit("unit-1"); - unit.getConversationMessages().add(new Message(Message.Character.USER, "m0")); - unit.getSlices().add(slice); + MemoryUnit updatedUnit = memoryCore.updateMemoryUnit(List.of( + new Message(Message.Character.ASSISTANT, "m1"), + new Message(Message.Character.USER, "m2") + ), "second-summary"); - memoryCore.saveMemoryUnit(unit); + assertEquals(sessionId, updatedUnit.getId()); + assertEquals(List.of("m0", "m1", "m2"), + updatedUnit.getConversationMessages().stream().map(Message::getContent).toList()); + assertEquals(2, updatedUnit.getSlices().size()); - MemoryUnit savedUnit = memoryCore.getMemoryUnit("unit-1"); - MemorySlice savedSlice = memoryCore.getMemorySlice("unit-1", "slice-1"); - assertTrue(savedUnit.getTimestamp() > 0); - assertEquals(savedUnit.getTimestamp(), savedSlice.getTimestamp()); + MemorySlice appendedSlice = updatedUnit.getSlices().getLast(); + assertNotNull(appendedSlice.getId()); + assertEquals(1, appendedSlice.getStartIndex()); + assertEquals(3, appendedSlice.getEndIndex()); + assertEquals("second-summary", appendedSlice.getSummary()); + assertTrue(appendedSlice.getTimestamp() > 0); + + MemorySlice loadedSlice = memoryCore.getMemorySlice(sessionId, appendedSlice.getId()); + assertNotNull(loadedSlice); + assertEquals(1, loadedSlice.getStartIndex()); + assertEquals(3, loadedSlice.getEndIndex()); + assertEquals("second-summary", loadedSlice.getSummary()); } } diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java index 008736cb..64b8de73 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java @@ -1,7 +1,9 @@ package work.slhaf.partner.module.memory.runtime; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import work.slhaf.partner.common.config.Config; import work.slhaf.partner.common.config.PartnerAgentConfigLoader; import work.slhaf.partner.core.memory.MemoryCapability; @@ -26,6 +28,11 @@ class MemoryRuntimeTest { private AgentConfigLoader previousLoader; private String runtimeAgentId; + @BeforeAll + public static void beforeAll(@TempDir Path tempDir) { + System.setProperty("user.home", tempDir.toAbsolutePath().toString()); + } + @SuppressWarnings("unchecked") private static Map> topicSlices(MemoryRuntime runtime) throws Exception { Field field = MemoryRuntime.class.getDeclaredField("topicSlices"); @@ -82,16 +89,13 @@ class MemoryRuntimeTest { message("m3") )); - MemorySlice slice = new MemorySlice(); - slice.setStartIndex(1); - slice.setEndIndex(3); + MemorySlice slice = MemorySlice.restore("slice-1", 1, 3, null, 1L); List messages = invokeSliceMessages(runtime, unit, slice); assertEquals(List.of("m1", "m2"), messages.stream().map(Message::getContent).toList()); - slice.setStartIndex(2); - slice.setEndIndex(2); - assertTrue(invokeSliceMessages(runtime, unit, slice).isEmpty()); + MemorySlice emptySlice = MemorySlice.restore("slice-2", 2, 2, null, 2L); + assertTrue(invokeSliceMessages(runtime, unit, emptySlice).isEmpty()); } @Test @@ -105,7 +109,7 @@ class MemoryRuntimeTest { MemoryRuntime runtime = new MemoryRuntime(); setField(runtime, "memoryCapability", memoryCapability); - MemoryUnit unit = new MemoryUnit("unit-1"); + MemoryUnit unit = new MemoryUnit("unit-99"); unit.getConversationMessages().addAll(List.of( message("m0"), message("m1"), @@ -113,19 +117,9 @@ class MemoryRuntimeTest { message("m3") )); - MemorySlice firstSlice = new MemorySlice(); - firstSlice.setId("slice-1"); - firstSlice.setStartIndex(0); - firstSlice.setEndIndex(2); - firstSlice.setSummary("first"); - firstSlice.setTimestamp(1L); + MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 2, "first", 1L); - MemorySlice secondSlice = new MemorySlice(); - secondSlice.setId("slice-2"); - secondSlice.setStartIndex(2); - secondSlice.setEndIndex(4); - secondSlice.setSummary("second"); - secondSlice.setTimestamp(2L); + MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 2L); unit.getSlices().addAll(List.of(firstSlice, secondSlice)); @@ -146,11 +140,6 @@ class MemoryRuntimeTest { this.sessionId = sessionId; } - @Override - public void saveMemoryUnit(MemoryUnit memoryUnit) { - units.put(memoryUnit.getId(), memoryUnit); - } - @Override public MemoryUnit getMemoryUnit(String unitId) { return units.get(unitId); @@ -168,6 +157,11 @@ class MemoryRuntimeTest { .orElse(null); } + @Override + public MemoryUnit updateMemoryUnit(List chatMessages, String summary) { + return null; + } + @Override public Collection listMemoryUnits() { return units.values(); diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java index f3114fab..6591bed1 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java @@ -1,30 +1,42 @@ package work.slhaf.partner.module.memory.updater; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.Mockito; import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemoryUnit; import work.slhaf.partner.framework.agent.model.pojo.Message; +import work.slhaf.partner.module.memory.runtime.MemoryRuntime; +import work.slhaf.partner.module.memory.updater.summarizer.MultiSummarizer; +import work.slhaf.partner.module.memory.updater.summarizer.SingleSummarizer; import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.nio.file.Path; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; class MemoryUpdaterTest { - private static MemoryUnit invokeBuildMemoryUnit(MemoryUpdater updater, - List chatMessages, - SummarizeResult summarizeResult) throws Exception { - Method method = MemoryUpdater.class.getDeclaredMethod("buildMemoryUnit", List.class, SummarizeResult.class); + @BeforeAll + static void beforeAll(@TempDir Path tempDir) { + System.setProperty("user.home", tempDir.toAbsolutePath().toString()); + } + + private static Object invokeUpdateMemory(MemoryUpdater updater, List chatMessages) throws Exception { + Method method = MemoryUpdater.class.getDeclaredMethod("updateMemory", List.class); method.setAccessible(true); - return (MemoryUnit) method.invoke(updater, chatMessages, summarizeResult); + return method.invoke(updater, chatMessages); } @SuppressWarnings("unchecked") @@ -35,15 +47,23 @@ class MemoryUpdaterTest { return (List) method.invoke(updater, chatMessages); } + private static String recordField(Object target, String fieldName) throws Exception { + Field field = target.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + return (String) field.get(target); + } + private static void setField(Object target, String fieldName, Object value) throws Exception { Field field = target.getClass().getDeclaredField(fieldName); field.setAccessible(true); field.set(target, value); } - private static SummarizeResult summarizeResult(String summary) { + private static SummarizeResult summarizeResult(String summary, String topicPath, List relatedTopicPath) { SummarizeResult result = new SummarizeResult(); result.setSummary(summary); + result.setTopicPath(topicPath); + result.setRelatedTopicPath(relatedTopicPath); return result; } @@ -52,69 +72,90 @@ class MemoryUpdaterTest { } @Test - void shouldAppendNewSliceToExistingMemoryUnitWithinSameSession() throws Exception { + void shouldDelegateMemoryUpdateToCapabilityAndRuntime() throws Exception { StubMemoryCapability memoryCapability = new StubMemoryCapability("session-1"); MemoryUpdater updater = new MemoryUpdater(); + MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); + MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class); + SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class); setField(updater, "memoryCapability", memoryCapability); + setField(updater, "memoryRuntime", memoryRuntime); + setField(updater, "multiSummarizer", multiSummarizer); + setField(updater, "singleSummarizer", singleSummarizer); - String sessionId = memoryCapability.getMemorySessionId(); - MemoryUnit existingUnit = new MemoryUnit(sessionId); + when(memoryRuntime.getTopicTree()).thenReturn("topic-tree"); + when(multiSummarizer.execute(Mockito.any())).thenReturn( + summarizeResult("new-summary", "topic/main", List.of("topic/related")) + ); + + MemoryUnit existingUnit = new MemoryUnit("session-1"); existingUnit.getConversationMessages().addAll(List.of( message(Message.Character.USER, "old-user"), message(Message.Character.ASSISTANT, "old-assistant") )); - MemorySlice existingSlice = new MemorySlice(); - existingSlice.setId("slice-1"); - existingSlice.setStartIndex(0); - existingSlice.setEndIndex(2); - existingSlice.setSummary("old-summary"); - existingSlice.setTimestamp(1L); - existingUnit.getSlices().add(existingSlice); - memoryCapability.saveMemoryUnit(existingUnit); + existingUnit.getSlices().add(MemorySlice.restore("slice-1", 0, 2, "old-summary", 1L)); + memoryCapability.putUnit(existingUnit); - MemoryUnit merged = invokeBuildMemoryUnit( - updater, - List.of( - message(Message.Character.USER, "new-user"), - message(Message.Character.ASSISTANT, "new-assistant") - ), - summarizeResult("new-summary") - ); + Object rollingRecord = invokeUpdateMemory(updater, List.of( + message(Message.Character.USER, "new-user"), + message(Message.Character.ASSISTANT, "new-assistant") + )); - assertEquals(sessionId, merged.getId()); - assertEquals(4, merged.getConversationMessages().size()); + MemoryUnit merged = memoryCapability.getMemoryUnit("session-1"); assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"), merged.getConversationMessages().stream().map(Message::getContent).toList()); assertEquals(2, merged.getSlices().size()); - MemorySlice appendedSlice = merged.getSlices().get(1); + MemorySlice appendedSlice = merged.getSlices().getLast(); assertNotNull(appendedSlice.getId()); assertEquals(2, appendedSlice.getStartIndex()); assertEquals(4, appendedSlice.getEndIndex()); assertEquals("new-summary", appendedSlice.getSummary()); + + assertEquals(List.of("new-user", "new-assistant"), + memoryCapability.lastChatMessages().stream().map(Message::getContent).toList()); + assertEquals("new-summary", memoryCapability.lastSummary()); + verify(memoryRuntime).recordMemory(eq(merged), eq("topic/main"), eq(List.of("topic/related"))); + assertEquals("session-1", recordField(rollingRecord, "unitId")); + assertEquals(appendedSlice.getId(), recordField(rollingRecord, "sliceId")); + assertEquals("new-summary", recordField(rollingRecord, "summary")); } @Test - void shouldCreateNewMemoryUnitForNewSessionId() throws Exception { + void shouldCreateFirstSliceForFreshSessionThroughCapability() throws Exception { StubMemoryCapability memoryCapability = new StubMemoryCapability("session-2"); MemoryUpdater updater = new MemoryUpdater(); + MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); + MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class); + SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class); setField(updater, "memoryCapability", memoryCapability); + setField(updater, "memoryRuntime", memoryRuntime); + setField(updater, "multiSummarizer", multiSummarizer); + setField(updater, "singleSummarizer", singleSummarizer); - MemoryUnit created = invokeBuildMemoryUnit( - updater, - List.of( - message(Message.Character.USER, "first"), - message(Message.Character.ASSISTANT, "second") - ), - summarizeResult("fresh-summary") + when(memoryRuntime.getTopicTree()).thenReturn("topic-tree"); + when(multiSummarizer.execute(Mockito.any())).thenReturn( + summarizeResult("fresh-summary", "topic/root", List.of()) ); + Object rollingRecord = invokeUpdateMemory(updater, List.of( + message(Message.Character.USER, "first"), + message(Message.Character.ASSISTANT, "second") + )); + + MemoryUnit created = memoryCapability.getMemoryUnit("session-2"); + assertNotNull(created); assertEquals("session-2", created.getId()); - assertEquals(2, created.getConversationMessages().size()); + assertEquals(List.of("first", "second"), + created.getConversationMessages().stream().map(Message::getContent).toList()); assertEquals(1, created.getSlices().size()); assertEquals(0, created.getSlices().getFirst().getStartIndex()); assertEquals(2, created.getSlices().getFirst().getEndIndex()); assertEquals("fresh-summary", created.getSlices().getFirst().getSummary()); + verify(memoryRuntime).recordMemory(eq(created), eq("topic/root"), eq(List.of())); + assertEquals("session-2", recordField(rollingRecord, "unitId")); + assertEquals(created.getSlices().getFirst().getId(), recordField(rollingRecord, "sliceId")); + assertEquals("fresh-summary", recordField(rollingRecord, "summary")); } @Test @@ -123,14 +164,14 @@ class MemoryUpdaterTest { MemoryUpdater updater = new MemoryUpdater(); setField(updater, "memoryCapability", memoryCapability); - MemoryUnit existingUnit = new MemoryUnit("session-3"); - existingUnit.getConversationMessages().addAll(List.of( + MemoryUnit existingUnit = Mockito.mock(MemoryUnit.class); + when(existingUnit.getConversationMessages()).thenReturn(List.of( message(Message.Character.USER, "m1"), message(Message.Character.ASSISTANT, "m2"), message(Message.Character.USER, "m3"), message(Message.Character.ASSISTANT, "m4") )); - memoryCapability.saveMemoryUnit(existingUnit); + memoryCapability.putUnit("session-3", existingUnit); List increment = invokeResolveChatIncrement( updater, @@ -151,13 +192,13 @@ class MemoryUpdaterTest { MemoryUpdater updater = new MemoryUpdater(); setField(updater, "memoryCapability", memoryCapability); - MemoryUnit existingUnit = new MemoryUnit("session-4"); - existingUnit.getConversationMessages().addAll(List.of( + MemoryUnit existingUnit = Mockito.mock(MemoryUnit.class); + when(existingUnit.getConversationMessages()).thenReturn(List.of( message(Message.Character.USER, "m1"), message(Message.Character.ASSISTANT, "m2"), message(Message.Character.USER, "m3") )); - memoryCapability.saveMemoryUnit(existingUnit); + memoryCapability.putUnit("session-4", existingUnit); List increment = invokeResolveChatIncrement( updater, @@ -170,19 +211,47 @@ class MemoryUpdaterTest { assertEquals(List.of(), increment); } + @Test + void shouldReturnNullWhenUpdateMemoryReceivesEmptySnapshot() throws Exception { + StubMemoryCapability memoryCapability = new StubMemoryCapability("session-5"); + MemoryUpdater updater = new MemoryUpdater(); + MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); + setField(updater, "memoryCapability", memoryCapability); + setField(updater, "memoryRuntime", memoryRuntime); + + Object rollingRecord = invokeUpdateMemory(updater, List.of()); + + assertNull(rollingRecord); + assertNull(memoryCapability.lastSummary()); + Mockito.verifyNoInteractions(memoryRuntime); + } + private static final class StubMemoryCapability implements MemoryCapability { private final String sessionId; private final Map units = new HashMap<>(); + private List lastChatMessages; + private String lastSummary; private StubMemoryCapability(String sessionId) { this.sessionId = sessionId; } - @Override - public void saveMemoryUnit(MemoryUnit memoryUnit) { + private void putUnit(String unitId, MemoryUnit memoryUnit) { + units.put(unitId, memoryUnit); + } + + private void putUnit(MemoryUnit memoryUnit) { units.put(memoryUnit.getId(), memoryUnit); } + private List lastChatMessages() { + return lastChatMessages; + } + + private String lastSummary() { + return lastSummary; + } + @Override public MemoryUnit getMemoryUnit(String unitId) { return units.get(unitId); @@ -200,6 +269,18 @@ class MemoryUpdaterTest { .orElse(null); } + @Override + public MemoryUnit updateMemoryUnit(List chatMessages, String summary) { + lastChatMessages = List.copyOf(chatMessages); + lastSummary = summary; + MemoryUnit unit = units.computeIfAbsent(sessionId, MemoryUnit::new); + unit.updateTimestamp(); + int startIndex = unit.getConversationMessages().size(); + unit.getConversationMessages().addAll(chatMessages); + unit.getSlices().add(new MemorySlice(startIndex, startIndex + chatMessages.size(), summary)); + return unit; + } + @Override public Collection listMemoryUnits() { return units.values();