refactor(memory): correct memorySlice-topicPath binding behavior and adjust slice index into [startIndex,endIndex)

This commit is contained in:
2026-03-29 18:11:02 +08:00
parent 1c995923a1
commit c7df35beb4
6 changed files with 455 additions and 20 deletions

View File

@@ -100,7 +100,7 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
if (memoryUnit.getSlices() == null) { if (memoryUnit.getSlices() == null) {
memoryUnit.setSlices(new ArrayList<>()); memoryUnit.setSlices(new ArrayList<>());
} }
int maxIndex = Math.max(memoryUnit.getConversationMessages().size() - 1, 0); int maxEndExclusive = Math.max(memoryUnit.getConversationMessages().size(), 0);
for (MemorySlice slice : memoryUnit.getSlices()) { for (MemorySlice slice : memoryUnit.getSlices()) {
if (slice.getId() == null || slice.getId().isBlank()) { if (slice.getId() == null || slice.getId().isBlank()) {
slice.setId(UUID.randomUUID().toString()); slice.setId(UUID.randomUUID().toString());
@@ -111,11 +111,14 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
if (slice.getStartIndex() == null || slice.getStartIndex() < 0) { if (slice.getStartIndex() == null || slice.getStartIndex() < 0) {
slice.setStartIndex(0); slice.setStartIndex(0);
} }
if (slice.getEndIndex() == null || slice.getEndIndex() < slice.getStartIndex()) { if (slice.getStartIndex() > maxEndExclusive) {
slice.setEndIndex(maxIndex); slice.setStartIndex(maxEndExclusive);
} }
if (slice.getEndIndex() > maxIndex) { if (slice.getEndIndex() == null || slice.getEndIndex() < slice.getStartIndex()) {
slice.setEndIndex(maxIndex); slice.setEndIndex(maxEndExclusive);
}
if (slice.getEndIndex() > maxEndExclusive) {
slice.setEndIndex(maxEndExclusive);
} }
} }
memoryUnit.getSlices().sort(Comparator.naturalOrder()); memoryUnit.getSlices().sort(Comparator.naturalOrder());

View File

@@ -7,6 +7,7 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapabili
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.annotation.Init; import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader; import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.common.entity.PersistableObject; import work.slhaf.partner.api.common.entity.PersistableObject;
import work.slhaf.partner.common.config.PartnerAgentConfigLoader; import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.cognition.CognitionCapability;
@@ -80,7 +81,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
public void recordMemory(MemoryUnit memoryUnit, String topicPath, List<String> relatedTopicPaths, String dialogSummary) { public void recordMemory(MemoryUnit memoryUnit, String topicPath, List<String> relatedTopicPaths, String dialogSummary) {
memoryCapability.saveMemoryUnit(memoryUnit); memoryCapability.saveMemoryUnit(memoryUnit);
MemorySlice memorySlice = memoryUnit.getSlices().getFirst(); MemorySlice memorySlice = memoryUnit.getSlices().getLast();
SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId()); SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId());
indexMemoryUnit(memoryUnit); indexMemoryUnit(memoryUnit);
bindTopic(topicPath, sliceRef); bindTopic(topicPath, sliceRef);
@@ -209,7 +210,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
if (memoryUnit == null || memorySlice == null) { if (memoryUnit == null || memorySlice == null) {
return null; return null;
} }
List<work.slhaf.partner.api.chat.pojo.Message> messages = sliceMessages(memoryUnit, memorySlice); List<Message> messages = sliceMessages(memoryUnit, memorySlice);
LocalDate date = Instant.ofEpochMilli(memorySlice.getTimestamp()) LocalDate date = Instant.ofEpochMilli(memorySlice.getTimestamp())
.atZone(ZoneId.systemDefault()) .atZone(ZoneId.systemDefault())
.toLocalDate(); .toLocalDate();
@@ -223,17 +224,18 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
.build(); .build();
} }
private List<work.slhaf.partner.api.chat.pojo.Message> sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) { private List<Message> sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) {
List<work.slhaf.partner.api.chat.pojo.Message> conversationMessages = memoryUnit.getConversationMessages(); List<Message> conversationMessages = memoryUnit.getConversationMessages();
if (conversationMessages == null || conversationMessages.isEmpty()) { if (conversationMessages == null || conversationMessages.isEmpty()) {
return List.of(); return List.of();
} }
int start = Math.max(0, memorySlice.getStartIndex()); int size = conversationMessages.size();
int end = Math.min(conversationMessages.size() - 1, memorySlice.getEndIndex()); int start = Math.max(0, Math.min(memorySlice.getStartIndex(), size));
if (start > end) { int end = Math.max(start, Math.min(memorySlice.getEndIndex(), size));
if (start >= end) {
return List.of(); return List.of();
} }
return new ArrayList<>(conversationMessages.subList(start, end + 1)); return new ArrayList<>(conversationMessages.subList(start, end));
} }
private void printSubTopicsTreeFormat(TopicTreeNode node, String prefix, StringBuilder stringBuilder) { private void printSubTopicsTreeFormat(TopicTreeNode node, String prefix, StringBuilder stringBuilder) {

View File

@@ -4,6 +4,7 @@ import com.alibaba.fastjson2.JSONObject;
import kotlin.Unit; import kotlin.Unit;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.annotation.Init; import work.slhaf.partner.api.agent.factory.component.annotation.Init;
@@ -81,7 +82,7 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
} }
@Override @Override
public void execute(PartnerRunningFlowContext context) { public void execute(@NotNull PartnerRunningFlowContext context) {
boolean trigger = cognitionCapability.getChatMessages().size() >= MEMORY_UPDATE_TRIGGER_ROLL_LIMIT; boolean trigger = cognitionCapability.getChatMessages().size() >= MEMORY_UPDATE_TRIGGER_ROLL_LIMIT;
if (!trigger) { if (!trigger) {
return; return;
@@ -149,19 +150,33 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) { private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) {
long now = System.currentTimeMillis(); long now = System.currentTimeMillis();
String memoryId = memoryCapability.getMemorySessionId();
String resolvedMemoryId = memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId;
MemoryUnit existingUnit = memoryCapability.getMemoryUnit(resolvedMemoryId);
List<Message> existingMessages = existingUnit != null && existingUnit.getConversationMessages() != null
? existingUnit.getConversationMessages()
: List.of();
int startIndex = existingMessages.size();
MemorySlice memorySlice = new MemorySlice(); MemorySlice memorySlice = new MemorySlice();
memorySlice.setId(UUID.randomUUID().toString()); memorySlice.setId(UUID.randomUUID().toString());
memorySlice.setStartIndex(0); memorySlice.setStartIndex(startIndex);
memorySlice.setEndIndex(chatMessages.size()); memorySlice.setEndIndex(startIndex + chatMessages.size());
memorySlice.setSummary(summarizeResult.getSummary()); memorySlice.setSummary(summarizeResult.getSummary());
memorySlice.setTimestamp(now); memorySlice.setTimestamp(now);
MemoryUnit memoryUnit = new MemoryUnit(); MemoryUnit memoryUnit = new MemoryUnit();
String memoryId = memoryCapability.getMemorySessionId(); memoryUnit.setId(resolvedMemoryId);
memoryUnit.setId(memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId);
memoryUnit.setTimestamp(now); memoryUnit.setTimestamp(now);
memoryUnit.setConversationMessages(new ArrayList<>(chatMessages)); List<Message> conversationMessages = new ArrayList<>(existingMessages);
memoryUnit.setSlices(new ArrayList<>(List.of(memorySlice))); conversationMessages.addAll(chatMessages);
memoryUnit.setConversationMessages(conversationMessages);
List<MemorySlice> slices = existingUnit != null && existingUnit.getSlices() != null
? new ArrayList<>(existingUnit.getSlices())
: new ArrayList<>();
slices.add(memorySlice);
memoryUnit.setSlices(slices);
return memoryUnit; return memoryUnit;
} }

View File

@@ -0,0 +1,75 @@
package work.slhaf.partner.core.memory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.common.config.Config;
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.assertEquals;
class MemoryCoreTest {
private AgentConfigLoader previousLoader;
private String agentId;
private static PartnerAgentConfigLoader testLoader(String agentId) {
PartnerAgentConfigLoader loader = new PartnerAgentConfigLoader();
Config config = new Config();
config.setAgentId(agentId);
Config.WebSocketConfig webSocketConfig = new Config.WebSocketConfig();
webSocketConfig.setPort(18080);
config.setWebSocketConfig(webSocketConfig);
loader.setConfig(config);
loader.setModelConfigMap(new HashMap<>());
return loader;
}
@AfterEach
void tearDown() throws Exception {
AgentConfigLoader.INSTANCE = previousLoader;
if (agentId != null) {
Files.deleteIfExists(Path.of("data/memory", agentId + "-memory-core.memory"));
Files.deleteIfExists(Path.of("data/memory", agentId + "-temp-memory-core.memory"));
}
}
@Test
void shouldNormalizeSliceEndIndexUsingExclusiveUpperBound() throws Exception {
agentId = "memory-core-test-" + UUID.randomUUID();
previousLoader = AgentConfigLoader.INSTANCE;
AgentConfigLoader.INSTANCE = testLoader(agentId);
MemoryCore memoryCore = new MemoryCore();
MemorySlice slice = new MemorySlice();
slice.setId("slice-1");
slice.setStartIndex(1);
slice.setEndIndex(99);
MemoryUnit unit = new MemoryUnit();
unit.setId("unit-1");
unit.setConversationMessages(new ArrayList<>(List.of(
new Message(Message.Character.USER, "m0"),
new Message(Message.Character.USER, "m1"),
new Message(Message.Character.USER, "m2")
)));
unit.setSlices(new ArrayList<>(List.of(slice)));
memoryCore.saveMemoryUnit(unit);
MemorySlice savedSlice = memoryCore.getMemorySlice("unit-1", "slice-1");
assertEquals(1, savedSlice.getStartIndex());
assertEquals(3, savedSlice.getEndIndex());
}
}

View File

@@ -0,0 +1,186 @@
package work.slhaf.partner.module.memory.runtime;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.common.config.Config;
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.concurrent.CopyOnWriteArrayList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
class MemoryRuntimeTest {
private AgentConfigLoader previousLoader;
private String runtimeAgentId;
@SuppressWarnings("unchecked")
private static Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices(MemoryRuntime runtime) throws Exception {
Field field = MemoryRuntime.class.getDeclaredField("topicSlices");
field.setAccessible(true);
return (Map<String, CopyOnWriteArrayList<SliceRef>>) field.get(runtime);
}
@SuppressWarnings("unchecked")
private static List<Message> invokeSliceMessages(MemoryRuntime runtime, MemoryUnit unit, MemorySlice slice) throws Exception {
Method method = MemoryRuntime.class.getDeclaredMethod("sliceMessages", MemoryUnit.class, MemorySlice.class);
method.setAccessible(true);
return (List<Message>) method.invoke(runtime, unit, slice);
}
private static void setField(Object target, String fieldName, Object value) throws Exception {
Field field = target.getClass().getDeclaredField(fieldName);
field.setAccessible(true);
field.set(target, value);
}
private static PartnerAgentConfigLoader testLoader(String agentId) {
PartnerAgentConfigLoader loader = new PartnerAgentConfigLoader();
Config config = new Config();
config.setAgentId(agentId);
Config.WebSocketConfig webSocketConfig = new Config.WebSocketConfig();
webSocketConfig.setPort(18080);
config.setWebSocketConfig(webSocketConfig);
loader.setConfig(config);
loader.setModelConfigMap(new HashMap<>());
return loader;
}
private static Message message(String content) {
return new Message(Message.Character.USER, content);
}
@AfterEach
void tearDown() throws Exception {
AgentConfigLoader.INSTANCE = previousLoader;
if (runtimeAgentId != null) {
Files.deleteIfExists(Path.of("data/memory", runtimeAgentId + "-memory-runtime.memory"));
Files.deleteIfExists(Path.of("data/memory", runtimeAgentId + "-memory-runtime-temp.memory"));
}
}
@Test
void shouldSliceMessagesUsingLeftClosedRightOpenRange() throws Exception {
MemoryRuntime runtime = new MemoryRuntime();
MemoryUnit unit = new MemoryUnit();
unit.setConversationMessages(new ArrayList<>(List.of(
message("m0"),
message("m1"),
message("m2"),
message("m3")
)));
MemorySlice slice = new MemorySlice();
slice.setStartIndex(1);
slice.setEndIndex(3);
List<Message> messages = invokeSliceMessages(runtime, unit, slice);
assertEquals(List.of("m1", "m2"), messages.stream().map(Message::getContent).toList());
slice.setStartIndex(2);
slice.setEndIndex(2);
assertTrue(invokeSliceMessages(runtime, unit, slice).isEmpty());
}
@Test
void shouldBindTopicToLatestMemorySliceInsteadOfFirstSlice() throws Exception {
runtimeAgentId = "runtime-test-" + UUID.randomUUID();
previousLoader = AgentConfigLoader.INSTANCE;
AgentConfigLoader.INSTANCE = testLoader(runtimeAgentId);
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability);
MemoryUnit unit = new MemoryUnit();
unit.setId("unit-1");
unit.setConversationMessages(new ArrayList<>(List.of(
message("m0"),
message("m1"),
message("m2"),
message("m3")
)));
MemorySlice firstSlice = new MemorySlice();
firstSlice.setId("slice-1");
firstSlice.setStartIndex(0);
firstSlice.setEndIndex(2);
firstSlice.setSummary("first");
firstSlice.setTimestamp(1L);
MemorySlice secondSlice = new MemorySlice();
secondSlice.setId("slice-2");
secondSlice.setStartIndex(2);
secondSlice.setEndIndex(4);
secondSlice.setSummary("second");
secondSlice.setTimestamp(2L);
unit.setSlices(new ArrayList<>(List.of(firstSlice, secondSlice)));
runtime.recordMemory(unit, "topic/main", List.of("topic/related"), "dialog-summary");
Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices = topicSlices(runtime);
assertEquals(List.of("slice-2"),
topicSlices.get("topic/main").stream().map(SliceRef::getSliceId).toList());
assertEquals(List.of("slice-2"),
topicSlices.get("topic/related").stream().map(SliceRef::getSliceId).toList());
}
private static final class StubMemoryCapability implements MemoryCapability {
private final String sessionId;
private final Map<String, MemoryUnit> units = new HashMap<>();
private StubMemoryCapability(String sessionId) {
this.sessionId = sessionId;
}
@Override
public void saveMemoryUnit(MemoryUnit memoryUnit) {
units.put(memoryUnit.getId(), memoryUnit);
}
@Override
public MemoryUnit getMemoryUnit(String unitId) {
return units.get(unitId);
}
@Override
public MemorySlice getMemorySlice(String unitId, String sliceId) {
MemoryUnit unit = units.get(unitId);
if (unit == null || unit.getSlices() == null) {
return null;
}
return unit.getSlices().stream()
.filter(slice -> sliceId.equals(slice.getId()))
.findFirst()
.orElse(null);
}
@Override
public Collection<MemoryUnit> listMemoryUnits() {
return units.values();
}
@Override
public void refreshMemorySession() {
}
@Override
public String getMemorySessionId() {
return sessionId;
}
}
}

View File

@@ -0,0 +1,154 @@
package work.slhaf.partner.module.memory.updater;
import org.junit.jupiter.api.Test;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
class MemoryUpdaterTest {
private static MemoryUnit invokeBuildMemoryUnit(MemoryUpdater updater,
List<Message> chatMessages,
SummarizeResult summarizeResult) throws Exception {
Method method = MemoryUpdater.class.getDeclaredMethod("buildMemoryUnit", List.class, SummarizeResult.class);
method.setAccessible(true);
return (MemoryUnit) method.invoke(updater, chatMessages, summarizeResult);
}
private static void setField(Object target, String fieldName, Object value) throws Exception {
Field field = target.getClass().getDeclaredField(fieldName);
field.setAccessible(true);
field.set(target, value);
}
private static SummarizeResult summarizeResult(String summary) {
SummarizeResult result = new SummarizeResult();
result.setSummary(summary);
return result;
}
private static Message message(Message.Character role, String content) {
return new Message(role, content);
}
@Test
void shouldAppendNewSliceToExistingMemoryUnitWithinSameSession() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-1");
MemoryUpdater updater = new MemoryUpdater();
setField(updater, "memoryCapability", memoryCapability);
String sessionId = memoryCapability.getMemorySessionId();
MemoryUnit existingUnit = new MemoryUnit();
existingUnit.setId(sessionId);
existingUnit.setConversationMessages(new ArrayList<>(List.of(
message(Message.Character.USER, "old-user"),
message(Message.Character.ASSISTANT, "old-assistant")
)));
MemorySlice existingSlice = new MemorySlice();
existingSlice.setId("slice-1");
existingSlice.setStartIndex(0);
existingSlice.setEndIndex(2);
existingSlice.setSummary("old-summary");
existingSlice.setTimestamp(1L);
existingUnit.setSlices(new ArrayList<>(List.of(existingSlice)));
memoryCapability.saveMemoryUnit(existingUnit);
MemoryUnit merged = invokeBuildMemoryUnit(
updater,
List.of(
message(Message.Character.USER, "new-user"),
message(Message.Character.ASSISTANT, "new-assistant")
),
summarizeResult("new-summary")
);
assertEquals(sessionId, merged.getId());
assertEquals(4, merged.getConversationMessages().size());
assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"),
merged.getConversationMessages().stream().map(Message::getContent).toList());
assertEquals(2, merged.getSlices().size());
MemorySlice appendedSlice = merged.getSlices().get(1);
assertNotNull(appendedSlice.getId());
assertEquals(2, appendedSlice.getStartIndex());
assertEquals(4, appendedSlice.getEndIndex());
assertEquals("new-summary", appendedSlice.getSummary());
}
@Test
void shouldCreateNewMemoryUnitForNewSessionId() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-2");
MemoryUpdater updater = new MemoryUpdater();
setField(updater, "memoryCapability", memoryCapability);
MemoryUnit created = invokeBuildMemoryUnit(
updater,
List.of(
message(Message.Character.USER, "first"),
message(Message.Character.ASSISTANT, "second")
),
summarizeResult("fresh-summary")
);
assertEquals("session-2", created.getId());
assertEquals(2, created.getConversationMessages().size());
assertEquals(1, created.getSlices().size());
assertEquals(0, created.getSlices().getFirst().getStartIndex());
assertEquals(2, created.getSlices().getFirst().getEndIndex());
assertEquals("fresh-summary", created.getSlices().getFirst().getSummary());
}
private static final class StubMemoryCapability implements MemoryCapability {
private final String sessionId;
private final Map<String, MemoryUnit> units = new HashMap<>();
private StubMemoryCapability(String sessionId) {
this.sessionId = sessionId;
}
@Override
public void saveMemoryUnit(MemoryUnit memoryUnit) {
units.put(memoryUnit.getId(), memoryUnit);
}
@Override
public MemoryUnit getMemoryUnit(String unitId) {
return units.get(unitId);
}
@Override
public MemorySlice getMemorySlice(String unitId, String sliceId) {
MemoryUnit unit = units.get(unitId);
if (unit == null || unit.getSlices() == null) {
return null;
}
return unit.getSlices().stream()
.filter(slice -> sliceId.equals(slice.getId()))
.findFirst()
.orElse(null);
}
@Override
public Collection<MemoryUnit> listMemoryUnits() {
return units.values();
}
@Override
public void refreshMemorySession() {
}
@Override
public String getMemorySessionId() {
return sessionId;
}
}
}