diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCapability.java b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCapability.java index 95d4605c..7f6520c3 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCapability.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCapability.java @@ -1,7 +1,7 @@ package work.slhaf.partner.core.memory; -import work.slhaf.partner.core.memory.pojo.MemorySlice; -import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot; +import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot; import work.slhaf.partner.framework.agent.factory.capability.annotation.Capability; import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.support.Result; @@ -12,13 +12,13 @@ import java.util.List; @Capability(value = "memory") public interface MemoryCapability { - MemoryUnit getMemoryUnit(String unitId); + MemoryUnitSnapshot getMemoryUnit(String unitId); - Result getMemorySlice(String unitId, String sliceId); + Result getMemorySlice(String unitId, String sliceId); - MemoryUnit updateMemoryUnit(List chatMessages, String summary); + MemoryUnitSnapshot updateMemoryUnit(List chatMessages, String summary); - Collection listMemoryUnits(); + Collection listMemoryUnits(); void refreshMemorySession(); diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java index 2b68bc4c..7b6ded82 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java @@ -5,7 +5,9 @@ import com.alibaba.fastjson2.JSONObject; import lombok.extern.slf4j.Slf4j; import org.jetbrains.annotations.NotNull; import work.slhaf.partner.core.memory.pojo.MemorySlice; +import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot; import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot; import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityCore; import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityMethod; import work.slhaf.partner.framework.agent.model.pojo.Message; @@ -36,10 +38,10 @@ public class MemoryCore implements StateSerializable { } @CapabilityMethod - public MemoryUnit updateMemoryUnit(List chatMessages, String summary) { + public MemoryUnitSnapshot updateMemoryUnit(List chatMessages, String summary) { memoryLock.lock(); try { - MemoryUnit unit = getMemoryUnit(memorySessionId); + MemoryUnit unit = getOrLoadMemoryUnit(memorySessionId); unit.updateTimestamp(); List conversationMessages = unit.getConversationMessages(); @@ -55,14 +57,60 @@ public class MemoryCore implements StateSerializable { unit.getSlices().add(memorySlice); normalizeMemoryUnit(unit); - return unit; + return unit.snapshot(); } finally { memoryLock.unlock(); } } @CapabilityMethod - public MemoryUnit getMemoryUnit(String unitId) { + public MemoryUnitSnapshot getMemoryUnit(String unitId) { + memoryLock.lock(); + try { + MemoryUnit unit = getOrLoadMemoryUnit(unitId); + normalizeMemoryUnit(unit); + return unit.snapshot(); + } finally { + memoryLock.unlock(); + } + } + + @CapabilityMethod + public Result getMemorySlice(String unitId, String sliceId) { + memoryLock.lock(); + try { + MemoryUnit memoryUnit = memoryUnits.get(unitId); + if (memoryUnit == null) { + return memorySliceNotFound(unitId, sliceId); + } + memoryUnit.load(); + normalizeMemoryUnit(memoryUnit); + for (MemorySlice slice : memoryUnit.getSlices()) { + if (sliceId.equals(slice.getId())) { + return Result.success(slice.snapshot()); + } + } + return memorySliceNotFound(unitId, sliceId); + } finally { + memoryLock.unlock(); + } + } + + @CapabilityMethod + public Collection listMemoryUnits() { + memoryLock.lock(); + try { + return memoryUnits.values().stream() + .peek(MemoryUnit::load) + .peek(this::normalizeMemoryUnit) + .map(MemoryUnit::snapshot) + .toList(); + } finally { + memoryLock.unlock(); + } + } + + private MemoryUnit getOrLoadMemoryUnit(String unitId) { MemoryUnit unit = memoryUnits.computeIfAbsent(unitId, id -> { MemoryUnit newUnit = new MemoryUnit(id); newUnit.register(); @@ -72,21 +120,7 @@ public class MemoryCore implements StateSerializable { return unit; } - @CapabilityMethod - public Result getMemorySlice(String unitId, String sliceId) { - MemoryUnit memoryUnit = memoryUnits.get(unitId); - if (memoryUnit == null) { - return Result.failure(new MemoryLookupException( - "Memory slice not found: " + unitId + ":" + sliceId, - unitId + ":" + sliceId, - "MEMORY_SLICE" - )); - } - for (MemorySlice slice : memoryUnit.getSlices()) { - if (sliceId.equals(slice.getId())) { - return Result.success(slice); - } - } + private Result memorySliceNotFound(String unitId, String sliceId) { return Result.failure(new MemoryLookupException( "Memory slice not found: " + unitId + ":" + sliceId, unitId + ":" + sliceId, @@ -94,11 +128,6 @@ public class MemoryCore implements StateSerializable { )); } - @CapabilityMethod - public Collection listMemoryUnits() { - return new ArrayList<>(memoryUnits.values()); - } - @CapabilityMethod public void refreshMemorySession() { memorySessionId = UUID.randomUUID().toString(); diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySlice.java b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySlice.java index 5cdc1cda..99a29f1b 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySlice.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySlice.java @@ -33,6 +33,16 @@ public class MemorySlice implements Comparable { return new MemorySlice(id, startIndex, endIndex, summary, timestamp); } + public MemorySliceSnapshot snapshot() { + return new MemorySliceSnapshot( + id, + startIndex == null ? 0 : startIndex, + endIndex == null ? 0 : endIndex, + summary, + timestamp == null ? 0L : timestamp + ); + } + @Override public int compareTo(MemorySlice memorySlice) { if (memorySlice.getTimestamp() > this.getTimestamp()) { diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySliceSnapshot.kt b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySliceSnapshot.kt new file mode 100644 index 00000000..2251e714 --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySliceSnapshot.kt @@ -0,0 +1,9 @@ +package work.slhaf.partner.core.memory.pojo + +data class MemorySliceSnapshot( + val id: String, + val startIndex: Int, + val endIndex: Int, + val summary: String?, + val timestamp: Long, +) diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemoryUnit.java b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemoryUnit.java index c4054a14..e87f3d20 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemoryUnit.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemoryUnit.java @@ -31,6 +31,15 @@ public class MemoryUnit implements StateSerializable { timestamp = System.currentTimeMillis(); } + public MemoryUnitSnapshot snapshot() { + return new MemoryUnitSnapshot( + id, + List.copyOf(conversationMessages), + timestamp == null ? 0L : timestamp, + slices.stream().map(MemorySlice::snapshot).toList() + ); + } + @Override public @NotNull Path statePath() { return Path.of("core", "memory", "memory-unit" + id + ".json"); diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemoryUnitSnapshot.kt b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemoryUnitSnapshot.kt new file mode 100644 index 00000000..0eab9c40 --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemoryUnitSnapshot.kt @@ -0,0 +1,23 @@ +package work.slhaf.partner.core.memory.pojo + +import work.slhaf.partner.framework.agent.model.pojo.Message + +data class MemoryUnitSnapshot( + val id: String, + val conversationMessages: List, + val timestamp: Long, + val slices: List, +) { + + fun messagesOf(slice: MemorySliceSnapshot): List { + if (conversationMessages.isEmpty()) { + return emptyList() + } + val start = slice.startIndex.coerceIn(0, conversationMessages.size) + val end = slice.endIndex.coerceIn(start, conversationMessages.size) + if (start >= end) { + return emptyList() + } + return conversationMessages.subList(start, end).toList() + } +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/action/builtin/BuiltinCapabilityActionProvider.java b/Partner-Core/src/main/java/work/slhaf/partner/module/action/builtin/BuiltinCapabilityActionProvider.java index d1071263..92daf57d 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/action/builtin/BuiltinCapabilityActionProvider.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/action/builtin/BuiltinCapabilityActionProvider.java @@ -10,8 +10,8 @@ import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.cognition.context.BlockContent; import work.slhaf.partner.core.cognition.context.ContextBlock; import work.slhaf.partner.core.memory.MemoryCapability; -import work.slhaf.partner.core.memory.pojo.MemorySlice; -import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot; +import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot; import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.framework.agent.factory.component.annotation.AgentComponent; import work.slhaf.partner.framework.agent.factory.component.annotation.Init; @@ -75,16 +75,16 @@ class BuiltinCapabilityActionProvider implements BuiltinActionProvider { Function, String> invoker = params -> { String unitId = BuiltinActionRegistry.BuiltinActionDefinition.requireString(params, "unit_id"); String sliceId = BuiltinActionRegistry.BuiltinActionDefinition.requireString(params, "slice_id"); - Result sliceResult = memoryCapability.getMemorySlice(unitId, sliceId); + Result sliceResult = memoryCapability.getMemorySlice(unitId, sliceId); if (sliceResult.exceptionOrNull() != null) { return JSONObject.of( "ok", false, "message", sliceResult.exceptionOrNull().getLocalizedMessage() ).toJSONString(); } - MemorySlice slice = sliceResult.getOrThrow(); + MemorySliceSnapshot slice = sliceResult.getOrThrow(); - MemoryUnit unit = memoryCapability.getMemoryUnit(unitId); + MemoryUnitSnapshot unit = memoryCapability.getMemoryUnit(unitId); cognitionCapability.contextWorkspace().register(new ContextBlock( buildMemoryRecallFullBlock(unit, slice), Set.of(ContextBlock.FocusedDomain.MEMORY), @@ -105,13 +105,13 @@ class BuiltinCapabilityActionProvider implements BuiltinActionProvider { ); } - private @NotNull BlockContent buildMemoryRecallFullBlock(MemoryUnit unit, MemorySlice slice) { + private @NotNull BlockContent buildMemoryRecallFullBlock(MemoryUnitSnapshot unit, MemorySliceSnapshot slice) { return new BlockContent("memory_recall", "memory_capability") { @Override protected void fillXml(@NotNull Document document, @NotNull Element root) { root.setAttribute("unit_id", unit.getId()); root.setAttribute("slice_id", slice.getId()); - appendRepeatedElements(document, root, "message", unit.getConversationMessages().subList(slice.getStartIndex(), slice.getEndIndex()), (messageElement, message) -> { + appendRepeatedElements(document, root, "message", unit.messagesOf(slice), (messageElement, message) -> { messageElement.setAttribute("role", message.getRole().name().toLowerCase(Locale.ROOT)); messageElement.setTextContent(message.getContent()); return Unit.INSTANCE; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRolling.java b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRolling.java index 2866afa4..6bfedb97 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRolling.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRolling.java @@ -13,8 +13,8 @@ import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.cognition.context.BlockContent; import work.slhaf.partner.core.cognition.context.ContextBlock; import work.slhaf.partner.core.memory.MemoryCapability; -import work.slhaf.partner.core.memory.pojo.MemorySlice; -import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot; +import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot; import work.slhaf.partner.core.perceive.PerceiveCapability; import work.slhaf.partner.framework.agent.exception.AgentRuntimeException; import work.slhaf.partner.framework.agent.exception.ExceptionReporterHandler; @@ -31,6 +31,7 @@ import work.slhaf.partner.runtime.PartnerRunningFlowContext; import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; +import java.util.ArrayList; import java.util.List; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; @@ -140,7 +141,7 @@ public class DialogRolling extends AbstractAgentModule.Running chatSnapshot, int rollingSize, int retainDivisor) { - messageCompressor.execute(chatSnapshot); - Result summaryResult = messageSummarizer.execute(chatSnapshot); + List rollingMessages = new ArrayList<>(chatSnapshot); + messageCompressor.execute(rollingMessages); + Result summaryResult = messageSummarizer.execute(rollingMessages); String summary = summaryResult.fold( value -> value, exp -> "no summary, due to exception" @@ -167,20 +169,20 @@ public class DialogRolling extends AbstractAgentModule.Running incrementMessages, - String summary, - int rollingSize, - int retainDivisor -) { -} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/RollingResult.kt b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/RollingResult.kt new file mode 100644 index 00000000..0e43e423 --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/RollingResult.kt @@ -0,0 +1,17 @@ +package work.slhaf.partner.module.communication + +import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot +import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot +import work.slhaf.partner.framework.agent.model.pojo.Message + +data class RollingResult( + val memoryUnit: MemoryUnitSnapshot, + val memorySlice: MemorySliceSnapshot, + val rollingSize: Int, + val retainDivisor: Int, +) { + val summary: String + get() = memorySlice.summary ?: "" + + fun incrementMessages(): List = memoryUnit.messagesOf(memorySlice) +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java index 8324960f..3386ec23 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java @@ -4,8 +4,8 @@ import com.alibaba.fastjson2.JSONObject; import org.jetbrains.annotations.NotNull; import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.memory.MemoryCapability; -import work.slhaf.partner.core.memory.pojo.MemorySlice; -import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot; +import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot; import work.slhaf.partner.core.memory.pojo.SliceRef; import work.slhaf.partner.framework.agent.exception.ExceptionReporterHandler; import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability; @@ -52,11 +52,11 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta } } - public void recordMemory(MemoryUnit memoryUnit, + public void recordMemory(MemoryUnitSnapshot memoryUnit, String topicPath, List relatedTopicPaths, ActivationProfile activationProfile) { - MemorySlice memorySlice = memoryUnit.getSlices().getLast(); + MemorySliceSnapshot memorySlice = memoryUnit.getSlices().getLast(); SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId()); LocalDate date = toLocalDate(memorySlice.getTimestamp()); runtimeLock.lock(); @@ -159,13 +159,13 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta } private ActivatedMemorySlice buildActivatedMemorySlice(SliceRef ref) { - MemoryUnit memoryUnit = memoryCapability.getMemoryUnit(ref.getUnitId()); - Result memorySliceResult = memoryCapability.getMemorySlice(ref.getUnitId(), ref.getSliceId()); + MemoryUnitSnapshot memoryUnit = memoryCapability.getMemoryUnit(ref.getUnitId()); + Result memorySliceResult = memoryCapability.getMemorySlice(ref.getUnitId(), ref.getSliceId()); if (memoryUnit == null || memorySliceResult.exceptionOrNull() != null) { return null; } - MemorySlice memorySlice = memorySliceResult.getOrThrow(); - List messages = sliceMessages(memoryUnit, memorySlice); + MemorySliceSnapshot memorySlice = memorySliceResult.getOrThrow(); + List messages = memoryUnit.messagesOf(memorySlice); LocalDate date = toLocalDate(memorySlice.getTimestamp()); return ActivatedMemorySlice.builder() .unitId(ref.getUnitId()) @@ -177,19 +177,6 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta .build(); } - private List sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) { - List conversationMessages = memoryUnit.getConversationMessages(); - if (conversationMessages.isEmpty()) { - return List.of(); - } - int size = conversationMessages.size(); - int start = Math.clamp(memorySlice.getStartIndex(), 0, size); - int end = Math.clamp(memorySlice.getEndIndex(), start, size); - if (start >= end) { - return List.of(); - } - return new ArrayList<>(conversationMessages.subList(start, end)); - } private LocalDate toLocalDate(Long timestamp) { return Instant.ofEpochMilli(timestamp) diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryRecallProfileExtractor.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryRecallProfileExtractor.java index d573bd5c..4ec57f59 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryRecallProfileExtractor.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryRecallProfileExtractor.java @@ -149,7 +149,7 @@ public class MemoryRecallProfileExtractor extends AbstractAgentModule.Standalone @Override public void consume(RollingResult result) { - List slicedMessages = sliceMessages(result); + List slicedMessages = result.incrementMessages(); if (slicedMessages.isEmpty()) { return; } @@ -169,31 +169,21 @@ public class MemoryRecallProfileExtractor extends AbstractAgentModule.Standalone relatedTopicPaths, slicedMessages ); - memoryRuntime.recordMemory(result.memoryUnit(), topicPath, relatedTopicPaths, activationProfile); + memoryRuntime.recordMemory(result.getMemoryUnit(), topicPath, relatedTopicPaths, activationProfile); }).onFailure(exp -> memoryRuntime.recordMemory( - result.memoryUnit(), + result.getMemoryUnit(), null, List.of(), defaultActivationProfile() )); } - private List sliceMessages(RollingResult result) { - int size = result.memoryUnit().getConversationMessages().size(); - int start = Math.clamp(result.memorySlice().getStartIndex(), 0, size); - int end = Math.clamp(result.memorySlice().getEndIndex(), start, size); - if (start >= end) { - return List.of(); - } - return result.memoryUnit().getConversationMessages().subList(start, end); - } - private Message resolveTopicTaskMessage(RollingResult result, List slicedMessages) { return new TaskBlock() { @Override protected void fillXml(@NotNull Document document, @NotNull Element root) { appendTextElement(document, root, "current_topic_tree", memoryRuntime.getTopicTree()); - appendTextElement(document, root, "slice_summary", result.summary()); + appendTextElement(document, root, "slice_summary", result.getSummary()); appendRepeatedElements(document, root, "message", slicedMessages, (messageElement, message) -> { messageElement.setAttribute("role", message.roleValue()); messageElement.setTextContent(message.getContent()); diff --git a/Partner-Core/src/test/java/work/slhaf/partner/core/memory/MemoryCoreTest.java b/Partner-Core/src/test/java/work/slhaf/partner/core/memory/MemoryCoreTest.java index 2f11c838..1b59d9b7 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/core/memory/MemoryCoreTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/core/memory/MemoryCoreTest.java @@ -4,8 +4,8 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import work.slhaf.partner.core.memory.pojo.MemorySlice; -import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot; +import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot; import work.slhaf.partner.framework.agent.model.pojo.Message; import java.nio.file.Path; @@ -32,7 +32,7 @@ class MemoryCoreTest { void shouldCreateFirstSliceFromChatMessages() { String sessionId = memoryCore.getMemorySessionId(); - MemoryUnit updatedUnit = memoryCore.updateMemoryUnit(List.of( + MemoryUnitSnapshot updatedUnit = memoryCore.updateMemoryUnit(List.of( new Message(Message.Character.USER, "m0"), new Message(Message.Character.USER, "m1"), new Message(Message.Character.USER, "m2") @@ -43,7 +43,7 @@ class MemoryCoreTest { updatedUnit.getConversationMessages().stream().map(Message::getContent).toList()); assertEquals(1, updatedUnit.getSlices().size()); - MemorySlice firstSlice = updatedUnit.getSlices().getFirst(); + MemorySliceSnapshot firstSlice = updatedUnit.getSlices().getFirst(); assertNotNull(firstSlice.getId()); assertEquals(0, firstSlice.getStartIndex()); assertEquals(3, firstSlice.getEndIndex()); @@ -60,7 +60,7 @@ class MemoryCoreTest { new Message(Message.Character.USER, "m0") ), "first-summary"); - MemoryUnit updatedUnit = memoryCore.updateMemoryUnit(List.of( + MemoryUnitSnapshot updatedUnit = memoryCore.updateMemoryUnit(List.of( new Message(Message.Character.ASSISTANT, "m1"), new Message(Message.Character.USER, "m2") ), "second-summary"); @@ -70,14 +70,14 @@ class MemoryCoreTest { updatedUnit.getConversationMessages().stream().map(Message::getContent).toList()); assertEquals(2, updatedUnit.getSlices().size()); - MemorySlice appendedSlice = updatedUnit.getSlices().getLast(); + MemorySliceSnapshot appendedSlice = updatedUnit.getSlices().getLast(); assertNotNull(appendedSlice.getId()); assertEquals(1, appendedSlice.getStartIndex()); assertEquals(3, appendedSlice.getEndIndex()); assertEquals("second-summary", appendedSlice.getSummary()); assertTrue(appendedSlice.getTimestamp() > 0); - MemorySlice loadedSlice = memoryCore.getMemorySlice(sessionId, appendedSlice.getId()).getOrThrow(); + MemorySliceSnapshot loadedSlice = memoryCore.getMemorySlice(sessionId, appendedSlice.getId()).getOrThrow(); assertNotNull(loadedSlice); assertEquals(1, loadedSlice.getStartIndex()); assertEquals(3, loadedSlice.getEndIndex()); diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/communication/DialogRollingTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/communication/DialogRollingTest.java index 53d5542b..134c6d25 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/module/communication/DialogRollingTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/communication/DialogRollingTest.java @@ -6,7 +6,9 @@ import org.junit.jupiter.api.io.TempDir; import org.mockito.Mockito; import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.pojo.MemorySlice; +import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot; import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot; import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.support.Result; import work.slhaf.partner.module.communication.summarizer.MessageCompressor; @@ -63,19 +65,19 @@ class DialogRollingTest { message(Message.Character.ASSISTANT, "new-assistant") ), 4, 6); - MemoryUnit merged = memoryCapability.getMemoryUnit(sessionId); + MemoryUnitSnapshot merged = memoryCapability.getMemoryUnit(sessionId); assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"), merged.getConversationMessages().stream().map(Message::getContent).toList()); assertEquals(2, merged.getSlices().size()); - MemorySlice appendedSlice = merged.getSlices().getLast(); + MemorySliceSnapshot appendedSlice = merged.getSlices().getLast(); assertNotNull(appendedSlice.getId()); assertEquals(2, appendedSlice.getStartIndex()); assertEquals(4, appendedSlice.getEndIndex()); assertEquals("new-summary", appendedSlice.getSummary()); - assertEquals(sessionId, rollingResult.memoryUnit().getId()); - assertEquals(appendedSlice.getId(), rollingResult.memorySlice().getId()); - assertEquals("new-summary", rollingResult.summary()); + assertEquals(sessionId, rollingResult.getMemoryUnit().getId()); + assertEquals(appendedSlice.getId(), rollingResult.getMemorySlice().getId()); + assertEquals("new-summary", rollingResult.getSummary()); } @Test @@ -96,7 +98,7 @@ class DialogRollingTest { message(Message.Character.ASSISTANT, "second") ), 2, 6); - MemoryUnit created = memoryCapability.getMemoryUnit(sessionId); + MemoryUnitSnapshot created = memoryCapability.getMemoryUnit(sessionId); assertNotNull(created); assertEquals(List.of("first", "second"), created.getConversationMessages().stream().map(Message::getContent).toList()); @@ -104,7 +106,7 @@ class DialogRollingTest { assertEquals(0, created.getSlices().getFirst().getStartIndex()); assertEquals(2, created.getSlices().getFirst().getEndIndex()); assertEquals("fresh-summary", created.getSlices().getFirst().getSummary()); - assertEquals(created, rollingResult.memoryUnit()); + assertEquals(created, rollingResult.getMemoryUnit()); } @Test @@ -151,8 +153,8 @@ class DialogRollingTest { message(Message.Character.ASSISTANT, "a1") ), 2, 6); - assertEquals(sessionId, rollingResult.memoryUnit().getId()); - assertEquals("no summary, due to empty summarize result", rollingResult.summary()); + assertEquals(sessionId, rollingResult.getMemoryUnit().getId()); + assertEquals("no summary, due to empty summarize result", rollingResult.getSummary()); } private static final class StubMemoryCapability implements MemoryCapability { @@ -172,28 +174,29 @@ class DialogRollingTest { } @Override - public MemoryUnit getMemoryUnit(String unitId) { - return units.get(unitId); + public MemoryUnitSnapshot getMemoryUnit(String unitId) { + MemoryUnit unit = units.get(unitId); + return unit == null ? null : unit.snapshot(); } @Override - public work.slhaf.partner.framework.agent.support.Result getMemorySlice(String unitId, String sliceId) { + public work.slhaf.partner.framework.agent.support.Result getMemorySlice(String unitId, String sliceId) { return null; } @Override - public MemoryUnit updateMemoryUnit(List chatMessages, String summary) { + public MemoryUnitSnapshot updateMemoryUnit(List chatMessages, String summary) { MemoryUnit unit = units.computeIfAbsent(sessionId, MemoryUnit::new); unit.updateTimestamp(); int startIndex = unit.getConversationMessages().size(); unit.getConversationMessages().addAll(chatMessages); unit.getSlices().add(new MemorySlice(startIndex, startIndex + chatMessages.size(), summary)); - return unit; + return unit.snapshot(); } @Override - public Collection listMemoryUnits() { - return units.values(); + public Collection listMemoryUnits() { + return units.values().stream().map(MemoryUnit::snapshot).toList(); } @Override diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java index f8868bc8..0f3de692 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java @@ -11,7 +11,9 @@ import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.cognition.context.ContextWorkspace; import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.pojo.MemorySlice; +import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot; import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot; import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.support.Result; import work.slhaf.partner.module.memory.pojo.ActivationProfile; @@ -19,7 +21,6 @@ import work.slhaf.partner.module.memory.runtime.exception.MemoryLookupException; import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice; import java.lang.reflect.Field; -import java.lang.reflect.Method; import java.nio.file.Path; import java.time.LocalDate; import java.util.Collection; @@ -41,11 +42,8 @@ class MemoryRuntimeTest { System.setProperty("user.home", tempDir.toAbsolutePath().toString()); } - @SuppressWarnings("unchecked") - private static List invokeSliceMessages(MemoryRuntime runtime, MemoryUnit unit, MemorySlice slice) throws Exception { - Method method = MemoryRuntime.class.getDeclaredMethod("sliceMessages", MemoryUnit.class, MemorySlice.class); - method.setAccessible(true); - return (List) method.invoke(runtime, unit, slice); + private static List invokeSliceMessages(MemoryRuntime runtime, MemoryUnit unit, MemorySlice slice) { + return unit.snapshot().messagesOf(slice.snapshot()); } private static void setField(Object target, String fieldName, Object value) throws Exception { @@ -200,7 +198,7 @@ class MemoryRuntimeTest { unit.getSlices().addAll(List.of(firstSlice, secondSlice)); memoryCapability.remember(unit); - runtime.recordMemory(unit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE); + runtime.recordMemory(unit.snapshot(), "topic/main", List.of("topic/related"), DEFAULT_PROFILE); List topicResult = runtime.queryActivatedMemoryByTopicPath("topic/main"); assertEquals(List.of("slice-2"), topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList()); @@ -240,8 +238,8 @@ class MemoryRuntimeTest { relatedUnit.getSlices().add(relatedSlice); memoryCapability.remember(relatedUnit); - runtime.recordMemory(mainUnit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE); - runtime.recordMemory(relatedUnit, "topic/related", List.of(), DEFAULT_PROFILE); + runtime.recordMemory(mainUnit.snapshot(), "topic/main", List.of("topic/related"), DEFAULT_PROFILE); + runtime.recordMemory(relatedUnit.snapshot(), "topic/related", List.of(), DEFAULT_PROFILE); List topicResult = runtime.queryActivatedMemoryByTopicPath("topic/main"); assertEquals(List.of("slice-main", "slice-related"), @@ -260,7 +258,7 @@ class MemoryRuntimeTest { MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 1, "first", 86_400_000L); firstUnitSnapshot.getSlices().add(firstSlice); memoryCapability.remember(firstUnitSnapshot); - runtime.recordMemory(firstUnitSnapshot, "topic/main", List.of(), DEFAULT_PROFILE); + runtime.recordMemory(firstUnitSnapshot.snapshot(), "topic/main", List.of(), DEFAULT_PROFILE); firstUnitSnapshot.getConversationMessages().clear(); firstUnitSnapshot.getConversationMessages().addAll(List.of(message("m2"), message("m3"))); @@ -268,7 +266,7 @@ class MemoryRuntimeTest { firstUnitSnapshot.getSlices().clear(); firstUnitSnapshot.getSlices().add(secondSlice); memoryCapability.remember(firstUnitSnapshot); - runtime.recordMemory(firstUnitSnapshot, "topic/main", List.of(), DEFAULT_PROFILE); + runtime.recordMemory(firstUnitSnapshot.snapshot(), "topic/main", List.of(), DEFAULT_PROFILE); JSONObject state = JSONObject.parseObject(runtime.convert().toString()); JSONArray dateIndex = state.getJSONArray("date_index"); @@ -306,14 +304,14 @@ class MemoryRuntimeTest { MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 172_800_000L); mainUnit.getSlices().addAll(List.of(firstSlice, secondSlice)); memoryCapability.remember(mainUnit); - runtime.recordMemory(mainUnit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE); + runtime.recordMemory(mainUnit.snapshot(), "topic/main", List.of("topic/related"), DEFAULT_PROFILE); MemoryUnit relatedUnit = new MemoryUnit("unit-201"); relatedUnit.getConversationMessages().addAll(List.of(message("r0"), message("r1"))); MemorySlice relatedSlice = MemorySlice.restore("slice-3", 0, 2, "related", 259_200_000L); relatedUnit.getSlices().add(relatedSlice); memoryCapability.remember(relatedUnit); - runtime.recordMemory(relatedUnit, "topic/related", List.of(), DEFAULT_PROFILE); + runtime.recordMemory(relatedUnit.snapshot(), "topic/related", List.of(), DEFAULT_PROFILE); JSONObject state = JSONObject.parseObject(runtime.convert().toString()); JSONArray topicSlices = state.getJSONArray("topic_slices"); @@ -380,21 +378,21 @@ class MemoryRuntimeTest { MemorySlice primarySlice = MemorySlice.restore("slice-primary", 0, 2, "primary", System.currentTimeMillis()); primaryUnit.getSlices().add(primarySlice); memoryCapability.remember(primaryUnit); - runtime.recordMemory(primaryUnit, "topic->main", List.of("topic->related"), new ActivationProfile(0.9f, 0.1f, 0.9f)); + runtime.recordMemory(primaryUnit.snapshot(), "topic->main", List.of("topic->related"), new ActivationProfile(0.9f, 0.1f, 0.9f)); MemoryUnit relatedUnit = new MemoryUnit("unit-related-rank"); relatedUnit.getConversationMessages().addAll(List.of(message("r0"), message("r1"))); MemorySlice relatedSlice = MemorySlice.restore("slice-related-rank", 0, 2, "related", System.currentTimeMillis()); relatedUnit.getSlices().add(relatedSlice); memoryCapability.remember(relatedUnit); - runtime.recordMemory(relatedUnit, "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f)); + runtime.recordMemory(relatedUnit.snapshot(), "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f)); MemoryUnit parentUnit = new MemoryUnit("unit-parent"); parentUnit.getConversationMessages().addAll(List.of(message("x0"), message("x1"))); MemorySlice parentSlice = MemorySlice.restore("slice-parent", 0, 2, "parent", System.currentTimeMillis()); parentUnit.getSlices().add(parentSlice); memoryCapability.remember(parentUnit); - runtime.recordMemory(parentUnit, "topic", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f)); + runtime.recordMemory(parentUnit.snapshot(), "topic", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f)); List topicResult = runtime.queryActivatedMemoryByTopicPath("topic->main"); assertEquals(List.of("slice-primary", "slice-related-rank", "slice-parent"), @@ -414,7 +412,7 @@ class MemoryRuntimeTest { primaryUnit.getSlices().add(primarySlice); memoryCapability.remember(primaryUnit); runtime.recordMemory( - primaryUnit, + primaryUnit.snapshot(), "topic->main", List.of("topic->related"), new ActivationProfile(0.8f, 0.0f, 0.8f) @@ -425,7 +423,7 @@ class MemoryRuntimeTest { MemorySlice relatedSlice = MemorySlice.restore("slice-related-zero", 0, 2, "related", System.currentTimeMillis()); relatedUnit.getSlices().add(relatedSlice); memoryCapability.remember(relatedUnit); - runtime.recordMemory(relatedUnit, "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f)); + runtime.recordMemory(relatedUnit.snapshot(), "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f)); List topicResult = runtime.queryActivatedMemoryByTopicPath("topic->main"); assertEquals(List.of("slice-primary-zero"), topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList()); @@ -444,10 +442,10 @@ class MemoryRuntimeTest { unit.getSlices().add(slice); memoryCapability.remember(unit); - runtime.recordMemory(unit, "topic->main", List.of("topic->related"), new ActivationProfile(0.2f, 0.1f, 0.2f)); + runtime.recordMemory(unit.snapshot(), "topic->main", List.of("topic->related"), new ActivationProfile(0.2f, 0.1f, 0.2f)); unit.getSlices().clear(); unit.getSlices().add(MemorySlice.restore("slice-refresh", 0, 2, "summary", 172_800_000L)); - runtime.recordMemory(unit, "topic->main", List.of("topic->related-2"), new ActivationProfile(0.9f, 0.8f, 0.7f)); + runtime.recordMemory(unit.snapshot(), "topic->main", List.of("topic->related-2"), new ActivationProfile(0.9f, 0.8f, 0.7f)); JSONObject state = JSONObject.parseObject(runtime.convert().toString()); JSONObject mainTopic = state.getJSONArray("topic_slices").stream() @@ -481,12 +479,13 @@ class MemoryRuntimeTest { } @Override - public MemoryUnit getMemoryUnit(String unitId) { - return units.get(unitId); + public MemoryUnitSnapshot getMemoryUnit(String unitId) { + MemoryUnit unit = units.get(unitId); + return unit == null ? null : unit.snapshot(); } @Override - public Result getMemorySlice(String unitId, String sliceId) { + public Result getMemorySlice(String unitId, String sliceId) { MemoryUnit unit = units.get(unitId); if (unit == null || unit.getSlices() == null) { return Result.failure(new MemoryLookupException( @@ -498,7 +497,7 @@ class MemoryRuntimeTest { return unit.getSlices().stream() .filter(slice -> sliceId.equals(slice.getId())) .findFirst() - .map(Result::success) + .map(slice -> Result.success(slice.snapshot())) .orElseGet(() -> Result.failure(new MemoryLookupException( "Memory slice not found: " + unitId + ":" + sliceId, unitId + ":" + sliceId, @@ -507,13 +506,13 @@ class MemoryRuntimeTest { } @Override - public MemoryUnit updateMemoryUnit(List chatMessages, String summary) { + public MemoryUnitSnapshot updateMemoryUnit(List chatMessages, String summary) { return null; } @Override - public Collection listMemoryUnits() { - return units.values(); + public Collection listMemoryUnits() { + return units.values().stream().map(MemoryUnit::snapshot).toList(); } @Override diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryRecallProfileExtractorTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryRecallProfileExtractorTest.java index e484d127..794a8c13 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryRecallProfileExtractorTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryRecallProfileExtractorTest.java @@ -79,13 +79,10 @@ class MemoryRecallProfileExtractorTest { MemorySlice slice = new MemorySlice(2, 4, "slice-summary"); unit.getSlices().add(slice); - updater.consume(new RollingResult(unit, slice, List.of( - message(Message.Character.USER, "new"), - message(Message.Character.ASSISTANT, "new-reply") - ), "slice-summary", 4, 6)); + updater.consume(new RollingResult(unit.snapshot(), slice.snapshot(), 4, 6)); verify(memoryRuntime).recordMemory( - eq(unit), + eq(unit.snapshot()), eq("root->branch"), eq(List.of("root->related")), argThat(profile -> profile != null @@ -113,10 +110,10 @@ class MemoryRecallProfileExtractorTest { MemorySlice slice = new MemorySlice(0, 2, "slice-summary"); unit.getSlices().add(slice); - updater.consume(new RollingResult(unit, slice, unit.getConversationMessages(), "slice-summary", 2, 6)); + updater.consume(new RollingResult(unit.snapshot(), slice.snapshot(), 2, 6)); verify(memoryRuntime).recordMemory( - eq(unit), + eq(unit.snapshot()), eq(null), eq(List.of()), argThat(profile -> profile != null @@ -147,10 +144,10 @@ class MemoryRecallProfileExtractorTest { MemorySlice slice = new MemorySlice(0, 1, "slice-summary"); unit.getSlices().add(slice); - updater.consume(new RollingResult(unit, slice, unit.getConversationMessages(), "slice-summary", 1, 6)); + updater.consume(new RollingResult(unit.snapshot(), slice.snapshot(), 1, 6)); verify(memoryRuntime).recordMemory( - eq(unit), + eq(unit.snapshot()), eq("root->branch"), eq(List.of()), argThat(profile -> profile != null