diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/pojo/ActivationProfile.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/pojo/ActivationProfile.java new file mode 100644 index 00000000..0d182052 --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/pojo/ActivationProfile.java @@ -0,0 +1,14 @@ +package work.slhaf.partner.module.memory.pojo; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class ActivationProfile { + private Float activationWeight; + private Float diffusionWeight; + private Float contextIndependenceWeight; +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/DateMemoryIndex.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/DateMemoryIndex.java new file mode 100644 index 00000000..425aae09 --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/DateMemoryIndex.java @@ -0,0 +1,36 @@ +package work.slhaf.partner.module.memory.runtime; + +import work.slhaf.partner.core.memory.pojo.SliceRef; + +import java.time.LocalDate; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; + +final class DateMemoryIndex { + + private final Map> dateIndex = new HashMap<>(); + + void record(SliceRef sliceRef, LocalDate date) { + dateIndex.computeIfAbsent(date, key -> new CopyOnWriteArrayList<>()).addIfAbsent(sliceRef); + } + + List find(LocalDate date) { + List refs = dateIndex.get(date); + return refs == null ? null : new ArrayList<>(refs); + } + + void reset() { + dateIndex.clear(); + } + + void restore(LocalDate date, CopyOnWriteArrayList refs) { + dateIndex.put(date, refs); + } + + Map> entries() { + return dateIndex; + } +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java index 247f566b..b83e178f 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java @@ -1,8 +1,6 @@ package work.slhaf.partner.module.memory.runtime; -import com.alibaba.fastjson2.JSONArray; import com.alibaba.fastjson2.JSONObject; -import lombok.extern.slf4j.Slf4j; import org.jetbrains.annotations.NotNull; import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.memory.MemoryCapability; @@ -16,8 +14,8 @@ import work.slhaf.partner.framework.agent.factory.component.annotation.Init; import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.state.State; import work.slhaf.partner.framework.agent.state.StateSerializable; -import work.slhaf.partner.framework.agent.state.StateValue; import work.slhaf.partner.framework.agent.support.Result; +import work.slhaf.partner.module.memory.pojo.ActivationProfile; import work.slhaf.partner.module.memory.runtime.exception.MemoryLookupException; import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice; @@ -25,11 +23,10 @@ import java.nio.file.Path; import java.time.Instant; import java.time.LocalDate; import java.time.ZoneId; -import java.util.*; -import java.util.concurrent.CopyOnWriteArrayList; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.locks.ReentrantLock; -@Slf4j public class MemoryRuntime extends AbstractAgentModule.Standalone implements StateSerializable { @InjectCapability @@ -38,8 +35,10 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta private CognitionCapability cognitionCapability; private final ReentrantLock runtimeLock = new ReentrantLock(); - private Map> topicSlices = new HashMap<>(); - private Map> dateIndex = new HashMap<>(); + private final TopicMemoryIndex topicIndex = new TopicMemoryIndex(); + private final DateMemoryIndex dateIndex = new DateMemoryIndex(); + private final TopicRecallCollector topicRecallCollector = new TopicRecallCollector(new TopicRecallScorer()); + private final MemoryRuntimeStateCodec stateCodec = new MemoryRuntimeStateCodec(); @Init public void init() { @@ -53,80 +52,32 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta } } - private void bindTopic(String topicPath, SliceRef sliceRef) { - String normalizedPath = normalizeTopicPath(topicPath); - runtimeLock.lock(); - try { - CopyOnWriteArrayList refs = topicSlices.computeIfAbsent(normalizedPath, key -> new CopyOnWriteArrayList<>()); - boolean exists = refs.stream().anyMatch(ref -> Objects.equals(ref.getUnitId(), sliceRef.getUnitId()) - && Objects.equals(ref.getSliceId(), sliceRef.getSliceId())); - if (!exists) { - refs.add(sliceRef); - } - } finally { - runtimeLock.unlock(); - } - } - - public void recordMemory(MemoryUnit memoryUnit, String topicPath, List relatedTopicPaths) { + public void recordMemory(MemoryUnit memoryUnit, + String topicPath, + List relatedTopicPaths, + ActivationProfile activationProfile) { MemorySlice memorySlice = memoryUnit.getSlices().getLast(); SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId()); - indexMemoryUnit(memoryUnit); - if (topicPath != null && !topicPath.isBlank()) { - bindTopic(topicPath, sliceRef); - } - for (String relatedTopicPath : relatedTopicPaths) { - if (relatedTopicPath != null && !relatedTopicPath.isBlank()) { - bindTopic(relatedTopicPath, sliceRef); - } - } - } - - private void indexMemoryUnit(MemoryUnit memoryUnit) { + LocalDate date = toLocalDate(memorySlice.getTimestamp()); runtimeLock.lock(); try { - for (CopyOnWriteArrayList refs : dateIndex.values()) { - refs.removeIf(ref -> memoryUnit.getId().equals(ref.getUnitId())); - } - for (MemorySlice slice : memoryUnit.getSlices()) { - LocalDate date = Instant.ofEpochMilli(slice.getTimestamp()) - .atZone(ZoneId.systemDefault()) - .toLocalDate(); - dateIndex.computeIfAbsent(date, key -> new CopyOnWriteArrayList<>()) - .addIfAbsent(new SliceRef(memoryUnit.getId(), slice.getId())); + List normalizedRelatedTopicPaths = topicIndex.normalizeTopicPaths(relatedTopicPaths); + dateIndex.record(sliceRef, date); + if (topicPath != null && !topicPath.isBlank()) { + topicIndex.recordBinding( + topicPath, + sliceRef, + memorySlice.getTimestamp(), + normalizedRelatedTopicPaths, + activationProfile + ); } + topicIndex.ensureTopicPaths(normalizedRelatedTopicPaths); } finally { runtimeLock.unlock(); } } - private List findByTopicPath(String topicPath) { - String normalizedPath = normalizeTopicPath(topicPath); - List refs = topicSlices.get(normalizedPath); - if (refs == null) { - ExceptionReporterHandler.INSTANCE.report(new MemoryLookupException( - "Unexisted topic path: " + normalizedPath, - normalizedPath, - "TOPIC" - )); - return List.of(); - } - return new ArrayList<>(refs); - } - - private List findByDate(LocalDate date) { - List refs = dateIndex.get(date); - if (refs == null) { - ExceptionReporterHandler.INSTANCE.report(new MemoryLookupException( - "Unexisted date index: " + date, - date.toString(), - "DATE_INDEX" - )); - return List.of(); - } - return new ArrayList<>(refs); - } - public List queryActivatedMemoryByTopicPath(String topicPath) { return buildActivatedMemorySlices(findByTopicPath(topicPath)); } @@ -136,23 +87,61 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta } public String getTopicTree() { - TopicTreeNode root = new TopicTreeNode(); - for (Map.Entry> entry : topicSlices.entrySet()) { - String[] parts = entry.getKey().split("->"); - TopicTreeNode current = root; - for (String part : parts) { - current = current.children.computeIfAbsent(part, key -> new TopicTreeNode()); - } - current.count += entry.getValue().size(); + runtimeLock.lock(); + try { + return topicIndex.getTopicTree(); + } finally { + runtimeLock.unlock(); } + } - StringBuilder stringBuilder = new StringBuilder(); - List> roots = new ArrayList<>(root.children.entrySet()); - for (Map.Entry entry : roots) { - stringBuilder.append(entry.getKey()).append("[root]").append("\r\n"); - printSubTopicsTreeFormat(entry.getValue(), "", stringBuilder); + public String fixTopicPath(String topicPath) { + String[] parts = topicPath.split("->"); + List cleanedParts = new ArrayList<>(); + for (String part : parts) { + String cleaned = part.replaceAll("\\[[^]]*]", "").trim(); + if (!cleaned.isEmpty()) { + cleanedParts.add(cleaned); + } + } + return String.join("->", cleanedParts); + } + + private List findByTopicPath(String topicPath) { + runtimeLock.lock(); + try { + TopicMemoryIndex.TopicTreeNode topicNode = topicIndex.findTopicNode(topicPath); + if (topicNode == null) { + String normalizedPath = topicIndex.normalizeTopicPath(topicPath); + ExceptionReporterHandler.INSTANCE.report(new MemoryLookupException( + "Unexisted topic path: " + normalizedPath, + normalizedPath, + "TOPIC" + )); + return List.of(); + } + return topicRecallCollector.collect(topicIndex, topicNode); + } finally { + runtimeLock.unlock(); + } + } + + private List findByDate(LocalDate date) { + runtimeLock.lock(); + try { + List refs = dateIndex.find(date); + if (refs == null) { + ExceptionReporterHandler.INSTANCE.report(new MemoryLookupException( + "Unexisted date index: " + date, + date.toString(), + "DATE_INDEX" + )); + return List.of(); + } + return refs; + } finally { + runtimeLock.unlock(); } - return stringBuilder.toString(); } private List buildActivatedMemorySlices(List refs) { @@ -169,14 +158,12 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta private ActivatedMemorySlice buildActivatedMemorySlice(SliceRef ref) { MemoryUnit memoryUnit = memoryCapability.getMemoryUnit(ref.getUnitId()); Result memorySliceResult = memoryCapability.getMemorySlice(ref.getUnitId(), ref.getSliceId()); - if (memorySliceResult.exceptionOrNull() != null) { + if (memoryUnit == null || memorySliceResult.exceptionOrNull() != null) { return null; } MemorySlice memorySlice = memorySliceResult.getOrThrow(); List messages = sliceMessages(memoryUnit, memorySlice); - LocalDate date = Instant.ofEpochMilli(memorySlice.getTimestamp()) - .atZone(ZoneId.systemDefault()) - .toLocalDate(); + LocalDate date = toLocalDate(memorySlice.getTimestamp()); return ActivatedMemorySlice.builder() .unitId(ref.getUnitId()) .sliceId(ref.getSliceId()) @@ -201,29 +188,14 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta return new ArrayList<>(conversationMessages.subList(start, end)); } - private void printSubTopicsTreeFormat(TopicTreeNode node, String prefix, StringBuilder stringBuilder) { - List> entries = new ArrayList<>(node.children.entrySet()); - for (int i = 0; i < entries.size(); i++) { - boolean last = i == entries.size() - 1; - Map.Entry entry = entries.get(i); - stringBuilder.append(prefix) - .append(last ? "└── " : "├── ") - .append(entry.getKey()) - .append("[") - .append(entry.getValue().count) - .append("]") - .append("\r\n"); - printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), stringBuilder); - } - } - - private String normalizeTopicPath(String topicPath) { - return topicPath == null ? "" : topicPath.trim(); + private LocalDate toLocalDate(Long timestamp) { + return Instant.ofEpochMilli(timestamp) + .atZone(ZoneId.systemDefault()) + .toLocalDate(); } @Override - @NotNull - public Path statePath() { + public @NotNull Path statePath() { return Path.of("module", "memory", "topic_based_memory.json"); } @@ -231,42 +203,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta public void load(@NotNull JSONObject state) { runtimeLock.lock(); try { - topicSlices = new HashMap<>(); - dateIndex = new HashMap<>(); - - JSONArray topicSlicesArray = state.getJSONArray("topic_slices"); - if (topicSlicesArray != null) { - for (int i = 0; i < topicSlicesArray.size(); i++) { - JSONObject topicObject = topicSlicesArray.getJSONObject(i); - if (topicObject == null) { - continue; - } - String topicPath = topicObject.getString("topic_path"); - if (topicPath == null) { - continue; - } - topicSlices.put(normalizeTopicPath(topicPath), decodeSliceRefs(topicObject.getJSONArray("refs"))); - } - } - - JSONArray dateIndexArray = state.getJSONArray("date_index"); - if (dateIndexArray != null) { - for (int i = 0; i < dateIndexArray.size(); i++) { - JSONObject dateObject = dateIndexArray.getJSONObject(i); - if (dateObject == null) { - continue; - } - String date = dateObject.getString("date"); - if (date == null) { - continue; - } - try { - dateIndex.put(LocalDate.parse(date), decodeSliceRefs(dateObject.getJSONArray("refs"))); - } catch (Exception e) { - log.warn("skip invalid date index: {}", date, e); - } - } - } + stateCodec.load(state, topicIndex, dateIndex); } finally { runtimeLock.unlock(); } @@ -276,78 +213,9 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta public @NotNull State convert() { runtimeLock.lock(); try { - State state = new State(); - - List topicSliceStates = topicSlices.entrySet().stream() - .sorted(Map.Entry.comparingByKey()) - .map(entry -> StateValue.obj(Map.of( - "topic_path", StateValue.str(entry.getKey()), - "refs", StateValue.arr(encodeSliceRefs(entry.getValue())) - ))) - .toList(); - state.append("topic_slices", StateValue.arr(topicSliceStates)); - - List dateIndexStates = dateIndex.entrySet().stream() - .sorted(Map.Entry.comparingByKey()) - .map(entry -> StateValue.obj(Map.of( - "date", StateValue.str(entry.getKey().toString()), - "refs", StateValue.arr(encodeSliceRefs(entry.getValue())) - ))) - .toList(); - state.append("date_index", StateValue.arr(dateIndexStates)); - - return state; + return stateCodec.convert(topicIndex, dateIndex); } finally { runtimeLock.unlock(); } } - - private List encodeSliceRefs(List refs) { - return refs.stream() - .map(ref -> (StateValue) StateValue.obj(Map.of( - "unit_id", StateValue.str(ref.getUnitId()), - "slice_id", StateValue.str(ref.getSliceId()) - ))) - .toList(); - } - - private CopyOnWriteArrayList decodeSliceRefs(JSONArray refsArray) { - CopyOnWriteArrayList refs = new CopyOnWriteArrayList<>(); - if (refsArray == null) { - return refs; - } - for (int i = 0; i < refsArray.size(); i++) { - JSONObject refObject = refsArray.getJSONObject(i); - if (refObject == null) { - continue; - } - String unitId = refObject.getString("unit_id"); - String sliceId = refObject.getString("slice_id"); - if (unitId == null || sliceId == null) { - continue; - } - refs.addIfAbsent(new SliceRef(unitId, sliceId)); - } - return refs; - } - - public String fixTopicPath(String topicPath) { - String[] parts = topicPath.split("->"); - List cleanedParts = new ArrayList<>(); - - for (String part : parts) { - // 修正正则表达式,正确移除 [xxx] 部分 - String cleaned = part.replaceAll("\\[[^\\]]*\\]", "").trim(); - if (!cleaned.isEmpty()) { // 忽略空字符串 - cleanedParts.add(cleaned); - } - } - - return String.join("->", cleanedParts); - } - - private static final class TopicTreeNode { - private final Map children = new LinkedHashMap<>(); - private int count; - } } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeStateCodec.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeStateCodec.java new file mode 100644 index 00000000..4b4dde49 --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeStateCodec.java @@ -0,0 +1,184 @@ +package work.slhaf.partner.module.memory.runtime; + +import com.alibaba.fastjson2.JSONArray; +import com.alibaba.fastjson2.JSONObject; +import lombok.extern.slf4j.Slf4j; +import work.slhaf.partner.core.memory.pojo.SliceRef; +import work.slhaf.partner.framework.agent.state.State; +import work.slhaf.partner.framework.agent.state.StateValue; +import work.slhaf.partner.module.memory.pojo.ActivationProfile; + +import java.time.LocalDate; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; + +@Slf4j +final class MemoryRuntimeStateCodec { + + void load(JSONObject state, TopicMemoryIndex topicIndex, DateMemoryIndex dateIndex) { + topicIndex.reset(); + dateIndex.reset(); + + JSONArray topicSlicesArray = state.getJSONArray("topic_slices"); + if (topicSlicesArray != null) { + for (int i = 0; i < topicSlicesArray.size(); i++) { + JSONObject topicObject = topicSlicesArray.getJSONObject(i); + if (topicObject == null) { + continue; + } + String topicPath = topicObject.getString("topic_path"); + if (topicPath == null) { + continue; + } + topicIndex.ensureTopicPath(topicPath); + decodeTopicBindings(topicIndex, topicPath, topicObject.getJSONArray("bindings")); + } + } + + JSONArray dateIndexArray = state.getJSONArray("date_index"); + if (dateIndexArray != null) { + for (int i = 0; i < dateIndexArray.size(); i++) { + JSONObject dateObject = dateIndexArray.getJSONObject(i); + if (dateObject == null) { + continue; + } + String date = dateObject.getString("date"); + if (date == null) { + continue; + } + try { + dateIndex.restore(LocalDate.parse(date), decodeSliceRefs(dateObject.getJSONArray("refs"))); + } catch (Exception e) { + log.warn("skip invalid date index: {}", date, e); + } + } + } + } + + State convert(TopicMemoryIndex topicIndex, DateMemoryIndex dateIndex) { + State state = new State(); + + List topicSliceStates = new ArrayList<>(); + for (Map.Entry entry : topicIndex.roots().entrySet()) { + collectTopicStates(entry.getKey(), entry.getValue(), topicSliceStates); + } + state.append("topic_slices", StateValue.arr(topicSliceStates)); + + List dateIndexStates = dateIndex.entries().entrySet().stream() + .sorted(Map.Entry.comparingByKey()) + .map(entry -> StateValue.obj(Map.of( + "date", StateValue.str(entry.getKey().toString()), + "refs", StateValue.arr(encodeSliceRefs(entry.getValue())) + ))) + .toList(); + state.append("date_index", StateValue.arr(dateIndexStates)); + + return state; + } + + private void collectTopicStates(String path, + TopicMemoryIndex.TopicTreeNode topicNode, + List topicStates) { + topicStates.add(StateValue.obj(Map.of( + "topic_path", StateValue.str(path), + "bindings", StateValue.arr(encodeTopicBindings(topicNode.bindings())) + ))); + for (Map.Entry childEntry : topicNode.children().entrySet()) { + collectTopicStates(path + "->" + childEntry.getKey(), childEntry.getValue(), topicStates); + } + } + + private List encodeTopicBindings(List bindings) { + return bindings.stream() + .map(binding -> (StateValue) StateValue.obj(Map.of( + "unit_id", StateValue.str(binding.sliceRef().getUnitId()), + "slice_id", StateValue.str(binding.sliceRef().getSliceId()), + "timestamp", StateValue.num(binding.timestamp()), + "activation_profile", StateValue.obj(Map.of( + "activation_weight", StateValue.num(binding.activationProfile().getActivationWeight()), + "diffusion_weight", StateValue.num(binding.activationProfile().getDiffusionWeight()), + "context_independence_weight", + StateValue.num(binding.activationProfile().getContextIndependenceWeight()) + )), + "related_topic_paths", StateValue.arr(binding.relatedTopicPaths().stream() + .map(StateValue::str) + .toList()) + ))) + .toList(); + } + + private void decodeTopicBindings(TopicMemoryIndex topicIndex, String topicPath, JSONArray bindingsArray) { + if (bindingsArray == null) { + return; + } + for (int i = 0; i < bindingsArray.size(); i++) { + JSONObject bindingObject = bindingsArray.getJSONObject(i); + if (bindingObject == null) { + continue; + } + String unitId = bindingObject.getString("unit_id"); + String sliceId = bindingObject.getString("slice_id"); + if (unitId == null || sliceId == null) { + continue; + } + Long timestamp = bindingObject.getLong("timestamp"); + if (timestamp == null) { + log.warn("skip topic binding without timestamp: {}:{}", unitId, sliceId); + continue; + } + List relatedTopicPaths = topicIndex.normalizeTopicPaths( + bindingObject.getList("related_topic_paths", String.class) + ); + topicIndex.recordBinding( + topicPath, + new SliceRef(unitId, sliceId), + timestamp, + relatedTopicPaths, + decodeActivationProfile(bindingObject.getJSONObject("activation_profile")) + ); + topicIndex.ensureTopicPaths(relatedTopicPaths); + } + } + + private ActivationProfile decodeActivationProfile(JSONObject profileObject) { + if (profileObject == null) { + return null; + } + return new ActivationProfile( + profileObject.getFloat("activation_weight"), + profileObject.getFloat("diffusion_weight"), + profileObject.getFloat("context_independence_weight") + ); + } + + private List encodeSliceRefs(List refs) { + return refs.stream() + .map(ref -> (StateValue) StateValue.obj(Map.of( + "unit_id", StateValue.str(ref.getUnitId()), + "slice_id", StateValue.str(ref.getSliceId()) + ))) + .toList(); + } + + private CopyOnWriteArrayList decodeSliceRefs(JSONArray refsArray) { + CopyOnWriteArrayList refs = new CopyOnWriteArrayList<>(); + if (refsArray == null) { + return refs; + } + for (int i = 0; i < refsArray.size(); i++) { + JSONObject refObject = refsArray.getJSONObject(i); + if (refObject == null) { + continue; + } + String unitId = refObject.getString("unit_id"); + String sliceId = refObject.getString("slice_id"); + if (unitId == null || sliceId == null) { + continue; + } + refs.addIfAbsent(new SliceRef(unitId, sliceId)); + } + return refs; + } +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/TopicMemoryIndex.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/TopicMemoryIndex.java new file mode 100644 index 00000000..7a4799c5 --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/TopicMemoryIndex.java @@ -0,0 +1,237 @@ +package work.slhaf.partner.module.memory.runtime; + +import work.slhaf.partner.core.memory.pojo.SliceRef; +import work.slhaf.partner.module.memory.pojo.ActivationProfile; + +import java.util.*; +import java.util.concurrent.CopyOnWriteArrayList; + +final class TopicMemoryIndex { + + private static final float DEFAULT_ACTIVATION_WEIGHT = 0.55f; + private static final float DEFAULT_DIFFUSION_WEIGHT = 0.35f; + private static final float DEFAULT_CONTEXT_INDEPENDENCE_WEIGHT = 0.50f; + + private final Map topicSlices = new LinkedHashMap<>(); + + void recordBinding(String topicPath, + SliceRef sliceRef, + long timestamp, + Collection relatedTopicPaths, + ActivationProfile activationProfile) { + String normalizedPath = normalizeTopicPath(topicPath); + if (normalizedPath.isBlank()) { + return; + } + ensureTopicNode(normalizedPath).addBinding( + sliceRef, + timestamp, + relatedTopicPaths, + normalizeActivationProfile(activationProfile) + ); + } + + void ensureTopicPaths(Collection topicPaths) { + if (topicPaths == null || topicPaths.isEmpty()) { + return; + } + for (String topicPath : topicPaths) { + ensureTopicNode(topicPath); + } + } + + void reset() { + topicSlices.clear(); + } + + void ensureTopicPath(String topicPath) { + String normalizedPath = normalizeTopicPath(topicPath); + if (normalizedPath.isBlank()) { + return; + } + ensureTopicNode(normalizedPath); + } + + TopicTreeNode findTopicNode(String topicPath) { + String normalizedPath = normalizeTopicPath(topicPath); + if (normalizedPath.isBlank()) { + return null; + } + String[] parts = normalizedPath.split("->"); + TopicTreeNode current = topicSlices.get(parts[0]); + for (int i = 1; current != null && i < parts.length; i++) { + current = current.children().get(parts[i]); + } + return current; + } + + String getTopicTree() { + List lines = new ArrayList<>(); + for (Map.Entry entry : topicSlices.entrySet()) { + collectTopicTreeLines(entry.getKey(), entry.getValue(), lines); + } + return String.join("\r\n", lines); + } + + List normalizeTopicPaths(Collection topicPaths) { + if (topicPaths == null || topicPaths.isEmpty()) { + return List.of(); + } + LinkedHashSet normalized = new LinkedHashSet<>(); + for (String topicPath : topicPaths) { + String normalizedPath = normalizeTopicPath(topicPath); + if (!normalizedPath.isBlank()) { + normalized.add(normalizedPath); + } + } + return List.copyOf(normalized); + } + + String normalizeTopicPath(String topicPath) { + return topicPath == null ? "" : topicPath.trim(); + } + + Map roots() { + return topicSlices; + } + + private TopicTreeNode ensureTopicNode(String topicPath) { + String[] parts = topicPath.split("->"); + TopicTreeNode current = topicSlices.computeIfAbsent(parts[0], ignored -> new TopicTreeNode(null)); + for (int i = 1; i < parts.length; i++) { + TopicTreeNode parent = current; + current = current.children.computeIfAbsent(parts[i], ignored -> new TopicTreeNode(parent)); + } + return current; + } + + private void collectTopicTreeLines(String path, TopicTreeNode node, List lines) { + if (node.parent() == null) { + lines.add(path + " [root]"); + } else { + lines.add(path + " {slices: " + node.bindings().size() + "}"); + } + for (Map.Entry childEntry : node.children().entrySet()) { + collectTopicTreeLines(path + "->" + childEntry.getKey(), childEntry.getValue(), lines); + } + } + + private ActivationProfile normalizeActivationProfile(ActivationProfile activationProfile) { + ActivationProfile profile = activationProfile == null ? defaultActivationProfile() : new ActivationProfile( + activationProfile.getActivationWeight(), + activationProfile.getDiffusionWeight(), + activationProfile.getContextIndependenceWeight() + ); + profile.setActivationWeight(clampOrDefault(profile.getActivationWeight(), DEFAULT_ACTIVATION_WEIGHT)); + profile.setDiffusionWeight(clampOrDefault(profile.getDiffusionWeight(), DEFAULT_DIFFUSION_WEIGHT)); + profile.setContextIndependenceWeight(clampOrDefault( + profile.getContextIndependenceWeight(), + DEFAULT_CONTEXT_INDEPENDENCE_WEIGHT + )); + return profile; + } + + private ActivationProfile defaultActivationProfile() { + return new ActivationProfile( + DEFAULT_ACTIVATION_WEIGHT, + DEFAULT_DIFFUSION_WEIGHT, + DEFAULT_CONTEXT_INDEPENDENCE_WEIGHT + ); + } + + private float clampOrDefault(Float value, float defaultValue) { + return value == null ? defaultValue : clamp(value); + } + + private float clamp(float value) { + return Math.clamp(value, 0.0f, 1.0f); + } + + static final class TopicTreeNode { + private final TopicTreeNode parent; + private final Map children = new LinkedHashMap<>(); + private final CopyOnWriteArrayList bindings = new CopyOnWriteArrayList<>(); + + private TopicTreeNode(TopicTreeNode parent) { + this.parent = parent; + } + + TopicTreeNode parent() { + return parent; + } + + Map children() { + return children; + } + + List bindings() { + return bindings; + } + + private void addBinding(SliceRef sliceRef, + long timestamp, + Collection relatedTopicPaths, + ActivationProfile activationProfile) { + for (TopicBinding binding : bindings) { + if (Objects.equals(binding.sliceRef().getUnitId(), sliceRef.getUnitId()) + && Objects.equals(binding.sliceRef().getSliceId(), sliceRef.getSliceId())) { + binding.refresh(timestamp, relatedTopicPaths, activationProfile); + return; + } + } + bindings.add(new TopicBinding(sliceRef, timestamp, relatedTopicPaths, activationProfile)); + } + } + + static final class TopicBinding { + private final SliceRef sliceRef; + private final CopyOnWriteArrayList relatedTopicPaths = new CopyOnWriteArrayList<>(); + private long timestamp; + private ActivationProfile activationProfile; + + private TopicBinding(SliceRef sliceRef, + long timestamp, + Collection relatedTopicPaths, + ActivationProfile activationProfile) { + this.sliceRef = sliceRef; + this.timestamp = timestamp; + this.activationProfile = activationProfile; + mergeRelatedTopicPaths(relatedTopicPaths); + } + + SliceRef sliceRef() { + return sliceRef; + } + + long timestamp() { + return timestamp; + } + + ActivationProfile activationProfile() { + return activationProfile; + } + + List relatedTopicPaths() { + return relatedTopicPaths; + } + + private void refresh(long timestamp, + Collection relatedTopicPaths, + ActivationProfile activationProfile) { + this.timestamp = timestamp; + this.activationProfile = activationProfile; + mergeRelatedTopicPaths(relatedTopicPaths); + } + + private void mergeRelatedTopicPaths(Collection relatedTopicPaths) { + if (relatedTopicPaths == null) { + return; + } + for (String relatedTopicPath : relatedTopicPaths) { + if (relatedTopicPath != null && !relatedTopicPath.isBlank()) { + this.relatedTopicPaths.addIfAbsent(relatedTopicPath); + } + } + } + } +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/TopicRecallCollector.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/TopicRecallCollector.java new file mode 100644 index 00000000..0fe12c8e --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/TopicRecallCollector.java @@ -0,0 +1,90 @@ +package work.slhaf.partner.module.memory.runtime; + +import work.slhaf.partner.core.memory.pojo.SliceRef; + +import java.util.*; + +final class TopicRecallCollector { + + private static final int TOPIC_RESULT_LIMIT = 5; + private static final int PARENT_CANDIDATE_LIMIT = 2; + private static final int RELATED_CANDIDATE_LIMIT = 2; + + private final TopicRecallScorer scorer; + + TopicRecallCollector(TopicRecallScorer scorer) { + this.scorer = scorer; + } + + List collect(TopicMemoryIndex topicIndex, TopicMemoryIndex.TopicTreeNode topicNode) { + LinkedHashMap candidates = new LinkedHashMap<>(); + LinkedHashMap relatedTopicPaths = new LinkedHashMap<>(); + collectTopicCandidates( + topicNode, + TopicRecallScorer.CandidateSource.PRIMARY, + Integer.MAX_VALUE, + candidates, + relatedTopicPaths + ); + collectTopicCandidates( + topicNode.parent(), + TopicRecallScorer.CandidateSource.PARENT, + PARENT_CANDIDATE_LIMIT, + candidates, + null + ); + for (Map.Entry relatedTopicEntry : relatedTopicPaths.entrySet()) { + if (relatedTopicEntry.getValue() <= 0.0f) { + continue; + } + collectTopicCandidates( + topicIndex.findTopicNode(relatedTopicEntry.getKey()), + TopicRecallScorer.CandidateSource.RELATED, + RELATED_CANDIDATE_LIMIT, + candidates, + null + ); + } + return candidates.values().stream() + .sorted(Comparator.comparingDouble(ScoredSliceCandidate::score) + .reversed() + .thenComparing(Comparator.comparingLong(ScoredSliceCandidate::timestamp).reversed())) + .limit(TOPIC_RESULT_LIMIT) + .map(ScoredSliceCandidate::sliceRef) + .toList(); + } + + private void collectTopicCandidates(TopicMemoryIndex.TopicTreeNode topicNode, + TopicRecallScorer.CandidateSource source, + int limit, + LinkedHashMap candidates, + Map relatedTopicPaths) { + if (topicNode == null || topicNode.bindings().isEmpty()) { + return; + } + List bindings = new ArrayList<>(topicNode.bindings()); + bindings.sort(Comparator.comparingLong(TopicMemoryIndex.TopicBinding::timestamp).reversed()); + int actualLimit = limit == Integer.MAX_VALUE ? bindings.size() : Math.min(limit, bindings.size()); + for (int i = 0; i < actualLimit; i++) { + TopicMemoryIndex.TopicBinding binding = bindings.get(i); + if (relatedTopicPaths != null) { + for (String relatedTopicPath : binding.relatedTopicPaths()) { + relatedTopicPaths.merge( + relatedTopicPath, + binding.activationProfile().getDiffusionWeight(), + Math::max + ); + } + } + double score = scorer.score(binding, source); + String key = binding.sliceRef().getUnitId() + ":" + binding.sliceRef().getSliceId(); + ScoredSliceCandidate current = candidates.get(key); + if (current == null || score > current.score()) { + candidates.put(key, new ScoredSliceCandidate(binding.sliceRef(), binding.timestamp(), score)); + } + } + } + + private record ScoredSliceCandidate(SliceRef sliceRef, long timestamp, double score) { + } +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/TopicRecallScorer.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/TopicRecallScorer.java new file mode 100644 index 00000000..aae0cb81 --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/TopicRecallScorer.java @@ -0,0 +1,47 @@ +package work.slhaf.partner.module.memory.runtime; + +import work.slhaf.partner.module.memory.pojo.ActivationProfile; + +final class TopicRecallScorer { + + double score(TopicMemoryIndex.TopicBinding binding, CandidateSource source) { + ActivationProfile profile = binding.activationProfile(); + return source.sourceScore + + recencyScore(binding.timestamp()) + + 0.50d * profile.getActivationWeight() + + 0.30d * profile.getContextIndependenceWeight() + + 0.20d * source.relationFactor * profile.getDiffusionWeight(); + } + + private double recencyScore(long timestamp) { + long ageMillis = Math.max(0L, System.currentTimeMillis() - timestamp); + long ageDays = ageMillis / 86_400_000L; + if (ageDays <= 1) { + return 0.30d; + } + if (ageDays <= 3) { + return 0.22d; + } + if (ageDays <= 7) { + return 0.15d; + } + if (ageDays <= 30) { + return 0.08d; + } + return 0.00d; + } + + enum CandidateSource { + PRIMARY(1.00f, 0.30f), + RELATED(0.65f, 1.00f), + PARENT(0.45f, 0.20f); + + private final float sourceScore; + private final float relationFactor; + + CandidateSource(float sourceScore, float relationFactor) { + this.sourceScore = sourceScore; + this.relationFactor = relationFactor; + } + } +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/selector/extractor/MemoryRecallCueExtractor.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/selector/extractor/MemoryRecallCueExtractor.java index d572553c..f2b9bc07 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/selector/extractor/MemoryRecallCueExtractor.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/selector/extractor/MemoryRecallCueExtractor.java @@ -177,6 +177,6 @@ public class MemoryRecallCueExtractor extends AbstractAgentModule.Sub extractResult = formattedChat( List.of( - cognitionCapability.contextWorkspace().resolve(List.of( - ContextBlock.FocusedDomain.COGNITION, - ContextBlock.FocusedDomain.MEMORY - )).encodeToMessage(), resolveTopicTaskMessage(result, slicedMessages) ), MemoryTopicResult.class @@ -57,8 +54,18 @@ public class MemoryUpdater extends AbstractAgentModule.Standalone implements Aft List relatedTopicPaths = topicResult.getRelatedTopicPaths() == null ? List.of() : topicResult.getRelatedTopicPaths().stream().map(memoryRuntime::fixTopicPath).toList(); - memoryRuntime.recordMemory(result.memoryUnit(), topicPath, relatedTopicPaths); - }).onFailure(exp -> memoryRuntime.recordMemory(result.memoryUnit(), null, List.of())); + ActivationProfile activationProfile = stabilizeActivationProfile( + topicResult.getActivationProfile(), + relatedTopicPaths, + slicedMessages + ); + memoryRuntime.recordMemory(result.memoryUnit(), topicPath, relatedTopicPaths, activationProfile); + }).onFailure(exp -> memoryRuntime.recordMemory( + result.memoryUnit(), + null, + List.of(), + defaultActivationProfile() + )); } private List sliceMessages(RollingResult result) { @@ -91,4 +98,44 @@ public class MemoryUpdater extends AbstractAgentModule.Standalone implements Aft public String modelKey() { return "topic_extractor"; } + + private ActivationProfile stabilizeActivationProfile(ActivationProfile activationProfile, + List relatedTopicPaths, + List slicedMessages) { + ActivationProfile profile = activationProfile == null ? defaultActivationProfile() : new ActivationProfile( + activationProfile.getActivationWeight(), + activationProfile.getDiffusionWeight(), + activationProfile.getContextIndependenceWeight() + ); + profile.setActivationWeight(clampOrDefault(profile.getActivationWeight(), DEFAULT_ACTIVATION_WEIGHT)); + profile.setDiffusionWeight(clampOrDefault(profile.getDiffusionWeight(), DEFAULT_DIFFUSION_WEIGHT)); + profile.setContextIndependenceWeight(clampOrDefault( + profile.getContextIndependenceWeight(), + DEFAULT_CONTEXT_INDEPENDENCE_WEIGHT + )); + + if (relatedTopicPaths.isEmpty()) { + profile.setDiffusionWeight(Math.min(profile.getDiffusionWeight(), NO_RELATED_DIFFUSION_CAP)); + } + if (slicedMessages.size() <= 1) { + profile.setActivationWeight(clamp(profile.getActivationWeight() - SINGLE_MESSAGE_ACTIVATION_PENALTY)); + } + return profile; + } + + private ActivationProfile defaultActivationProfile() { + return new ActivationProfile( + DEFAULT_ACTIVATION_WEIGHT, + DEFAULT_DIFFUSION_WEIGHT, + DEFAULT_CONTEXT_INDEPENDENCE_WEIGHT + ); + } + + private float clampOrDefault(Float value, float defaultValue) { + return value == null ? defaultValue : clamp(value); + } + + private float clamp(float value) { + return Math.clamp(value, 0.0f, 1.0f); + } } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/entity/MemoryTopicResult.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/entity/MemoryTopicResult.java index 15ec6523..478a94f9 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/entity/MemoryTopicResult.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/entity/MemoryTopicResult.java @@ -1,6 +1,7 @@ package work.slhaf.partner.module.memory.updater.summarizer.entity; import lombok.Data; +import work.slhaf.partner.module.memory.pojo.ActivationProfile; import java.util.List; @@ -8,4 +9,5 @@ import java.util.List; public class MemoryTopicResult { private String topicPath; private List relatedTopicPaths; + private ActivationProfile activationProfile; } diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java index a2b53305..da81c75f 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java @@ -11,9 +11,9 @@ import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemoryUnit; -import work.slhaf.partner.core.memory.pojo.SliceRef; import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.support.Result; +import work.slhaf.partner.module.memory.pojo.ActivationProfile; import work.slhaf.partner.module.memory.runtime.exception.MemoryLookupException; import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice; @@ -25,7 +25,6 @@ import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; @@ -34,18 +33,13 @@ import static org.junit.jupiter.api.Assertions.assertTrue; class MemoryRuntimeTest { + private static final ActivationProfile DEFAULT_PROFILE = new ActivationProfile(0.55f, 0.35f, 0.50f); + @BeforeAll public static void beforeAll(@TempDir Path tempDir) { System.setProperty("user.home", tempDir.toAbsolutePath().toString()); } - @SuppressWarnings("unchecked") - private static Map> topicSlices(MemoryRuntime runtime) throws Exception { - Field field = MemoryRuntime.class.getDeclaredField("topicSlices"); - field.setAccessible(true); - return (Map>) field.get(runtime); - } - @SuppressWarnings("unchecked") private static List invokeSliceMessages(MemoryRuntime runtime, MemoryUnit unit, MemorySlice slice) throws Exception { Method method = MemoryRuntime.class.getDeclaredMethod("sliceMessages", MemoryUnit.class, MemorySlice.class); @@ -152,13 +146,92 @@ class MemoryRuntimeTest { unit.getSlices().addAll(List.of(firstSlice, secondSlice)); memoryCapability.remember(unit); - runtime.recordMemory(unit, "topic/main", List.of("topic/related")); + runtime.recordMemory(unit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE); - Map> topicSlices = topicSlices(runtime); + List topicResult = runtime.queryActivatedMemoryByTopicPath("topic/main"); + assertEquals(List.of("slice-2"), topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList()); + assertTrue(runtime.getTopicTree().contains("topic/main [root]")); + assertTrue(runtime.getTopicTree().contains("topic/related [root]")); + assertTrue(JSONObject.parseObject(runtime.convert().toString()) + .getJSONArray("topic_slices") + .stream() + .map(JSONObject.class::cast) + .anyMatch(item -> "topic/main".equals(item.getString("topic_path")))); + } + + @Test + void shouldExpandTopicQueryToLatestRelatedTopicMemory() throws Exception { + StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test"); + MemoryRuntime runtime = new MemoryRuntime(); + setField(runtime, "memoryCapability", memoryCapability); + setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed")))); + + MemoryUnit mainUnit = new MemoryUnit("unit-main"); + mainUnit.getConversationMessages().addAll(List.of( + message("m0"), + message("m1"), + message("m2"), + message("m3") + )); + MemorySlice mainSlice = MemorySlice.restore("slice-main", 0, 2, "main", 86_400_000L); + mainUnit.getSlices().add(mainSlice); + memoryCapability.remember(mainUnit); + + MemoryUnit relatedUnit = new MemoryUnit("unit-related"); + relatedUnit.getConversationMessages().addAll(List.of( + message("r0"), + message("r1") + )); + MemorySlice relatedSlice = MemorySlice.restore("slice-related", 0, 2, "related", 172_800_000L); + relatedUnit.getSlices().add(relatedSlice); + memoryCapability.remember(relatedUnit); + + runtime.recordMemory(mainUnit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE); + runtime.recordMemory(relatedUnit, "topic/related", List.of(), DEFAULT_PROFILE); + + List topicResult = runtime.queryActivatedMemoryByTopicPath("topic/main"); + assertEquals(List.of("slice-main", "slice-related"), + topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList()); + } + + @Test + void shouldIndexDateIncrementallyWithoutRebuildingWholeUnit() throws Exception { + StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test"); + MemoryRuntime runtime = new MemoryRuntime(); + setField(runtime, "memoryCapability", memoryCapability); + setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed")))); + + MemoryUnit firstUnitSnapshot = new MemoryUnit("unit-100"); + firstUnitSnapshot.getConversationMessages().addAll(List.of(message("m0"), message("m1"))); + MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 1, "first", 86_400_000L); + firstUnitSnapshot.getSlices().add(firstSlice); + memoryCapability.remember(firstUnitSnapshot); + runtime.recordMemory(firstUnitSnapshot, "topic/main", List.of(), DEFAULT_PROFILE); + + firstUnitSnapshot.getConversationMessages().clear(); + firstUnitSnapshot.getConversationMessages().addAll(List.of(message("m2"), message("m3"))); + MemorySlice secondSlice = MemorySlice.restore("slice-2", 0, 1, "second", 172_800_000L); + firstUnitSnapshot.getSlices().clear(); + firstUnitSnapshot.getSlices().add(secondSlice); + memoryCapability.remember(firstUnitSnapshot); + runtime.recordMemory(firstUnitSnapshot, "topic/main", List.of(), DEFAULT_PROFILE); + + JSONObject state = JSONObject.parseObject(runtime.convert().toString()); + JSONArray dateIndex = state.getJSONArray("date_index"); + JSONObject firstDate = dateIndex.stream() + .map(JSONObject.class::cast) + .filter(item -> "1970-01-02".equals(item.getString("date"))) + .findFirst() + .orElseThrow(); + JSONObject secondDate = dateIndex.stream() + .map(JSONObject.class::cast) + .filter(item -> "1970-01-03".equals(item.getString("date"))) + .findFirst() + .orElseThrow(); + assertEquals(List.of("slice-1"), + firstDate.getJSONArray("refs").toJavaList(JSONObject.class).stream().map(obj -> obj.getString("slice_id")).toList()); assertEquals(List.of("slice-2"), - topicSlices.get("topic/main").stream().map(SliceRef::getSliceId).toList()); - assertEquals(List.of("slice-2"), - topicSlices.get("topic/related").stream().map(SliceRef::getSliceId).toList()); + secondDate.getJSONArray("refs").toJavaList(JSONObject.class).stream().map(obj -> obj.getString("slice_id")).toList()); } @Test @@ -168,8 +241,8 @@ class MemoryRuntimeTest { setField(runtime, "memoryCapability", memoryCapability); setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed")))); - MemoryUnit unit = new MemoryUnit("unit-100"); - unit.getConversationMessages().addAll(List.of( + MemoryUnit mainUnit = new MemoryUnit("unit-200"); + mainUnit.getConversationMessages().addAll(List.of( message("m0"), message("m1"), message("m2"), @@ -177,10 +250,16 @@ class MemoryRuntimeTest { )); MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 2, "first", 86_400_000L); MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 172_800_000L); - unit.getSlices().addAll(List.of(firstSlice, secondSlice)); - memoryCapability.remember(unit); + mainUnit.getSlices().addAll(List.of(firstSlice, secondSlice)); + memoryCapability.remember(mainUnit); + runtime.recordMemory(mainUnit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE); - runtime.recordMemory(unit, "topic/main", List.of("topic/related")); + MemoryUnit relatedUnit = new MemoryUnit("unit-201"); + relatedUnit.getConversationMessages().addAll(List.of(message("r0"), message("r1"))); + MemorySlice relatedSlice = MemorySlice.restore("slice-3", 0, 2, "related", 259_200_000L); + relatedUnit.getSlices().add(relatedSlice); + memoryCapability.remember(relatedUnit); + runtime.recordMemory(relatedUnit, "topic/related", List.of(), DEFAULT_PROFILE); JSONObject state = JSONObject.parseObject(runtime.convert().toString()); JSONArray topicSlices = state.getJSONArray("topic_slices"); @@ -190,16 +269,23 @@ class MemoryRuntimeTest { .filter(item -> "topic/main".equals(item.getString("topic_path"))) .findFirst() .orElseThrow(); - assertEquals("slice-2", mainTopic.getJSONArray("refs").getJSONObject(0).getString("slice_id")); + JSONObject binding = mainTopic.getJSONArray("bindings").getJSONObject(0); + assertEquals("slice-2", binding.getString("slice_id")); + assertEquals(172_800_000L, binding.getLongValue("timestamp")); + assertEquals(List.of("topic/related"), binding.getJSONArray("related_topic_paths").toJavaList(String.class)); + JSONObject activationProfile = binding.getJSONObject("activation_profile"); + assertEquals(0.55f, activationProfile.getFloatValue("activation_weight")); + assertEquals(0.35f, activationProfile.getFloatValue("diffusion_weight")); + assertEquals(0.50f, activationProfile.getFloatValue("context_independence_weight")); JSONArray dateIndex = state.getJSONArray("date_index"); assertEquals(2, dateIndex.size()); - JSONObject secondDate = dateIndex.stream() + JSONObject thirdDate = dateIndex.stream() .map(JSONObject.class::cast) - .filter(item -> "1970-01-03".equals(item.getString("date"))) + .filter(item -> "1970-01-04".equals(item.getString("date"))) .findFirst() .orElseThrow(); - assertEquals("slice-2", secondDate.getJSONArray("refs").getJSONObject(0).getString("slice_id")); + assertEquals("slice-3", thirdDate.getJSONArray("refs").getJSONObject(0).getString("slice_id")); MemoryRuntime restored = new MemoryRuntime(); setField(restored, "memoryCapability", memoryCapability); @@ -207,14 +293,14 @@ class MemoryRuntimeTest { restored.load(state); List topicResult = restored.queryActivatedMemoryByTopicPath("topic/main"); - assertEquals(1, topicResult.size()); - assertEquals("slice-2", topicResult.getFirst().getSliceId()); + assertEquals(List.of("slice-2", "slice-3"), + topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList()); assertEquals(List.of("m2", "m3"), topicResult.getFirst().getMessages().stream().map(Message::getContent).toList()); - List dateResult = restored.queryActivatedMemoryByDate(LocalDate.parse("1970-01-03")); + List dateResult = restored.queryActivatedMemoryByDate(LocalDate.parse("1970-01-04")); assertEquals(1, dateResult.size()); - assertEquals("slice-2", dateResult.getFirst().getSliceId()); - assertEquals("second", dateResult.getFirst().getSummary()); + assertEquals("slice-3", dateResult.getFirst().getSliceId()); + assertEquals("related", dateResult.getFirst().getSummary()); } @Test @@ -228,6 +314,106 @@ class MemoryRuntimeTest { assertTrue(runtime.queryActivatedMemoryByDate(LocalDate.parse("1970-01-01")).isEmpty()); } + @Test + void shouldRankTopicMatchesBySourceAndActivationProfile() throws Exception { + StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test"); + MemoryRuntime runtime = new MemoryRuntime(); + setField(runtime, "memoryCapability", memoryCapability); + setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed")))); + + MemoryUnit primaryUnit = new MemoryUnit("unit-primary"); + primaryUnit.getConversationMessages().addAll(List.of(message("p0"), message("p1"))); + MemorySlice primarySlice = MemorySlice.restore("slice-primary", 0, 2, "primary", System.currentTimeMillis()); + primaryUnit.getSlices().add(primarySlice); + memoryCapability.remember(primaryUnit); + runtime.recordMemory(primaryUnit, "topic->main", List.of("topic->related"), new ActivationProfile(0.9f, 0.1f, 0.9f)); + + MemoryUnit relatedUnit = new MemoryUnit("unit-related-rank"); + relatedUnit.getConversationMessages().addAll(List.of(message("r0"), message("r1"))); + MemorySlice relatedSlice = MemorySlice.restore("slice-related-rank", 0, 2, "related", System.currentTimeMillis()); + relatedUnit.getSlices().add(relatedSlice); + memoryCapability.remember(relatedUnit); + runtime.recordMemory(relatedUnit, "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f)); + + MemoryUnit parentUnit = new MemoryUnit("unit-parent"); + parentUnit.getConversationMessages().addAll(List.of(message("x0"), message("x1"))); + MemorySlice parentSlice = MemorySlice.restore("slice-parent", 0, 2, "parent", System.currentTimeMillis()); + parentUnit.getSlices().add(parentSlice); + memoryCapability.remember(parentUnit); + runtime.recordMemory(parentUnit, "topic", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f)); + + List topicResult = runtime.queryActivatedMemoryByTopicPath("topic->main"); + assertEquals(List.of("slice-primary", "slice-related-rank", "slice-parent"), + topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList()); + } + + @Test + void shouldNotExpandRelatedTopicWhenDiffusionWeightIsZero() throws Exception { + StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test"); + MemoryRuntime runtime = new MemoryRuntime(); + setField(runtime, "memoryCapability", memoryCapability); + setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed")))); + + MemoryUnit primaryUnit = new MemoryUnit("unit-primary-zero"); + primaryUnit.getConversationMessages().addAll(List.of(message("p0"), message("p1"))); + MemorySlice primarySlice = MemorySlice.restore("slice-primary-zero", 0, 2, "primary", System.currentTimeMillis()); + primaryUnit.getSlices().add(primarySlice); + memoryCapability.remember(primaryUnit); + runtime.recordMemory( + primaryUnit, + "topic->main", + List.of("topic->related"), + new ActivationProfile(0.8f, 0.0f, 0.8f) + ); + + MemoryUnit relatedUnit = new MemoryUnit("unit-related-zero"); + relatedUnit.getConversationMessages().addAll(List.of(message("r0"), message("r1"))); + MemorySlice relatedSlice = MemorySlice.restore("slice-related-zero", 0, 2, "related", System.currentTimeMillis()); + relatedUnit.getSlices().add(relatedSlice); + memoryCapability.remember(relatedUnit); + runtime.recordMemory(relatedUnit, "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f)); + + List topicResult = runtime.queryActivatedMemoryByTopicPath("topic->main"); + assertEquals(List.of("slice-primary-zero"), topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList()); + } + + @Test + void shouldRefreshBindingTimestampAndActivationProfileWhenSameSliceRebound() throws Exception { + StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test"); + MemoryRuntime runtime = new MemoryRuntime(); + setField(runtime, "memoryCapability", memoryCapability); + setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed")))); + + MemoryUnit unit = new MemoryUnit("unit-refresh"); + unit.getConversationMessages().addAll(List.of(message("m0"), message("m1"))); + MemorySlice slice = MemorySlice.restore("slice-refresh", 0, 2, "summary", 86_400_000L); + unit.getSlices().add(slice); + memoryCapability.remember(unit); + + runtime.recordMemory(unit, "topic->main", List.of("topic->related"), new ActivationProfile(0.2f, 0.1f, 0.2f)); + unit.getSlices().clear(); + unit.getSlices().add(MemorySlice.restore("slice-refresh", 0, 2, "summary", 172_800_000L)); + runtime.recordMemory(unit, "topic->main", List.of("topic->related-2"), new ActivationProfile(0.9f, 0.8f, 0.7f)); + + JSONObject state = JSONObject.parseObject(runtime.convert().toString()); + JSONObject mainTopic = state.getJSONArray("topic_slices").stream() + .map(JSONObject.class::cast) + .filter(item -> "topic->main".equals(item.getString("topic_path"))) + .findFirst() + .orElseThrow(); + JSONObject binding = mainTopic.getJSONArray("bindings").getJSONObject(0); + JSONObject activationProfile = binding.getJSONObject("activation_profile"); + + assertEquals(172_800_000L, binding.getLongValue("timestamp")); + assertEquals(0.9f, activationProfile.getFloatValue("activation_weight")); + assertEquals(0.8f, activationProfile.getFloatValue("diffusion_weight")); + assertEquals(0.7f, activationProfile.getFloatValue("context_independence_weight")); + assertEquals( + List.of("topic->related", "topic->related-2"), + binding.getJSONArray("related_topic_paths").toJavaList(String.class) + ); + } + private static final class StubMemoryCapability implements MemoryCapability { private final String sessionId; private final Map units = new HashMap<>(); diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java index 620c5a9e..48568e66 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java @@ -4,7 +4,6 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.mockito.Mockito; -import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemoryUnit; import work.slhaf.partner.framework.agent.exception.AgentRuntimeException; @@ -12,6 +11,7 @@ import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.support.Result; import work.slhaf.partner.module.communication.AfterRollingRegistry; import work.slhaf.partner.module.communication.RollingResult; +import work.slhaf.partner.module.memory.pojo.ActivationProfile; import work.slhaf.partner.module.memory.runtime.MemoryRuntime; import work.slhaf.partner.module.memory.updater.summarizer.entity.MemoryTopicResult; @@ -19,6 +19,7 @@ import java.lang.reflect.Field; import java.nio.file.Path; import java.util.List; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -55,11 +56,7 @@ class MemoryUpdaterTest { void shouldExtractTopicAndRecordMemoryOnConsume() throws Exception { MemoryUpdater updater = Mockito.spy(new MemoryUpdater()); MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); - CognitionCapability cognitionCapability = Mockito.mock(CognitionCapability.class); setField(updater, "memoryRuntime", memoryRuntime); - setField(updater, "cognitionCapability", cognitionCapability); - - when(cognitionCapability.contextWorkspace()).thenReturn(new work.slhaf.partner.core.cognition.ContextWorkspace()); when(memoryRuntime.getTopicTree()).thenReturn("topic-tree"); when(memoryRuntime.fixTopicPath("root[2]->branch[1]")).thenReturn("root->branch"); when(memoryRuntime.fixTopicPath("root[2]->related[1]")).thenReturn("root->related"); @@ -67,6 +64,7 @@ class MemoryUpdaterTest { MemoryTopicResult topicResult = new MemoryTopicResult(); topicResult.setTopicPath("root[2]->branch[1]"); topicResult.setRelatedTopicPaths(List.of("root[2]->related[1]")); + topicResult.setActivationProfile(new ActivationProfile(0.8f, 0.9f, 0.7f)); Mockito.doReturn(Result.success(topicResult)) .when(updater) .formattedChat(Mockito.anyList(), eq(MemoryTopicResult.class)); @@ -86,18 +84,22 @@ class MemoryUpdaterTest { message(Message.Character.ASSISTANT, "new-reply") ), "slice-summary", 4, 6)); - verify(memoryRuntime).recordMemory(eq(unit), eq("root->branch"), eq(List.of("root->related"))); + verify(memoryRuntime).recordMemory( + eq(unit), + eq("root->branch"), + eq(List.of("root->related")), + argThat(profile -> profile != null + && profile.getActivationWeight() == 0.8f + && profile.getDiffusionWeight() == 0.9f + && profile.getContextIndependenceWeight() == 0.7f) + ); } @Test void shouldFallbackToDateOnlyRecordWhenExtractionFails() throws Exception { MemoryUpdater updater = Mockito.spy(new MemoryUpdater()); MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); - CognitionCapability cognitionCapability = Mockito.mock(CognitionCapability.class); setField(updater, "memoryRuntime", memoryRuntime); - setField(updater, "cognitionCapability", cognitionCapability); - - when(cognitionCapability.contextWorkspace()).thenReturn(new work.slhaf.partner.core.cognition.ContextWorkspace()); when(memoryRuntime.getTopicTree()).thenReturn("topic-tree"); Mockito.doReturn(Result.failure(new AgentRuntimeException("boom"))) .when(updater) @@ -113,6 +115,48 @@ class MemoryUpdaterTest { updater.consume(new RollingResult(unit, slice, unit.getConversationMessages(), "slice-summary", 2, 6)); - verify(memoryRuntime).recordMemory(eq(unit), eq(null), eq(List.of())); + verify(memoryRuntime).recordMemory( + eq(unit), + eq(null), + eq(List.of()), + argThat(profile -> profile != null + && profile.getActivationWeight() == 0.55f + && profile.getDiffusionWeight() == 0.35f + && profile.getContextIndependenceWeight() == 0.50f) + ); + } + + @Test + void shouldClampAndAdjustActivationProfileBeforeRecording() throws Exception { + MemoryUpdater updater = Mockito.spy(new MemoryUpdater()); + MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); + setField(updater, "memoryRuntime", memoryRuntime); + when(memoryRuntime.getTopicTree()).thenReturn("topic-tree"); + when(memoryRuntime.fixTopicPath("root[2]->branch[1]")).thenReturn("root->branch"); + + MemoryTopicResult topicResult = new MemoryTopicResult(); + topicResult.setTopicPath("root[2]->branch[1]"); + topicResult.setRelatedTopicPaths(List.of()); + topicResult.setActivationProfile(new ActivationProfile(1.5f, 0.9f, -0.2f)); + Mockito.doReturn(Result.success(topicResult)) + .when(updater) + .formattedChat(Mockito.anyList(), eq(MemoryTopicResult.class)); + + MemoryUnit unit = new MemoryUnit("session-3"); + unit.getConversationMessages().add(message(Message.Character.USER, "only")); + MemorySlice slice = new MemorySlice(0, 1, "slice-summary"); + unit.getSlices().add(slice); + + updater.consume(new RollingResult(unit, slice, unit.getConversationMessages(), "slice-summary", 1, 6)); + + verify(memoryRuntime).recordMemory( + eq(unit), + eq("root->branch"), + eq(List.of()), + argThat(profile -> profile != null + && profile.getActivationWeight() == 0.95f + && profile.getDiffusionWeight() == 0.45f + && profile.getContextIndependenceWeight() == 0.0f) + ); } }