refactor(memory): enhance topic-based memory runtime on recalling and indexing

This commit is contained in:
2026-04-18 22:28:40 +08:00
parent 92c8e01000
commit a7ef9bff49
12 changed files with 1022 additions and 267 deletions

View File

@@ -0,0 +1,14 @@
package work.slhaf.partner.module.memory.pojo;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class ActivationProfile {
private Float activationWeight;
private Float diffusionWeight;
private Float contextIndependenceWeight;
}

View File

@@ -0,0 +1,36 @@
package work.slhaf.partner.module.memory.runtime;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
final class DateMemoryIndex {
private final Map<LocalDate, CopyOnWriteArrayList<SliceRef>> dateIndex = new HashMap<>();
void record(SliceRef sliceRef, LocalDate date) {
dateIndex.computeIfAbsent(date, key -> new CopyOnWriteArrayList<>()).addIfAbsent(sliceRef);
}
List<SliceRef> find(LocalDate date) {
List<SliceRef> refs = dateIndex.get(date);
return refs == null ? null : new ArrayList<>(refs);
}
void reset() {
dateIndex.clear();
}
void restore(LocalDate date, CopyOnWriteArrayList<SliceRef> refs) {
dateIndex.put(date, refs);
}
Map<LocalDate, CopyOnWriteArrayList<SliceRef>> entries() {
return dateIndex;
}
}

View File

@@ -1,8 +1,6 @@
package work.slhaf.partner.module.memory.runtime; package work.slhaf.partner.module.memory.runtime;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.MemoryCapability;
@@ -16,8 +14,8 @@ import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.state.State; import work.slhaf.partner.framework.agent.state.State;
import work.slhaf.partner.framework.agent.state.StateSerializable; import work.slhaf.partner.framework.agent.state.StateSerializable;
import work.slhaf.partner.framework.agent.state.StateValue;
import work.slhaf.partner.framework.agent.support.Result; import work.slhaf.partner.framework.agent.support.Result;
import work.slhaf.partner.module.memory.pojo.ActivationProfile;
import work.slhaf.partner.module.memory.runtime.exception.MemoryLookupException; import work.slhaf.partner.module.memory.runtime.exception.MemoryLookupException;
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice; import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
@@ -25,11 +23,10 @@ import java.nio.file.Path;
import java.time.Instant; import java.time.Instant;
import java.time.LocalDate; import java.time.LocalDate;
import java.time.ZoneId; import java.time.ZoneId;
import java.util.*; import java.util.ArrayList;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.List;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
@Slf4j
public class MemoryRuntime extends AbstractAgentModule.Standalone implements StateSerializable { public class MemoryRuntime extends AbstractAgentModule.Standalone implements StateSerializable {
@InjectCapability @InjectCapability
@@ -38,8 +35,10 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
private CognitionCapability cognitionCapability; private CognitionCapability cognitionCapability;
private final ReentrantLock runtimeLock = new ReentrantLock(); private final ReentrantLock runtimeLock = new ReentrantLock();
private Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices = new HashMap<>(); private final TopicMemoryIndex topicIndex = new TopicMemoryIndex();
private Map<LocalDate, CopyOnWriteArrayList<SliceRef>> dateIndex = new HashMap<>(); private final DateMemoryIndex dateIndex = new DateMemoryIndex();
private final TopicRecallCollector topicRecallCollector = new TopicRecallCollector(new TopicRecallScorer());
private final MemoryRuntimeStateCodec stateCodec = new MemoryRuntimeStateCodec();
@Init @Init
public void init() { public void init() {
@@ -53,80 +52,32 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
} }
} }
private void bindTopic(String topicPath, SliceRef sliceRef) { public void recordMemory(MemoryUnit memoryUnit,
String normalizedPath = normalizeTopicPath(topicPath); String topicPath,
runtimeLock.lock(); List<String> relatedTopicPaths,
try { ActivationProfile activationProfile) {
CopyOnWriteArrayList<SliceRef> refs = topicSlices.computeIfAbsent(normalizedPath, key -> new CopyOnWriteArrayList<>());
boolean exists = refs.stream().anyMatch(ref -> Objects.equals(ref.getUnitId(), sliceRef.getUnitId())
&& Objects.equals(ref.getSliceId(), sliceRef.getSliceId()));
if (!exists) {
refs.add(sliceRef);
}
} finally {
runtimeLock.unlock();
}
}
public void recordMemory(MemoryUnit memoryUnit, String topicPath, List<String> relatedTopicPaths) {
MemorySlice memorySlice = memoryUnit.getSlices().getLast(); MemorySlice memorySlice = memoryUnit.getSlices().getLast();
SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId()); SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId());
indexMemoryUnit(memoryUnit); LocalDate date = toLocalDate(memorySlice.getTimestamp());
if (topicPath != null && !topicPath.isBlank()) {
bindTopic(topicPath, sliceRef);
}
for (String relatedTopicPath : relatedTopicPaths) {
if (relatedTopicPath != null && !relatedTopicPath.isBlank()) {
bindTopic(relatedTopicPath, sliceRef);
}
}
}
private void indexMemoryUnit(MemoryUnit memoryUnit) {
runtimeLock.lock(); runtimeLock.lock();
try { try {
for (CopyOnWriteArrayList<SliceRef> refs : dateIndex.values()) { List<String> normalizedRelatedTopicPaths = topicIndex.normalizeTopicPaths(relatedTopicPaths);
refs.removeIf(ref -> memoryUnit.getId().equals(ref.getUnitId())); dateIndex.record(sliceRef, date);
} if (topicPath != null && !topicPath.isBlank()) {
for (MemorySlice slice : memoryUnit.getSlices()) { topicIndex.recordBinding(
LocalDate date = Instant.ofEpochMilli(slice.getTimestamp()) topicPath,
.atZone(ZoneId.systemDefault()) sliceRef,
.toLocalDate(); memorySlice.getTimestamp(),
dateIndex.computeIfAbsent(date, key -> new CopyOnWriteArrayList<>()) normalizedRelatedTopicPaths,
.addIfAbsent(new SliceRef(memoryUnit.getId(), slice.getId())); activationProfile
);
} }
topicIndex.ensureTopicPaths(normalizedRelatedTopicPaths);
} finally { } finally {
runtimeLock.unlock(); runtimeLock.unlock();
} }
} }
private List<SliceRef> findByTopicPath(String topicPath) {
String normalizedPath = normalizeTopicPath(topicPath);
List<SliceRef> refs = topicSlices.get(normalizedPath);
if (refs == null) {
ExceptionReporterHandler.INSTANCE.report(new MemoryLookupException(
"Unexisted topic path: " + normalizedPath,
normalizedPath,
"TOPIC"
));
return List.of();
}
return new ArrayList<>(refs);
}
private List<SliceRef> findByDate(LocalDate date) {
List<SliceRef> refs = dateIndex.get(date);
if (refs == null) {
ExceptionReporterHandler.INSTANCE.report(new MemoryLookupException(
"Unexisted date index: " + date,
date.toString(),
"DATE_INDEX"
));
return List.of();
}
return new ArrayList<>(refs);
}
public List<ActivatedMemorySlice> queryActivatedMemoryByTopicPath(String topicPath) { public List<ActivatedMemorySlice> queryActivatedMemoryByTopicPath(String topicPath) {
return buildActivatedMemorySlices(findByTopicPath(topicPath)); return buildActivatedMemorySlices(findByTopicPath(topicPath));
} }
@@ -136,23 +87,61 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
} }
public String getTopicTree() { public String getTopicTree() {
TopicTreeNode root = new TopicTreeNode(); runtimeLock.lock();
for (Map.Entry<String, CopyOnWriteArrayList<SliceRef>> entry : topicSlices.entrySet()) { try {
String[] parts = entry.getKey().split("->"); return topicIndex.getTopicTree();
TopicTreeNode current = root; } finally {
for (String part : parts) { runtimeLock.unlock();
current = current.children.computeIfAbsent(part, key -> new TopicTreeNode());
}
current.count += entry.getValue().size();
} }
}
StringBuilder stringBuilder = new StringBuilder(); public String fixTopicPath(String topicPath) {
List<Map.Entry<String, TopicTreeNode>> roots = new ArrayList<>(root.children.entrySet()); String[] parts = topicPath.split("->");
for (Map.Entry<String, TopicTreeNode> entry : roots) { List<String> cleanedParts = new ArrayList<>();
stringBuilder.append(entry.getKey()).append("[root]").append("\r\n"); for (String part : parts) {
printSubTopicsTreeFormat(entry.getValue(), "", stringBuilder); String cleaned = part.replaceAll("\\[[^]]*]", "").trim();
if (!cleaned.isEmpty()) {
cleanedParts.add(cleaned);
}
}
return String.join("->", cleanedParts);
}
private List<SliceRef> findByTopicPath(String topicPath) {
runtimeLock.lock();
try {
TopicMemoryIndex.TopicTreeNode topicNode = topicIndex.findTopicNode(topicPath);
if (topicNode == null) {
String normalizedPath = topicIndex.normalizeTopicPath(topicPath);
ExceptionReporterHandler.INSTANCE.report(new MemoryLookupException(
"Unexisted topic path: " + normalizedPath,
normalizedPath,
"TOPIC"
));
return List.of();
}
return topicRecallCollector.collect(topicIndex, topicNode);
} finally {
runtimeLock.unlock();
}
}
private List<SliceRef> findByDate(LocalDate date) {
runtimeLock.lock();
try {
List<SliceRef> refs = dateIndex.find(date);
if (refs == null) {
ExceptionReporterHandler.INSTANCE.report(new MemoryLookupException(
"Unexisted date index: " + date,
date.toString(),
"DATE_INDEX"
));
return List.of();
}
return refs;
} finally {
runtimeLock.unlock();
} }
return stringBuilder.toString();
} }
private List<ActivatedMemorySlice> buildActivatedMemorySlices(List<SliceRef> refs) { private List<ActivatedMemorySlice> buildActivatedMemorySlices(List<SliceRef> refs) {
@@ -169,14 +158,12 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
private ActivatedMemorySlice buildActivatedMemorySlice(SliceRef ref) { private ActivatedMemorySlice buildActivatedMemorySlice(SliceRef ref) {
MemoryUnit memoryUnit = memoryCapability.getMemoryUnit(ref.getUnitId()); MemoryUnit memoryUnit = memoryCapability.getMemoryUnit(ref.getUnitId());
Result<MemorySlice> memorySliceResult = memoryCapability.getMemorySlice(ref.getUnitId(), ref.getSliceId()); Result<MemorySlice> memorySliceResult = memoryCapability.getMemorySlice(ref.getUnitId(), ref.getSliceId());
if (memorySliceResult.exceptionOrNull() != null) { if (memoryUnit == null || memorySliceResult.exceptionOrNull() != null) {
return null; return null;
} }
MemorySlice memorySlice = memorySliceResult.getOrThrow(); MemorySlice memorySlice = memorySliceResult.getOrThrow();
List<Message> messages = sliceMessages(memoryUnit, memorySlice); List<Message> messages = sliceMessages(memoryUnit, memorySlice);
LocalDate date = Instant.ofEpochMilli(memorySlice.getTimestamp()) LocalDate date = toLocalDate(memorySlice.getTimestamp());
.atZone(ZoneId.systemDefault())
.toLocalDate();
return ActivatedMemorySlice.builder() return ActivatedMemorySlice.builder()
.unitId(ref.getUnitId()) .unitId(ref.getUnitId())
.sliceId(ref.getSliceId()) .sliceId(ref.getSliceId())
@@ -201,29 +188,14 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
return new ArrayList<>(conversationMessages.subList(start, end)); return new ArrayList<>(conversationMessages.subList(start, end));
} }
private void printSubTopicsTreeFormat(TopicTreeNode node, String prefix, StringBuilder stringBuilder) { private LocalDate toLocalDate(Long timestamp) {
List<Map.Entry<String, TopicTreeNode>> entries = new ArrayList<>(node.children.entrySet()); return Instant.ofEpochMilli(timestamp)
for (int i = 0; i < entries.size(); i++) { .atZone(ZoneId.systemDefault())
boolean last = i == entries.size() - 1; .toLocalDate();
Map.Entry<String, TopicTreeNode> entry = entries.get(i);
stringBuilder.append(prefix)
.append(last ? "└── " : "├── ")
.append(entry.getKey())
.append("[")
.append(entry.getValue().count)
.append("]")
.append("\r\n");
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : ""), stringBuilder);
}
}
private String normalizeTopicPath(String topicPath) {
return topicPath == null ? "" : topicPath.trim();
} }
@Override @Override
@NotNull public @NotNull Path statePath() {
public Path statePath() {
return Path.of("module", "memory", "topic_based_memory.json"); return Path.of("module", "memory", "topic_based_memory.json");
} }
@@ -231,42 +203,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
public void load(@NotNull JSONObject state) { public void load(@NotNull JSONObject state) {
runtimeLock.lock(); runtimeLock.lock();
try { try {
topicSlices = new HashMap<>(); stateCodec.load(state, topicIndex, dateIndex);
dateIndex = new HashMap<>();
JSONArray topicSlicesArray = state.getJSONArray("topic_slices");
if (topicSlicesArray != null) {
for (int i = 0; i < topicSlicesArray.size(); i++) {
JSONObject topicObject = topicSlicesArray.getJSONObject(i);
if (topicObject == null) {
continue;
}
String topicPath = topicObject.getString("topic_path");
if (topicPath == null) {
continue;
}
topicSlices.put(normalizeTopicPath(topicPath), decodeSliceRefs(topicObject.getJSONArray("refs")));
}
}
JSONArray dateIndexArray = state.getJSONArray("date_index");
if (dateIndexArray != null) {
for (int i = 0; i < dateIndexArray.size(); i++) {
JSONObject dateObject = dateIndexArray.getJSONObject(i);
if (dateObject == null) {
continue;
}
String date = dateObject.getString("date");
if (date == null) {
continue;
}
try {
dateIndex.put(LocalDate.parse(date), decodeSliceRefs(dateObject.getJSONArray("refs")));
} catch (Exception e) {
log.warn("skip invalid date index: {}", date, e);
}
}
}
} finally { } finally {
runtimeLock.unlock(); runtimeLock.unlock();
} }
@@ -276,78 +213,9 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
public @NotNull State convert() { public @NotNull State convert() {
runtimeLock.lock(); runtimeLock.lock();
try { try {
State state = new State(); return stateCodec.convert(topicIndex, dateIndex);
List<StateValue.Obj> topicSliceStates = topicSlices.entrySet().stream()
.sorted(Map.Entry.comparingByKey())
.map(entry -> StateValue.obj(Map.of(
"topic_path", StateValue.str(entry.getKey()),
"refs", StateValue.arr(encodeSliceRefs(entry.getValue()))
)))
.toList();
state.append("topic_slices", StateValue.arr(topicSliceStates));
List<StateValue.Obj> dateIndexStates = dateIndex.entrySet().stream()
.sorted(Map.Entry.comparingByKey())
.map(entry -> StateValue.obj(Map.of(
"date", StateValue.str(entry.getKey().toString()),
"refs", StateValue.arr(encodeSliceRefs(entry.getValue()))
)))
.toList();
state.append("date_index", StateValue.arr(dateIndexStates));
return state;
} finally { } finally {
runtimeLock.unlock(); runtimeLock.unlock();
} }
} }
private List<StateValue> encodeSliceRefs(List<SliceRef> refs) {
return refs.stream()
.map(ref -> (StateValue) StateValue.obj(Map.of(
"unit_id", StateValue.str(ref.getUnitId()),
"slice_id", StateValue.str(ref.getSliceId())
)))
.toList();
}
private CopyOnWriteArrayList<SliceRef> decodeSliceRefs(JSONArray refsArray) {
CopyOnWriteArrayList<SliceRef> refs = new CopyOnWriteArrayList<>();
if (refsArray == null) {
return refs;
}
for (int i = 0; i < refsArray.size(); i++) {
JSONObject refObject = refsArray.getJSONObject(i);
if (refObject == null) {
continue;
}
String unitId = refObject.getString("unit_id");
String sliceId = refObject.getString("slice_id");
if (unitId == null || sliceId == null) {
continue;
}
refs.addIfAbsent(new SliceRef(unitId, sliceId));
}
return refs;
}
public String fixTopicPath(String topicPath) {
String[] parts = topicPath.split("->");
List<String> cleanedParts = new ArrayList<>();
for (String part : parts) {
// 修正正则表达式,正确移除 [xxx] 部分
String cleaned = part.replaceAll("\\[[^\\]]*\\]", "").trim();
if (!cleaned.isEmpty()) { // 忽略空字符串
cleanedParts.add(cleaned);
}
}
return String.join("->", cleanedParts);
}
private static final class TopicTreeNode {
private final Map<String, TopicTreeNode> children = new LinkedHashMap<>();
private int count;
}
} }

View File

@@ -0,0 +1,184 @@
package work.slhaf.partner.module.memory.runtime;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import work.slhaf.partner.framework.agent.state.State;
import work.slhaf.partner.framework.agent.state.StateValue;
import work.slhaf.partner.module.memory.pojo.ActivationProfile;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
@Slf4j
final class MemoryRuntimeStateCodec {
void load(JSONObject state, TopicMemoryIndex topicIndex, DateMemoryIndex dateIndex) {
topicIndex.reset();
dateIndex.reset();
JSONArray topicSlicesArray = state.getJSONArray("topic_slices");
if (topicSlicesArray != null) {
for (int i = 0; i < topicSlicesArray.size(); i++) {
JSONObject topicObject = topicSlicesArray.getJSONObject(i);
if (topicObject == null) {
continue;
}
String topicPath = topicObject.getString("topic_path");
if (topicPath == null) {
continue;
}
topicIndex.ensureTopicPath(topicPath);
decodeTopicBindings(topicIndex, topicPath, topicObject.getJSONArray("bindings"));
}
}
JSONArray dateIndexArray = state.getJSONArray("date_index");
if (dateIndexArray != null) {
for (int i = 0; i < dateIndexArray.size(); i++) {
JSONObject dateObject = dateIndexArray.getJSONObject(i);
if (dateObject == null) {
continue;
}
String date = dateObject.getString("date");
if (date == null) {
continue;
}
try {
dateIndex.restore(LocalDate.parse(date), decodeSliceRefs(dateObject.getJSONArray("refs")));
} catch (Exception e) {
log.warn("skip invalid date index: {}", date, e);
}
}
}
}
State convert(TopicMemoryIndex topicIndex, DateMemoryIndex dateIndex) {
State state = new State();
List<StateValue.Obj> topicSliceStates = new ArrayList<>();
for (Map.Entry<String, TopicMemoryIndex.TopicTreeNode> entry : topicIndex.roots().entrySet()) {
collectTopicStates(entry.getKey(), entry.getValue(), topicSliceStates);
}
state.append("topic_slices", StateValue.arr(topicSliceStates));
List<StateValue.Obj> dateIndexStates = dateIndex.entries().entrySet().stream()
.sorted(Map.Entry.comparingByKey())
.map(entry -> StateValue.obj(Map.of(
"date", StateValue.str(entry.getKey().toString()),
"refs", StateValue.arr(encodeSliceRefs(entry.getValue()))
)))
.toList();
state.append("date_index", StateValue.arr(dateIndexStates));
return state;
}
private void collectTopicStates(String path,
TopicMemoryIndex.TopicTreeNode topicNode,
List<StateValue.Obj> topicStates) {
topicStates.add(StateValue.obj(Map.of(
"topic_path", StateValue.str(path),
"bindings", StateValue.arr(encodeTopicBindings(topicNode.bindings()))
)));
for (Map.Entry<String, TopicMemoryIndex.TopicTreeNode> childEntry : topicNode.children().entrySet()) {
collectTopicStates(path + "->" + childEntry.getKey(), childEntry.getValue(), topicStates);
}
}
private List<StateValue> encodeTopicBindings(List<TopicMemoryIndex.TopicBinding> bindings) {
return bindings.stream()
.map(binding -> (StateValue) StateValue.obj(Map.of(
"unit_id", StateValue.str(binding.sliceRef().getUnitId()),
"slice_id", StateValue.str(binding.sliceRef().getSliceId()),
"timestamp", StateValue.num(binding.timestamp()),
"activation_profile", StateValue.obj(Map.of(
"activation_weight", StateValue.num(binding.activationProfile().getActivationWeight()),
"diffusion_weight", StateValue.num(binding.activationProfile().getDiffusionWeight()),
"context_independence_weight",
StateValue.num(binding.activationProfile().getContextIndependenceWeight())
)),
"related_topic_paths", StateValue.arr(binding.relatedTopicPaths().stream()
.map(StateValue::str)
.toList())
)))
.toList();
}
private void decodeTopicBindings(TopicMemoryIndex topicIndex, String topicPath, JSONArray bindingsArray) {
if (bindingsArray == null) {
return;
}
for (int i = 0; i < bindingsArray.size(); i++) {
JSONObject bindingObject = bindingsArray.getJSONObject(i);
if (bindingObject == null) {
continue;
}
String unitId = bindingObject.getString("unit_id");
String sliceId = bindingObject.getString("slice_id");
if (unitId == null || sliceId == null) {
continue;
}
Long timestamp = bindingObject.getLong("timestamp");
if (timestamp == null) {
log.warn("skip topic binding without timestamp: {}:{}", unitId, sliceId);
continue;
}
List<String> relatedTopicPaths = topicIndex.normalizeTopicPaths(
bindingObject.getList("related_topic_paths", String.class)
);
topicIndex.recordBinding(
topicPath,
new SliceRef(unitId, sliceId),
timestamp,
relatedTopicPaths,
decodeActivationProfile(bindingObject.getJSONObject("activation_profile"))
);
topicIndex.ensureTopicPaths(relatedTopicPaths);
}
}
private ActivationProfile decodeActivationProfile(JSONObject profileObject) {
if (profileObject == null) {
return null;
}
return new ActivationProfile(
profileObject.getFloat("activation_weight"),
profileObject.getFloat("diffusion_weight"),
profileObject.getFloat("context_independence_weight")
);
}
private List<StateValue> encodeSliceRefs(List<SliceRef> refs) {
return refs.stream()
.map(ref -> (StateValue) StateValue.obj(Map.of(
"unit_id", StateValue.str(ref.getUnitId()),
"slice_id", StateValue.str(ref.getSliceId())
)))
.toList();
}
private CopyOnWriteArrayList<SliceRef> decodeSliceRefs(JSONArray refsArray) {
CopyOnWriteArrayList<SliceRef> refs = new CopyOnWriteArrayList<>();
if (refsArray == null) {
return refs;
}
for (int i = 0; i < refsArray.size(); i++) {
JSONObject refObject = refsArray.getJSONObject(i);
if (refObject == null) {
continue;
}
String unitId = refObject.getString("unit_id");
String sliceId = refObject.getString("slice_id");
if (unitId == null || sliceId == null) {
continue;
}
refs.addIfAbsent(new SliceRef(unitId, sliceId));
}
return refs;
}
}

View File

@@ -0,0 +1,237 @@
package work.slhaf.partner.module.memory.runtime;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import work.slhaf.partner.module.memory.pojo.ActivationProfile;
import java.util.*;
import java.util.concurrent.CopyOnWriteArrayList;
final class TopicMemoryIndex {
private static final float DEFAULT_ACTIVATION_WEIGHT = 0.55f;
private static final float DEFAULT_DIFFUSION_WEIGHT = 0.35f;
private static final float DEFAULT_CONTEXT_INDEPENDENCE_WEIGHT = 0.50f;
private final Map<String, TopicTreeNode> topicSlices = new LinkedHashMap<>();
void recordBinding(String topicPath,
SliceRef sliceRef,
long timestamp,
Collection<String> relatedTopicPaths,
ActivationProfile activationProfile) {
String normalizedPath = normalizeTopicPath(topicPath);
if (normalizedPath.isBlank()) {
return;
}
ensureTopicNode(normalizedPath).addBinding(
sliceRef,
timestamp,
relatedTopicPaths,
normalizeActivationProfile(activationProfile)
);
}
void ensureTopicPaths(Collection<String> topicPaths) {
if (topicPaths == null || topicPaths.isEmpty()) {
return;
}
for (String topicPath : topicPaths) {
ensureTopicNode(topicPath);
}
}
void reset() {
topicSlices.clear();
}
void ensureTopicPath(String topicPath) {
String normalizedPath = normalizeTopicPath(topicPath);
if (normalizedPath.isBlank()) {
return;
}
ensureTopicNode(normalizedPath);
}
TopicTreeNode findTopicNode(String topicPath) {
String normalizedPath = normalizeTopicPath(topicPath);
if (normalizedPath.isBlank()) {
return null;
}
String[] parts = normalizedPath.split("->");
TopicTreeNode current = topicSlices.get(parts[0]);
for (int i = 1; current != null && i < parts.length; i++) {
current = current.children().get(parts[i]);
}
return current;
}
String getTopicTree() {
List<String> lines = new ArrayList<>();
for (Map.Entry<String, TopicTreeNode> entry : topicSlices.entrySet()) {
collectTopicTreeLines(entry.getKey(), entry.getValue(), lines);
}
return String.join("\r\n", lines);
}
List<String> normalizeTopicPaths(Collection<String> topicPaths) {
if (topicPaths == null || topicPaths.isEmpty()) {
return List.of();
}
LinkedHashSet<String> normalized = new LinkedHashSet<>();
for (String topicPath : topicPaths) {
String normalizedPath = normalizeTopicPath(topicPath);
if (!normalizedPath.isBlank()) {
normalized.add(normalizedPath);
}
}
return List.copyOf(normalized);
}
String normalizeTopicPath(String topicPath) {
return topicPath == null ? "" : topicPath.trim();
}
Map<String, TopicTreeNode> roots() {
return topicSlices;
}
private TopicTreeNode ensureTopicNode(String topicPath) {
String[] parts = topicPath.split("->");
TopicTreeNode current = topicSlices.computeIfAbsent(parts[0], ignored -> new TopicTreeNode(null));
for (int i = 1; i < parts.length; i++) {
TopicTreeNode parent = current;
current = current.children.computeIfAbsent(parts[i], ignored -> new TopicTreeNode(parent));
}
return current;
}
private void collectTopicTreeLines(String path, TopicTreeNode node, List<String> lines) {
if (node.parent() == null) {
lines.add(path + " [root]");
} else {
lines.add(path + " {slices: " + node.bindings().size() + "}");
}
for (Map.Entry<String, TopicTreeNode> childEntry : node.children().entrySet()) {
collectTopicTreeLines(path + "->" + childEntry.getKey(), childEntry.getValue(), lines);
}
}
private ActivationProfile normalizeActivationProfile(ActivationProfile activationProfile) {
ActivationProfile profile = activationProfile == null ? defaultActivationProfile() : new ActivationProfile(
activationProfile.getActivationWeight(),
activationProfile.getDiffusionWeight(),
activationProfile.getContextIndependenceWeight()
);
profile.setActivationWeight(clampOrDefault(profile.getActivationWeight(), DEFAULT_ACTIVATION_WEIGHT));
profile.setDiffusionWeight(clampOrDefault(profile.getDiffusionWeight(), DEFAULT_DIFFUSION_WEIGHT));
profile.setContextIndependenceWeight(clampOrDefault(
profile.getContextIndependenceWeight(),
DEFAULT_CONTEXT_INDEPENDENCE_WEIGHT
));
return profile;
}
private ActivationProfile defaultActivationProfile() {
return new ActivationProfile(
DEFAULT_ACTIVATION_WEIGHT,
DEFAULT_DIFFUSION_WEIGHT,
DEFAULT_CONTEXT_INDEPENDENCE_WEIGHT
);
}
private float clampOrDefault(Float value, float defaultValue) {
return value == null ? defaultValue : clamp(value);
}
private float clamp(float value) {
return Math.clamp(value, 0.0f, 1.0f);
}
static final class TopicTreeNode {
private final TopicTreeNode parent;
private final Map<String, TopicTreeNode> children = new LinkedHashMap<>();
private final CopyOnWriteArrayList<TopicBinding> bindings = new CopyOnWriteArrayList<>();
private TopicTreeNode(TopicTreeNode parent) {
this.parent = parent;
}
TopicTreeNode parent() {
return parent;
}
Map<String, TopicTreeNode> children() {
return children;
}
List<TopicBinding> bindings() {
return bindings;
}
private void addBinding(SliceRef sliceRef,
long timestamp,
Collection<String> relatedTopicPaths,
ActivationProfile activationProfile) {
for (TopicBinding binding : bindings) {
if (Objects.equals(binding.sliceRef().getUnitId(), sliceRef.getUnitId())
&& Objects.equals(binding.sliceRef().getSliceId(), sliceRef.getSliceId())) {
binding.refresh(timestamp, relatedTopicPaths, activationProfile);
return;
}
}
bindings.add(new TopicBinding(sliceRef, timestamp, relatedTopicPaths, activationProfile));
}
}
static final class TopicBinding {
private final SliceRef sliceRef;
private final CopyOnWriteArrayList<String> relatedTopicPaths = new CopyOnWriteArrayList<>();
private long timestamp;
private ActivationProfile activationProfile;
private TopicBinding(SliceRef sliceRef,
long timestamp,
Collection<String> relatedTopicPaths,
ActivationProfile activationProfile) {
this.sliceRef = sliceRef;
this.timestamp = timestamp;
this.activationProfile = activationProfile;
mergeRelatedTopicPaths(relatedTopicPaths);
}
SliceRef sliceRef() {
return sliceRef;
}
long timestamp() {
return timestamp;
}
ActivationProfile activationProfile() {
return activationProfile;
}
List<String> relatedTopicPaths() {
return relatedTopicPaths;
}
private void refresh(long timestamp,
Collection<String> relatedTopicPaths,
ActivationProfile activationProfile) {
this.timestamp = timestamp;
this.activationProfile = activationProfile;
mergeRelatedTopicPaths(relatedTopicPaths);
}
private void mergeRelatedTopicPaths(Collection<String> relatedTopicPaths) {
if (relatedTopicPaths == null) {
return;
}
for (String relatedTopicPath : relatedTopicPaths) {
if (relatedTopicPath != null && !relatedTopicPath.isBlank()) {
this.relatedTopicPaths.addIfAbsent(relatedTopicPath);
}
}
}
}
}

View File

@@ -0,0 +1,90 @@
package work.slhaf.partner.module.memory.runtime;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import java.util.*;
final class TopicRecallCollector {
private static final int TOPIC_RESULT_LIMIT = 5;
private static final int PARENT_CANDIDATE_LIMIT = 2;
private static final int RELATED_CANDIDATE_LIMIT = 2;
private final TopicRecallScorer scorer;
TopicRecallCollector(TopicRecallScorer scorer) {
this.scorer = scorer;
}
List<SliceRef> collect(TopicMemoryIndex topicIndex, TopicMemoryIndex.TopicTreeNode topicNode) {
LinkedHashMap<String, ScoredSliceCandidate> candidates = new LinkedHashMap<>();
LinkedHashMap<String, Float> relatedTopicPaths = new LinkedHashMap<>();
collectTopicCandidates(
topicNode,
TopicRecallScorer.CandidateSource.PRIMARY,
Integer.MAX_VALUE,
candidates,
relatedTopicPaths
);
collectTopicCandidates(
topicNode.parent(),
TopicRecallScorer.CandidateSource.PARENT,
PARENT_CANDIDATE_LIMIT,
candidates,
null
);
for (Map.Entry<String, Float> relatedTopicEntry : relatedTopicPaths.entrySet()) {
if (relatedTopicEntry.getValue() <= 0.0f) {
continue;
}
collectTopicCandidates(
topicIndex.findTopicNode(relatedTopicEntry.getKey()),
TopicRecallScorer.CandidateSource.RELATED,
RELATED_CANDIDATE_LIMIT,
candidates,
null
);
}
return candidates.values().stream()
.sorted(Comparator.comparingDouble(ScoredSliceCandidate::score)
.reversed()
.thenComparing(Comparator.comparingLong(ScoredSliceCandidate::timestamp).reversed()))
.limit(TOPIC_RESULT_LIMIT)
.map(ScoredSliceCandidate::sliceRef)
.toList();
}
private void collectTopicCandidates(TopicMemoryIndex.TopicTreeNode topicNode,
TopicRecallScorer.CandidateSource source,
int limit,
LinkedHashMap<String, ScoredSliceCandidate> candidates,
Map<String, Float> relatedTopicPaths) {
if (topicNode == null || topicNode.bindings().isEmpty()) {
return;
}
List<TopicMemoryIndex.TopicBinding> bindings = new ArrayList<>(topicNode.bindings());
bindings.sort(Comparator.comparingLong(TopicMemoryIndex.TopicBinding::timestamp).reversed());
int actualLimit = limit == Integer.MAX_VALUE ? bindings.size() : Math.min(limit, bindings.size());
for (int i = 0; i < actualLimit; i++) {
TopicMemoryIndex.TopicBinding binding = bindings.get(i);
if (relatedTopicPaths != null) {
for (String relatedTopicPath : binding.relatedTopicPaths()) {
relatedTopicPaths.merge(
relatedTopicPath,
binding.activationProfile().getDiffusionWeight(),
Math::max
);
}
}
double score = scorer.score(binding, source);
String key = binding.sliceRef().getUnitId() + ":" + binding.sliceRef().getSliceId();
ScoredSliceCandidate current = candidates.get(key);
if (current == null || score > current.score()) {
candidates.put(key, new ScoredSliceCandidate(binding.sliceRef(), binding.timestamp(), score));
}
}
}
private record ScoredSliceCandidate(SliceRef sliceRef, long timestamp, double score) {
}
}

View File

@@ -0,0 +1,47 @@
package work.slhaf.partner.module.memory.runtime;
import work.slhaf.partner.module.memory.pojo.ActivationProfile;
final class TopicRecallScorer {
double score(TopicMemoryIndex.TopicBinding binding, CandidateSource source) {
ActivationProfile profile = binding.activationProfile();
return source.sourceScore
+ recencyScore(binding.timestamp())
+ 0.50d * profile.getActivationWeight()
+ 0.30d * profile.getContextIndependenceWeight()
+ 0.20d * source.relationFactor * profile.getDiffusionWeight();
}
private double recencyScore(long timestamp) {
long ageMillis = Math.max(0L, System.currentTimeMillis() - timestamp);
long ageDays = ageMillis / 86_400_000L;
if (ageDays <= 1) {
return 0.30d;
}
if (ageDays <= 3) {
return 0.22d;
}
if (ageDays <= 7) {
return 0.15d;
}
if (ageDays <= 30) {
return 0.08d;
}
return 0.00d;
}
enum CandidateSource {
PRIMARY(1.00f, 0.30f),
RELATED(0.65f, 1.00f),
PARENT(0.45f, 0.20f);
private final float sourceScore;
private final float relationFactor;
CandidateSource(float sourceScore, float relationFactor) {
this.sourceScore = sourceScore;
this.relationFactor = relationFactor;
}
}
}

View File

@@ -177,6 +177,6 @@ public class MemoryRecallCueExtractor extends AbstractAgentModule.Sub<ExtractorI
@NotNull @NotNull
@Override @Override
public String modelKey() { public String modelKey() {
return "topic_extractor"; return "memory_recall_cue_extractor";
} }
} }

View File

@@ -3,9 +3,6 @@ package work.slhaf.partner.module.memory.updater;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.w3c.dom.Document; import org.w3c.dom.Document;
import org.w3c.dom.Element; import org.w3c.dom.Element;
import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.cognition.ContextBlock;
import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.framework.agent.factory.component.annotation.Init; import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
import work.slhaf.partner.framework.agent.factory.component.annotation.InjectModule; import work.slhaf.partner.framework.agent.factory.component.annotation.InjectModule;
@@ -16,6 +13,7 @@ import work.slhaf.partner.module.TaskBlock;
import work.slhaf.partner.module.communication.AfterRolling; import work.slhaf.partner.module.communication.AfterRolling;
import work.slhaf.partner.module.communication.AfterRollingRegistry; import work.slhaf.partner.module.communication.AfterRollingRegistry;
import work.slhaf.partner.module.communication.RollingResult; import work.slhaf.partner.module.communication.RollingResult;
import work.slhaf.partner.module.memory.pojo.ActivationProfile;
import work.slhaf.partner.module.memory.runtime.MemoryRuntime; import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
import work.slhaf.partner.module.memory.updater.summarizer.entity.MemoryTopicResult; import work.slhaf.partner.module.memory.updater.summarizer.entity.MemoryTopicResult;
@@ -23,8 +21,11 @@ import java.util.List;
public class MemoryUpdater extends AbstractAgentModule.Standalone implements AfterRolling, ActivateModel { public class MemoryUpdater extends AbstractAgentModule.Standalone implements AfterRolling, ActivateModel {
@InjectCapability private static final float DEFAULT_ACTIVATION_WEIGHT = 0.55f;
private CognitionCapability cognitionCapability; private static final float DEFAULT_DIFFUSION_WEIGHT = 0.35f;
private static final float DEFAULT_CONTEXT_INDEPENDENCE_WEIGHT = 0.50f;
private static final float NO_RELATED_DIFFUSION_CAP = 0.45f;
private static final float SINGLE_MESSAGE_ACTIVATION_PENALTY = 0.05f;
@InjectModule @InjectModule
private MemoryRuntime memoryRuntime; private MemoryRuntime memoryRuntime;
@@ -44,10 +45,6 @@ public class MemoryUpdater extends AbstractAgentModule.Standalone implements Aft
} }
Result<MemoryTopicResult> extractResult = formattedChat( Result<MemoryTopicResult> extractResult = formattedChat(
List.of( List.of(
cognitionCapability.contextWorkspace().resolve(List.of(
ContextBlock.FocusedDomain.COGNITION,
ContextBlock.FocusedDomain.MEMORY
)).encodeToMessage(),
resolveTopicTaskMessage(result, slicedMessages) resolveTopicTaskMessage(result, slicedMessages)
), ),
MemoryTopicResult.class MemoryTopicResult.class
@@ -57,8 +54,18 @@ public class MemoryUpdater extends AbstractAgentModule.Standalone implements Aft
List<String> relatedTopicPaths = topicResult.getRelatedTopicPaths() == null List<String> relatedTopicPaths = topicResult.getRelatedTopicPaths() == null
? List.of() ? List.of()
: topicResult.getRelatedTopicPaths().stream().map(memoryRuntime::fixTopicPath).toList(); : topicResult.getRelatedTopicPaths().stream().map(memoryRuntime::fixTopicPath).toList();
memoryRuntime.recordMemory(result.memoryUnit(), topicPath, relatedTopicPaths); ActivationProfile activationProfile = stabilizeActivationProfile(
}).onFailure(exp -> memoryRuntime.recordMemory(result.memoryUnit(), null, List.of())); topicResult.getActivationProfile(),
relatedTopicPaths,
slicedMessages
);
memoryRuntime.recordMemory(result.memoryUnit(), topicPath, relatedTopicPaths, activationProfile);
}).onFailure(exp -> memoryRuntime.recordMemory(
result.memoryUnit(),
null,
List.of(),
defaultActivationProfile()
));
} }
private List<Message> sliceMessages(RollingResult result) { private List<Message> sliceMessages(RollingResult result) {
@@ -91,4 +98,44 @@ public class MemoryUpdater extends AbstractAgentModule.Standalone implements Aft
public String modelKey() { public String modelKey() {
return "topic_extractor"; return "topic_extractor";
} }
private ActivationProfile stabilizeActivationProfile(ActivationProfile activationProfile,
List<String> relatedTopicPaths,
List<Message> slicedMessages) {
ActivationProfile profile = activationProfile == null ? defaultActivationProfile() : new ActivationProfile(
activationProfile.getActivationWeight(),
activationProfile.getDiffusionWeight(),
activationProfile.getContextIndependenceWeight()
);
profile.setActivationWeight(clampOrDefault(profile.getActivationWeight(), DEFAULT_ACTIVATION_WEIGHT));
profile.setDiffusionWeight(clampOrDefault(profile.getDiffusionWeight(), DEFAULT_DIFFUSION_WEIGHT));
profile.setContextIndependenceWeight(clampOrDefault(
profile.getContextIndependenceWeight(),
DEFAULT_CONTEXT_INDEPENDENCE_WEIGHT
));
if (relatedTopicPaths.isEmpty()) {
profile.setDiffusionWeight(Math.min(profile.getDiffusionWeight(), NO_RELATED_DIFFUSION_CAP));
}
if (slicedMessages.size() <= 1) {
profile.setActivationWeight(clamp(profile.getActivationWeight() - SINGLE_MESSAGE_ACTIVATION_PENALTY));
}
return profile;
}
private ActivationProfile defaultActivationProfile() {
return new ActivationProfile(
DEFAULT_ACTIVATION_WEIGHT,
DEFAULT_DIFFUSION_WEIGHT,
DEFAULT_CONTEXT_INDEPENDENCE_WEIGHT
);
}
private float clampOrDefault(Float value, float defaultValue) {
return value == null ? defaultValue : clamp(value);
}
private float clamp(float value) {
return Math.clamp(value, 0.0f, 1.0f);
}
} }

View File

@@ -1,6 +1,7 @@
package work.slhaf.partner.module.memory.updater.summarizer.entity; package work.slhaf.partner.module.memory.updater.summarizer.entity;
import lombok.Data; import lombok.Data;
import work.slhaf.partner.module.memory.pojo.ActivationProfile;
import java.util.List; import java.util.List;
@@ -8,4 +9,5 @@ import java.util.List;
public class MemoryTopicResult { public class MemoryTopicResult {
private String topicPath; private String topicPath;
private List<String> relatedTopicPaths; private List<String> relatedTopicPaths;
private ActivationProfile activationProfile;
} }

View File

@@ -11,9 +11,9 @@ import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit; import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.support.Result; import work.slhaf.partner.framework.agent.support.Result;
import work.slhaf.partner.module.memory.pojo.ActivationProfile;
import work.slhaf.partner.module.memory.runtime.exception.MemoryLookupException; import work.slhaf.partner.module.memory.runtime.exception.MemoryLookupException;
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice; import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
@@ -25,7 +25,6 @@ import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
@@ -34,18 +33,13 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
class MemoryRuntimeTest { class MemoryRuntimeTest {
private static final ActivationProfile DEFAULT_PROFILE = new ActivationProfile(0.55f, 0.35f, 0.50f);
@BeforeAll @BeforeAll
public static void beforeAll(@TempDir Path tempDir) { public static void beforeAll(@TempDir Path tempDir) {
System.setProperty("user.home", tempDir.toAbsolutePath().toString()); System.setProperty("user.home", tempDir.toAbsolutePath().toString());
} }
@SuppressWarnings("unchecked")
private static Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices(MemoryRuntime runtime) throws Exception {
Field field = MemoryRuntime.class.getDeclaredField("topicSlices");
field.setAccessible(true);
return (Map<String, CopyOnWriteArrayList<SliceRef>>) field.get(runtime);
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static List<Message> invokeSliceMessages(MemoryRuntime runtime, MemoryUnit unit, MemorySlice slice) throws Exception { private static List<Message> invokeSliceMessages(MemoryRuntime runtime, MemoryUnit unit, MemorySlice slice) throws Exception {
Method method = MemoryRuntime.class.getDeclaredMethod("sliceMessages", MemoryUnit.class, MemorySlice.class); Method method = MemoryRuntime.class.getDeclaredMethod("sliceMessages", MemoryUnit.class, MemorySlice.class);
@@ -152,13 +146,92 @@ class MemoryRuntimeTest {
unit.getSlices().addAll(List.of(firstSlice, secondSlice)); unit.getSlices().addAll(List.of(firstSlice, secondSlice));
memoryCapability.remember(unit); memoryCapability.remember(unit);
runtime.recordMemory(unit, "topic/main", List.of("topic/related")); runtime.recordMemory(unit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE);
Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices = topicSlices(runtime); List<ActivatedMemorySlice> topicResult = runtime.queryActivatedMemoryByTopicPath("topic/main");
assertEquals(List.of("slice-2"), topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList());
assertTrue(runtime.getTopicTree().contains("topic/main [root]"));
assertTrue(runtime.getTopicTree().contains("topic/related [root]"));
assertTrue(JSONObject.parseObject(runtime.convert().toString())
.getJSONArray("topic_slices")
.stream()
.map(JSONObject.class::cast)
.anyMatch(item -> "topic/main".equals(item.getString("topic_path"))));
}
@Test
void shouldExpandTopicQueryToLatestRelatedTopicMemory() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability);
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
MemoryUnit mainUnit = new MemoryUnit("unit-main");
mainUnit.getConversationMessages().addAll(List.of(
message("m0"),
message("m1"),
message("m2"),
message("m3")
));
MemorySlice mainSlice = MemorySlice.restore("slice-main", 0, 2, "main", 86_400_000L);
mainUnit.getSlices().add(mainSlice);
memoryCapability.remember(mainUnit);
MemoryUnit relatedUnit = new MemoryUnit("unit-related");
relatedUnit.getConversationMessages().addAll(List.of(
message("r0"),
message("r1")
));
MemorySlice relatedSlice = MemorySlice.restore("slice-related", 0, 2, "related", 172_800_000L);
relatedUnit.getSlices().add(relatedSlice);
memoryCapability.remember(relatedUnit);
runtime.recordMemory(mainUnit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE);
runtime.recordMemory(relatedUnit, "topic/related", List.of(), DEFAULT_PROFILE);
List<ActivatedMemorySlice> topicResult = runtime.queryActivatedMemoryByTopicPath("topic/main");
assertEquals(List.of("slice-main", "slice-related"),
topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList());
}
@Test
void shouldIndexDateIncrementallyWithoutRebuildingWholeUnit() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability);
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
MemoryUnit firstUnitSnapshot = new MemoryUnit("unit-100");
firstUnitSnapshot.getConversationMessages().addAll(List.of(message("m0"), message("m1")));
MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 1, "first", 86_400_000L);
firstUnitSnapshot.getSlices().add(firstSlice);
memoryCapability.remember(firstUnitSnapshot);
runtime.recordMemory(firstUnitSnapshot, "topic/main", List.of(), DEFAULT_PROFILE);
firstUnitSnapshot.getConversationMessages().clear();
firstUnitSnapshot.getConversationMessages().addAll(List.of(message("m2"), message("m3")));
MemorySlice secondSlice = MemorySlice.restore("slice-2", 0, 1, "second", 172_800_000L);
firstUnitSnapshot.getSlices().clear();
firstUnitSnapshot.getSlices().add(secondSlice);
memoryCapability.remember(firstUnitSnapshot);
runtime.recordMemory(firstUnitSnapshot, "topic/main", List.of(), DEFAULT_PROFILE);
JSONObject state = JSONObject.parseObject(runtime.convert().toString());
JSONArray dateIndex = state.getJSONArray("date_index");
JSONObject firstDate = dateIndex.stream()
.map(JSONObject.class::cast)
.filter(item -> "1970-01-02".equals(item.getString("date")))
.findFirst()
.orElseThrow();
JSONObject secondDate = dateIndex.stream()
.map(JSONObject.class::cast)
.filter(item -> "1970-01-03".equals(item.getString("date")))
.findFirst()
.orElseThrow();
assertEquals(List.of("slice-1"),
firstDate.getJSONArray("refs").toJavaList(JSONObject.class).stream().map(obj -> obj.getString("slice_id")).toList());
assertEquals(List.of("slice-2"), assertEquals(List.of("slice-2"),
topicSlices.get("topic/main").stream().map(SliceRef::getSliceId).toList()); secondDate.getJSONArray("refs").toJavaList(JSONObject.class).stream().map(obj -> obj.getString("slice_id")).toList());
assertEquals(List.of("slice-2"),
topicSlices.get("topic/related").stream().map(SliceRef::getSliceId).toList());
} }
@Test @Test
@@ -168,8 +241,8 @@ class MemoryRuntimeTest {
setField(runtime, "memoryCapability", memoryCapability); setField(runtime, "memoryCapability", memoryCapability);
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed")))); setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
MemoryUnit unit = new MemoryUnit("unit-100"); MemoryUnit mainUnit = new MemoryUnit("unit-200");
unit.getConversationMessages().addAll(List.of( mainUnit.getConversationMessages().addAll(List.of(
message("m0"), message("m0"),
message("m1"), message("m1"),
message("m2"), message("m2"),
@@ -177,10 +250,16 @@ class MemoryRuntimeTest {
)); ));
MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 2, "first", 86_400_000L); MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 2, "first", 86_400_000L);
MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 172_800_000L); MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 172_800_000L);
unit.getSlices().addAll(List.of(firstSlice, secondSlice)); mainUnit.getSlices().addAll(List.of(firstSlice, secondSlice));
memoryCapability.remember(unit); memoryCapability.remember(mainUnit);
runtime.recordMemory(mainUnit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE);
runtime.recordMemory(unit, "topic/main", List.of("topic/related")); MemoryUnit relatedUnit = new MemoryUnit("unit-201");
relatedUnit.getConversationMessages().addAll(List.of(message("r0"), message("r1")));
MemorySlice relatedSlice = MemorySlice.restore("slice-3", 0, 2, "related", 259_200_000L);
relatedUnit.getSlices().add(relatedSlice);
memoryCapability.remember(relatedUnit);
runtime.recordMemory(relatedUnit, "topic/related", List.of(), DEFAULT_PROFILE);
JSONObject state = JSONObject.parseObject(runtime.convert().toString()); JSONObject state = JSONObject.parseObject(runtime.convert().toString());
JSONArray topicSlices = state.getJSONArray("topic_slices"); JSONArray topicSlices = state.getJSONArray("topic_slices");
@@ -190,16 +269,23 @@ class MemoryRuntimeTest {
.filter(item -> "topic/main".equals(item.getString("topic_path"))) .filter(item -> "topic/main".equals(item.getString("topic_path")))
.findFirst() .findFirst()
.orElseThrow(); .orElseThrow();
assertEquals("slice-2", mainTopic.getJSONArray("refs").getJSONObject(0).getString("slice_id")); JSONObject binding = mainTopic.getJSONArray("bindings").getJSONObject(0);
assertEquals("slice-2", binding.getString("slice_id"));
assertEquals(172_800_000L, binding.getLongValue("timestamp"));
assertEquals(List.of("topic/related"), binding.getJSONArray("related_topic_paths").toJavaList(String.class));
JSONObject activationProfile = binding.getJSONObject("activation_profile");
assertEquals(0.55f, activationProfile.getFloatValue("activation_weight"));
assertEquals(0.35f, activationProfile.getFloatValue("diffusion_weight"));
assertEquals(0.50f, activationProfile.getFloatValue("context_independence_weight"));
JSONArray dateIndex = state.getJSONArray("date_index"); JSONArray dateIndex = state.getJSONArray("date_index");
assertEquals(2, dateIndex.size()); assertEquals(2, dateIndex.size());
JSONObject secondDate = dateIndex.stream() JSONObject thirdDate = dateIndex.stream()
.map(JSONObject.class::cast) .map(JSONObject.class::cast)
.filter(item -> "1970-01-03".equals(item.getString("date"))) .filter(item -> "1970-01-04".equals(item.getString("date")))
.findFirst() .findFirst()
.orElseThrow(); .orElseThrow();
assertEquals("slice-2", secondDate.getJSONArray("refs").getJSONObject(0).getString("slice_id")); assertEquals("slice-3", thirdDate.getJSONArray("refs").getJSONObject(0).getString("slice_id"));
MemoryRuntime restored = new MemoryRuntime(); MemoryRuntime restored = new MemoryRuntime();
setField(restored, "memoryCapability", memoryCapability); setField(restored, "memoryCapability", memoryCapability);
@@ -207,14 +293,14 @@ class MemoryRuntimeTest {
restored.load(state); restored.load(state);
List<ActivatedMemorySlice> topicResult = restored.queryActivatedMemoryByTopicPath("topic/main"); List<ActivatedMemorySlice> topicResult = restored.queryActivatedMemoryByTopicPath("topic/main");
assertEquals(1, topicResult.size()); assertEquals(List.of("slice-2", "slice-3"),
assertEquals("slice-2", topicResult.getFirst().getSliceId()); topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList());
assertEquals(List.of("m2", "m3"), topicResult.getFirst().getMessages().stream().map(Message::getContent).toList()); assertEquals(List.of("m2", "m3"), topicResult.getFirst().getMessages().stream().map(Message::getContent).toList());
List<ActivatedMemorySlice> dateResult = restored.queryActivatedMemoryByDate(LocalDate.parse("1970-01-03")); List<ActivatedMemorySlice> dateResult = restored.queryActivatedMemoryByDate(LocalDate.parse("1970-01-04"));
assertEquals(1, dateResult.size()); assertEquals(1, dateResult.size());
assertEquals("slice-2", dateResult.getFirst().getSliceId()); assertEquals("slice-3", dateResult.getFirst().getSliceId());
assertEquals("second", dateResult.getFirst().getSummary()); assertEquals("related", dateResult.getFirst().getSummary());
} }
@Test @Test
@@ -228,6 +314,106 @@ class MemoryRuntimeTest {
assertTrue(runtime.queryActivatedMemoryByDate(LocalDate.parse("1970-01-01")).isEmpty()); assertTrue(runtime.queryActivatedMemoryByDate(LocalDate.parse("1970-01-01")).isEmpty());
} }
@Test
void shouldRankTopicMatchesBySourceAndActivationProfile() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability);
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
MemoryUnit primaryUnit = new MemoryUnit("unit-primary");
primaryUnit.getConversationMessages().addAll(List.of(message("p0"), message("p1")));
MemorySlice primarySlice = MemorySlice.restore("slice-primary", 0, 2, "primary", System.currentTimeMillis());
primaryUnit.getSlices().add(primarySlice);
memoryCapability.remember(primaryUnit);
runtime.recordMemory(primaryUnit, "topic->main", List.of("topic->related"), new ActivationProfile(0.9f, 0.1f, 0.9f));
MemoryUnit relatedUnit = new MemoryUnit("unit-related-rank");
relatedUnit.getConversationMessages().addAll(List.of(message("r0"), message("r1")));
MemorySlice relatedSlice = MemorySlice.restore("slice-related-rank", 0, 2, "related", System.currentTimeMillis());
relatedUnit.getSlices().add(relatedSlice);
memoryCapability.remember(relatedUnit);
runtime.recordMemory(relatedUnit, "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f));
MemoryUnit parentUnit = new MemoryUnit("unit-parent");
parentUnit.getConversationMessages().addAll(List.of(message("x0"), message("x1")));
MemorySlice parentSlice = MemorySlice.restore("slice-parent", 0, 2, "parent", System.currentTimeMillis());
parentUnit.getSlices().add(parentSlice);
memoryCapability.remember(parentUnit);
runtime.recordMemory(parentUnit, "topic", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f));
List<ActivatedMemorySlice> topicResult = runtime.queryActivatedMemoryByTopicPath("topic->main");
assertEquals(List.of("slice-primary", "slice-related-rank", "slice-parent"),
topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList());
}
@Test
void shouldNotExpandRelatedTopicWhenDiffusionWeightIsZero() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability);
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
MemoryUnit primaryUnit = new MemoryUnit("unit-primary-zero");
primaryUnit.getConversationMessages().addAll(List.of(message("p0"), message("p1")));
MemorySlice primarySlice = MemorySlice.restore("slice-primary-zero", 0, 2, "primary", System.currentTimeMillis());
primaryUnit.getSlices().add(primarySlice);
memoryCapability.remember(primaryUnit);
runtime.recordMemory(
primaryUnit,
"topic->main",
List.of("topic->related"),
new ActivationProfile(0.8f, 0.0f, 0.8f)
);
MemoryUnit relatedUnit = new MemoryUnit("unit-related-zero");
relatedUnit.getConversationMessages().addAll(List.of(message("r0"), message("r1")));
MemorySlice relatedSlice = MemorySlice.restore("slice-related-zero", 0, 2, "related", System.currentTimeMillis());
relatedUnit.getSlices().add(relatedSlice);
memoryCapability.remember(relatedUnit);
runtime.recordMemory(relatedUnit, "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f));
List<ActivatedMemorySlice> topicResult = runtime.queryActivatedMemoryByTopicPath("topic->main");
assertEquals(List.of("slice-primary-zero"), topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList());
}
@Test
void shouldRefreshBindingTimestampAndActivationProfileWhenSameSliceRebound() throws Exception {
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability);
setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed"))));
MemoryUnit unit = new MemoryUnit("unit-refresh");
unit.getConversationMessages().addAll(List.of(message("m0"), message("m1")));
MemorySlice slice = MemorySlice.restore("slice-refresh", 0, 2, "summary", 86_400_000L);
unit.getSlices().add(slice);
memoryCapability.remember(unit);
runtime.recordMemory(unit, "topic->main", List.of("topic->related"), new ActivationProfile(0.2f, 0.1f, 0.2f));
unit.getSlices().clear();
unit.getSlices().add(MemorySlice.restore("slice-refresh", 0, 2, "summary", 172_800_000L));
runtime.recordMemory(unit, "topic->main", List.of("topic->related-2"), new ActivationProfile(0.9f, 0.8f, 0.7f));
JSONObject state = JSONObject.parseObject(runtime.convert().toString());
JSONObject mainTopic = state.getJSONArray("topic_slices").stream()
.map(JSONObject.class::cast)
.filter(item -> "topic->main".equals(item.getString("topic_path")))
.findFirst()
.orElseThrow();
JSONObject binding = mainTopic.getJSONArray("bindings").getJSONObject(0);
JSONObject activationProfile = binding.getJSONObject("activation_profile");
assertEquals(172_800_000L, binding.getLongValue("timestamp"));
assertEquals(0.9f, activationProfile.getFloatValue("activation_weight"));
assertEquals(0.8f, activationProfile.getFloatValue("diffusion_weight"));
assertEquals(0.7f, activationProfile.getFloatValue("context_independence_weight"));
assertEquals(
List.of("topic->related", "topic->related-2"),
binding.getJSONArray("related_topic_paths").toJavaList(String.class)
);
}
private static final class StubMemoryCapability implements MemoryCapability { private static final class StubMemoryCapability implements MemoryCapability {
private final String sessionId; private final String sessionId;
private final Map<String, MemoryUnit> units = new HashMap<>(); private final Map<String, MemoryUnit> units = new HashMap<>();

View File

@@ -4,7 +4,6 @@ import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.mockito.Mockito; import org.mockito.Mockito;
import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit; import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.framework.agent.exception.AgentRuntimeException; import work.slhaf.partner.framework.agent.exception.AgentRuntimeException;
@@ -12,6 +11,7 @@ import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.support.Result; import work.slhaf.partner.framework.agent.support.Result;
import work.slhaf.partner.module.communication.AfterRollingRegistry; import work.slhaf.partner.module.communication.AfterRollingRegistry;
import work.slhaf.partner.module.communication.RollingResult; import work.slhaf.partner.module.communication.RollingResult;
import work.slhaf.partner.module.memory.pojo.ActivationProfile;
import work.slhaf.partner.module.memory.runtime.MemoryRuntime; import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
import work.slhaf.partner.module.memory.updater.summarizer.entity.MemoryTopicResult; import work.slhaf.partner.module.memory.updater.summarizer.entity.MemoryTopicResult;
@@ -19,6 +19,7 @@ import java.lang.reflect.Field;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.List; import java.util.List;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@@ -55,11 +56,7 @@ class MemoryUpdaterTest {
void shouldExtractTopicAndRecordMemoryOnConsume() throws Exception { void shouldExtractTopicAndRecordMemoryOnConsume() throws Exception {
MemoryUpdater updater = Mockito.spy(new MemoryUpdater()); MemoryUpdater updater = Mockito.spy(new MemoryUpdater());
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
CognitionCapability cognitionCapability = Mockito.mock(CognitionCapability.class);
setField(updater, "memoryRuntime", memoryRuntime); setField(updater, "memoryRuntime", memoryRuntime);
setField(updater, "cognitionCapability", cognitionCapability);
when(cognitionCapability.contextWorkspace()).thenReturn(new work.slhaf.partner.core.cognition.ContextWorkspace());
when(memoryRuntime.getTopicTree()).thenReturn("topic-tree"); when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
when(memoryRuntime.fixTopicPath("root[2]->branch[1]")).thenReturn("root->branch"); when(memoryRuntime.fixTopicPath("root[2]->branch[1]")).thenReturn("root->branch");
when(memoryRuntime.fixTopicPath("root[2]->related[1]")).thenReturn("root->related"); when(memoryRuntime.fixTopicPath("root[2]->related[1]")).thenReturn("root->related");
@@ -67,6 +64,7 @@ class MemoryUpdaterTest {
MemoryTopicResult topicResult = new MemoryTopicResult(); MemoryTopicResult topicResult = new MemoryTopicResult();
topicResult.setTopicPath("root[2]->branch[1]"); topicResult.setTopicPath("root[2]->branch[1]");
topicResult.setRelatedTopicPaths(List.of("root[2]->related[1]")); topicResult.setRelatedTopicPaths(List.of("root[2]->related[1]"));
topicResult.setActivationProfile(new ActivationProfile(0.8f, 0.9f, 0.7f));
Mockito.doReturn(Result.success(topicResult)) Mockito.doReturn(Result.success(topicResult))
.when(updater) .when(updater)
.formattedChat(Mockito.anyList(), eq(MemoryTopicResult.class)); .formattedChat(Mockito.anyList(), eq(MemoryTopicResult.class));
@@ -86,18 +84,22 @@ class MemoryUpdaterTest {
message(Message.Character.ASSISTANT, "new-reply") message(Message.Character.ASSISTANT, "new-reply")
), "slice-summary", 4, 6)); ), "slice-summary", 4, 6));
verify(memoryRuntime).recordMemory(eq(unit), eq("root->branch"), eq(List.of("root->related"))); verify(memoryRuntime).recordMemory(
eq(unit),
eq("root->branch"),
eq(List.of("root->related")),
argThat(profile -> profile != null
&& profile.getActivationWeight() == 0.8f
&& profile.getDiffusionWeight() == 0.9f
&& profile.getContextIndependenceWeight() == 0.7f)
);
} }
@Test @Test
void shouldFallbackToDateOnlyRecordWhenExtractionFails() throws Exception { void shouldFallbackToDateOnlyRecordWhenExtractionFails() throws Exception {
MemoryUpdater updater = Mockito.spy(new MemoryUpdater()); MemoryUpdater updater = Mockito.spy(new MemoryUpdater());
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
CognitionCapability cognitionCapability = Mockito.mock(CognitionCapability.class);
setField(updater, "memoryRuntime", memoryRuntime); setField(updater, "memoryRuntime", memoryRuntime);
setField(updater, "cognitionCapability", cognitionCapability);
when(cognitionCapability.contextWorkspace()).thenReturn(new work.slhaf.partner.core.cognition.ContextWorkspace());
when(memoryRuntime.getTopicTree()).thenReturn("topic-tree"); when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
Mockito.doReturn(Result.failure(new AgentRuntimeException("boom"))) Mockito.doReturn(Result.failure(new AgentRuntimeException("boom")))
.when(updater) .when(updater)
@@ -113,6 +115,48 @@ class MemoryUpdaterTest {
updater.consume(new RollingResult(unit, slice, unit.getConversationMessages(), "slice-summary", 2, 6)); updater.consume(new RollingResult(unit, slice, unit.getConversationMessages(), "slice-summary", 2, 6));
verify(memoryRuntime).recordMemory(eq(unit), eq(null), eq(List.of())); verify(memoryRuntime).recordMemory(
eq(unit),
eq(null),
eq(List.of()),
argThat(profile -> profile != null
&& profile.getActivationWeight() == 0.55f
&& profile.getDiffusionWeight() == 0.35f
&& profile.getContextIndependenceWeight() == 0.50f)
);
}
@Test
void shouldClampAndAdjustActivationProfileBeforeRecording() throws Exception {
MemoryUpdater updater = Mockito.spy(new MemoryUpdater());
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
setField(updater, "memoryRuntime", memoryRuntime);
when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
when(memoryRuntime.fixTopicPath("root[2]->branch[1]")).thenReturn("root->branch");
MemoryTopicResult topicResult = new MemoryTopicResult();
topicResult.setTopicPath("root[2]->branch[1]");
topicResult.setRelatedTopicPaths(List.of());
topicResult.setActivationProfile(new ActivationProfile(1.5f, 0.9f, -0.2f));
Mockito.doReturn(Result.success(topicResult))
.when(updater)
.formattedChat(Mockito.anyList(), eq(MemoryTopicResult.class));
MemoryUnit unit = new MemoryUnit("session-3");
unit.getConversationMessages().add(message(Message.Character.USER, "only"));
MemorySlice slice = new MemorySlice(0, 1, "slice-summary");
unit.getSlices().add(slice);
updater.consume(new RollingResult(unit, slice, unit.getConversationMessages(), "slice-summary", 1, 6));
verify(memoryRuntime).recordMemory(
eq(unit),
eq("root->branch"),
eq(List.of()),
argThat(profile -> profile != null
&& profile.getActivationWeight() == 0.95f
&& profile.getDiffusionWeight() == 0.45f
&& profile.getContextIndependenceWeight() == 0.0f)
);
} }
} }