refactor(memory): correct memorySlice-topicPath binding behavior and adjust slice index into [startIndex,endIndex)

This commit is contained in:
2026-03-29 18:11:02 +08:00
parent 1c995923a1
commit c7df35beb4
6 changed files with 455 additions and 20 deletions

View File

@@ -0,0 +1,75 @@
package work.slhaf.partner.core.memory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.common.config.Config;
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.assertEquals;
class MemoryCoreTest {
private AgentConfigLoader previousLoader;
private String agentId;
private static PartnerAgentConfigLoader testLoader(String agentId) {
PartnerAgentConfigLoader loader = new PartnerAgentConfigLoader();
Config config = new Config();
config.setAgentId(agentId);
Config.WebSocketConfig webSocketConfig = new Config.WebSocketConfig();
webSocketConfig.setPort(18080);
config.setWebSocketConfig(webSocketConfig);
loader.setConfig(config);
loader.setModelConfigMap(new HashMap<>());
return loader;
}
@AfterEach
void tearDown() throws Exception {
AgentConfigLoader.INSTANCE = previousLoader;
if (agentId != null) {
Files.deleteIfExists(Path.of("data/memory", agentId + "-memory-core.memory"));
Files.deleteIfExists(Path.of("data/memory", agentId + "-temp-memory-core.memory"));
}
}
@Test
void shouldNormalizeSliceEndIndexUsingExclusiveUpperBound() throws Exception {
agentId = "memory-core-test-" + UUID.randomUUID();
previousLoader = AgentConfigLoader.INSTANCE;
AgentConfigLoader.INSTANCE = testLoader(agentId);
MemoryCore memoryCore = new MemoryCore();
MemorySlice slice = new MemorySlice();
slice.setId("slice-1");
slice.setStartIndex(1);
slice.setEndIndex(99);
MemoryUnit unit = new MemoryUnit();
unit.setId("unit-1");
unit.setConversationMessages(new ArrayList<>(List.of(
new Message(Message.Character.USER, "m0"),
new Message(Message.Character.USER, "m1"),
new Message(Message.Character.USER, "m2")
)));
unit.setSlices(new ArrayList<>(List.of(slice)));
memoryCore.saveMemoryUnit(unit);
MemorySlice savedSlice = memoryCore.getMemorySlice("unit-1", "slice-1");
assertEquals(1, savedSlice.getStartIndex());
assertEquals(3, savedSlice.getEndIndex());
}
}

View File

@@ -0,0 +1,186 @@
package work.slhaf.partner.module.memory.runtime;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.common.config.Config;
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.concurrent.CopyOnWriteArrayList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
class MemoryRuntimeTest {
private AgentConfigLoader previousLoader;
private String runtimeAgentId;
@SuppressWarnings("unchecked")
private static Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices(MemoryRuntime runtime) throws Exception {
Field field = MemoryRuntime.class.getDeclaredField("topicSlices");
field.setAccessible(true);
return (Map<String, CopyOnWriteArrayList<SliceRef>>) field.get(runtime);
}
@SuppressWarnings("unchecked")
private static List<Message> invokeSliceMessages(MemoryRuntime runtime, MemoryUnit unit, MemorySlice slice) throws Exception {
Method method = MemoryRuntime.class.getDeclaredMethod("sliceMessages", MemoryUnit.class, MemorySlice.class);
method.setAccessible(true);
return (List<Message>) method.invoke(runtime, unit, slice);
}
private static void setField(Object target, String fieldName, Object value) throws Exception {
Field field = target.getClass().getDeclaredField(fieldName);
field.setAccessible(true);
field.set(target, value);
}
private static PartnerAgentConfigLoader testLoader(String agentId) {
PartnerAgentConfigLoader loader = new PartnerAgentConfigLoader();
Config config = new Config();
config.setAgentId(agentId);
Config.WebSocketConfig webSocketConfig = new Config.WebSocketConfig();
webSocketConfig.setPort(18080);
config.setWebSocketConfig(webSocketConfig);
loader.setConfig(config);
loader.setModelConfigMap(new HashMap<>());
return loader;
}
private static Message message(String content) {
return new Message(Message.Character.USER, content);
}
@AfterEach
void tearDown() throws Exception {
AgentConfigLoader.INSTANCE = previousLoader;
if (runtimeAgentId != null) {
Files.deleteIfExists(Path.of("data/memory", runtimeAgentId + "-memory-runtime.memory"));
Files.deleteIfExists(Path.of("data/memory", runtimeAgentId + "-memory-runtime-temp.memory"));
}
}
@Test
void shouldSliceMessagesUsingLeftClosedRightOpenRange() throws Exception {
MemoryRuntime runtime = new MemoryRuntime();
MemoryUnit unit = new MemoryUnit();
unit.setConversationMessages(new ArrayList<>(List.of(
message("m0"),
message("m1"),
message("m2"),
message("m3")
)));
MemorySlice slice = new MemorySlice();
slice.setStartIndex(1);
slice.setEndIndex(3);
List<Message> messages = invokeSliceMessages(runtime, unit, slice);
assertEquals(List.of("m1", "m2"), messages.stream().map(Message::getContent).toList());
slice.setStartIndex(2);
slice.setEndIndex(2);
assertTrue(invokeSliceMessages(runtime, unit, slice).isEmpty());
}
@Test
void shouldBindTopicToLatestMemorySliceInsteadOfFirstSlice() throws Exception {
runtimeAgentId = "runtime-test-" + UUID.randomUUID();
previousLoader = AgentConfigLoader.INSTANCE;
AgentConfigLoader.INSTANCE = testLoader(runtimeAgentId);
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability);
MemoryUnit unit = new MemoryUnit();
unit.setId("unit-1");
unit.setConversationMessages(new ArrayList<>(List.of(
message("m0"),
message("m1"),
message("m2"),
message("m3")
)));
MemorySlice firstSlice = new MemorySlice();
firstSlice.setId("slice-1");
firstSlice.setStartIndex(0);
firstSlice.setEndIndex(2);
firstSlice.setSummary("first");
firstSlice.setTimestamp(1L);
MemorySlice secondSlice = new MemorySlice();
secondSlice.setId("slice-2");
secondSlice.setStartIndex(2);
secondSlice.setEndIndex(4);
secondSlice.setSummary("second");
secondSlice.setTimestamp(2L);
unit.setSlices(new ArrayList<>(List.of(firstSlice, secondSlice)));
runtime.recordMemory(unit, "topic/main", List.of("topic/related"), "dialog-summary");
Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices = topicSlices(runtime);
assertEquals(List.of("slice-2"),
topicSlices.get("topic/main").stream().map(SliceRef::getSliceId).toList());
assertEquals(List.of("slice-2"),
topicSlices.get("topic/related").stream().map(SliceRef::getSliceId).toList());
}
private static final class StubMemoryCapability implements MemoryCapability {
private final String sessionId;
private final Map<String, MemoryUnit> units = new HashMap<>();
private StubMemoryCapability(String sessionId) {
this.sessionId = sessionId;
}
@Override
public void saveMemoryUnit(MemoryUnit memoryUnit) {
units.put(memoryUnit.getId(), memoryUnit);
}
@Override
public MemoryUnit getMemoryUnit(String unitId) {
return units.get(unitId);
}
@Override
public MemorySlice getMemorySlice(String unitId, String sliceId) {
MemoryUnit unit = units.get(unitId);
if (unit == null || unit.getSlices() == null) {
return null;
}
return unit.getSlices().stream()
.filter(slice -> sliceId.equals(slice.getId()))
.findFirst()
.orElse(null);
}
@Override
public Collection<MemoryUnit> listMemoryUnits() {
return units.values();
}
@Override
public void refreshMemorySession() {
}
@Override
public String getMemorySessionId() {
return sessionId;
}
}
}

View File

@@ -0,0 +1,154 @@
package work.slhaf.partner.module.memory.updater;
import org.junit.jupiter.api.Test;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
class MemoryUpdaterTest {
private static MemoryUnit invokeBuildMemoryUnit(MemoryUpdater updater,
List<Message> chatMessages,
SummarizeResult summarizeResult) throws Exception {
Method method = MemoryUpdater.class.getDeclaredMethod("buildMemoryUnit", List.class, SummarizeResult.class);
method.setAccessible(true);
return (MemoryUnit) method.invoke(updater, chatMessages, summarizeResult);
}
private static void setField(Object target, String fieldName, Object value) throws Exception {
Field field = target.getClass().getDeclaredField(fieldName);
field.setAccessible(true);
field.set(target, value);
}
private static SummarizeResult summarizeResult(String summary) {
SummarizeResult result = new SummarizeResult();
result.setSummary(summary);
return result;
}
private static Message message(Message.Character role, String content) {
return new Message(role, content);
}
@Test
void shouldAppendNewSliceToExistingMemoryUnitWithinSameSession() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-1");
MemoryUpdater updater = new MemoryUpdater();
setField(updater, "memoryCapability", memoryCapability);
String sessionId = memoryCapability.getMemorySessionId();
MemoryUnit existingUnit = new MemoryUnit();
existingUnit.setId(sessionId);
existingUnit.setConversationMessages(new ArrayList<>(List.of(
message(Message.Character.USER, "old-user"),
message(Message.Character.ASSISTANT, "old-assistant")
)));
MemorySlice existingSlice = new MemorySlice();
existingSlice.setId("slice-1");
existingSlice.setStartIndex(0);
existingSlice.setEndIndex(2);
existingSlice.setSummary("old-summary");
existingSlice.setTimestamp(1L);
existingUnit.setSlices(new ArrayList<>(List.of(existingSlice)));
memoryCapability.saveMemoryUnit(existingUnit);
MemoryUnit merged = invokeBuildMemoryUnit(
updater,
List.of(
message(Message.Character.USER, "new-user"),
message(Message.Character.ASSISTANT, "new-assistant")
),
summarizeResult("new-summary")
);
assertEquals(sessionId, merged.getId());
assertEquals(4, merged.getConversationMessages().size());
assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"),
merged.getConversationMessages().stream().map(Message::getContent).toList());
assertEquals(2, merged.getSlices().size());
MemorySlice appendedSlice = merged.getSlices().get(1);
assertNotNull(appendedSlice.getId());
assertEquals(2, appendedSlice.getStartIndex());
assertEquals(4, appendedSlice.getEndIndex());
assertEquals("new-summary", appendedSlice.getSummary());
}
@Test
void shouldCreateNewMemoryUnitForNewSessionId() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-2");
MemoryUpdater updater = new MemoryUpdater();
setField(updater, "memoryCapability", memoryCapability);
MemoryUnit created = invokeBuildMemoryUnit(
updater,
List.of(
message(Message.Character.USER, "first"),
message(Message.Character.ASSISTANT, "second")
),
summarizeResult("fresh-summary")
);
assertEquals("session-2", created.getId());
assertEquals(2, created.getConversationMessages().size());
assertEquals(1, created.getSlices().size());
assertEquals(0, created.getSlices().getFirst().getStartIndex());
assertEquals(2, created.getSlices().getFirst().getEndIndex());
assertEquals("fresh-summary", created.getSlices().getFirst().getSummary());
}
private static final class StubMemoryCapability implements MemoryCapability {
private final String sessionId;
private final Map<String, MemoryUnit> units = new HashMap<>();
private StubMemoryCapability(String sessionId) {
this.sessionId = sessionId;
}
@Override
public void saveMemoryUnit(MemoryUnit memoryUnit) {
units.put(memoryUnit.getId(), memoryUnit);
}
@Override
public MemoryUnit getMemoryUnit(String unitId) {
return units.get(unitId);
}
@Override
public MemorySlice getMemorySlice(String unitId, String sliceId) {
MemoryUnit unit = units.get(unitId);
if (unit == null || unit.getSlices() == null) {
return null;
}
return unit.getSlices().stream()
.filter(slice -> sliceId.equals(slice.getId()))
.findFirst()
.orElse(null);
}
@Override
public Collection<MemoryUnit> listMemoryUnits() {
return units.values();
}
@Override
public void refreshMemorySession() {
}
@Override
public String getMemorySessionId() {
return sessionId;
}
}
}