mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
test(memory): test with new memory behavior
This commit is contained in:
@@ -29,7 +29,6 @@ import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResul
|
|||||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.UUID;
|
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
|
||||||
@@ -171,7 +170,7 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
|
|||||||
log.debug("[MemoryUpdater] 记忆更新-总结流程-输入: {}", JSONObject.toJSONString(summarizeInput));
|
log.debug("[MemoryUpdater] 记忆更新-总结流程-输入: {}", JSONObject.toJSONString(summarizeInput));
|
||||||
SummarizeResult summarizeResult = summarize(summarizeInput);
|
SummarizeResult summarizeResult = summarize(summarizeInput);
|
||||||
log.debug("[MemoryUpdater] 记忆更新-总结流程-输出: {}", JSONObject.toJSONString(summarizeResult));
|
log.debug("[MemoryUpdater] 记忆更新-总结流程-输出: {}", JSONObject.toJSONString(summarizeResult));
|
||||||
MemoryUnit memoryUnit = buildMemoryUnit(chatSnapshot, summarizeResult);
|
MemoryUnit memoryUnit = memoryCapability.updateMemoryUnit(chatSnapshot, summarizeResult.getSummary());
|
||||||
memoryRuntime.recordMemory(
|
memoryRuntime.recordMemory(
|
||||||
memoryUnit,
|
memoryUnit,
|
||||||
summarizeResult.getTopicPath(),
|
summarizeResult.getTopicPath(),
|
||||||
@@ -187,28 +186,6 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
|
|||||||
return multiSummarizer.execute(summarizeInput);
|
return multiSummarizer.execute(summarizeInput);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO update memory unit via memory capability
|
|
||||||
private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) {
|
|
||||||
String memoryId = memoryCapability.getMemorySessionId();
|
|
||||||
String resolvedMemoryId = memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId;
|
|
||||||
MemoryUnit existingUnit = memoryCapability.getMemoryUnit(resolvedMemoryId);
|
|
||||||
List<Message> existingMessages = existingUnit.getConversationMessages();
|
|
||||||
int startIndex = existingMessages.size();
|
|
||||||
|
|
||||||
MemorySlice memorySlice = new MemorySlice(
|
|
||||||
startIndex,
|
|
||||||
startIndex + chatMessages.size(),
|
|
||||||
summarizeResult.getSummary()
|
|
||||||
);
|
|
||||||
|
|
||||||
MemoryUnit memoryUnit = new MemoryUnit(resolvedMemoryId);
|
|
||||||
memoryUnit.updateTimestamp();
|
|
||||||
memoryUnit.getConversationMessages().addAll(chatMessages);
|
|
||||||
|
|
||||||
memoryUnit.getSlices().add(memorySlice);
|
|
||||||
return memoryUnit;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int order() {
|
public int order() {
|
||||||
return 7;
|
return 7;
|
||||||
|
|||||||
@@ -1,92 +1,86 @@
|
|||||||
package work.slhaf.partner.core.memory;
|
package work.slhaf.partner.core.memory;
|
||||||
|
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import work.slhaf.partner.common.config.Config;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
|
||||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||||
import work.slhaf.partner.framework.agent.config.AgentConfigLoader;
|
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||||
|
|
||||||
import java.nio.file.Files;
|
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.UUID;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
||||||
|
|
||||||
class MemoryCoreTest {
|
class MemoryCoreTest {
|
||||||
|
|
||||||
private AgentConfigLoader previousLoader;
|
private static MemoryCore memoryCore;
|
||||||
private String agentId;
|
|
||||||
|
|
||||||
private static PartnerAgentConfigLoader testLoader(String agentId) {
|
@BeforeAll
|
||||||
PartnerAgentConfigLoader loader = new PartnerAgentConfigLoader();
|
static void beforeAll(@TempDir Path tempDir) {
|
||||||
Config config = new Config();
|
System.setProperty("user.home", tempDir.toAbsolutePath().toString());
|
||||||
config.setAgentId(agentId);
|
memoryCore = new MemoryCore();
|
||||||
Config.WebSocketConfig webSocketConfig = new Config.WebSocketConfig();
|
|
||||||
webSocketConfig.setPort(18080);
|
|
||||||
config.setWebSocketConfig(webSocketConfig);
|
|
||||||
loader.setConfig(config);
|
|
||||||
loader.setModelConfigMap(new HashMap<>());
|
|
||||||
return loader;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterEach
|
@BeforeEach
|
||||||
void tearDown() throws Exception {
|
void setUp() {
|
||||||
AgentConfigLoader.INSTANCE = previousLoader;
|
memoryCore.refreshMemorySession();
|
||||||
if (agentId != null) {
|
|
||||||
Files.deleteIfExists(Path.of("data/memory", agentId + "-memory-core.memory"));
|
|
||||||
Files.deleteIfExists(Path.of("data/memory", agentId + "-temp-memory-core.memory"));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void shouldNormalizeSliceEndIndexUsingExclusiveUpperBound() throws Exception {
|
void shouldCreateFirstSliceFromChatMessages() {
|
||||||
agentId = "memory-core-test-" + UUID.randomUUID();
|
String sessionId = memoryCore.getMemorySessionId();
|
||||||
previousLoader = AgentConfigLoader.INSTANCE;
|
|
||||||
AgentConfigLoader.INSTANCE = testLoader(agentId);
|
|
||||||
|
|
||||||
MemoryCore memoryCore = new MemoryCore();
|
MemoryUnit updatedUnit = memoryCore.updateMemoryUnit(List.of(
|
||||||
|
|
||||||
MemorySlice slice = MemorySlice.restore("slice-1", 1, 99, null, 1L);
|
|
||||||
|
|
||||||
MemoryUnit unit = new MemoryUnit("unit-1");
|
|
||||||
unit.getConversationMessages().addAll(List.of(
|
|
||||||
new Message(Message.Character.USER, "m0"),
|
new Message(Message.Character.USER, "m0"),
|
||||||
new Message(Message.Character.USER, "m1"),
|
new Message(Message.Character.USER, "m1"),
|
||||||
new Message(Message.Character.USER, "m2")
|
new Message(Message.Character.USER, "m2")
|
||||||
));
|
), "first-summary");
|
||||||
unit.getSlices().add(slice);
|
|
||||||
|
|
||||||
memoryCore.saveMemoryUnit(unit);
|
assertEquals(sessionId, updatedUnit.getId());
|
||||||
|
assertEquals(List.of("m0", "m1", "m2"),
|
||||||
|
updatedUnit.getConversationMessages().stream().map(Message::getContent).toList());
|
||||||
|
assertEquals(1, updatedUnit.getSlices().size());
|
||||||
|
|
||||||
MemorySlice savedSlice = memoryCore.getMemorySlice("unit-1", "slice-1");
|
MemorySlice firstSlice = updatedUnit.getSlices().getFirst();
|
||||||
assertEquals(1, savedSlice.getStartIndex());
|
assertNotNull(firstSlice.getId());
|
||||||
assertEquals(3, savedSlice.getEndIndex());
|
assertEquals(0, firstSlice.getStartIndex());
|
||||||
|
assertEquals(3, firstSlice.getEndIndex());
|
||||||
|
assertEquals("first-summary", firstSlice.getSummary());
|
||||||
|
assertTrue(updatedUnit.getTimestamp() > 0);
|
||||||
|
assertTrue(firstSlice.getTimestamp() > 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void shouldFillMissingTimestampsWhenSavingMemoryUnit() throws Exception {
|
void shouldAppendMessagesAndCreateNextSlice() {
|
||||||
agentId = "memory-core-test-" + UUID.randomUUID();
|
String sessionId = memoryCore.getMemorySessionId();
|
||||||
previousLoader = AgentConfigLoader.INSTANCE;
|
|
||||||
AgentConfigLoader.INSTANCE = testLoader(agentId);
|
|
||||||
|
|
||||||
MemoryCore memoryCore = new MemoryCore();
|
memoryCore.updateMemoryUnit(List.of(
|
||||||
|
new Message(Message.Character.USER, "m0")
|
||||||
|
), "first-summary");
|
||||||
|
|
||||||
MemorySlice slice = MemorySlice.restore("slice-1", 0, 1, "summary", 0L);
|
MemoryUnit updatedUnit = memoryCore.updateMemoryUnit(List.of(
|
||||||
MemoryUnit unit = new MemoryUnit("unit-1");
|
new Message(Message.Character.ASSISTANT, "m1"),
|
||||||
unit.getConversationMessages().add(new Message(Message.Character.USER, "m0"));
|
new Message(Message.Character.USER, "m2")
|
||||||
unit.getSlices().add(slice);
|
), "second-summary");
|
||||||
|
|
||||||
memoryCore.saveMemoryUnit(unit);
|
assertEquals(sessionId, updatedUnit.getId());
|
||||||
|
assertEquals(List.of("m0", "m1", "m2"),
|
||||||
|
updatedUnit.getConversationMessages().stream().map(Message::getContent).toList());
|
||||||
|
assertEquals(2, updatedUnit.getSlices().size());
|
||||||
|
|
||||||
MemoryUnit savedUnit = memoryCore.getMemoryUnit("unit-1");
|
MemorySlice appendedSlice = updatedUnit.getSlices().getLast();
|
||||||
MemorySlice savedSlice = memoryCore.getMemorySlice("unit-1", "slice-1");
|
assertNotNull(appendedSlice.getId());
|
||||||
assertTrue(savedUnit.getTimestamp() > 0);
|
assertEquals(1, appendedSlice.getStartIndex());
|
||||||
assertEquals(savedUnit.getTimestamp(), savedSlice.getTimestamp());
|
assertEquals(3, appendedSlice.getEndIndex());
|
||||||
|
assertEquals("second-summary", appendedSlice.getSummary());
|
||||||
|
assertTrue(appendedSlice.getTimestamp() > 0);
|
||||||
|
|
||||||
|
MemorySlice loadedSlice = memoryCore.getMemorySlice(sessionId, appendedSlice.getId());
|
||||||
|
assertNotNull(loadedSlice);
|
||||||
|
assertEquals(1, loadedSlice.getStartIndex());
|
||||||
|
assertEquals(3, loadedSlice.getEndIndex());
|
||||||
|
assertEquals("second-summary", loadedSlice.getSummary());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package work.slhaf.partner.module.memory.runtime;
|
package work.slhaf.partner.module.memory.runtime;
|
||||||
|
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import work.slhaf.partner.common.config.Config;
|
import work.slhaf.partner.common.config.Config;
|
||||||
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
||||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||||
@@ -26,6 +28,11 @@ class MemoryRuntimeTest {
|
|||||||
private AgentConfigLoader previousLoader;
|
private AgentConfigLoader previousLoader;
|
||||||
private String runtimeAgentId;
|
private String runtimeAgentId;
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
public static void beforeAll(@TempDir Path tempDir) {
|
||||||
|
System.setProperty("user.home", tempDir.toAbsolutePath().toString());
|
||||||
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
private static Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices(MemoryRuntime runtime) throws Exception {
|
private static Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices(MemoryRuntime runtime) throws Exception {
|
||||||
Field field = MemoryRuntime.class.getDeclaredField("topicSlices");
|
Field field = MemoryRuntime.class.getDeclaredField("topicSlices");
|
||||||
@@ -82,16 +89,13 @@ class MemoryRuntimeTest {
|
|||||||
message("m3")
|
message("m3")
|
||||||
));
|
));
|
||||||
|
|
||||||
MemorySlice slice = new MemorySlice();
|
MemorySlice slice = MemorySlice.restore("slice-1", 1, 3, null, 1L);
|
||||||
slice.setStartIndex(1);
|
|
||||||
slice.setEndIndex(3);
|
|
||||||
|
|
||||||
List<Message> messages = invokeSliceMessages(runtime, unit, slice);
|
List<Message> messages = invokeSliceMessages(runtime, unit, slice);
|
||||||
assertEquals(List.of("m1", "m2"), messages.stream().map(Message::getContent).toList());
|
assertEquals(List.of("m1", "m2"), messages.stream().map(Message::getContent).toList());
|
||||||
|
|
||||||
slice.setStartIndex(2);
|
MemorySlice emptySlice = MemorySlice.restore("slice-2", 2, 2, null, 2L);
|
||||||
slice.setEndIndex(2);
|
assertTrue(invokeSliceMessages(runtime, unit, emptySlice).isEmpty());
|
||||||
assertTrue(invokeSliceMessages(runtime, unit, slice).isEmpty());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -105,7 +109,7 @@ class MemoryRuntimeTest {
|
|||||||
MemoryRuntime runtime = new MemoryRuntime();
|
MemoryRuntime runtime = new MemoryRuntime();
|
||||||
setField(runtime, "memoryCapability", memoryCapability);
|
setField(runtime, "memoryCapability", memoryCapability);
|
||||||
|
|
||||||
MemoryUnit unit = new MemoryUnit("unit-1");
|
MemoryUnit unit = new MemoryUnit("unit-99");
|
||||||
unit.getConversationMessages().addAll(List.of(
|
unit.getConversationMessages().addAll(List.of(
|
||||||
message("m0"),
|
message("m0"),
|
||||||
message("m1"),
|
message("m1"),
|
||||||
@@ -113,19 +117,9 @@ class MemoryRuntimeTest {
|
|||||||
message("m3")
|
message("m3")
|
||||||
));
|
));
|
||||||
|
|
||||||
MemorySlice firstSlice = new MemorySlice();
|
MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 2, "first", 1L);
|
||||||
firstSlice.setId("slice-1");
|
|
||||||
firstSlice.setStartIndex(0);
|
|
||||||
firstSlice.setEndIndex(2);
|
|
||||||
firstSlice.setSummary("first");
|
|
||||||
firstSlice.setTimestamp(1L);
|
|
||||||
|
|
||||||
MemorySlice secondSlice = new MemorySlice();
|
MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 2L);
|
||||||
secondSlice.setId("slice-2");
|
|
||||||
secondSlice.setStartIndex(2);
|
|
||||||
secondSlice.setEndIndex(4);
|
|
||||||
secondSlice.setSummary("second");
|
|
||||||
secondSlice.setTimestamp(2L);
|
|
||||||
|
|
||||||
unit.getSlices().addAll(List.of(firstSlice, secondSlice));
|
unit.getSlices().addAll(List.of(firstSlice, secondSlice));
|
||||||
|
|
||||||
@@ -146,11 +140,6 @@ class MemoryRuntimeTest {
|
|||||||
this.sessionId = sessionId;
|
this.sessionId = sessionId;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void saveMemoryUnit(MemoryUnit memoryUnit) {
|
|
||||||
units.put(memoryUnit.getId(), memoryUnit);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MemoryUnit getMemoryUnit(String unitId) {
|
public MemoryUnit getMemoryUnit(String unitId) {
|
||||||
return units.get(unitId);
|
return units.get(unitId);
|
||||||
@@ -168,6 +157,11 @@ class MemoryRuntimeTest {
|
|||||||
.orElse(null);
|
.orElse(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MemoryUnit updateMemoryUnit(List<Message> chatMessages, String summary) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Collection<MemoryUnit> listMemoryUnits() {
|
public Collection<MemoryUnit> listMemoryUnits() {
|
||||||
return units.values();
|
return units.values();
|
||||||
|
|||||||
@@ -1,30 +1,42 @@
|
|||||||
package work.slhaf.partner.module.memory.updater;
|
package work.slhaf.partner.module.memory.updater;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.mockito.Mockito;
|
||||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||||
|
import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
|
||||||
|
import work.slhaf.partner.module.memory.updater.summarizer.MultiSummarizer;
|
||||||
|
import work.slhaf.partner.module.memory.updater.summarizer.SingleSummarizer;
|
||||||
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
|
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
|
||||||
|
|
||||||
import java.lang.reflect.Field;
|
import java.lang.reflect.Field;
|
||||||
import java.lang.reflect.Method;
|
import java.lang.reflect.Method;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
class MemoryUpdaterTest {
|
class MemoryUpdaterTest {
|
||||||
|
|
||||||
private static MemoryUnit invokeBuildMemoryUnit(MemoryUpdater updater,
|
@BeforeAll
|
||||||
List<Message> chatMessages,
|
static void beforeAll(@TempDir Path tempDir) {
|
||||||
SummarizeResult summarizeResult) throws Exception {
|
System.setProperty("user.home", tempDir.toAbsolutePath().toString());
|
||||||
Method method = MemoryUpdater.class.getDeclaredMethod("buildMemoryUnit", List.class, SummarizeResult.class);
|
}
|
||||||
|
|
||||||
|
private static Object invokeUpdateMemory(MemoryUpdater updater, List<Message> chatMessages) throws Exception {
|
||||||
|
Method method = MemoryUpdater.class.getDeclaredMethod("updateMemory", List.class);
|
||||||
method.setAccessible(true);
|
method.setAccessible(true);
|
||||||
return (MemoryUnit) method.invoke(updater, chatMessages, summarizeResult);
|
return method.invoke(updater, chatMessages);
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
@@ -35,15 +47,23 @@ class MemoryUpdaterTest {
|
|||||||
return (List<Message>) method.invoke(updater, chatMessages);
|
return (List<Message>) method.invoke(updater, chatMessages);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static String recordField(Object target, String fieldName) throws Exception {
|
||||||
|
Field field = target.getClass().getDeclaredField(fieldName);
|
||||||
|
field.setAccessible(true);
|
||||||
|
return (String) field.get(target);
|
||||||
|
}
|
||||||
|
|
||||||
private static void setField(Object target, String fieldName, Object value) throws Exception {
|
private static void setField(Object target, String fieldName, Object value) throws Exception {
|
||||||
Field field = target.getClass().getDeclaredField(fieldName);
|
Field field = target.getClass().getDeclaredField(fieldName);
|
||||||
field.setAccessible(true);
|
field.setAccessible(true);
|
||||||
field.set(target, value);
|
field.set(target, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static SummarizeResult summarizeResult(String summary) {
|
private static SummarizeResult summarizeResult(String summary, String topicPath, List<String> relatedTopicPath) {
|
||||||
SummarizeResult result = new SummarizeResult();
|
SummarizeResult result = new SummarizeResult();
|
||||||
result.setSummary(summary);
|
result.setSummary(summary);
|
||||||
|
result.setTopicPath(topicPath);
|
||||||
|
result.setRelatedTopicPath(relatedTopicPath);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,69 +72,90 @@ class MemoryUpdaterTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void shouldAppendNewSliceToExistingMemoryUnitWithinSameSession() throws Exception {
|
void shouldDelegateMemoryUpdateToCapabilityAndRuntime() throws Exception {
|
||||||
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-1");
|
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-1");
|
||||||
MemoryUpdater updater = new MemoryUpdater();
|
MemoryUpdater updater = new MemoryUpdater();
|
||||||
|
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
|
||||||
|
MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class);
|
||||||
|
SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class);
|
||||||
setField(updater, "memoryCapability", memoryCapability);
|
setField(updater, "memoryCapability", memoryCapability);
|
||||||
|
setField(updater, "memoryRuntime", memoryRuntime);
|
||||||
|
setField(updater, "multiSummarizer", multiSummarizer);
|
||||||
|
setField(updater, "singleSummarizer", singleSummarizer);
|
||||||
|
|
||||||
String sessionId = memoryCapability.getMemorySessionId();
|
when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
|
||||||
MemoryUnit existingUnit = new MemoryUnit(sessionId);
|
when(multiSummarizer.execute(Mockito.any())).thenReturn(
|
||||||
|
summarizeResult("new-summary", "topic/main", List.of("topic/related"))
|
||||||
|
);
|
||||||
|
|
||||||
|
MemoryUnit existingUnit = new MemoryUnit("session-1");
|
||||||
existingUnit.getConversationMessages().addAll(List.of(
|
existingUnit.getConversationMessages().addAll(List.of(
|
||||||
message(Message.Character.USER, "old-user"),
|
message(Message.Character.USER, "old-user"),
|
||||||
message(Message.Character.ASSISTANT, "old-assistant")
|
message(Message.Character.ASSISTANT, "old-assistant")
|
||||||
));
|
));
|
||||||
MemorySlice existingSlice = new MemorySlice();
|
existingUnit.getSlices().add(MemorySlice.restore("slice-1", 0, 2, "old-summary", 1L));
|
||||||
existingSlice.setId("slice-1");
|
memoryCapability.putUnit(existingUnit);
|
||||||
existingSlice.setStartIndex(0);
|
|
||||||
existingSlice.setEndIndex(2);
|
|
||||||
existingSlice.setSummary("old-summary");
|
|
||||||
existingSlice.setTimestamp(1L);
|
|
||||||
existingUnit.getSlices().add(existingSlice);
|
|
||||||
memoryCapability.saveMemoryUnit(existingUnit);
|
|
||||||
|
|
||||||
MemoryUnit merged = invokeBuildMemoryUnit(
|
Object rollingRecord = invokeUpdateMemory(updater, List.of(
|
||||||
updater,
|
message(Message.Character.USER, "new-user"),
|
||||||
List.of(
|
message(Message.Character.ASSISTANT, "new-assistant")
|
||||||
message(Message.Character.USER, "new-user"),
|
));
|
||||||
message(Message.Character.ASSISTANT, "new-assistant")
|
|
||||||
),
|
|
||||||
summarizeResult("new-summary")
|
|
||||||
);
|
|
||||||
|
|
||||||
assertEquals(sessionId, merged.getId());
|
MemoryUnit merged = memoryCapability.getMemoryUnit("session-1");
|
||||||
assertEquals(4, merged.getConversationMessages().size());
|
|
||||||
assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"),
|
assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"),
|
||||||
merged.getConversationMessages().stream().map(Message::getContent).toList());
|
merged.getConversationMessages().stream().map(Message::getContent).toList());
|
||||||
assertEquals(2, merged.getSlices().size());
|
assertEquals(2, merged.getSlices().size());
|
||||||
|
|
||||||
MemorySlice appendedSlice = merged.getSlices().get(1);
|
MemorySlice appendedSlice = merged.getSlices().getLast();
|
||||||
assertNotNull(appendedSlice.getId());
|
assertNotNull(appendedSlice.getId());
|
||||||
assertEquals(2, appendedSlice.getStartIndex());
|
assertEquals(2, appendedSlice.getStartIndex());
|
||||||
assertEquals(4, appendedSlice.getEndIndex());
|
assertEquals(4, appendedSlice.getEndIndex());
|
||||||
assertEquals("new-summary", appendedSlice.getSummary());
|
assertEquals("new-summary", appendedSlice.getSummary());
|
||||||
|
|
||||||
|
assertEquals(List.of("new-user", "new-assistant"),
|
||||||
|
memoryCapability.lastChatMessages().stream().map(Message::getContent).toList());
|
||||||
|
assertEquals("new-summary", memoryCapability.lastSummary());
|
||||||
|
verify(memoryRuntime).recordMemory(eq(merged), eq("topic/main"), eq(List.of("topic/related")));
|
||||||
|
assertEquals("session-1", recordField(rollingRecord, "unitId"));
|
||||||
|
assertEquals(appendedSlice.getId(), recordField(rollingRecord, "sliceId"));
|
||||||
|
assertEquals("new-summary", recordField(rollingRecord, "summary"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void shouldCreateNewMemoryUnitForNewSessionId() throws Exception {
|
void shouldCreateFirstSliceForFreshSessionThroughCapability() throws Exception {
|
||||||
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-2");
|
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-2");
|
||||||
MemoryUpdater updater = new MemoryUpdater();
|
MemoryUpdater updater = new MemoryUpdater();
|
||||||
|
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
|
||||||
|
MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class);
|
||||||
|
SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class);
|
||||||
setField(updater, "memoryCapability", memoryCapability);
|
setField(updater, "memoryCapability", memoryCapability);
|
||||||
|
setField(updater, "memoryRuntime", memoryRuntime);
|
||||||
|
setField(updater, "multiSummarizer", multiSummarizer);
|
||||||
|
setField(updater, "singleSummarizer", singleSummarizer);
|
||||||
|
|
||||||
MemoryUnit created = invokeBuildMemoryUnit(
|
when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
|
||||||
updater,
|
when(multiSummarizer.execute(Mockito.any())).thenReturn(
|
||||||
List.of(
|
summarizeResult("fresh-summary", "topic/root", List.of())
|
||||||
message(Message.Character.USER, "first"),
|
|
||||||
message(Message.Character.ASSISTANT, "second")
|
|
||||||
),
|
|
||||||
summarizeResult("fresh-summary")
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
Object rollingRecord = invokeUpdateMemory(updater, List.of(
|
||||||
|
message(Message.Character.USER, "first"),
|
||||||
|
message(Message.Character.ASSISTANT, "second")
|
||||||
|
));
|
||||||
|
|
||||||
|
MemoryUnit created = memoryCapability.getMemoryUnit("session-2");
|
||||||
|
assertNotNull(created);
|
||||||
assertEquals("session-2", created.getId());
|
assertEquals("session-2", created.getId());
|
||||||
assertEquals(2, created.getConversationMessages().size());
|
assertEquals(List.of("first", "second"),
|
||||||
|
created.getConversationMessages().stream().map(Message::getContent).toList());
|
||||||
assertEquals(1, created.getSlices().size());
|
assertEquals(1, created.getSlices().size());
|
||||||
assertEquals(0, created.getSlices().getFirst().getStartIndex());
|
assertEquals(0, created.getSlices().getFirst().getStartIndex());
|
||||||
assertEquals(2, created.getSlices().getFirst().getEndIndex());
|
assertEquals(2, created.getSlices().getFirst().getEndIndex());
|
||||||
assertEquals("fresh-summary", created.getSlices().getFirst().getSummary());
|
assertEquals("fresh-summary", created.getSlices().getFirst().getSummary());
|
||||||
|
verify(memoryRuntime).recordMemory(eq(created), eq("topic/root"), eq(List.of()));
|
||||||
|
assertEquals("session-2", recordField(rollingRecord, "unitId"));
|
||||||
|
assertEquals(created.getSlices().getFirst().getId(), recordField(rollingRecord, "sliceId"));
|
||||||
|
assertEquals("fresh-summary", recordField(rollingRecord, "summary"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -123,14 +164,14 @@ class MemoryUpdaterTest {
|
|||||||
MemoryUpdater updater = new MemoryUpdater();
|
MemoryUpdater updater = new MemoryUpdater();
|
||||||
setField(updater, "memoryCapability", memoryCapability);
|
setField(updater, "memoryCapability", memoryCapability);
|
||||||
|
|
||||||
MemoryUnit existingUnit = new MemoryUnit("session-3");
|
MemoryUnit existingUnit = Mockito.mock(MemoryUnit.class);
|
||||||
existingUnit.getConversationMessages().addAll(List.of(
|
when(existingUnit.getConversationMessages()).thenReturn(List.of(
|
||||||
message(Message.Character.USER, "m1"),
|
message(Message.Character.USER, "m1"),
|
||||||
message(Message.Character.ASSISTANT, "m2"),
|
message(Message.Character.ASSISTANT, "m2"),
|
||||||
message(Message.Character.USER, "m3"),
|
message(Message.Character.USER, "m3"),
|
||||||
message(Message.Character.ASSISTANT, "m4")
|
message(Message.Character.ASSISTANT, "m4")
|
||||||
));
|
));
|
||||||
memoryCapability.saveMemoryUnit(existingUnit);
|
memoryCapability.putUnit("session-3", existingUnit);
|
||||||
|
|
||||||
List<Message> increment = invokeResolveChatIncrement(
|
List<Message> increment = invokeResolveChatIncrement(
|
||||||
updater,
|
updater,
|
||||||
@@ -151,13 +192,13 @@ class MemoryUpdaterTest {
|
|||||||
MemoryUpdater updater = new MemoryUpdater();
|
MemoryUpdater updater = new MemoryUpdater();
|
||||||
setField(updater, "memoryCapability", memoryCapability);
|
setField(updater, "memoryCapability", memoryCapability);
|
||||||
|
|
||||||
MemoryUnit existingUnit = new MemoryUnit("session-4");
|
MemoryUnit existingUnit = Mockito.mock(MemoryUnit.class);
|
||||||
existingUnit.getConversationMessages().addAll(List.of(
|
when(existingUnit.getConversationMessages()).thenReturn(List.of(
|
||||||
message(Message.Character.USER, "m1"),
|
message(Message.Character.USER, "m1"),
|
||||||
message(Message.Character.ASSISTANT, "m2"),
|
message(Message.Character.ASSISTANT, "m2"),
|
||||||
message(Message.Character.USER, "m3")
|
message(Message.Character.USER, "m3")
|
||||||
));
|
));
|
||||||
memoryCapability.saveMemoryUnit(existingUnit);
|
memoryCapability.putUnit("session-4", existingUnit);
|
||||||
|
|
||||||
List<Message> increment = invokeResolveChatIncrement(
|
List<Message> increment = invokeResolveChatIncrement(
|
||||||
updater,
|
updater,
|
||||||
@@ -170,19 +211,47 @@ class MemoryUpdaterTest {
|
|||||||
assertEquals(List.of(), increment);
|
assertEquals(List.of(), increment);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldReturnNullWhenUpdateMemoryReceivesEmptySnapshot() throws Exception {
|
||||||
|
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-5");
|
||||||
|
MemoryUpdater updater = new MemoryUpdater();
|
||||||
|
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
|
||||||
|
setField(updater, "memoryCapability", memoryCapability);
|
||||||
|
setField(updater, "memoryRuntime", memoryRuntime);
|
||||||
|
|
||||||
|
Object rollingRecord = invokeUpdateMemory(updater, List.of());
|
||||||
|
|
||||||
|
assertNull(rollingRecord);
|
||||||
|
assertNull(memoryCapability.lastSummary());
|
||||||
|
Mockito.verifyNoInteractions(memoryRuntime);
|
||||||
|
}
|
||||||
|
|
||||||
private static final class StubMemoryCapability implements MemoryCapability {
|
private static final class StubMemoryCapability implements MemoryCapability {
|
||||||
private final String sessionId;
|
private final String sessionId;
|
||||||
private final Map<String, MemoryUnit> units = new HashMap<>();
|
private final Map<String, MemoryUnit> units = new HashMap<>();
|
||||||
|
private List<Message> lastChatMessages;
|
||||||
|
private String lastSummary;
|
||||||
|
|
||||||
private StubMemoryCapability(String sessionId) {
|
private StubMemoryCapability(String sessionId) {
|
||||||
this.sessionId = sessionId;
|
this.sessionId = sessionId;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
private void putUnit(String unitId, MemoryUnit memoryUnit) {
|
||||||
public void saveMemoryUnit(MemoryUnit memoryUnit) {
|
units.put(unitId, memoryUnit);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void putUnit(MemoryUnit memoryUnit) {
|
||||||
units.put(memoryUnit.getId(), memoryUnit);
|
units.put(memoryUnit.getId(), memoryUnit);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private List<Message> lastChatMessages() {
|
||||||
|
return lastChatMessages;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String lastSummary() {
|
||||||
|
return lastSummary;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MemoryUnit getMemoryUnit(String unitId) {
|
public MemoryUnit getMemoryUnit(String unitId) {
|
||||||
return units.get(unitId);
|
return units.get(unitId);
|
||||||
@@ -200,6 +269,18 @@ class MemoryUpdaterTest {
|
|||||||
.orElse(null);
|
.orElse(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MemoryUnit updateMemoryUnit(List<Message> chatMessages, String summary) {
|
||||||
|
lastChatMessages = List.copyOf(chatMessages);
|
||||||
|
lastSummary = summary;
|
||||||
|
MemoryUnit unit = units.computeIfAbsent(sessionId, MemoryUnit::new);
|
||||||
|
unit.updateTimestamp();
|
||||||
|
int startIndex = unit.getConversationMessages().size();
|
||||||
|
unit.getConversationMessages().addAll(chatMessages);
|
||||||
|
unit.getSlices().add(new MemorySlice(startIndex, startIndex + chatMessages.size(), summary));
|
||||||
|
return unit;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Collection<MemoryUnit> listMemoryUnits() {
|
public Collection<MemoryUnit> listMemoryUnits() {
|
||||||
return units.values();
|
return units.values();
|
||||||
|
|||||||
Reference in New Issue
Block a user