mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
refactor(rolling): separate dialog rolling with memory topic binding,
now support register after-rolling consumer independently
This commit is contained in:
@@ -75,7 +75,7 @@ public class MemoryCore implements StateSerializable {
|
||||
@CapabilityMethod
|
||||
public Result<MemorySlice> getMemorySlice(String unitId, String sliceId) {
|
||||
MemoryUnit memoryUnit = memoryUnits.get(unitId);
|
||||
if (memoryUnit == null || memoryUnit.getSlices() == null) {
|
||||
if (memoryUnit == null) {
|
||||
return Result.failure(new MemoryLookupException(
|
||||
"Memory slice not found: " + unitId + ":" + sliceId,
|
||||
unitId + ":" + sliceId,
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
package work.slhaf.partner.module.communication;
|
||||
|
||||
public interface AfterRolling {
|
||||
void consume(RollingResult result);
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package work.slhaf.partner.module.communication;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.core.action.ActionCapability;
|
||||
import work.slhaf.partner.core.action.ActionCore;
|
||||
import work.slhaf.partner.framework.agent.exception.AgentRuntimeException;
|
||||
import work.slhaf.partner.framework.agent.exception.ExceptionReporterHandler;
|
||||
import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability;
|
||||
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
|
||||
import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Slf4j
|
||||
public class AfterRollingRegistry extends AbstractAgentModule.Standalone {
|
||||
|
||||
private final CopyOnWriteArrayList<AfterRolling> consumers = new CopyOnWriteArrayList<>();
|
||||
@InjectCapability
|
||||
private ActionCapability actionCapability;
|
||||
private ExecutorService executor;
|
||||
|
||||
@Init
|
||||
public void init() {
|
||||
executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
|
||||
}
|
||||
|
||||
public void register(AfterRolling consumer) {
|
||||
if (consumers.contains(consumer)) {
|
||||
return;
|
||||
}
|
||||
consumers.add(consumer);
|
||||
}
|
||||
|
||||
public void trigger(RollingResult result) {
|
||||
if (consumers.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
executor.execute(() -> {
|
||||
for (AfterRolling consumer : List.copyOf(consumers)) {
|
||||
try {
|
||||
consumer.consume(result);
|
||||
} catch (Exception e) {
|
||||
ExceptionReporterHandler.INSTANCE.report(new AgentRuntimeException("after-rolling consumer occurred exception: " + consumer.getClass().getSimpleName(), e));
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,208 @@
|
||||
package work.slhaf.partner.module.communication;
|
||||
|
||||
import kotlin.Unit;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.w3c.dom.Document;
|
||||
import org.w3c.dom.Element;
|
||||
import work.slhaf.partner.core.action.ActionCapability;
|
||||
import work.slhaf.partner.core.action.ActionCore;
|
||||
import work.slhaf.partner.core.action.entity.Schedulable;
|
||||
import work.slhaf.partner.core.action.entity.StateAction;
|
||||
import work.slhaf.partner.core.cognition.BlockContent;
|
||||
import work.slhaf.partner.core.cognition.CognitionCapability;
|
||||
import work.slhaf.partner.core.cognition.ContextBlock;
|
||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||
import work.slhaf.partner.core.perceive.PerceiveCapability;
|
||||
import work.slhaf.partner.framework.agent.exception.AgentRuntimeException;
|
||||
import work.slhaf.partner.framework.agent.exception.ExceptionReporterHandler;
|
||||
import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability;
|
||||
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
|
||||
import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
|
||||
import work.slhaf.partner.framework.agent.factory.component.annotation.InjectModule;
|
||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||
import work.slhaf.partner.framework.agent.support.Result;
|
||||
import work.slhaf.partner.module.action.scheduler.ActionScheduler;
|
||||
import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.MultiSummarizer;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.SingleSummarizer;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeInput;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
|
||||
import work.slhaf.partner.runtime.PartnerRunningFlowContext;
|
||||
|
||||
import java.time.ZonedDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class DialogRolling extends AbstractAgentModule.Running<PartnerRunningFlowContext> {
|
||||
|
||||
private static final String AUTO_UPDATE_CRON = "0/10 * * * * ?";
|
||||
private static final long UPDATE_TRIGGER_INTERVAL = 60 * 60 * 1000;
|
||||
private static final int CONTEXT_RETAIN_DIVISOR = 6;
|
||||
private static final int MEMORY_UPDATE_TRIGGER_ROLL_LIMIT = 36;
|
||||
private final AtomicBoolean rolling = new AtomicBoolean(false);
|
||||
@InjectCapability
|
||||
private CognitionCapability cognitionCapability;
|
||||
@InjectCapability
|
||||
private MemoryCapability memoryCapability;
|
||||
@InjectCapability
|
||||
private PerceiveCapability perceiveCapability;
|
||||
@InjectCapability
|
||||
private ActionCapability actionCapability;
|
||||
@InjectModule
|
||||
private MemoryRuntime memoryRuntime;
|
||||
@InjectModule
|
||||
private MultiSummarizer multiSummarizer;
|
||||
@InjectModule
|
||||
private SingleSummarizer singleSummarizer;
|
||||
@InjectModule
|
||||
private ActionScheduler actionScheduler;
|
||||
@InjectModule
|
||||
private AfterRollingRegistry afterRollingRegistry;
|
||||
private ExecutorService executor;
|
||||
|
||||
@Init
|
||||
public void init() {
|
||||
executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
|
||||
registerScheduledUpdater();
|
||||
}
|
||||
|
||||
private void registerScheduledUpdater() {
|
||||
StateAction stateAction = new StateAction(
|
||||
"system",
|
||||
"dialog-rolling-auto-update",
|
||||
"定时检查并触发对话滚动",
|
||||
Schedulable.ScheduleType.CYCLE,
|
||||
AUTO_UPDATE_CRON,
|
||||
new StateAction.Trigger.Call(() -> {
|
||||
tryAutoRolling();
|
||||
return Unit.INSTANCE;
|
||||
})
|
||||
);
|
||||
actionScheduler.schedule(stateAction);
|
||||
log.info("Dialog rolling has been registered into ActionScheduler, cron={}", AUTO_UPDATE_CRON);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doExecute(@NotNull PartnerRunningFlowContext context) {
|
||||
if (cognitionCapability.getChatMessages().size() < MEMORY_UPDATE_TRIGGER_ROLL_LIMIT) {
|
||||
return;
|
||||
}
|
||||
executor.execute(() -> triggerRolling(false));
|
||||
}
|
||||
|
||||
private void tryAutoRolling() {
|
||||
long currentTime = System.currentTimeMillis();
|
||||
int chatCount = cognitionCapability.snapshotChatMessages().size();
|
||||
if (currentTime - perceiveCapability.showLastInteract().toEpochMilli() > UPDATE_TRIGGER_INTERVAL && chatCount > 1) {
|
||||
triggerRolling(true);
|
||||
log.debug("Dialog rolling: auto triggered");
|
||||
}
|
||||
}
|
||||
|
||||
private void triggerRolling(boolean refreshMemoryId) {
|
||||
if (!rolling.compareAndSet(false, true)) {
|
||||
log.debug("Dialog rolling: rolling is already executing");
|
||||
return;
|
||||
}
|
||||
try {
|
||||
List<Message> fullChatSnapshot = cognitionCapability.snapshotChatMessages();
|
||||
if (fullChatSnapshot.size() <= 1) {
|
||||
return;
|
||||
}
|
||||
List<Message> chatIncrement = resolveChatIncrement(fullChatSnapshot);
|
||||
if (chatIncrement.isEmpty()) {
|
||||
if (refreshMemoryId) {
|
||||
memoryCapability.refreshMemorySession();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
RollingResult result = buildRollingResult(chatIncrement, fullChatSnapshot.size(), CONTEXT_RETAIN_DIVISOR);
|
||||
applyRolling(result);
|
||||
afterRollingRegistry.trigger(result);
|
||||
|
||||
if (refreshMemoryId) {
|
||||
memoryCapability.refreshMemorySession();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
ExceptionReporterHandler.INSTANCE.report(new AgentRuntimeException("Dialog rolling failed", e));
|
||||
} finally {
|
||||
rolling.set(false);
|
||||
}
|
||||
}
|
||||
|
||||
List<Message> resolveChatIncrement(List<Message> fullChatSnapshot) {
|
||||
String memoryId = memoryCapability.getMemorySessionId();
|
||||
if (memoryId.isBlank()) {
|
||||
return fullChatSnapshot;
|
||||
}
|
||||
MemoryUnit existingUnit = memoryCapability.getMemoryUnit(memoryId);
|
||||
if (existingUnit.getConversationMessages().isEmpty()) {
|
||||
return fullChatSnapshot;
|
||||
}
|
||||
List<Message> existingMessages = existingUnit.getConversationMessages();
|
||||
int maxOverlap = Math.min(existingMessages.size(), fullChatSnapshot.size());
|
||||
for (int overlap = maxOverlap; overlap > 0; overlap--) {
|
||||
List<Message> existingSuffix = existingMessages.subList(existingMessages.size() - overlap, existingMessages.size());
|
||||
List<Message> snapshotPrefix = fullChatSnapshot.subList(0, overlap);
|
||||
if (existingSuffix.equals(snapshotPrefix)) {
|
||||
return fullChatSnapshot.subList(overlap, fullChatSnapshot.size());
|
||||
}
|
||||
}
|
||||
return fullChatSnapshot;
|
||||
}
|
||||
|
||||
@NotNull
|
||||
RollingResult buildRollingResult(List<Message> chatSnapshot, int rollingSize, int retainDivisor) {
|
||||
SummarizeInput summarizeInput = new SummarizeInput(chatSnapshot, memoryRuntime.getTopicTree());
|
||||
singleSummarizer.execute(summarizeInput.getChatMessages());
|
||||
Result<SummarizeResult> summarizeResult = multiSummarizer.execute(summarizeInput);
|
||||
String summary = summarizeResult.fold(
|
||||
SummarizeResult::getSummary,
|
||||
exp -> "no summary, due to exception"
|
||||
);
|
||||
if (summary.isBlank()) {
|
||||
summary = "no summary, due to empty summarize result";
|
||||
}
|
||||
MemoryUnit memoryUnit = memoryCapability.updateMemoryUnit(chatSnapshot, summary);
|
||||
MemorySlice newSlice = memoryUnit.getSlices().getLast();
|
||||
return new RollingResult(memoryUnit, newSlice, List.copyOf(chatSnapshot), newSlice.getSummary(), rollingSize, retainDivisor);
|
||||
}
|
||||
|
||||
private void applyRolling(RollingResult result) {
|
||||
cognitionCapability.contextWorkspace().register(new ContextBlock(
|
||||
buildDialogAbstractBlock(result.summary(), result.memoryUnit().getId(), result.memorySlice().getId()),
|
||||
Set.of(ContextBlock.VisibleDomain.MEMORY, ContextBlock.VisibleDomain.COMMUNICATION),
|
||||
20,
|
||||
5,
|
||||
10
|
||||
));
|
||||
cognitionCapability.rollChatMessagesWithSnapshot(result.rollingSize(), result.retainDivisor());
|
||||
}
|
||||
|
||||
private @NotNull BlockContent buildDialogAbstractBlock(String summary, String unitId, String sliceId) {
|
||||
return new BlockContent("dialog_history", "dialog_rolling") {
|
||||
@Override
|
||||
protected void fillXml(@NotNull Document document, @NotNull Element root) {
|
||||
root.setAttribute("related_memory_unit_id", unitId);
|
||||
root.setAttribute("related_memory_slice_id", sliceId);
|
||||
root.setAttribute("datetime", ZonedDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
|
||||
appendTextElement(document, root, "summary", summary);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public int order() {
|
||||
return 7;
|
||||
}
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
package work.slhaf.partner.module.communication;
|
||||
|
||||
import kotlin.Unit;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import org.w3c.dom.Document;
|
||||
import org.w3c.dom.Element;
|
||||
import work.slhaf.partner.core.cognition.BlockContent;
|
||||
import work.slhaf.partner.core.cognition.CognitionCapability;
|
||||
import work.slhaf.partner.core.cognition.ContextBlock;
|
||||
import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability;
|
||||
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
|
||||
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||
import work.slhaf.partner.module.TaskBlock;
|
||||
|
||||
import java.time.ZonedDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Set;
|
||||
|
||||
public class DialogRollingService extends AbstractAgentModule.Standalone implements ActivateModel {
|
||||
|
||||
@InjectCapability
|
||||
private CognitionCapability cognitionCapability;
|
||||
|
||||
public void rollMessages(List<Message> snapshotMessages, int rollingSize, int retainSize) {
|
||||
rollMessages(snapshotMessages, rollingSize, retainSize, null, null, null);
|
||||
}
|
||||
|
||||
public void rollMessages(List<Message> snapshotMessages, int rollingSize, int retainSize, @Nullable String unitId, @Nullable String sliceId, @Nullable String summary) {
|
||||
summary = summary == null ? summarize(snapshotMessages) : summary;
|
||||
cognitionCapability.contextWorkspace().register(new ContextBlock(
|
||||
buildDialogAbstractBlock(summary, unitId, sliceId),
|
||||
Set.of(ContextBlock.VisibleDomain.MEMORY, ContextBlock.VisibleDomain.COMMUNICATION),
|
||||
35,
|
||||
8,
|
||||
10
|
||||
));
|
||||
cognitionCapability.rollChatMessagesWithSnapshot(rollingSize, retainSize);
|
||||
}
|
||||
|
||||
private String summarize(List<Message> snapshotMessages) {
|
||||
List<Message> messages = List.of(
|
||||
resolveTaskBlock(snapshotMessages)
|
||||
);
|
||||
return chat(messages).getOrThrow();
|
||||
}
|
||||
|
||||
private @NotNull BlockContent buildDialogAbstractBlock(String summary, @Nullable String unitId, @Nullable String sliceId) {
|
||||
return new BlockContent("dialog_history", "dialog_rolling_service") {
|
||||
@Override
|
||||
protected void fillXml(@NotNull Document document, @NotNull Element root) {
|
||||
if (unitId != null) root.setAttribute("related_memory_unit_id", unitId);
|
||||
if (sliceId != null) root.setAttribute("related_memory_slice_id", sliceId);
|
||||
root.setAttribute("datetime", ZonedDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
|
||||
|
||||
appendTextElement(document, root, "summary", summary);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private Message resolveTaskBlock(List<Message> snapshotMessages) {
|
||||
return new TaskBlock() {
|
||||
@Override
|
||||
protected void fillXml(@NotNull Document document, @NotNull Element root) {
|
||||
appendRepeatedElements(document, root, "message", snapshotMessages, (messageElement, message) -> {
|
||||
messageElement.setAttribute("role", message.getRole().name().toLowerCase(Locale.ROOT));
|
||||
messageElement.setTextContent(message.getContent());
|
||||
return Unit.INSTANCE;
|
||||
});
|
||||
}
|
||||
}.encodeToMessage();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package work.slhaf.partner.module.communication;
|
||||
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public record RollingResult(
|
||||
MemoryUnit memoryUnit,
|
||||
MemorySlice memorySlice,
|
||||
List<Message> incrementMessages,
|
||||
String summary,
|
||||
int rollingSize,
|
||||
int retainDivisor
|
||||
) {
|
||||
}
|
||||
@@ -50,8 +50,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
|
||||
}
|
||||
|
||||
private void checkAndSetMemoryId() {
|
||||
String currentMemoryId = memoryCapability.getMemorySessionId();
|
||||
if (currentMemoryId == null || cognitionCapability.getChatMessages().isEmpty()) {
|
||||
if (cognitionCapability.getChatMessages().isEmpty()) {
|
||||
memoryCapability.refreshMemorySession();
|
||||
}
|
||||
}
|
||||
@@ -75,9 +74,11 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
|
||||
MemorySlice memorySlice = memoryUnit.getSlices().getLast();
|
||||
SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId());
|
||||
indexMemoryUnit(memoryUnit);
|
||||
bindTopic(topicPath, sliceRef);
|
||||
if (relatedTopicPaths != null) {
|
||||
for (String relatedTopicPath : relatedTopicPaths) {
|
||||
if (topicPath != null && !topicPath.isBlank()) {
|
||||
bindTopic(topicPath, sliceRef);
|
||||
}
|
||||
for (String relatedTopicPath : relatedTopicPaths) {
|
||||
if (relatedTopicPath != null && !relatedTopicPath.isBlank()) {
|
||||
bindTopic(relatedTopicPath, sliceRef);
|
||||
}
|
||||
}
|
||||
@@ -89,14 +90,12 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
|
||||
for (CopyOnWriteArrayList<SliceRef> refs : dateIndex.values()) {
|
||||
refs.removeIf(ref -> memoryUnit.getId().equals(ref.getUnitId()));
|
||||
}
|
||||
if (memoryUnit.getSlices() != null) {
|
||||
for (MemorySlice slice : memoryUnit.getSlices()) {
|
||||
LocalDate date = Instant.ofEpochMilli(slice.getTimestamp())
|
||||
.atZone(ZoneId.systemDefault())
|
||||
.toLocalDate();
|
||||
dateIndex.computeIfAbsent(date, key -> new CopyOnWriteArrayList<>())
|
||||
.addIfAbsent(new SliceRef(memoryUnit.getId(), slice.getId()));
|
||||
}
|
||||
for (MemorySlice slice : memoryUnit.getSlices()) {
|
||||
LocalDate date = Instant.ofEpochMilli(slice.getTimestamp())
|
||||
.atZone(ZoneId.systemDefault())
|
||||
.toLocalDate();
|
||||
dateIndex.computeIfAbsent(date, key -> new CopyOnWriteArrayList<>())
|
||||
.addIfAbsent(new SliceRef(memoryUnit.getId(), slice.getId()));
|
||||
}
|
||||
} finally {
|
||||
runtimeLock.unlock();
|
||||
@@ -192,7 +191,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
|
||||
|
||||
private List<Message> sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) {
|
||||
List<Message> conversationMessages = memoryUnit.getConversationMessages();
|
||||
if (conversationMessages == null || conversationMessages.isEmpty()) {
|
||||
if (conversationMessages.isEmpty()) {
|
||||
return List.of();
|
||||
}
|
||||
int size = conversationMessages.size();
|
||||
|
||||
@@ -1,192 +1,94 @@
|
||||
package work.slhaf.partner.module.memory.updater;
|
||||
|
||||
import kotlin.Unit;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import work.slhaf.partner.core.action.ActionCapability;
|
||||
import work.slhaf.partner.core.action.ActionCore;
|
||||
import work.slhaf.partner.core.action.entity.Schedulable;
|
||||
import work.slhaf.partner.core.action.entity.StateAction;
|
||||
import org.w3c.dom.Document;
|
||||
import org.w3c.dom.Element;
|
||||
import work.slhaf.partner.core.cognition.CognitionCapability;
|
||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||
import work.slhaf.partner.core.perceive.PerceiveCapability;
|
||||
import work.slhaf.partner.core.cognition.ContextBlock;
|
||||
import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability;
|
||||
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
|
||||
import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
|
||||
import work.slhaf.partner.framework.agent.factory.component.annotation.InjectModule;
|
||||
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||
import work.slhaf.partner.module.action.scheduler.ActionScheduler;
|
||||
import work.slhaf.partner.module.communication.DialogRollingService;
|
||||
import work.slhaf.partner.framework.agent.support.Result;
|
||||
import work.slhaf.partner.module.TaskBlock;
|
||||
import work.slhaf.partner.module.communication.AfterRolling;
|
||||
import work.slhaf.partner.module.communication.AfterRollingRegistry;
|
||||
import work.slhaf.partner.module.communication.RollingResult;
|
||||
import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.MultiSummarizer;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.SingleSummarizer;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeInput;
|
||||
import work.slhaf.partner.runtime.PartnerRunningFlowContext;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.entity.MemoryTopicResult;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlowContext> {
|
||||
|
||||
private static final String AUTO_UPDATE_CRON = "0/10 * * * * ?";
|
||||
private static final long UPDATE_TRIGGER_INTERVAL = 60 * 60 * 1000;
|
||||
private static final int CONTEXT_RETAIN_DIVISOR = 6;
|
||||
private static final int MEMORY_UPDATE_TRIGGER_ROLL_LIMIT = 36;
|
||||
public class MemoryUpdater extends AbstractAgentModule.Standalone implements AfterRolling, ActivateModel {
|
||||
|
||||
@InjectCapability
|
||||
private CognitionCapability cognitionCapability;
|
||||
@InjectCapability
|
||||
private MemoryCapability memoryCapability;
|
||||
@InjectCapability
|
||||
private PerceiveCapability perceiveCapability;
|
||||
@InjectCapability
|
||||
private ActionCapability actionCapability;
|
||||
|
||||
@InjectModule
|
||||
private MemoryRuntime memoryRuntime;
|
||||
@InjectModule
|
||||
private MultiSummarizer multiSummarizer;
|
||||
@InjectModule
|
||||
private SingleSummarizer singleSummarizer;
|
||||
@InjectModule
|
||||
private ActionScheduler actionScheduler;
|
||||
@InjectModule
|
||||
private DialogRollingService dialogRollingService;
|
||||
|
||||
private final AtomicBoolean updating = new AtomicBoolean(false);
|
||||
private ExecutorService executor;
|
||||
private AfterRollingRegistry afterRollingRegistry;
|
||||
|
||||
@Init
|
||||
public void init() {
|
||||
executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
|
||||
registerScheduledUpdater();
|
||||
afterRollingRegistry.register(this);
|
||||
}
|
||||
|
||||
private void registerScheduledUpdater() {
|
||||
StateAction stateAction = new StateAction(
|
||||
"system",
|
||||
"memory-auto-update",
|
||||
"定时检查并触发记忆更新",
|
||||
Schedulable.ScheduleType.CYCLE,
|
||||
AUTO_UPDATE_CRON,
|
||||
new StateAction.Trigger.Call(() -> {
|
||||
tryAutoUpdate();
|
||||
return Unit.INSTANCE;
|
||||
})
|
||||
@Override
|
||||
public void consume(RollingResult result) {
|
||||
List<Message> slicedMessages = sliceMessages(result);
|
||||
if (slicedMessages.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
Result<MemoryTopicResult> extractResult = formattedChat(
|
||||
List.of(
|
||||
cognitionCapability.contextWorkspace().resolve(List.of(
|
||||
ContextBlock.VisibleDomain.COGNITION,
|
||||
ContextBlock.VisibleDomain.MEMORY
|
||||
)).encodeToMessage(),
|
||||
resolveTopicTaskMessage(result, slicedMessages)
|
||||
),
|
||||
MemoryTopicResult.class
|
||||
);
|
||||
actionScheduler.schedule(stateAction);
|
||||
log.info("[MemoryUpdater] 记忆自动更新已注册到 ActionScheduler, cron={}", AUTO_UPDATE_CRON);
|
||||
extractResult.onSuccess(topicResult -> {
|
||||
String topicPath = topicResult.getTopicPath() == null ? null : memoryRuntime.fixTopicPath(topicResult.getTopicPath());
|
||||
List<String> relatedTopicPaths = topicResult.getRelatedTopicPaths() == null
|
||||
? List.of()
|
||||
: topicResult.getRelatedTopicPaths().stream().map(memoryRuntime::fixTopicPath).toList();
|
||||
memoryRuntime.recordMemory(result.memoryUnit(), topicPath, relatedTopicPaths);
|
||||
}).onFailure(exp -> memoryRuntime.recordMemory(result.memoryUnit(), null, List.of()));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doExecute(@NotNull PartnerRunningFlowContext context) {
|
||||
boolean trigger = cognitionCapability.getChatMessages().size() >= MEMORY_UPDATE_TRIGGER_ROLL_LIMIT;
|
||||
if (!trigger) {
|
||||
return;
|
||||
private List<Message> sliceMessages(RollingResult result) {
|
||||
int size = result.memoryUnit().getConversationMessages().size();
|
||||
int start = Math.clamp(result.memorySlice().getStartIndex(), 0, size);
|
||||
int end = Math.clamp(result.memorySlice().getEndIndex(), start, size);
|
||||
if (start >= end) {
|
||||
return List.of();
|
||||
}
|
||||
executor.execute(() -> {
|
||||
log.debug("[MemoryUpdater] 记忆更新触发");
|
||||
triggerMemoryUpdate(false);
|
||||
});
|
||||
return result.memoryUnit().getConversationMessages().subList(start, end);
|
||||
}
|
||||
|
||||
private void tryAutoUpdate() {
|
||||
long currentTime = System.currentTimeMillis();
|
||||
int chatCount = cognitionCapability.snapshotChatMessages().size();
|
||||
if (currentTime - perceiveCapability.showLastInteract().toEpochMilli() > UPDATE_TRIGGER_INTERVAL && chatCount > 1) {
|
||||
triggerMemoryUpdate(true);
|
||||
log.info("[MemoryUpdater] 记忆更新: 自动触发");
|
||||
}
|
||||
}
|
||||
|
||||
private void triggerMemoryUpdate(boolean refreshMemoryId) {
|
||||
if (!updating.compareAndSet(false, true)) {
|
||||
log.debug("[MemoryUpdater] 更新任务已在执行中,本次触发跳过");
|
||||
return;
|
||||
}
|
||||
try {
|
||||
List<Message> fullChatSnapshot = cognitionCapability.snapshotChatMessages();
|
||||
if (fullChatSnapshot.size() <= 1) {
|
||||
return;
|
||||
}
|
||||
List<Message> chatIncrement = resolveChatIncrement(fullChatSnapshot);
|
||||
if (chatIncrement.isEmpty()) {
|
||||
if (refreshMemoryId) {
|
||||
memoryCapability.refreshMemorySession();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
RollingRecord record = updateMemory(chatIncrement);
|
||||
dialogRollingService.rollMessages(chatIncrement, fullChatSnapshot.size(), CONTEXT_RETAIN_DIVISOR, record.unitId, record.sliceId, record.summary);
|
||||
|
||||
if (refreshMemoryId) {
|
||||
memoryCapability.refreshMemorySession();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("[MemoryUpdater] 记忆更新线程出错: ", e);
|
||||
} finally {
|
||||
updating.set(false);
|
||||
}
|
||||
}
|
||||
|
||||
private List<Message> resolveChatIncrement(List<Message> fullChatSnapshot) {
|
||||
String memoryId = memoryCapability.getMemorySessionId();
|
||||
if (memoryId == null || memoryId.isBlank()) {
|
||||
return fullChatSnapshot;
|
||||
}
|
||||
MemoryUnit existingUnit = memoryCapability.getMemoryUnit(memoryId);
|
||||
if (existingUnit == null || existingUnit.getConversationMessages() == null || existingUnit.getConversationMessages().isEmpty()) {
|
||||
return fullChatSnapshot;
|
||||
}
|
||||
List<Message> existingMessages = existingUnit.getConversationMessages();
|
||||
int maxOverlap = Math.min(existingMessages.size(), fullChatSnapshot.size());
|
||||
for (int overlap = maxOverlap; overlap > 0; overlap--) {
|
||||
List<Message> existingSuffix = existingMessages.subList(existingMessages.size() - overlap, existingMessages.size());
|
||||
List<Message> snapshotPrefix = fullChatSnapshot.subList(0, overlap);
|
||||
if (existingSuffix.equals(snapshotPrefix)) {
|
||||
return fullChatSnapshot.subList(overlap, fullChatSnapshot.size());
|
||||
}
|
||||
}
|
||||
return fullChatSnapshot;
|
||||
}
|
||||
|
||||
private RollingRecord updateMemory(List<Message> chatSnapshot) {
|
||||
log.debug("[MemoryUpdater] 记忆更新流程开始...");
|
||||
if (chatSnapshot.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
SummarizeInput summarizeInput = new SummarizeInput(chatSnapshot, memoryRuntime.getTopicTree());
|
||||
singleSummarizer.execute(summarizeInput.getChatMessages());
|
||||
return multiSummarizer.execute(summarizeInput).fold(
|
||||
summarizeResult -> {
|
||||
MemoryUnit memoryUnit = memoryCapability.updateMemoryUnit(chatSnapshot, summarizeResult.getSummary());
|
||||
memoryRuntime.recordMemory(
|
||||
memoryUnit,
|
||||
summarizeResult.getTopicPath(),
|
||||
summarizeResult.getRelatedTopicPath()
|
||||
);
|
||||
MemorySlice newSlice = memoryUnit.getSlices().getLast();
|
||||
return new RollingRecord(memoryUnit.getId(), newSlice.getId(), newSlice.getSummary());
|
||||
},
|
||||
exp -> {
|
||||
MemoryUnit memoryUnit = memoryCapability.updateMemoryUnit(chatSnapshot, "no summary, due to exception");
|
||||
MemorySlice newSlice = memoryUnit.getSlices().getLast();
|
||||
return new RollingRecord(memoryUnit.getId(), newSlice.getId(), newSlice.getSummary());
|
||||
private Message resolveTopicTaskMessage(RollingResult result, List<Message> slicedMessages) {
|
||||
return new TaskBlock() {
|
||||
@Override
|
||||
protected void fillXml(@NotNull Document document, @NotNull Element root) {
|
||||
appendTextElement(document, root, "current_topic_tree", memoryRuntime.getTopicTree());
|
||||
appendTextElement(document, root, "slice_summary", result.summary());
|
||||
appendRepeatedElements(document, root, "message", slicedMessages, (messageElement, message) -> {
|
||||
messageElement.setAttribute("role", message.roleValue());
|
||||
messageElement.setTextContent(message.getContent());
|
||||
return kotlin.Unit.INSTANCE;
|
||||
});
|
||||
}
|
||||
}.encodeToMessage();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int order() {
|
||||
return 7;
|
||||
}
|
||||
|
||||
private record RollingRecord(String unitId, String sliceId, String summary) {
|
||||
@NotNull
|
||||
public String modelKey() {
|
||||
return "topic_extractor";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package work.slhaf.partner.module.memory.updater.summarizer.entity;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class MemoryTopicResult {
|
||||
private String topicPath;
|
||||
private List<String> relatedTopicPaths;
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
package work.slhaf.partner.module.communication;
|
||||
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.mockito.Mockito;
|
||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||
import work.slhaf.partner.framework.agent.support.Result;
|
||||
import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.MultiSummarizer;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.SingleSummarizer;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class DialogRollingTest {
|
||||
|
||||
@BeforeAll
|
||||
static void beforeAll(@TempDir Path tempDir) {
|
||||
System.setProperty("user.home", tempDir.toAbsolutePath().toString());
|
||||
}
|
||||
|
||||
private static void setField(Object target, String fieldName, Object value) throws Exception {
|
||||
Field field = target.getClass().getDeclaredField(fieldName);
|
||||
field.setAccessible(true);
|
||||
field.set(target, value);
|
||||
}
|
||||
|
||||
private static Message message(Message.Character role, String content) {
|
||||
return new Message(role, content);
|
||||
}
|
||||
|
||||
private static SummarizeResult summarizeResult(String summary, String topicPath, List<String> relatedTopicPath) {
|
||||
SummarizeResult result = new SummarizeResult();
|
||||
result.setSummary(summary);
|
||||
result.setTopicPath(topicPath);
|
||||
result.setRelatedTopicPath(relatedTopicPath);
|
||||
return result;
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldDelegateMemoryUpdateToCapability() throws Exception {
|
||||
String sessionId = "dialog-rolling-" + UUID.randomUUID();
|
||||
StubMemoryCapability memoryCapability = new StubMemoryCapability(sessionId);
|
||||
DialogRolling dialogRolling = new DialogRolling();
|
||||
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
|
||||
MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class);
|
||||
SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class);
|
||||
setField(dialogRolling, "memoryCapability", memoryCapability);
|
||||
setField(dialogRolling, "memoryRuntime", memoryRuntime);
|
||||
setField(dialogRolling, "multiSummarizer", multiSummarizer);
|
||||
setField(dialogRolling, "singleSummarizer", singleSummarizer);
|
||||
|
||||
when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
|
||||
when(multiSummarizer.execute(Mockito.any())).thenReturn(Result.success(
|
||||
summarizeResult("new-summary", "topic/main", List.of("topic/related"))
|
||||
));
|
||||
|
||||
MemoryUnit existingUnit = new MemoryUnit(sessionId);
|
||||
existingUnit.getConversationMessages().addAll(List.of(
|
||||
message(Message.Character.USER, "old-user"),
|
||||
message(Message.Character.ASSISTANT, "old-assistant")
|
||||
));
|
||||
existingUnit.getSlices().add(MemorySlice.restore("slice-1", 0, 2, "old-summary", 1L));
|
||||
memoryCapability.putUnit(existingUnit);
|
||||
|
||||
RollingResult rollingResult = dialogRolling.buildRollingResult(List.of(
|
||||
message(Message.Character.USER, "new-user"),
|
||||
message(Message.Character.ASSISTANT, "new-assistant")
|
||||
), 4, 6);
|
||||
|
||||
MemoryUnit merged = memoryCapability.getMemoryUnit(sessionId);
|
||||
assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"),
|
||||
merged.getConversationMessages().stream().map(Message::getContent).toList());
|
||||
assertEquals(2, merged.getSlices().size());
|
||||
|
||||
MemorySlice appendedSlice = merged.getSlices().getLast();
|
||||
assertNotNull(appendedSlice.getId());
|
||||
assertEquals(2, appendedSlice.getStartIndex());
|
||||
assertEquals(4, appendedSlice.getEndIndex());
|
||||
assertEquals("new-summary", appendedSlice.getSummary());
|
||||
assertEquals(sessionId, rollingResult.memoryUnit().getId());
|
||||
assertEquals(appendedSlice.getId(), rollingResult.memorySlice().getId());
|
||||
assertEquals("new-summary", rollingResult.summary());
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldCreateFirstSliceForFreshSessionThroughCapability() throws Exception {
|
||||
String sessionId = "dialog-rolling-" + UUID.randomUUID();
|
||||
StubMemoryCapability memoryCapability = new StubMemoryCapability(sessionId);
|
||||
DialogRolling dialogRolling = new DialogRolling();
|
||||
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
|
||||
MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class);
|
||||
SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class);
|
||||
setField(dialogRolling, "memoryCapability", memoryCapability);
|
||||
setField(dialogRolling, "memoryRuntime", memoryRuntime);
|
||||
setField(dialogRolling, "multiSummarizer", multiSummarizer);
|
||||
setField(dialogRolling, "singleSummarizer", singleSummarizer);
|
||||
|
||||
when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
|
||||
when(multiSummarizer.execute(Mockito.any())).thenReturn(Result.success(
|
||||
summarizeResult("fresh-summary", "topic/root", List.of())
|
||||
));
|
||||
|
||||
RollingResult rollingResult = dialogRolling.buildRollingResult(List.of(
|
||||
message(Message.Character.USER, "first"),
|
||||
message(Message.Character.ASSISTANT, "second")
|
||||
), 2, 6);
|
||||
|
||||
MemoryUnit created = memoryCapability.getMemoryUnit(sessionId);
|
||||
assertNotNull(created);
|
||||
assertEquals(List.of("first", "second"),
|
||||
created.getConversationMessages().stream().map(Message::getContent).toList());
|
||||
assertEquals(1, created.getSlices().size());
|
||||
assertEquals(0, created.getSlices().getFirst().getStartIndex());
|
||||
assertEquals(2, created.getSlices().getFirst().getEndIndex());
|
||||
assertEquals("fresh-summary", created.getSlices().getFirst().getSummary());
|
||||
assertEquals(created, rollingResult.memoryUnit());
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldTrimPersistedOverlapFromCurrentSnapshot() throws Exception {
|
||||
String sessionId = "dialog-rolling-" + UUID.randomUUID();
|
||||
StubMemoryCapability memoryCapability = new StubMemoryCapability(sessionId);
|
||||
DialogRolling dialogRolling = new DialogRolling();
|
||||
setField(dialogRolling, "memoryCapability", memoryCapability);
|
||||
|
||||
MemoryUnit existingUnit = Mockito.mock(MemoryUnit.class);
|
||||
when(existingUnit.getConversationMessages()).thenReturn(List.of(
|
||||
message(Message.Character.USER, "m1"),
|
||||
message(Message.Character.ASSISTANT, "m2"),
|
||||
message(Message.Character.USER, "m3"),
|
||||
message(Message.Character.ASSISTANT, "m4")
|
||||
));
|
||||
memoryCapability.putUnit(sessionId, existingUnit);
|
||||
|
||||
List<Message> increment = dialogRolling.resolveChatIncrement(List.of(
|
||||
message(Message.Character.USER, "m3"),
|
||||
message(Message.Character.ASSISTANT, "m4"),
|
||||
message(Message.Character.USER, "m5"),
|
||||
message(Message.Character.ASSISTANT, "m6")
|
||||
));
|
||||
|
||||
assertEquals(List.of("m5", "m6"), increment.stream().map(Message::getContent).toList());
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldFallbackWhenSummarizeResultIsBlank() throws Exception {
|
||||
String sessionId = "dialog-rolling-" + UUID.randomUUID();
|
||||
StubMemoryCapability memoryCapability = new StubMemoryCapability(sessionId);
|
||||
DialogRolling dialogRolling = new DialogRolling();
|
||||
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
|
||||
MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class);
|
||||
SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class);
|
||||
setField(dialogRolling, "memoryCapability", memoryCapability);
|
||||
setField(dialogRolling, "memoryRuntime", memoryRuntime);
|
||||
setField(dialogRolling, "multiSummarizer", multiSummarizer);
|
||||
setField(dialogRolling, "singleSummarizer", singleSummarizer);
|
||||
|
||||
when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
|
||||
when(multiSummarizer.execute(Mockito.any())).thenReturn(Result.success(summarizeResult(" ", "topic/root", List.of())));
|
||||
|
||||
RollingResult rollingResult = dialogRolling.buildRollingResult(List.of(
|
||||
message(Message.Character.USER, "u1"),
|
||||
message(Message.Character.ASSISTANT, "a1")
|
||||
), 2, 6);
|
||||
|
||||
assertEquals(sessionId, rollingResult.memoryUnit().getId());
|
||||
assertEquals("no summary, due to empty summarize result", rollingResult.summary());
|
||||
}
|
||||
|
||||
private static final class StubMemoryCapability implements MemoryCapability {
|
||||
private final String sessionId;
|
||||
private final Map<String, MemoryUnit> units = new HashMap<>();
|
||||
|
||||
private StubMemoryCapability(String sessionId) {
|
||||
this.sessionId = sessionId;
|
||||
}
|
||||
|
||||
private void putUnit(String unitId, MemoryUnit memoryUnit) {
|
||||
units.put(unitId == null ? memoryUnit.getId() : unitId, memoryUnit);
|
||||
}
|
||||
|
||||
private void putUnit(MemoryUnit memoryUnit) {
|
||||
units.put(memoryUnit.getId(), memoryUnit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public MemoryUnit getMemoryUnit(String unitId) {
|
||||
return units.get(unitId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public work.slhaf.partner.framework.agent.support.Result<MemorySlice> getMemorySlice(String unitId, String sliceId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MemoryUnit updateMemoryUnit(List<Message> chatMessages, String summary) {
|
||||
MemoryUnit unit = units.computeIfAbsent(sessionId, MemoryUnit::new);
|
||||
unit.updateTimestamp();
|
||||
int startIndex = unit.getConversationMessages().size();
|
||||
unit.getConversationMessages().addAll(chatMessages);
|
||||
unit.getSlices().add(new MemorySlice(startIndex, startIndex + chatMessages.size(), summary));
|
||||
return unit;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<MemoryUnit> listMemoryUnits() {
|
||||
return units.values();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void refreshMemorySession() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMemorySessionId() {
|
||||
return sessionId;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,26 +4,21 @@ import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.mockito.Mockito;
|
||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||
import work.slhaf.partner.core.cognition.CognitionCapability;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||
import work.slhaf.partner.framework.agent.exception.AgentRuntimeException;
|
||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||
import work.slhaf.partner.framework.agent.support.Result;
|
||||
import work.slhaf.partner.module.communication.AfterRollingRegistry;
|
||||
import work.slhaf.partner.module.communication.RollingResult;
|
||||
import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
|
||||
import work.slhaf.partner.module.memory.runtime.exception.MemoryLookupException;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.MultiSummarizer;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.SingleSummarizer;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.entity.MemoryTopicResult;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
@@ -35,275 +30,89 @@ class MemoryUpdaterTest {
|
||||
System.setProperty("user.home", tempDir.toAbsolutePath().toString());
|
||||
}
|
||||
|
||||
private static Object invokeUpdateMemory(MemoryUpdater updater, List<Message> chatMessages) throws Exception {
|
||||
Method method = MemoryUpdater.class.getDeclaredMethod("updateMemory", List.class);
|
||||
method.setAccessible(true);
|
||||
return method.invoke(updater, chatMessages);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static List<Message> invokeResolveChatIncrement(MemoryUpdater updater,
|
||||
List<Message> chatMessages) throws Exception {
|
||||
Method method = MemoryUpdater.class.getDeclaredMethod("resolveChatIncrement", List.class);
|
||||
method.setAccessible(true);
|
||||
return (List<Message>) method.invoke(updater, chatMessages);
|
||||
}
|
||||
|
||||
private static String recordField(Object target, String fieldName) throws Exception {
|
||||
Field field = target.getClass().getDeclaredField(fieldName);
|
||||
field.setAccessible(true);
|
||||
return (String) field.get(target);
|
||||
}
|
||||
|
||||
private static void setField(Object target, String fieldName, Object value) throws Exception {
|
||||
Field field = target.getClass().getDeclaredField(fieldName);
|
||||
field.setAccessible(true);
|
||||
field.set(target, value);
|
||||
}
|
||||
|
||||
private static SummarizeResult summarizeResult(String summary, String topicPath, List<String> relatedTopicPath) {
|
||||
SummarizeResult result = new SummarizeResult();
|
||||
result.setSummary(summary);
|
||||
result.setTopicPath(topicPath);
|
||||
result.setRelatedTopicPath(relatedTopicPath);
|
||||
return result;
|
||||
}
|
||||
|
||||
private static Message message(Message.Character role, String content) {
|
||||
return new Message(role, content);
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldDelegateMemoryUpdateToCapabilityAndRuntime() throws Exception {
|
||||
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-1");
|
||||
MemoryUpdater updater = new MemoryUpdater();
|
||||
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
|
||||
MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class);
|
||||
SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class);
|
||||
setField(updater, "memoryCapability", memoryCapability);
|
||||
setField(updater, "memoryRuntime", memoryRuntime);
|
||||
setField(updater, "multiSummarizer", multiSummarizer);
|
||||
setField(updater, "singleSummarizer", singleSummarizer);
|
||||
void shouldRegisterItselfToAfterRollingRegistryOnInit() throws Exception {
|
||||
MemoryUpdater updater = Mockito.spy(new MemoryUpdater());
|
||||
AfterRollingRegistry registry = Mockito.mock(AfterRollingRegistry.class);
|
||||
setField(updater, "afterRollingRegistry", registry);
|
||||
|
||||
updater.init();
|
||||
|
||||
verify(registry).register(updater);
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldExtractTopicAndRecordMemoryOnConsume() throws Exception {
|
||||
MemoryUpdater updater = Mockito.spy(new MemoryUpdater());
|
||||
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
|
||||
CognitionCapability cognitionCapability = Mockito.mock(CognitionCapability.class);
|
||||
setField(updater, "memoryRuntime", memoryRuntime);
|
||||
setField(updater, "cognitionCapability", cognitionCapability);
|
||||
|
||||
when(cognitionCapability.contextWorkspace()).thenReturn(new work.slhaf.partner.core.cognition.ContextWorkspace());
|
||||
when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
|
||||
when(multiSummarizer.execute(Mockito.any())).thenReturn(Result.success(
|
||||
summarizeResult("new-summary", "topic/main", List.of("topic/related"))
|
||||
when(memoryRuntime.fixTopicPath("root[2]->branch[1]")).thenReturn("root->branch");
|
||||
when(memoryRuntime.fixTopicPath("root[2]->related[1]")).thenReturn("root->related");
|
||||
|
||||
MemoryTopicResult topicResult = new MemoryTopicResult();
|
||||
topicResult.setTopicPath("root[2]->branch[1]");
|
||||
topicResult.setRelatedTopicPaths(List.of("root[2]->related[1]"));
|
||||
Mockito.doReturn(Result.success(topicResult))
|
||||
.when(updater)
|
||||
.formattedChat(Mockito.anyList(), eq(MemoryTopicResult.class));
|
||||
|
||||
MemoryUnit unit = new MemoryUnit("session-1");
|
||||
unit.getConversationMessages().addAll(List.of(
|
||||
message(Message.Character.USER, "old"),
|
||||
message(Message.Character.ASSISTANT, "old-reply"),
|
||||
message(Message.Character.USER, "new"),
|
||||
message(Message.Character.ASSISTANT, "new-reply")
|
||||
));
|
||||
MemorySlice slice = new MemorySlice(2, 4, "slice-summary");
|
||||
unit.getSlices().add(slice);
|
||||
|
||||
MemoryUnit existingUnit = new MemoryUnit("session-1");
|
||||
existingUnit.getConversationMessages().addAll(List.of(
|
||||
message(Message.Character.USER, "old-user"),
|
||||
message(Message.Character.ASSISTANT, "old-assistant")
|
||||
));
|
||||
existingUnit.getSlices().add(MemorySlice.restore("slice-1", 0, 2, "old-summary", 1L));
|
||||
memoryCapability.putUnit(existingUnit);
|
||||
updater.consume(new RollingResult(unit, slice, List.of(
|
||||
message(Message.Character.USER, "new"),
|
||||
message(Message.Character.ASSISTANT, "new-reply")
|
||||
), "slice-summary", 4, 6));
|
||||
|
||||
Object rollingRecord = invokeUpdateMemory(updater, List.of(
|
||||
message(Message.Character.USER, "new-user"),
|
||||
message(Message.Character.ASSISTANT, "new-assistant")
|
||||
));
|
||||
|
||||
MemoryUnit merged = memoryCapability.getMemoryUnit("session-1");
|
||||
assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"),
|
||||
merged.getConversationMessages().stream().map(Message::getContent).toList());
|
||||
assertEquals(2, merged.getSlices().size());
|
||||
|
||||
MemorySlice appendedSlice = merged.getSlices().getLast();
|
||||
assertNotNull(appendedSlice.getId());
|
||||
assertEquals(2, appendedSlice.getStartIndex());
|
||||
assertEquals(4, appendedSlice.getEndIndex());
|
||||
assertEquals("new-summary", appendedSlice.getSummary());
|
||||
|
||||
assertEquals(List.of("new-user", "new-assistant"),
|
||||
memoryCapability.lastChatMessages().stream().map(Message::getContent).toList());
|
||||
assertEquals("new-summary", memoryCapability.lastSummary());
|
||||
verify(memoryRuntime).recordMemory(eq(merged), eq("topic/main"), eq(List.of("topic/related")));
|
||||
assertEquals("session-1", recordField(rollingRecord, "unitId"));
|
||||
assertEquals(appendedSlice.getId(), recordField(rollingRecord, "sliceId"));
|
||||
assertEquals("new-summary", recordField(rollingRecord, "summary"));
|
||||
verify(memoryRuntime).recordMemory(eq(unit), eq("root->branch"), eq(List.of("root->related")));
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldCreateFirstSliceForFreshSessionThroughCapability() throws Exception {
|
||||
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-2");
|
||||
MemoryUpdater updater = new MemoryUpdater();
|
||||
void shouldFallbackToDateOnlyRecordWhenExtractionFails() throws Exception {
|
||||
MemoryUpdater updater = Mockito.spy(new MemoryUpdater());
|
||||
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
|
||||
MultiSummarizer multiSummarizer = Mockito.mock(MultiSummarizer.class);
|
||||
SingleSummarizer singleSummarizer = Mockito.mock(SingleSummarizer.class);
|
||||
setField(updater, "memoryCapability", memoryCapability);
|
||||
CognitionCapability cognitionCapability = Mockito.mock(CognitionCapability.class);
|
||||
setField(updater, "memoryRuntime", memoryRuntime);
|
||||
setField(updater, "multiSummarizer", multiSummarizer);
|
||||
setField(updater, "singleSummarizer", singleSummarizer);
|
||||
setField(updater, "cognitionCapability", cognitionCapability);
|
||||
|
||||
when(cognitionCapability.contextWorkspace()).thenReturn(new work.slhaf.partner.core.cognition.ContextWorkspace());
|
||||
when(memoryRuntime.getTopicTree()).thenReturn("topic-tree");
|
||||
when(multiSummarizer.execute(Mockito.any())).thenReturn(Result.success(
|
||||
summarizeResult("fresh-summary", "topic/root", List.of())
|
||||
Mockito.doReturn(Result.failure(new AgentRuntimeException("boom")))
|
||||
.when(updater)
|
||||
.formattedChat(Mockito.anyList(), eq(MemoryTopicResult.class));
|
||||
|
||||
MemoryUnit unit = new MemoryUnit("session-2");
|
||||
unit.getConversationMessages().addAll(List.of(
|
||||
message(Message.Character.USER, "u1"),
|
||||
message(Message.Character.ASSISTANT, "a1")
|
||||
));
|
||||
MemorySlice slice = new MemorySlice(0, 2, "slice-summary");
|
||||
unit.getSlices().add(slice);
|
||||
|
||||
Object rollingRecord = invokeUpdateMemory(updater, List.of(
|
||||
message(Message.Character.USER, "first"),
|
||||
message(Message.Character.ASSISTANT, "second")
|
||||
));
|
||||
updater.consume(new RollingResult(unit, slice, unit.getConversationMessages(), "slice-summary", 2, 6));
|
||||
|
||||
MemoryUnit created = memoryCapability.getMemoryUnit("session-2");
|
||||
assertNotNull(created);
|
||||
assertEquals("session-2", created.getId());
|
||||
assertEquals(List.of("first", "second"),
|
||||
created.getConversationMessages().stream().map(Message::getContent).toList());
|
||||
assertEquals(1, created.getSlices().size());
|
||||
assertEquals(0, created.getSlices().getFirst().getStartIndex());
|
||||
assertEquals(2, created.getSlices().getFirst().getEndIndex());
|
||||
assertEquals("fresh-summary", created.getSlices().getFirst().getSummary());
|
||||
verify(memoryRuntime).recordMemory(eq(created), eq("topic/root"), eq(List.of()));
|
||||
assertEquals("session-2", recordField(rollingRecord, "unitId"));
|
||||
assertEquals(created.getSlices().getFirst().getId(), recordField(rollingRecord, "sliceId"));
|
||||
assertEquals("fresh-summary", recordField(rollingRecord, "summary"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldTrimPersistedOverlapFromCurrentSnapshot() throws Exception {
|
||||
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-3");
|
||||
MemoryUpdater updater = new MemoryUpdater();
|
||||
setField(updater, "memoryCapability", memoryCapability);
|
||||
|
||||
MemoryUnit existingUnit = Mockito.mock(MemoryUnit.class);
|
||||
when(existingUnit.getConversationMessages()).thenReturn(List.of(
|
||||
message(Message.Character.USER, "m1"),
|
||||
message(Message.Character.ASSISTANT, "m2"),
|
||||
message(Message.Character.USER, "m3"),
|
||||
message(Message.Character.ASSISTANT, "m4")
|
||||
));
|
||||
memoryCapability.putUnit("session-3", existingUnit);
|
||||
|
||||
List<Message> increment = invokeResolveChatIncrement(
|
||||
updater,
|
||||
List.of(
|
||||
message(Message.Character.USER, "m3"),
|
||||
message(Message.Character.ASSISTANT, "m4"),
|
||||
message(Message.Character.USER, "m5"),
|
||||
message(Message.Character.ASSISTANT, "m6")
|
||||
)
|
||||
);
|
||||
|
||||
assertEquals(List.of("m5", "m6"), increment.stream().map(Message::getContent).toList());
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldReturnEmptyIncrementWhenSnapshotIsFullyPersisted() throws Exception {
|
||||
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-4");
|
||||
MemoryUpdater updater = new MemoryUpdater();
|
||||
setField(updater, "memoryCapability", memoryCapability);
|
||||
|
||||
MemoryUnit existingUnit = Mockito.mock(MemoryUnit.class);
|
||||
when(existingUnit.getConversationMessages()).thenReturn(List.of(
|
||||
message(Message.Character.USER, "m1"),
|
||||
message(Message.Character.ASSISTANT, "m2"),
|
||||
message(Message.Character.USER, "m3")
|
||||
));
|
||||
memoryCapability.putUnit("session-4", existingUnit);
|
||||
|
||||
List<Message> increment = invokeResolveChatIncrement(
|
||||
updater,
|
||||
List.of(
|
||||
message(Message.Character.ASSISTANT, "m2"),
|
||||
message(Message.Character.USER, "m3")
|
||||
)
|
||||
);
|
||||
|
||||
assertEquals(List.of(), increment);
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldReturnNullWhenUpdateMemoryReceivesEmptySnapshot() throws Exception {
|
||||
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-5");
|
||||
MemoryUpdater updater = new MemoryUpdater();
|
||||
MemoryRuntime memoryRuntime = Mockito.mock(MemoryRuntime.class);
|
||||
setField(updater, "memoryCapability", memoryCapability);
|
||||
setField(updater, "memoryRuntime", memoryRuntime);
|
||||
|
||||
Object rollingRecord = invokeUpdateMemory(updater, List.of());
|
||||
|
||||
assertNull(rollingRecord);
|
||||
assertNull(memoryCapability.lastSummary());
|
||||
Mockito.verifyNoInteractions(memoryRuntime);
|
||||
}
|
||||
|
||||
private static final class StubMemoryCapability implements MemoryCapability {
|
||||
private final String sessionId;
|
||||
private final Map<String, MemoryUnit> units = new HashMap<>();
|
||||
private List<Message> lastChatMessages;
|
||||
private String lastSummary;
|
||||
|
||||
private StubMemoryCapability(String sessionId) {
|
||||
this.sessionId = sessionId;
|
||||
}
|
||||
|
||||
private void putUnit(String unitId, MemoryUnit memoryUnit) {
|
||||
units.put(unitId, memoryUnit);
|
||||
}
|
||||
|
||||
private void putUnit(MemoryUnit memoryUnit) {
|
||||
units.put(memoryUnit.getId(), memoryUnit);
|
||||
}
|
||||
|
||||
private List<Message> lastChatMessages() {
|
||||
return lastChatMessages;
|
||||
}
|
||||
|
||||
private String lastSummary() {
|
||||
return lastSummary;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MemoryUnit getMemoryUnit(String unitId) {
|
||||
return units.get(unitId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Result<MemorySlice> getMemorySlice(String unitId, String sliceId) {
|
||||
MemoryUnit unit = units.get(unitId);
|
||||
if (unit == null || unit.getSlices() == null) {
|
||||
return Result.failure(new MemoryLookupException(
|
||||
"Memory slice not found: " + unitId + ":" + sliceId,
|
||||
unitId + ":" + sliceId,
|
||||
"MEMORY_SLICE"
|
||||
));
|
||||
}
|
||||
return unit.getSlices().stream()
|
||||
.filter(slice -> sliceId.equals(slice.getId()))
|
||||
.findFirst()
|
||||
.map(Result::success)
|
||||
.orElseGet(() -> Result.failure(new MemoryLookupException(
|
||||
"Memory slice not found: " + unitId + ":" + sliceId,
|
||||
unitId + ":" + sliceId,
|
||||
"MEMORY_SLICE"
|
||||
)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MemoryUnit updateMemoryUnit(List<Message> chatMessages, String summary) {
|
||||
lastChatMessages = List.copyOf(chatMessages);
|
||||
lastSummary = summary;
|
||||
MemoryUnit unit = units.computeIfAbsent(sessionId, MemoryUnit::new);
|
||||
unit.updateTimestamp();
|
||||
int startIndex = unit.getConversationMessages().size();
|
||||
unit.getConversationMessages().addAll(chatMessages);
|
||||
unit.getSlices().add(new MemorySlice(startIndex, startIndex + chatMessages.size(), summary));
|
||||
return unit;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<MemoryUnit> listMemoryUnits() {
|
||||
return units.values();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void refreshMemorySession() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMemorySessionId() {
|
||||
return sessionId;
|
||||
}
|
||||
verify(memoryRuntime).recordMemory(eq(unit), eq(null), eq(List.of()));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user