fix(memory): restore normalizing logic in memory core, and fix errors in MemoryCoreTest

This commit is contained in:
2026-04-07 15:02:55 +08:00
parent 2cbaccedba
commit 6fd12cd19f
2 changed files with 55 additions and 4 deletions

View File

@@ -84,6 +84,38 @@ public class MemoryCore implements StateSerializable {
}
private void normalizeMemoryUnit(MemoryUnit memoryUnit) {
if (memoryUnit.getTimestamp() == null || memoryUnit.getTimestamp() <= 0) {
memoryUnit.updateTimestamp();
}
int maxEndExclusive = memoryUnit.getConversationMessages().size();
List<MemorySlice> normalizedSlices = new ArrayList<>(memoryUnit.getSlices().size());
for (MemorySlice slice : memoryUnit.getSlices()) {
if (slice == null) {
continue;
}
String sliceId = slice.getId();
if (sliceId == null || sliceId.isBlank()) {
sliceId = UUID.randomUUID().toString();
}
long sliceTimestamp = slice.getTimestamp() == null || slice.getTimestamp() <= 0
? memoryUnit.getTimestamp()
: slice.getTimestamp();
int startIndex = slice.getStartIndex() == null || slice.getStartIndex() < 0
? 0
: Math.min(slice.getStartIndex(), maxEndExclusive);
int endIndex = slice.getEndIndex() == null || slice.getEndIndex() < startIndex
? maxEndExclusive
: Math.min(slice.getEndIndex(), maxEndExclusive);
normalizedSlices.add(MemorySlice.restore(
sliceId,
startIndex,
endIndex,
slice.getSummary(),
sliceTimestamp
));
}
memoryUnit.getSlices().clear();
memoryUnit.getSlices().addAll(normalizedSlices);
memoryUnit.getSlices().sort(Comparator.naturalOrder());
}

View File

@@ -16,6 +16,7 @@ import java.util.List;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
class MemoryCoreTest {
@@ -51,10 +52,7 @@ class MemoryCoreTest {
MemoryCore memoryCore = new MemoryCore();
MemorySlice slice = new MemorySlice();
slice.setId("slice-1");
slice.setStartIndex(1);
slice.setEndIndex(99);
MemorySlice slice = MemorySlice.restore("slice-1", 1, 99, null, 1L);
MemoryUnit unit = new MemoryUnit("unit-1");
unit.getConversationMessages().addAll(List.of(
@@ -70,4 +68,25 @@ class MemoryCoreTest {
assertEquals(1, savedSlice.getStartIndex());
assertEquals(3, savedSlice.getEndIndex());
}
@Test
void shouldFillMissingTimestampsWhenSavingMemoryUnit() throws Exception {
agentId = "memory-core-test-" + UUID.randomUUID();
previousLoader = AgentConfigLoader.INSTANCE;
AgentConfigLoader.INSTANCE = testLoader(agentId);
MemoryCore memoryCore = new MemoryCore();
MemorySlice slice = MemorySlice.restore("slice-1", 0, 1, "summary", 0L);
MemoryUnit unit = new MemoryUnit("unit-1");
unit.getConversationMessages().add(new Message(Message.Character.USER, "m0"));
unit.getSlices().add(slice);
memoryCore.saveMemoryUnit(unit);
MemoryUnit savedUnit = memoryCore.getMemoryUnit("unit-1");
MemorySlice savedSlice = memoryCore.getMemorySlice("unit-1", "slice-1");
assertTrue(savedUnit.getTimestamp() > 0);
assertEquals(savedUnit.getTimestamp(), savedSlice.getTimestamp());
}
}