diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java index d74c84aa..4d15af3e 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java @@ -75,7 +75,7 @@ public class MemoryCore implements StateSerializable { @CapabilityMethod public Result getMemorySlice(String unitId, String sliceId) { MemoryUnit memoryUnit = memoryUnits.get(unitId); - if (memoryUnit == null || memoryUnit.getSlices() == null) { + if (memoryUnit == null) { return Result.failure(new MemoryLookupException( "Memory slice not found: " + unitId + ":" + sliceId, unitId + ":" + sliceId, diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/AfterRolling.java b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/AfterRolling.java new file mode 100644 index 00000000..62fdbd23 --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/AfterRolling.java @@ -0,0 +1,5 @@ +package work.slhaf.partner.module.communication; + +public interface AfterRolling { + void consume(RollingResult result); +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/AfterRollingRegistry.java b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/AfterRollingRegistry.java new file mode 100644 index 00000000..b0caed07 --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/AfterRollingRegistry.java @@ -0,0 +1,52 @@ +package work.slhaf.partner.module.communication; + +import lombok.EqualsAndHashCode; +import lombok.extern.slf4j.Slf4j; +import work.slhaf.partner.core.action.ActionCapability; +import work.slhaf.partner.core.action.ActionCore; +import work.slhaf.partner.framework.agent.exception.AgentRuntimeException; +import work.slhaf.partner.framework.agent.exception.ExceptionReporterHandler; +import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability; +import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule; +import work.slhaf.partner.framework.agent.factory.component.annotation.Init; + +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.ExecutorService; + +@EqualsAndHashCode(callSuper = true) +@Slf4j +public class AfterRollingRegistry extends AbstractAgentModule.Standalone { + + private final CopyOnWriteArrayList consumers = new CopyOnWriteArrayList<>(); + @InjectCapability + private ActionCapability actionCapability; + private ExecutorService executor; + + @Init + public void init() { + executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL); + } + + public void register(AfterRolling consumer) { + if (consumers.contains(consumer)) { + return; + } + consumers.add(consumer); + } + + public void trigger(RollingResult result) { + if (consumers.isEmpty()) { + return; + } + executor.execute(() -> { + for (AfterRolling consumer : List.copyOf(consumers)) { + try { + consumer.consume(result); + } catch (Exception e) { + ExceptionReporterHandler.INSTANCE.report(new AgentRuntimeException("after-rolling consumer occurred exception: " + consumer.getClass().getSimpleName(), e)); + } + } + }); + } +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRolling.java b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRolling.java new file mode 100644 index 00000000..1bb187c5 --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRolling.java @@ -0,0 +1,208 @@ +package work.slhaf.partner.module.communication; + +import kotlin.Unit; +import lombok.Data; +import lombok.EqualsAndHashCode; +import org.jetbrains.annotations.NotNull; +import org.w3c.dom.Document; +import org.w3c.dom.Element; +import work.slhaf.partner.core.action.ActionCapability; +import work.slhaf.partner.core.action.ActionCore; +import work.slhaf.partner.core.action.entity.Schedulable; +import work.slhaf.partner.core.action.entity.StateAction; +import work.slhaf.partner.core.cognition.BlockContent; +import work.slhaf.partner.core.cognition.CognitionCapability; +import work.slhaf.partner.core.cognition.ContextBlock; +import work.slhaf.partner.core.memory.MemoryCapability; +import work.slhaf.partner.core.memory.pojo.MemorySlice; +import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.core.perceive.PerceiveCapability; +import work.slhaf.partner.framework.agent.exception.AgentRuntimeException; +import work.slhaf.partner.framework.agent.exception.ExceptionReporterHandler; +import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability; +import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule; +import work.slhaf.partner.framework.agent.factory.component.annotation.Init; +import work.slhaf.partner.framework.agent.factory.component.annotation.InjectModule; +import work.slhaf.partner.framework.agent.model.pojo.Message; +import work.slhaf.partner.framework.agent.support.Result; +import work.slhaf.partner.module.action.scheduler.ActionScheduler; +import work.slhaf.partner.module.memory.runtime.MemoryRuntime; +import work.slhaf.partner.module.memory.updater.summarizer.MultiSummarizer; +import work.slhaf.partner.module.memory.updater.summarizer.SingleSummarizer; +import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeInput; +import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult; +import work.slhaf.partner.runtime.PartnerRunningFlowContext; + +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; + +@EqualsAndHashCode(callSuper = true) +@Data +public class DialogRolling extends AbstractAgentModule.Running { + + private static final String AUTO_UPDATE_CRON = "0/10 * * * * ?"; + private static final long UPDATE_TRIGGER_INTERVAL = 60 * 60 * 1000; + private static final int CONTEXT_RETAIN_DIVISOR = 6; + private static final int MEMORY_UPDATE_TRIGGER_ROLL_LIMIT = 36; + private final AtomicBoolean rolling = new AtomicBoolean(false); + @InjectCapability + private CognitionCapability cognitionCapability; + @InjectCapability + private MemoryCapability memoryCapability; + @InjectCapability + private PerceiveCapability perceiveCapability; + @InjectCapability + private ActionCapability actionCapability; + @InjectModule + private MemoryRuntime memoryRuntime; + @InjectModule + private MultiSummarizer multiSummarizer; + @InjectModule + private SingleSummarizer singleSummarizer; + @InjectModule + private ActionScheduler actionScheduler; + @InjectModule + private AfterRollingRegistry afterRollingRegistry; + private ExecutorService executor; + + @Init + public void init() { + executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL); + registerScheduledUpdater(); + } + + private void registerScheduledUpdater() { + StateAction stateAction = new StateAction( + "system", + "dialog-rolling-auto-update", + "定时检查并触发对话滚动", + Schedulable.ScheduleType.CYCLE, + AUTO_UPDATE_CRON, + new StateAction.Trigger.Call(() -> { + tryAutoRolling(); + return Unit.INSTANCE; + }) + ); + actionScheduler.schedule(stateAction); + log.info("Dialog rolling has been registered into ActionScheduler, cron={}", AUTO_UPDATE_CRON); + } + + @Override + protected void doExecute(@NotNull PartnerRunningFlowContext context) { + if (cognitionCapability.getChatMessages().size() < MEMORY_UPDATE_TRIGGER_ROLL_LIMIT) { + return; + } + executor.execute(() -> triggerRolling(false)); + } + + private void tryAutoRolling() { + long currentTime = System.currentTimeMillis(); + int chatCount = cognitionCapability.snapshotChatMessages().size(); + if (currentTime - perceiveCapability.showLastInteract().toEpochMilli() > UPDATE_TRIGGER_INTERVAL && chatCount > 1) { + triggerRolling(true); + log.debug("Dialog rolling: auto triggered"); + } + } + + private void triggerRolling(boolean refreshMemoryId) { + if (!rolling.compareAndSet(false, true)) { + log.debug("Dialog rolling: rolling is already executing"); + return; + } + try { + List fullChatSnapshot = cognitionCapability.snapshotChatMessages(); + if (fullChatSnapshot.size() <= 1) { + return; + } + List chatIncrement = resolveChatIncrement(fullChatSnapshot); + if (chatIncrement.isEmpty()) { + if (refreshMemoryId) { + memoryCapability.refreshMemorySession(); + } + return; + } + + RollingResult result = buildRollingResult(chatIncrement, fullChatSnapshot.size(), CONTEXT_RETAIN_DIVISOR); + applyRolling(result); + afterRollingRegistry.trigger(result); + + if (refreshMemoryId) { + memoryCapability.refreshMemorySession(); + } + } catch (Exception e) { + ExceptionReporterHandler.INSTANCE.report(new AgentRuntimeException("Dialog rolling failed", e)); + } finally { + rolling.set(false); + } + } + + List resolveChatIncrement(List fullChatSnapshot) { + String memoryId = memoryCapability.getMemorySessionId(); + if (memoryId.isBlank()) { + return fullChatSnapshot; + } + MemoryUnit existingUnit = memoryCapability.getMemoryUnit(memoryId); + if (existingUnit.getConversationMessages().isEmpty()) { + return fullChatSnapshot; + } + List existingMessages = existingUnit.getConversationMessages(); + int maxOverlap = Math.min(existingMessages.size(), fullChatSnapshot.size()); + for (int overlap = maxOverlap; overlap > 0; overlap--) { + List existingSuffix = existingMessages.subList(existingMessages.size() - overlap, existingMessages.size()); + List snapshotPrefix = fullChatSnapshot.subList(0, overlap); + if (existingSuffix.equals(snapshotPrefix)) { + return fullChatSnapshot.subList(overlap, fullChatSnapshot.size()); + } + } + return fullChatSnapshot; + } + + @NotNull + RollingResult buildRollingResult(List chatSnapshot, int rollingSize, int retainDivisor) { + SummarizeInput summarizeInput = new SummarizeInput(chatSnapshot, memoryRuntime.getTopicTree()); + singleSummarizer.execute(summarizeInput.getChatMessages()); + Result summarizeResult = multiSummarizer.execute(summarizeInput); + String summary = summarizeResult.fold( + SummarizeResult::getSummary, + exp -> "no summary, due to exception" + ); + if (summary.isBlank()) { + summary = "no summary, due to empty summarize result"; + } + MemoryUnit memoryUnit = memoryCapability.updateMemoryUnit(chatSnapshot, summary); + MemorySlice newSlice = memoryUnit.getSlices().getLast(); + return new RollingResult(memoryUnit, newSlice, List.copyOf(chatSnapshot), newSlice.getSummary(), rollingSize, retainDivisor); + } + + private void applyRolling(RollingResult result) { + cognitionCapability.contextWorkspace().register(new ContextBlock( + buildDialogAbstractBlock(result.summary(), result.memoryUnit().getId(), result.memorySlice().getId()), + Set.of(ContextBlock.VisibleDomain.MEMORY, ContextBlock.VisibleDomain.COMMUNICATION), + 20, + 5, + 10 + )); + cognitionCapability.rollChatMessagesWithSnapshot(result.rollingSize(), result.retainDivisor()); + } + + private @NotNull BlockContent buildDialogAbstractBlock(String summary, String unitId, String sliceId) { + return new BlockContent("dialog_history", "dialog_rolling") { + @Override + protected void fillXml(@NotNull Document document, @NotNull Element root) { + root.setAttribute("related_memory_unit_id", unitId); + root.setAttribute("related_memory_slice_id", sliceId); + root.setAttribute("datetime", ZonedDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"))); + appendTextElement(document, root, "summary", summary); + } + }; + } + + @Override + public int order() { + return 7; + } +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRollingService.java b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRollingService.java deleted file mode 100644 index bd7e94c2..00000000 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRollingService.java +++ /dev/null @@ -1,77 +0,0 @@ -package work.slhaf.partner.module.communication; - -import kotlin.Unit; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; -import org.w3c.dom.Document; -import org.w3c.dom.Element; -import work.slhaf.partner.core.cognition.BlockContent; -import work.slhaf.partner.core.cognition.CognitionCapability; -import work.slhaf.partner.core.cognition.ContextBlock; -import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability; -import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule; -import work.slhaf.partner.framework.agent.model.ActivateModel; -import work.slhaf.partner.framework.agent.model.pojo.Message; -import work.slhaf.partner.module.TaskBlock; - -import java.time.ZonedDateTime; -import java.time.format.DateTimeFormatter; -import java.util.List; -import java.util.Locale; -import java.util.Set; - -public class DialogRollingService extends AbstractAgentModule.Standalone implements ActivateModel { - - @InjectCapability - private CognitionCapability cognitionCapability; - - public void rollMessages(List snapshotMessages, int rollingSize, int retainSize) { - rollMessages(snapshotMessages, rollingSize, retainSize, null, null, null); - } - - public void rollMessages(List snapshotMessages, int rollingSize, int retainSize, @Nullable String unitId, @Nullable String sliceId, @Nullable String summary) { - summary = summary == null ? summarize(snapshotMessages) : summary; - cognitionCapability.contextWorkspace().register(new ContextBlock( - buildDialogAbstractBlock(summary, unitId, sliceId), - Set.of(ContextBlock.VisibleDomain.MEMORY, ContextBlock.VisibleDomain.COMMUNICATION), - 35, - 8, - 10 - )); - cognitionCapability.rollChatMessagesWithSnapshot(rollingSize, retainSize); - } - - private String summarize(List snapshotMessages) { - List messages = List.of( - resolveTaskBlock(snapshotMessages) - ); - return chat(messages).getOrThrow(); - } - - private @NotNull BlockContent buildDialogAbstractBlock(String summary, @Nullable String unitId, @Nullable String sliceId) { - return new BlockContent("dialog_history", "dialog_rolling_service") { - @Override - protected void fillXml(@NotNull Document document, @NotNull Element root) { - if (unitId != null) root.setAttribute("related_memory_unit_id", unitId); - if (sliceId != null) root.setAttribute("related_memory_slice_id", sliceId); - root.setAttribute("datetime", ZonedDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"))); - - appendTextElement(document, root, "summary", summary); - } - }; - } - - private Message resolveTaskBlock(List snapshotMessages) { - return new TaskBlock() { - @Override - protected void fillXml(@NotNull Document document, @NotNull Element root) { - appendRepeatedElements(document, root, "message", snapshotMessages, (messageElement, message) -> { - messageElement.setAttribute("role", message.getRole().name().toLowerCase(Locale.ROOT)); - messageElement.setTextContent(message.getContent()); - return Unit.INSTANCE; - }); - } - }.encodeToMessage(); - } - -} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/RollingResult.java b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/RollingResult.java new file mode 100644 index 00000000..8540586b --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/RollingResult.java @@ -0,0 +1,17 @@ +package work.slhaf.partner.module.communication; + +import work.slhaf.partner.core.memory.pojo.MemorySlice; +import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.framework.agent.model.pojo.Message; + +import java.util.List; + +public record RollingResult( + MemoryUnit memoryUnit, + MemorySlice memorySlice, + List incrementMessages, + String summary, + int rollingSize, + int retainDivisor +) { +} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java index fc2890c2..c34bb7e3 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java @@ -50,8 +50,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta } private void checkAndSetMemoryId() { - String currentMemoryId = memoryCapability.getMemorySessionId(); - if (currentMemoryId == null || cognitionCapability.getChatMessages().isEmpty()) { + if (cognitionCapability.getChatMessages().isEmpty()) { memoryCapability.refreshMemorySession(); } } @@ -75,9 +74,11 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta MemorySlice memorySlice = memoryUnit.getSlices().getLast(); SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId()); indexMemoryUnit(memoryUnit); - bindTopic(topicPath, sliceRef); - if (relatedTopicPaths != null) { - for (String relatedTopicPath : relatedTopicPaths) { + if (topicPath != null && !topicPath.isBlank()) { + bindTopic(topicPath, sliceRef); + } + for (String relatedTopicPath : relatedTopicPaths) { + if (relatedTopicPath != null && !relatedTopicPath.isBlank()) { bindTopic(relatedTopicPath, sliceRef); } } @@ -89,14 +90,12 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta for (CopyOnWriteArrayList refs : dateIndex.values()) { refs.removeIf(ref -> memoryUnit.getId().equals(ref.getUnitId())); } - if (memoryUnit.getSlices() != null) { - for (MemorySlice slice : memoryUnit.getSlices()) { - LocalDate date = Instant.ofEpochMilli(slice.getTimestamp()) - .atZone(ZoneId.systemDefault()) - .toLocalDate(); - dateIndex.computeIfAbsent(date, key -> new CopyOnWriteArrayList<>()) - .addIfAbsent(new SliceRef(memoryUnit.getId(), slice.getId())); - } + for (MemorySlice slice : memoryUnit.getSlices()) { + LocalDate date = Instant.ofEpochMilli(slice.getTimestamp()) + .atZone(ZoneId.systemDefault()) + .toLocalDate(); + dateIndex.computeIfAbsent(date, key -> new CopyOnWriteArrayList<>()) + .addIfAbsent(new SliceRef(memoryUnit.getId(), slice.getId())); } } finally { runtimeLock.unlock(); @@ -192,7 +191,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta private List sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) { List conversationMessages = memoryUnit.getConversationMessages(); - if (conversationMessages == null || conversationMessages.isEmpty()) { + if (conversationMessages.isEmpty()) { return List.of(); } int size = conversationMessages.size(); diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryUpdater.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryUpdater.java index 9519f369..a6f95be8 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryUpdater.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryUpdater.java @@ -1,192 +1,94 @@ package work.slhaf.partner.module.memory.updater; -import kotlin.Unit; -import lombok.Data; -import lombok.EqualsAndHashCode; import org.jetbrains.annotations.NotNull; -import work.slhaf.partner.core.action.ActionCapability; -import work.slhaf.partner.core.action.ActionCore; -import work.slhaf.partner.core.action.entity.Schedulable; -import work.slhaf.partner.core.action.entity.StateAction; +import org.w3c.dom.Document; +import org.w3c.dom.Element; import work.slhaf.partner.core.cognition.CognitionCapability; -import work.slhaf.partner.core.memory.MemoryCapability; -import work.slhaf.partner.core.memory.pojo.MemorySlice; -import work.slhaf.partner.core.memory.pojo.MemoryUnit; -import work.slhaf.partner.core.perceive.PerceiveCapability; +import work.slhaf.partner.core.cognition.ContextBlock; import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.framework.agent.factory.component.annotation.Init; import work.slhaf.partner.framework.agent.factory.component.annotation.InjectModule; +import work.slhaf.partner.framework.agent.model.ActivateModel; import work.slhaf.partner.framework.agent.model.pojo.Message; -import work.slhaf.partner.module.action.scheduler.ActionScheduler; -import work.slhaf.partner.module.communication.DialogRollingService; +import work.slhaf.partner.framework.agent.support.Result; +import work.slhaf.partner.module.TaskBlock; +import work.slhaf.partner.module.communication.AfterRolling; +import work.slhaf.partner.module.communication.AfterRollingRegistry; +import work.slhaf.partner.module.communication.RollingResult; import work.slhaf.partner.module.memory.runtime.MemoryRuntime; -import work.slhaf.partner.module.memory.updater.summarizer.MultiSummarizer; -import work.slhaf.partner.module.memory.updater.summarizer.SingleSummarizer; -import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeInput; -import work.slhaf.partner.runtime.PartnerRunningFlowContext; +import work.slhaf.partner.module.memory.updater.summarizer.entity.MemoryTopicResult; import java.util.List; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.atomic.AtomicBoolean; -@EqualsAndHashCode(callSuper = true) -@Data -public class MemoryUpdater extends AbstractAgentModule.Running { - - private static final String AUTO_UPDATE_CRON = "0/10 * * * * ?"; - private static final long UPDATE_TRIGGER_INTERVAL = 60 * 60 * 1000; - private static final int CONTEXT_RETAIN_DIVISOR = 6; - private static final int MEMORY_UPDATE_TRIGGER_ROLL_LIMIT = 36; +public class MemoryUpdater extends AbstractAgentModule.Standalone implements AfterRolling, ActivateModel { @InjectCapability private CognitionCapability cognitionCapability; - @InjectCapability - private MemoryCapability memoryCapability; - @InjectCapability - private PerceiveCapability perceiveCapability; - @InjectCapability - private ActionCapability actionCapability; @InjectModule private MemoryRuntime memoryRuntime; @InjectModule - private MultiSummarizer multiSummarizer; - @InjectModule - private SingleSummarizer singleSummarizer; - @InjectModule - private ActionScheduler actionScheduler; - @InjectModule - private DialogRollingService dialogRollingService; - - private final AtomicBoolean updating = new AtomicBoolean(false); - private ExecutorService executor; + private AfterRollingRegistry afterRollingRegistry; @Init public void init() { - executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL); - registerScheduledUpdater(); + afterRollingRegistry.register(this); } - private void registerScheduledUpdater() { - StateAction stateAction = new StateAction( - "system", - "memory-auto-update", - "定时检查并触发记忆更新", - Schedulable.ScheduleType.CYCLE, - AUTO_UPDATE_CRON, - new StateAction.Trigger.Call(() -> { - tryAutoUpdate(); - return Unit.INSTANCE; - }) + @Override + public void consume(RollingResult result) { + List slicedMessages = sliceMessages(result); + if (slicedMessages.isEmpty()) { + return; + } + Result extractResult = formattedChat( + List.of( + cognitionCapability.contextWorkspace().resolve(List.of( + ContextBlock.VisibleDomain.COGNITION, + ContextBlock.VisibleDomain.MEMORY + )).encodeToMessage(), + resolveTopicTaskMessage(result, slicedMessages) + ), + MemoryTopicResult.class ); - actionScheduler.schedule(stateAction); - log.info("[MemoryUpdater] 记忆自动更新已注册到 ActionScheduler, cron={}", AUTO_UPDATE_CRON); + extractResult.onSuccess(topicResult -> { + String topicPath = topicResult.getTopicPath() == null ? null : memoryRuntime.fixTopicPath(topicResult.getTopicPath()); + List relatedTopicPaths = topicResult.getRelatedTopicPaths() == null + ? List.of() + : topicResult.getRelatedTopicPaths().stream().map(memoryRuntime::fixTopicPath).toList(); + memoryRuntime.recordMemory(result.memoryUnit(), topicPath, relatedTopicPaths); + }).onFailure(exp -> memoryRuntime.recordMemory(result.memoryUnit(), null, List.of())); } - @Override - protected void doExecute(@NotNull PartnerRunningFlowContext context) { - boolean trigger = cognitionCapability.getChatMessages().size() >= MEMORY_UPDATE_TRIGGER_ROLL_LIMIT; - if (!trigger) { - return; + private List sliceMessages(RollingResult result) { + int size = result.memoryUnit().getConversationMessages().size(); + int start = Math.clamp(result.memorySlice().getStartIndex(), 0, size); + int end = Math.clamp(result.memorySlice().getEndIndex(), start, size); + if (start >= end) { + return List.of(); } - executor.execute(() -> { - log.debug("[MemoryUpdater] 记忆更新触发"); - triggerMemoryUpdate(false); - }); + return result.memoryUnit().getConversationMessages().subList(start, end); } - private void tryAutoUpdate() { - long currentTime = System.currentTimeMillis(); - int chatCount = cognitionCapability.snapshotChatMessages().size(); - if (currentTime - perceiveCapability.showLastInteract().toEpochMilli() > UPDATE_TRIGGER_INTERVAL && chatCount > 1) { - triggerMemoryUpdate(true); - log.info("[MemoryUpdater] 记忆更新: 自动触发"); - } - } - - private void triggerMemoryUpdate(boolean refreshMemoryId) { - if (!updating.compareAndSet(false, true)) { - log.debug("[MemoryUpdater] 更新任务已在执行中,本次触发跳过"); - return; - } - try { - List fullChatSnapshot = cognitionCapability.snapshotChatMessages(); - if (fullChatSnapshot.size() <= 1) { - return; - } - List chatIncrement = resolveChatIncrement(fullChatSnapshot); - if (chatIncrement.isEmpty()) { - if (refreshMemoryId) { - memoryCapability.refreshMemorySession(); - } - return; - } - - RollingRecord record = updateMemory(chatIncrement); - dialogRollingService.rollMessages(chatIncrement, fullChatSnapshot.size(), CONTEXT_RETAIN_DIVISOR, record.unitId, record.sliceId, record.summary); - - if (refreshMemoryId) { - memoryCapability.refreshMemorySession(); - } - } catch (Exception e) { - log.error("[MemoryUpdater] 记忆更新线程出错: ", e); - } finally { - updating.set(false); - } - } - - private List resolveChatIncrement(List fullChatSnapshot) { - String memoryId = memoryCapability.getMemorySessionId(); - if (memoryId == null || memoryId.isBlank()) { - return fullChatSnapshot; - } - MemoryUnit existingUnit = memoryCapability.getMemoryUnit(memoryId); - if (existingUnit == null || existingUnit.getConversationMessages() == null || existingUnit.getConversationMessages().isEmpty()) { - return fullChatSnapshot; - } - List existingMessages = existingUnit.getConversationMessages(); - int maxOverlap = Math.min(existingMessages.size(), fullChatSnapshot.size()); - for (int overlap = maxOverlap; overlap > 0; overlap--) { - List existingSuffix = existingMessages.subList(existingMessages.size() - overlap, existingMessages.size()); - List snapshotPrefix = fullChatSnapshot.subList(0, overlap); - if (existingSuffix.equals(snapshotPrefix)) { - return fullChatSnapshot.subList(overlap, fullChatSnapshot.size()); - } - } - return fullChatSnapshot; - } - - private RollingRecord updateMemory(List chatSnapshot) { - log.debug("[MemoryUpdater] 记忆更新流程开始..."); - if (chatSnapshot.isEmpty()) { - return null; - } - SummarizeInput summarizeInput = new SummarizeInput(chatSnapshot, memoryRuntime.getTopicTree()); - singleSummarizer.execute(summarizeInput.getChatMessages()); - return multiSummarizer.execute(summarizeInput).fold( - summarizeResult -> { - MemoryUnit memoryUnit = memoryCapability.updateMemoryUnit(chatSnapshot, summarizeResult.getSummary()); - memoryRuntime.recordMemory( - memoryUnit, - summarizeResult.getTopicPath(), - summarizeResult.getRelatedTopicPath() - ); - MemorySlice newSlice = memoryUnit.getSlices().getLast(); - return new RollingRecord(memoryUnit.getId(), newSlice.getId(), newSlice.getSummary()); - }, - exp -> { - MemoryUnit memoryUnit = memoryCapability.updateMemoryUnit(chatSnapshot, "no summary, due to exception"); - MemorySlice newSlice = memoryUnit.getSlices().getLast(); - return new RollingRecord(memoryUnit.getId(), newSlice.getId(), newSlice.getSummary()); + private Message resolveTopicTaskMessage(RollingResult result, List slicedMessages) { + return new TaskBlock() { + @Override + protected void fillXml(@NotNull Document document, @NotNull Element root) { + appendTextElement(document, root, "current_topic_tree", memoryRuntime.getTopicTree()); + appendTextElement(document, root, "slice_summary", result.summary()); + appendRepeatedElements(document, root, "message", slicedMessages, (messageElement, message) -> { + messageElement.setAttribute("role", message.roleValue()); + messageElement.setTextContent(message.getContent()); + return kotlin.Unit.INSTANCE; }); + } + }.encodeToMessage(); } @Override - public int order() { - return 7; - } - - private record RollingRecord(String unitId, String sliceId, String summary) { + @NotNull + public String modelKey() { + return "topic_extractor"; } } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/entity/MemoryTopicResult.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/entity/MemoryTopicResult.java new file mode 100644 index 00000000..15ec6523 --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/entity/MemoryTopicResult.java @@ -0,0 +1,11 @@ +package work.slhaf.partner.module.memory.updater.summarizer.entity; + +import lombok.Data; + +import java.util.List; + +@Data +public class MemoryTopicResult { + private String topicPath; + private List relatedTopicPaths; +} diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/communication/DialogRollingTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/communication/DialogRollingTest.java new file mode 100644 index 00000000..039522b3 --- /dev/null +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/communication/DialogRollingTest.java @@ -0,0 +1,231 @@ +package work.slhaf.partner.module.communication; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.Mockito; +import work.slhaf.partner.core.memory.MemoryCapability; +import work.slhaf.partner.core.memory.pojo.MemorySlice; +import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.framework.agent.model.pojo.Message; +import work.slhaf.partner.framework.agent.support.Result; +import work.slhaf.partner.module.memory.runtime.MemoryRuntime; +import work.slhaf.partner.module.memory.updater.summarizer.MultiSummarizer; +import work.slhaf.partner.module.memory.updater.summarizer.SingleSummarizer; +import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult; + +import java.lang.reflect.Field; +import java.nio.file.Path; +import java.util.*; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.Mockito.when; + +class DialogRollingTest { + + @BeforeAll + static void beforeAll(@TempDir Path tempDir) { + System.setProperty("user.home", tempDir.toAbsolutePath().toString()); + } + + private static void setField(Object target, String fieldName, Object value) throws Exception { + Field field = target.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(target, value); + } + + private static Message message(Message.Character role, String content) { + return new Message(role, content); + } + + private static SummarizeResult summarizeResult(String summary, String topicPath, List relatedTopicPath) { + SummarizeResult result = new SummarizeResult(); + result.setSummary(summary); + result.setTopicPath(topicPath); + result.setRelatedTopicPath(relatedTopicPath); + return result; + } + + @Test + void shouldDelegateMemoryUpdateToCapability() throws Exception { + String sessionId = "dialog-rolling-" + UUID.randomUUID(); + StubMemoryCapability memoryCapability = new StubMemoryCapability(sessionId); + DialogRolling dialogRolling = new DialogRolling(); + MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); + MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class); + SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class); + setField(dialogRolling, "memoryCapability", memoryCapability); + setField(dialogRolling, "memoryRuntime", memoryRuntime); + setField(dialogRolling, "multiSummarizer", multiSummarizer); + setField(dialogRolling, "singleSummarizer", singleSummarizer); + + when(memoryRuntime.getTopicTree()).thenReturn("topic-tree"); + when(multiSummarizer.execute(Mockito.any())).thenReturn(Result.success( + summarizeResult("new-summary", "topic/main", List.of("topic/related")) + )); + + MemoryUnit existingUnit = new MemoryUnit(sessionId); + existingUnit.getConversationMessages().addAll(List.of( + message(Message.Character.USER, "old-user"), + message(Message.Character.ASSISTANT, "old-assistant") + )); + existingUnit.getSlices().add(MemorySlice.restore("slice-1", 0, 2, "old-summary", 1L)); + memoryCapability.putUnit(existingUnit); + + RollingResult rollingResult = dialogRolling.buildRollingResult(List.of( + message(Message.Character.USER, "new-user"), + message(Message.Character.ASSISTANT, "new-assistant") + ), 4, 6); + + MemoryUnit merged = memoryCapability.getMemoryUnit(sessionId); + assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"), + merged.getConversationMessages().stream().map(Message::getContent).toList()); + assertEquals(2, merged.getSlices().size()); + + MemorySlice appendedSlice = merged.getSlices().getLast(); + assertNotNull(appendedSlice.getId()); + assertEquals(2, appendedSlice.getStartIndex()); + assertEquals(4, appendedSlice.getEndIndex()); + assertEquals("new-summary", appendedSlice.getSummary()); + assertEquals(sessionId, rollingResult.memoryUnit().getId()); + assertEquals(appendedSlice.getId(), rollingResult.memorySlice().getId()); + assertEquals("new-summary", rollingResult.summary()); + } + + @Test + void shouldCreateFirstSliceForFreshSessionThroughCapability() throws Exception { + String sessionId = "dialog-rolling-" + UUID.randomUUID(); + StubMemoryCapability memoryCapability = new StubMemoryCapability(sessionId); + DialogRolling dialogRolling = new DialogRolling(); + MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); + MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class); + SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class); + setField(dialogRolling, "memoryCapability", memoryCapability); + setField(dialogRolling, "memoryRuntime", memoryRuntime); + setField(dialogRolling, "multiSummarizer", multiSummarizer); + setField(dialogRolling, "singleSummarizer", singleSummarizer); + + when(memoryRuntime.getTopicTree()).thenReturn("topic-tree"); + when(multiSummarizer.execute(Mockito.any())).thenReturn(Result.success( + summarizeResult("fresh-summary", "topic/root", List.of()) + )); + + RollingResult rollingResult = dialogRolling.buildRollingResult(List.of( + message(Message.Character.USER, "first"), + message(Message.Character.ASSISTANT, "second") + ), 2, 6); + + MemoryUnit created = memoryCapability.getMemoryUnit(sessionId); + assertNotNull(created); + assertEquals(List.of("first", "second"), + created.getConversationMessages().stream().map(Message::getContent).toList()); + assertEquals(1, created.getSlices().size()); + assertEquals(0, created.getSlices().getFirst().getStartIndex()); + assertEquals(2, created.getSlices().getFirst().getEndIndex()); + assertEquals("fresh-summary", created.getSlices().getFirst().getSummary()); + assertEquals(created, rollingResult.memoryUnit()); + } + + @Test + void shouldTrimPersistedOverlapFromCurrentSnapshot() throws Exception { + String sessionId = "dialog-rolling-" + UUID.randomUUID(); + StubMemoryCapability memoryCapability = new StubMemoryCapability(sessionId); + DialogRolling dialogRolling = new DialogRolling(); + setField(dialogRolling, "memoryCapability", memoryCapability); + + MemoryUnit existingUnit = Mockito.mock(MemoryUnit.class); + when(existingUnit.getConversationMessages()).thenReturn(List.of( + message(Message.Character.USER, "m1"), + message(Message.Character.ASSISTANT, "m2"), + message(Message.Character.USER, "m3"), + message(Message.Character.ASSISTANT, "m4") + )); + memoryCapability.putUnit(sessionId, existingUnit); + + List increment = dialogRolling.resolveChatIncrement(List.of( + message(Message.Character.USER, "m3"), + message(Message.Character.ASSISTANT, "m4"), + message(Message.Character.USER, "m5"), + message(Message.Character.ASSISTANT, "m6") + )); + + assertEquals(List.of("m5", "m6"), increment.stream().map(Message::getContent).toList()); + } + + @Test + void shouldFallbackWhenSummarizeResultIsBlank() throws Exception { + String sessionId = "dialog-rolling-" + UUID.randomUUID(); + StubMemoryCapability memoryCapability = new StubMemoryCapability(sessionId); + DialogRolling dialogRolling = new DialogRolling(); + MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); + MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class); + SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class); + setField(dialogRolling, "memoryCapability", memoryCapability); + setField(dialogRolling, "memoryRuntime", memoryRuntime); + setField(dialogRolling, "multiSummarizer", multiSummarizer); + setField(dialogRolling, "singleSummarizer", singleSummarizer); + + when(memoryRuntime.getTopicTree()).thenReturn("topic-tree"); + when(multiSummarizer.execute(Mockito.any())).thenReturn(Result.success(summarizeResult(" ", "topic/root", List.of()))); + + RollingResult rollingResult = dialogRolling.buildRollingResult(List.of( + message(Message.Character.USER, "u1"), + message(Message.Character.ASSISTANT, "a1") + ), 2, 6); + + assertEquals(sessionId, rollingResult.memoryUnit().getId()); + assertEquals("no summary, due to empty summarize result", rollingResult.summary()); + } + + private static final class StubMemoryCapability implements MemoryCapability { + private final String sessionId; + private final Map units = new HashMap<>(); + + private StubMemoryCapability(String sessionId) { + this.sessionId = sessionId; + } + + private void putUnit(String unitId, MemoryUnit memoryUnit) { + units.put(unitId == null ? memoryUnit.getId() : unitId, memoryUnit); + } + + private void putUnit(MemoryUnit memoryUnit) { + units.put(memoryUnit.getId(), memoryUnit); + } + + @Override + public MemoryUnit getMemoryUnit(String unitId) { + return units.get(unitId); + } + + @Override + public work.slhaf.partner.framework.agent.support.Result getMemorySlice(String unitId, String sliceId) { + return null; + } + + @Override + public MemoryUnit updateMemoryUnit(List chatMessages, String summary) { + MemoryUnit unit = units.computeIfAbsent(sessionId, MemoryUnit::new); + unit.updateTimestamp(); + int startIndex = unit.getConversationMessages().size(); + unit.getConversationMessages().addAll(chatMessages); + unit.getSlices().add(new MemorySlice(startIndex, startIndex + chatMessages.size(), summary)); + return unit; + } + + @Override + public Collection listMemoryUnits() { + return units.values(); + } + + @Override + public void refreshMemorySession() { + } + + @Override + public String getMemorySessionId() { + return sessionId; + } + } +} diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java index bd651818..620c5a9e 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java @@ -4,26 +4,21 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.mockito.Mockito; -import work.slhaf.partner.core.memory.MemoryCapability; +import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemoryUnit; +import work.slhaf.partner.framework.agent.exception.AgentRuntimeException; import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.support.Result; +import work.slhaf.partner.module.communication.AfterRollingRegistry; +import work.slhaf.partner.module.communication.RollingResult; import work.slhaf.partner.module.memory.runtime.MemoryRuntime; -import work.slhaf.partner.module.memory.runtime.exception.MemoryLookupException; -import work.slhaf.partner.module.memory.updater.summarizer.MultiSummarizer; -import work.slhaf.partner.module.memory.updater.summarizer.SingleSummarizer; -import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult; +import work.slhaf.partner.module.memory.updater.summarizer.entity.MemoryTopicResult; import java.lang.reflect.Field; -import java.lang.reflect.Method; import java.nio.file.Path; -import java.util.Collection; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -35,275 +30,89 @@ class MemoryUpdaterTest { System.setProperty("user.home", tempDir.toAbsolutePath().toString()); } - private static Object invokeUpdateMemory(MemoryUpdater updater, List chatMessages) throws Exception { - Method method = MemoryUpdater.class.getDeclaredMethod("updateMemory", List.class); - method.setAccessible(true); - return method.invoke(updater, chatMessages); - } - - @SuppressWarnings("unchecked") - private static List invokeResolveChatIncrement(MemoryUpdater updater, - List chatMessages) throws Exception { - Method method = MemoryUpdater.class.getDeclaredMethod("resolveChatIncrement", List.class); - method.setAccessible(true); - return (List) method.invoke(updater, chatMessages); - } - - private static String recordField(Object target, String fieldName) throws Exception { - Field field = target.getClass().getDeclaredField(fieldName); - field.setAccessible(true); - return (String) field.get(target); - } - private static void setField(Object target, String fieldName, Object value) throws Exception { Field field = target.getClass().getDeclaredField(fieldName); field.setAccessible(true); field.set(target, value); } - private static SummarizeResult summarizeResult(String summary, String topicPath, List relatedTopicPath) { - SummarizeResult result = new SummarizeResult(); - result.setSummary(summary); - result.setTopicPath(topicPath); - result.setRelatedTopicPath(relatedTopicPath); - return result; - } - private static Message message(Message.Character role, String content) { return new Message(role, content); } @Test - void shouldDelegateMemoryUpdateToCapabilityAndRuntime() throws Exception { - StubMemoryCapability memoryCapability = new StubMemoryCapability("session-1"); - MemoryUpdater updater = new MemoryUpdater(); - MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); - MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class); - SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class); - setField(updater, "memoryCapability", memoryCapability); - setField(updater, "memoryRuntime", memoryRuntime); - setField(updater, "multiSummarizer", multiSummarizer); - setField(updater, "singleSummarizer", singleSummarizer); + void shouldRegisterItselfToAfterRollingRegistryOnInit() throws Exception { + MemoryUpdater updater = Mockito.spy(new MemoryUpdater()); + AfterRollingRegistry registry = Mockito.mock(AfterRollingRegistry.class); + setField(updater, "afterRollingRegistry", registry); + updater.init(); + + verify(registry).register(updater); + } + + @Test + void shouldExtractTopicAndRecordMemoryOnConsume() throws Exception { + MemoryUpdater updater = Mockito.spy(new MemoryUpdater()); + MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); + CognitionCapability cognitionCapability = Mockito.mock(CognitionCapability.class); + setField(updater, "memoryRuntime", memoryRuntime); + setField(updater, "cognitionCapability", cognitionCapability); + + when(cognitionCapability.contextWorkspace()).thenReturn(new work.slhaf.partner.core.cognition.ContextWorkspace()); when(memoryRuntime.getTopicTree()).thenReturn("topic-tree"); - when(multiSummarizer.execute(Mockito.any())).thenReturn(Result.success( - summarizeResult("new-summary", "topic/main", List.of("topic/related")) + when(memoryRuntime.fixTopicPath("root[2]->branch[1]")).thenReturn("root->branch"); + when(memoryRuntime.fixTopicPath("root[2]->related[1]")).thenReturn("root->related"); + + MemoryTopicResult topicResult = new MemoryTopicResult(); + topicResult.setTopicPath("root[2]->branch[1]"); + topicResult.setRelatedTopicPaths(List.of("root[2]->related[1]")); + Mockito.doReturn(Result.success(topicResult)) + .when(updater) + .formattedChat(Mockito.anyList(), eq(MemoryTopicResult.class)); + + MemoryUnit unit = new MemoryUnit("session-1"); + unit.getConversationMessages().addAll(List.of( + message(Message.Character.USER, "old"), + message(Message.Character.ASSISTANT, "old-reply"), + message(Message.Character.USER, "new"), + message(Message.Character.ASSISTANT, "new-reply") )); + MemorySlice slice = new MemorySlice(2, 4, "slice-summary"); + unit.getSlices().add(slice); - MemoryUnit existingUnit = new MemoryUnit("session-1"); - existingUnit.getConversationMessages().addAll(List.of( - message(Message.Character.USER, "old-user"), - message(Message.Character.ASSISTANT, "old-assistant") - )); - existingUnit.getSlices().add(MemorySlice.restore("slice-1", 0, 2, "old-summary", 1L)); - memoryCapability.putUnit(existingUnit); + updater.consume(new RollingResult(unit, slice, List.of( + message(Message.Character.USER, "new"), + message(Message.Character.ASSISTANT, "new-reply") + ), "slice-summary", 4, 6)); - Object rollingRecord = invokeUpdateMemory(updater, List.of( - message(Message.Character.USER, "new-user"), - message(Message.Character.ASSISTANT, "new-assistant") - )); - - MemoryUnit merged = memoryCapability.getMemoryUnit("session-1"); - assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"), - merged.getConversationMessages().stream().map(Message::getContent).toList()); - assertEquals(2, merged.getSlices().size()); - - MemorySlice appendedSlice = merged.getSlices().getLast(); - assertNotNull(appendedSlice.getId()); - assertEquals(2, appendedSlice.getStartIndex()); - assertEquals(4, appendedSlice.getEndIndex()); - assertEquals("new-summary", appendedSlice.getSummary()); - - assertEquals(List.of("new-user", "new-assistant"), - memoryCapability.lastChatMessages().stream().map(Message::getContent).toList()); - assertEquals("new-summary", memoryCapability.lastSummary()); - verify(memoryRuntime).recordMemory(eq(merged), eq("topic/main"), eq(List.of("topic/related"))); - assertEquals("session-1", recordField(rollingRecord, "unitId")); - assertEquals(appendedSlice.getId(), recordField(rollingRecord, "sliceId")); - assertEquals("new-summary", recordField(rollingRecord, "summary")); + verify(memoryRuntime).recordMemory(eq(unit), eq("root->branch"), eq(List.of("root->related"))); } @Test - void shouldCreateFirstSliceForFreshSessionThroughCapability() throws Exception { - StubMemoryCapability memoryCapability = new StubMemoryCapability("session-2"); - MemoryUpdater updater = new MemoryUpdater(); + void shouldFallbackToDateOnlyRecordWhenExtractionFails() throws Exception { + MemoryUpdater updater = Mockito.spy(new MemoryUpdater()); MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); - MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class); - SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class); - setField(updater, "memoryCapability", memoryCapability); + CognitionCapability cognitionCapability = Mockito.mock(CognitionCapability.class); setField(updater, "memoryRuntime", memoryRuntime); - setField(updater, "multiSummarizer", multiSummarizer); - setField(updater, "singleSummarizer", singleSummarizer); + setField(updater, "cognitionCapability", cognitionCapability); + when(cognitionCapability.contextWorkspace()).thenReturn(new work.slhaf.partner.core.cognition.ContextWorkspace()); when(memoryRuntime.getTopicTree()).thenReturn("topic-tree"); - when(multiSummarizer.execute(Mockito.any())).thenReturn(Result.success( - summarizeResult("fresh-summary", "topic/root", List.of()) + Mockito.doReturn(Result.failure(new AgentRuntimeException("boom"))) + .when(updater) + .formattedChat(Mockito.anyList(), eq(MemoryTopicResult.class)); + + MemoryUnit unit = new MemoryUnit("session-2"); + unit.getConversationMessages().addAll(List.of( + message(Message.Character.USER, "u1"), + message(Message.Character.ASSISTANT, "a1") )); + MemorySlice slice = new MemorySlice(0, 2, "slice-summary"); + unit.getSlices().add(slice); - Object rollingRecord = invokeUpdateMemory(updater, List.of( - message(Message.Character.USER, "first"), - message(Message.Character.ASSISTANT, "second") - )); + updater.consume(new RollingResult(unit, slice, unit.getConversationMessages(), "slice-summary", 2, 6)); - MemoryUnit created = memoryCapability.getMemoryUnit("session-2"); - assertNotNull(created); - assertEquals("session-2", created.getId()); - assertEquals(List.of("first", "second"), - created.getConversationMessages().stream().map(Message::getContent).toList()); - assertEquals(1, created.getSlices().size()); - assertEquals(0, created.getSlices().getFirst().getStartIndex()); - assertEquals(2, created.getSlices().getFirst().getEndIndex()); - assertEquals("fresh-summary", created.getSlices().getFirst().getSummary()); - verify(memoryRuntime).recordMemory(eq(created), eq("topic/root"), eq(List.of())); - assertEquals("session-2", recordField(rollingRecord, "unitId")); - assertEquals(created.getSlices().getFirst().getId(), recordField(rollingRecord, "sliceId")); - assertEquals("fresh-summary", recordField(rollingRecord, "summary")); - } - - @Test - void shouldTrimPersistedOverlapFromCurrentSnapshot() throws Exception { - StubMemoryCapability memoryCapability = new StubMemoryCapability("session-3"); - MemoryUpdater updater = new MemoryUpdater(); - setField(updater, "memoryCapability", memoryCapability); - - MemoryUnit existingUnit = Mockito.mock(MemoryUnit.class); - when(existingUnit.getConversationMessages()).thenReturn(List.of( - message(Message.Character.USER, "m1"), - message(Message.Character.ASSISTANT, "m2"), - message(Message.Character.USER, "m3"), - message(Message.Character.ASSISTANT, "m4") - )); - memoryCapability.putUnit("session-3", existingUnit); - - List increment = invokeResolveChatIncrement( - updater, - List.of( - message(Message.Character.USER, "m3"), - message(Message.Character.ASSISTANT, "m4"), - message(Message.Character.USER, "m5"), - message(Message.Character.ASSISTANT, "m6") - ) - ); - - assertEquals(List.of("m5", "m6"), increment.stream().map(Message::getContent).toList()); - } - - @Test - void shouldReturnEmptyIncrementWhenSnapshotIsFullyPersisted() throws Exception { - StubMemoryCapability memoryCapability = new StubMemoryCapability("session-4"); - MemoryUpdater updater = new MemoryUpdater(); - setField(updater, "memoryCapability", memoryCapability); - - MemoryUnit existingUnit = Mockito.mock(MemoryUnit.class); - when(existingUnit.getConversationMessages()).thenReturn(List.of( - message(Message.Character.USER, "m1"), - message(Message.Character.ASSISTANT, "m2"), - message(Message.Character.USER, "m3") - )); - memoryCapability.putUnit("session-4", existingUnit); - - List increment = invokeResolveChatIncrement( - updater, - List.of( - message(Message.Character.ASSISTANT, "m2"), - message(Message.Character.USER, "m3") - ) - ); - - assertEquals(List.of(), increment); - } - - @Test - void shouldReturnNullWhenUpdateMemoryReceivesEmptySnapshot() throws Exception { - StubMemoryCapability memoryCapability = new StubMemoryCapability("session-5"); - MemoryUpdater updater = new MemoryUpdater(); - MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class); - setField(updater, "memoryCapability", memoryCapability); - setField(updater, "memoryRuntime", memoryRuntime); - - Object rollingRecord = invokeUpdateMemory(updater, List.of()); - - assertNull(rollingRecord); - assertNull(memoryCapability.lastSummary()); - Mockito.verifyNoInteractions(memoryRuntime); - } - - private static final class StubMemoryCapability implements MemoryCapability { - private final String sessionId; - private final Map units = new HashMap<>(); - private List lastChatMessages; - private String lastSummary; - - private StubMemoryCapability(String sessionId) { - this.sessionId = sessionId; - } - - private void putUnit(String unitId, MemoryUnit memoryUnit) { - units.put(unitId, memoryUnit); - } - - private void putUnit(MemoryUnit memoryUnit) { - units.put(memoryUnit.getId(), memoryUnit); - } - - private List lastChatMessages() { - return lastChatMessages; - } - - private String lastSummary() { - return lastSummary; - } - - @Override - public MemoryUnit getMemoryUnit(String unitId) { - return units.get(unitId); - } - - @Override - public Result getMemorySlice(String unitId, String sliceId) { - MemoryUnit unit = units.get(unitId); - if (unit == null || unit.getSlices() == null) { - return Result.failure(new MemoryLookupException( - "Memory slice not found: " + unitId + ":" + sliceId, - unitId + ":" + sliceId, - "MEMORY_SLICE" - )); - } - return unit.getSlices().stream() - .filter(slice -> sliceId.equals(slice.getId())) - .findFirst() - .map(Result::success) - .orElseGet(() -> Result.failure(new MemoryLookupException( - "Memory slice not found: " + unitId + ":" + sliceId, - unitId + ":" + sliceId, - "MEMORY_SLICE" - ))); - } - - @Override - public MemoryUnit updateMemoryUnit(List chatMessages, String summary) { - lastChatMessages = List.copyOf(chatMessages); - lastSummary = summary; - MemoryUnit unit = units.computeIfAbsent(sessionId, MemoryUnit::new); - unit.updateTimestamp(); - int startIndex = unit.getConversationMessages().size(); - unit.getConversationMessages().addAll(chatMessages); - unit.getSlices().add(new MemorySlice(startIndex, startIndex + chatMessages.size(), summary)); - return unit; - } - - @Override - public Collection listMemoryUnits() { - return units.values(); - } - - @Override - public void refreshMemorySession() { - } - - @Override - public String getMemorySessionId() { - return sessionId; - } + verify(memoryRuntime).recordMemory(eq(unit), eq(null), eq(List.of())); } }