refactor(memory): enhance topic-based memory runtime on recalling and indexing

This commit is contained in:
2026-04-18 22:28:40 +08:00
parent 92c8e01000
commit a7ef9bff49
12 changed files with 1022 additions and 267 deletions

View File

@@ -11,9 +11,9 @@ import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.support.Result;
import work.slhaf.partner.module.memory.pojo.ActivationProfile;
import work.slhaf.partner.module.memory.runtime.exception.MemoryLookupException;
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
@@ -25,7 +25,6 @@ import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@@ -34,18 +33,13 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
class MemoryRuntimeTest {
private static final ActivationProfile DEFAULT_PROFILE = new ActivationProfile(0.55f, 0.35f, 0.50f);
@BeforeAll
public static void beforeAll(@TempDir Path tempDir) {
System.setProperty("user.home", tempDir.toAbsolutePath().toString());
}
@SuppressWarnings("unchecked")
private static Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices(MemoryRuntime runtime) throws Exception {
Field field = MemoryRuntime.class.getDeclaredField("topicSlices");
field.setAccessible(true);
return (Map<String, CopyOnWriteArrayList<SliceRef>>) field.get(runtime);
}
@SuppressWarnings("unchecked")
private static List<Message> invokeSliceMessages(MemoryRuntime runtime, MemoryUnit unit, MemorySlice slice) throws Exception {
Method method = MemoryRuntime.class.getDeclaredMethod("sliceMessages", MemoryUnit.class, MemorySlice.class);
@@ -152,13 +146,92 @@ class MemoryRuntimeTest {
unit.getSlices().addAll(List.of(firstSlice, secondSlice));
memoryCapability.remember(unit);
runtime.recordMemory(unit, "topic/main", List.of("topic/related"));
runtime.recordMemory(unit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE);
Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices = topicSlices(runtime);
List<ActivatedMemorySlice> topicResult = runtime.queryActivatedMemoryByTopicPath("topic/main");
assertEquals(List.of("slice-2"), topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList());
assertTrue(runtime.getTopicTree().contains("topic/main [root]"));
assertTrue(runtime.getTopicTree().contains("topic/related [root]"));
assertTrue(JSONObject.parseObject(runtime.convert().toString())
.getJSONArray("topic_slices")
.stream()
.map(JSONObject.class::cast)
.anyMatch(item -> "topic/main".equals(item.getString("topic_path"))));
}
@Test
void shouldExpandTopicQueryToLatestRelatedTopicMemory() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability);
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
MemoryUnit mainUnit = new MemoryUnit("unit-main");
mainUnit.getConversationMessages().addAll(List.of(
message("m0"),
message("m1"),
message("m2"),
message("m3")
));
MemorySlice mainSlice = MemorySlice.restore("slice-main", 0, 2, "main", 86_400_000L);
mainUnit.getSlices().add(mainSlice);
memoryCapability.remember(mainUnit);
MemoryUnit relatedUnit = new MemoryUnit("unit-related");
relatedUnit.getConversationMessages().addAll(List.of(
message("r0"),
message("r1")
));
MemorySlice relatedSlice = MemorySlice.restore("slice-related", 0, 2, "related", 172_800_000L);
relatedUnit.getSlices().add(relatedSlice);
memoryCapability.remember(relatedUnit);
runtime.recordMemory(mainUnit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE);
runtime.recordMemory(relatedUnit, "topic/related", List.of(), DEFAULT_PROFILE);
List<ActivatedMemorySlice> topicResult = runtime.queryActivatedMemoryByTopicPath("topic/main");
assertEquals(List.of("slice-main", "slice-related"),
topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList());
}
@Test
void shouldIndexDateIncrementallyWithoutRebuildingWholeUnit() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability);
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
MemoryUnit firstUnitSnapshot = new MemoryUnit("unit-100");
firstUnitSnapshot.getConversationMessages().addAll(List.of(message("m0"), message("m1")));
MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 1, "first", 86_400_000L);
firstUnitSnapshot.getSlices().add(firstSlice);
memoryCapability.remember(firstUnitSnapshot);
runtime.recordMemory(firstUnitSnapshot, "topic/main", List.of(), DEFAULT_PROFILE);
firstUnitSnapshot.getConversationMessages().clear();
firstUnitSnapshot.getConversationMessages().addAll(List.of(message("m2"), message("m3")));
MemorySlice secondSlice = MemorySlice.restore("slice-2", 0, 1, "second", 172_800_000L);
firstUnitSnapshot.getSlices().clear();
firstUnitSnapshot.getSlices().add(secondSlice);
memoryCapability.remember(firstUnitSnapshot);
runtime.recordMemory(firstUnitSnapshot, "topic/main", List.of(), DEFAULT_PROFILE);
JSONObject state = JSONObject.parseObject(runtime.convert().toString());
JSONArray dateIndex = state.getJSONArray("date_index");
JSONObject firstDate = dateIndex.stream()
.map(JSONObject.class::cast)
.filter(item -> "1970-01-02".equals(item.getString("date")))
.findFirst()
.orElseThrow();
JSONObject secondDate = dateIndex.stream()
.map(JSONObject.class::cast)
.filter(item -> "1970-01-03".equals(item.getString("date")))
.findFirst()
.orElseThrow();
assertEquals(List.of("slice-1"),
firstDate.getJSONArray("refs").toJavaList(JSONObject.class).stream().map(obj -> obj.getString("slice_id")).toList());
assertEquals(List.of("slice-2"),
topicSlices.get("topic/main").stream().map(SliceRef::getSliceId).toList());
assertEquals(List.of("slice-2"),
topicSlices.get("topic/related").stream().map(SliceRef::getSliceId).toList());
secondDate.getJSONArray("refs").toJavaList(JSONObject.class).stream().map(obj -> obj.getString("slice_id")).toList());
}
@Test
@@ -168,8 +241,8 @@ class MemoryRuntimeTest {
setField(runtime, "memoryCapability", memoryCapability);
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
MemoryUnit unit = new MemoryUnit("unit-100");
unit.getConversationMessages().addAll(List.of(
MemoryUnit mainUnit = new MemoryUnit("unit-200");
mainUnit.getConversationMessages().addAll(List.of(
message("m0"),
message("m1"),
message("m2"),
@@ -177,10 +250,16 @@ class MemoryRuntimeTest {
));
MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 2, "first", 86_400_000L);
MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 172_800_000L);
unit.getSlices().addAll(List.of(firstSlice, secondSlice));
memoryCapability.remember(unit);
mainUnit.getSlices().addAll(List.of(firstSlice, secondSlice));
memoryCapability.remember(mainUnit);
runtime.recordMemory(mainUnit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE);
runtime.recordMemory(unit, "topic/main", List.of("topic/related"));
MemoryUnit relatedUnit = new MemoryUnit("unit-201");
relatedUnit.getConversationMessages().addAll(List.of(message("r0"), message("r1")));
MemorySlice relatedSlice = MemorySlice.restore("slice-3", 0, 2, "related", 259_200_000L);
relatedUnit.getSlices().add(relatedSlice);
memoryCapability.remember(relatedUnit);
runtime.recordMemory(relatedUnit, "topic/related", List.of(), DEFAULT_PROFILE);
JSONObject state = JSONObject.parseObject(runtime.convert().toString());
JSONArray topicSlices = state.getJSONArray("topic_slices");
@@ -190,16 +269,23 @@ class MemoryRuntimeTest {
.filter(item -> "topic/main".equals(item.getString("topic_path")))
.findFirst()
.orElseThrow();
assertEquals("slice-2", mainTopic.getJSONArray("refs").getJSONObject(0).getString("slice_id"));
JSONObject binding = mainTopic.getJSONArray("bindings").getJSONObject(0);
assertEquals("slice-2", binding.getString("slice_id"));
assertEquals(172_800_000L, binding.getLongValue("timestamp"));
assertEquals(List.of("topic/related"), binding.getJSONArray("related_topic_paths").toJavaList(String.class));
JSONObject activationProfile = binding.getJSONObject("activation_profile");
assertEquals(0.55f, activationProfile.getFloatValue("activation_weight"));
assertEquals(0.35f, activationProfile.getFloatValue("diffusion_weight"));
assertEquals(0.50f, activationProfile.getFloatValue("context_independence_weight"));
JSONArray dateIndex = state.getJSONArray("date_index");
assertEquals(2, dateIndex.size());
JSONObject secondDate = dateIndex.stream()
JSONObject thirdDate = dateIndex.stream()
.map(JSONObject.class::cast)
.filter(item -> "1970-01-03".equals(item.getString("date")))
.filter(item -> "1970-01-04".equals(item.getString("date")))
.findFirst()
.orElseThrow();
assertEquals("slice-2", secondDate.getJSONArray("refs").getJSONObject(0).getString("slice_id"));
assertEquals("slice-3", thirdDate.getJSONArray("refs").getJSONObject(0).getString("slice_id"));
MemoryRuntime restored = new MemoryRuntime();
setField(restored, "memoryCapability", memoryCapability);
@@ -207,14 +293,14 @@ class MemoryRuntimeTest {
restored.load(state);
List<ActivatedMemorySlice> topicResult = restored.queryActivatedMemoryByTopicPath("topic/main");
assertEquals(1, topicResult.size());
assertEquals("slice-2", topicResult.getFirst().getSliceId());
assertEquals(List.of("slice-2", "slice-3"),
topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList());
assertEquals(List.of("m2", "m3"), topicResult.getFirst().getMessages().stream().map(Message::getContent).toList());
List<ActivatedMemorySlice> dateResult = restored.queryActivatedMemoryByDate(LocalDate.parse("1970-01-03"));
List<ActivatedMemorySlice> dateResult = restored.queryActivatedMemoryByDate(LocalDate.parse("1970-01-04"));
assertEquals(1, dateResult.size());
assertEquals("slice-2", dateResult.getFirst().getSliceId());
assertEquals("second", dateResult.getFirst().getSummary());
assertEquals("slice-3", dateResult.getFirst().getSliceId());
assertEquals("related", dateResult.getFirst().getSummary());
}
@Test
@@ -228,6 +314,106 @@ class MemoryRuntimeTest {
assertTrue(runtime.queryActivatedMemoryByDate(LocalDate.parse("1970-01-01")).isEmpty());
}
@Test
void shouldRankTopicMatchesBySourceAndActivationProfile() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability);
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
MemoryUnit primaryUnit = new MemoryUnit("unit-primary");
primaryUnit.getConversationMessages().addAll(List.of(message("p0"), message("p1")));
MemorySlice primarySlice = MemorySlice.restore("slice-primary", 0, 2, "primary", System.currentTimeMillis());
primaryUnit.getSlices().add(primarySlice);
memoryCapability.remember(primaryUnit);
runtime.recordMemory(primaryUnit, "topic->main", List.of("topic->related"), new ActivationProfile(0.9f, 0.1f, 0.9f));
MemoryUnit relatedUnit = new MemoryUnit("unit-related-rank");
relatedUnit.getConversationMessages().addAll(List.of(message("r0"), message("r1")));
MemorySlice relatedSlice = MemorySlice.restore("slice-related-rank", 0, 2, "related", System.currentTimeMillis());
relatedUnit.getSlices().add(relatedSlice);
memoryCapability.remember(relatedUnit);
runtime.recordMemory(relatedUnit, "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f));
MemoryUnit parentUnit = new MemoryUnit("unit-parent");
parentUnit.getConversationMessages().addAll(List.of(message("x0"), message("x1")));
MemorySlice parentSlice = MemorySlice.restore("slice-parent", 0, 2, "parent", System.currentTimeMillis());
parentUnit.getSlices().add(parentSlice);
memoryCapability.remember(parentUnit);
runtime.recordMemory(parentUnit, "topic", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f));
List<ActivatedMemorySlice> topicResult = runtime.queryActivatedMemoryByTopicPath("topic->main");
assertEquals(List.of("slice-primary", "slice-related-rank", "slice-parent"),
topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList());
}
@Test
void shouldNotExpandRelatedTopicWhenDiffusionWeightIsZero() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability);
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
MemoryUnit primaryUnit = new MemoryUnit("unit-primary-zero");
primaryUnit.getConversationMessages().addAll(List.of(message("p0"), message("p1")));
MemorySlice primarySlice = MemorySlice.restore("slice-primary-zero", 0, 2, "primary", System.currentTimeMillis());
primaryUnit.getSlices().add(primarySlice);
memoryCapability.remember(primaryUnit);
runtime.recordMemory(
primaryUnit,
"topic->main",
List.of("topic->related"),
new ActivationProfile(0.8f, 0.0f, 0.8f)
);
MemoryUnit relatedUnit = new MemoryUnit("unit-related-zero");
relatedUnit.getConversationMessages().addAll(List.of(message("r0"), message("r1")));
MemorySlice relatedSlice = MemorySlice.restore("slice-related-zero", 0, 2, "related", System.currentTimeMillis());
relatedUnit.getSlices().add(relatedSlice);
memoryCapability.remember(relatedUnit);
runtime.recordMemory(relatedUnit, "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f));
List<ActivatedMemorySlice> topicResult = runtime.queryActivatedMemoryByTopicPath("topic->main");
assertEquals(List.of("slice-primary-zero"), topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList());
}
@Test
void shouldRefreshBindingTimestampAndActivationProfileWhenSameSliceRebound() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability);
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
MemoryUnit unit = new MemoryUnit("unit-refresh");
unit.getConversationMessages().addAll(List.of(message("m0"), message("m1")));
MemorySlice slice = MemorySlice.restore("slice-refresh", 0, 2, "summary", 86_400_000L);
unit.getSlices().add(slice);
memoryCapability.remember(unit);
runtime.recordMemory(unit, "topic->main", List.of("topic->related"), new ActivationProfile(0.2f, 0.1f, 0.2f));
unit.getSlices().clear();
unit.getSlices().add(MemorySlice.restore("slice-refresh", 0, 2, "summary", 172_800_000L));
runtime.recordMemory(unit, "topic->main", List.of("topic->related-2"), new ActivationProfile(0.9f, 0.8f, 0.7f));
JSONObject state = JSONObject.parseObject(runtime.convert().toString());
JSONObject mainTopic = state.getJSONArray("topic_slices").stream()
.map(JSONObject.class::cast)
.filter(item -> "topic->main".equals(item.getString("topic_path")))
.findFirst()
.orElseThrow();
JSONObject binding = mainTopic.getJSONArray("bindings").getJSONObject(0);
JSONObject activationProfile = binding.getJSONObject("activation_profile");
assertEquals(172_800_000L, binding.getLongValue("timestamp"));
assertEquals(0.9f, activationProfile.getFloatValue("activation_weight"));
assertEquals(0.8f, activationProfile.getFloatValue("diffusion_weight"));
assertEquals(0.7f, activationProfile.getFloatValue("context_independence_weight"));
assertEquals(
List.of("topic->related", "topic->related-2"),
binding.getJSONArray("related_topic_paths").toJavaList(String.class)
);
}
private static final class StubMemoryCapability implements MemoryCapability {
private final String sessionId;
private final Map<String, MemoryUnit> units = new HashMap<>();

View File

@@ -4,7 +4,6 @@ import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.mockito.Mockito;
import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.framework.agent.exception.AgentRuntimeException;
@@ -12,6 +11,7 @@ import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.support.Result;
import work.slhaf.partner.module.communication.AfterRollingRegistry;
import work.slhaf.partner.module.communication.RollingResult;
import work.slhaf.partner.module.memory.pojo.ActivationProfile;
import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
import work.slhaf.partner.module.memory.updater.summarizer.entity.MemoryTopicResult;
@@ -19,6 +19,7 @@ import java.lang.reflect.Field;
import java.nio.file.Path;
import java.util.List;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -55,11 +56,7 @@ class MemoryUpdaterTest {
void shouldExtractTopicAndRecordMemoryOnConsume() throws Exception {
MemoryUpdater updater = Mockito.spy(new MemoryUpdater());
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
CognitionCapability cognitionCapability = Mockito.mock(CognitionCapability.class);
setField(updater, "memoryRuntime", memoryRuntime);
setField(updater, "cognitionCapability", cognitionCapability);
when(cognitionCapability.contextWorkspace()).thenReturn(new work.slhaf.partner.core.cognition.ContextWorkspace());
when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
when(memoryRuntime.fixTopicPath("root[2]->branch[1]")).thenReturn("root->branch");
when(memoryRuntime.fixTopicPath("root[2]->related[1]")).thenReturn("root->related");
@@ -67,6 +64,7 @@ class MemoryUpdaterTest {
MemoryTopicResult topicResult = new MemoryTopicResult();
topicResult.setTopicPath("root[2]->branch[1]");
topicResult.setRelatedTopicPaths(List.of("root[2]->related[1]"));
topicResult.setActivationProfile(new ActivationProfile(0.8f, 0.9f, 0.7f));
Mockito.doReturn(Result.success(topicResult))
.when(updater)
.formattedChat(Mockito.anyList(), eq(MemoryTopicResult.class));
@@ -86,18 +84,22 @@ class MemoryUpdaterTest {
message(Message.Character.ASSISTANT, "new-reply")
), "slice-summary", 4, 6));
verify(memoryRuntime).recordMemory(eq(unit), eq("root->branch"), eq(List.of("root->related")));
verify(memoryRuntime).recordMemory(
eq(unit),
eq("root->branch"),
eq(List.of("root->related")),
argThat(profile -> profile != null
&& profile.getActivationWeight() == 0.8f
&& profile.getDiffusionWeight() == 0.9f
&& profile.getContextIndependenceWeight() == 0.7f)
);
}
@Test
void shouldFallbackToDateOnlyRecordWhenExtractionFails() throws Exception {
MemoryUpdater updater = Mockito.spy(new MemoryUpdater());
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
CognitionCapability cognitionCapability = Mockito.mock(CognitionCapability.class);
setField(updater, "memoryRuntime", memoryRuntime);
setField(updater, "cognitionCapability", cognitionCapability);
when(cognitionCapability.contextWorkspace()).thenReturn(new work.slhaf.partner.core.cognition.ContextWorkspace());
when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
Mockito.doReturn(Result.failure(new AgentRuntimeException("boom")))
.when(updater)
@@ -113,6 +115,48 @@ class MemoryUpdaterTest {
updater.consume(new RollingResult(unit, slice, unit.getConversationMessages(), "slice-summary", 2, 6));
verify(memoryRuntime).recordMemory(eq(unit), eq(null), eq(List.of()));
verify(memoryRuntime).recordMemory(
eq(unit),
eq(null),
eq(List.of()),
argThat(profile -> profile != null
&& profile.getActivationWeight() == 0.55f
&& profile.getDiffusionWeight() == 0.35f
&& profile.getContextIndependenceWeight() == 0.50f)
);
}
@Test
void shouldClampAndAdjustActivationProfileBeforeRecording() throws Exception {
MemoryUpdater updater = Mockito.spy(new MemoryUpdater());
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
setField(updater, "memoryRuntime", memoryRuntime);
when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
when(memoryRuntime.fixTopicPath("root[2]->branch[1]")).thenReturn("root->branch");
MemoryTopicResult topicResult = new MemoryTopicResult();
topicResult.setTopicPath("root[2]->branch[1]");
topicResult.setRelatedTopicPaths(List.of());
topicResult.setActivationProfile(new ActivationProfile(1.5f, 0.9f, -0.2f));
Mockito.doReturn(Result.success(topicResult))
.when(updater)
.formattedChat(Mockito.anyList(), eq(MemoryTopicResult.class));
MemoryUnit unit = new MemoryUnit("session-3");
unit.getConversationMessages().add(message(Message.Character.USER, "only"));
MemorySlice slice = new MemorySlice(0, 1, "slice-summary");
unit.getSlices().add(slice);
updater.consume(new RollingResult(unit, slice, unit.getConversationMessages(), "slice-summary", 1, 6));
verify(memoryRuntime).recordMemory(
eq(unit),
eq("root->branch"),
eq(List.of()),
argThat(profile -> profile != null
&& profile.getActivationWeight() == 0.95f
&& profile.getDiffusionWeight() == 0.45f
&& profile.getContextIndependenceWeight() == 0.0f)
);
}
}