mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
fix(memory): trim persisted overlap from chat snapshot in MemoryUpdater
This commit is contained in:
@@ -111,14 +111,21 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
List<Message> chatSnapshot = cognitionCapability.snapshotChatMessages();
|
List<Message> fullChatSnapshot = cognitionCapability.snapshotChatMessages();
|
||||||
if (chatSnapshot.size() <= 1) {
|
if (fullChatSnapshot.size() <= 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
List<Message> chatIncrement = resolveChatIncrement(fullChatSnapshot);
|
||||||
|
if (chatIncrement.isEmpty()) {
|
||||||
|
if (refreshMemoryId) {
|
||||||
|
memoryCapability.refreshMemorySession();
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
RollingRecord record = updateMemory(chatSnapshot);
|
RollingRecord record = updateMemory(chatIncrement);
|
||||||
if (record != null) {
|
if (record != null) {
|
||||||
dialogRollingService.rollMessages(chatSnapshot, chatSnapshot.size(), CONTEXT_RETAIN_DIVISOR, record.unitId, record.sliceId, record.summary);
|
dialogRollingService.rollMessages(chatIncrement, fullChatSnapshot.size(), CONTEXT_RETAIN_DIVISOR, record.unitId, record.sliceId, record.summary);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (refreshMemoryId) {
|
if (refreshMemoryId) {
|
||||||
@@ -131,6 +138,27 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private List<Message> resolveChatIncrement(List<Message> fullChatSnapshot) {
|
||||||
|
String memoryId = memoryCapability.getMemorySessionId();
|
||||||
|
if (memoryId == null || memoryId.isBlank()) {
|
||||||
|
return fullChatSnapshot;
|
||||||
|
}
|
||||||
|
MemoryUnit existingUnit = memoryCapability.getMemoryUnit(memoryId);
|
||||||
|
if (existingUnit == null || existingUnit.getConversationMessages() == null || existingUnit.getConversationMessages().isEmpty()) {
|
||||||
|
return fullChatSnapshot;
|
||||||
|
}
|
||||||
|
List<Message> existingMessages = existingUnit.getConversationMessages();
|
||||||
|
int maxOverlap = Math.min(existingMessages.size(), fullChatSnapshot.size());
|
||||||
|
for (int overlap = maxOverlap; overlap > 0; overlap--) {
|
||||||
|
List<Message> existingSuffix = existingMessages.subList(existingMessages.size() - overlap, existingMessages.size());
|
||||||
|
List<Message> snapshotPrefix = fullChatSnapshot.subList(0, overlap);
|
||||||
|
if (existingSuffix.equals(snapshotPrefix)) {
|
||||||
|
return fullChatSnapshot.subList(overlap, fullChatSnapshot.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fullChatSnapshot;
|
||||||
|
}
|
||||||
|
|
||||||
private RollingRecord updateMemory(List<Message> chatSnapshot) {
|
private RollingRecord updateMemory(List<Message> chatSnapshot) {
|
||||||
log.debug("[MemoryUpdater] 记忆更新流程开始...");
|
log.debug("[MemoryUpdater] 记忆更新流程开始...");
|
||||||
if (chatSnapshot.isEmpty()) {
|
if (chatSnapshot.isEmpty()) {
|
||||||
|
|||||||
@@ -24,6 +24,14 @@ class MemoryUpdaterTest {
|
|||||||
return (MemoryUnit) method.invoke(updater, chatMessages, summarizeResult);
|
return (MemoryUnit) method.invoke(updater, chatMessages, summarizeResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private static List<Message> invokeResolveChatIncrement(MemoryUpdater updater,
|
||||||
|
List<Message> chatMessages) throws Exception {
|
||||||
|
Method method = MemoryUpdater.class.getDeclaredMethod("resolveChatIncrement", List.class);
|
||||||
|
method.setAccessible(true);
|
||||||
|
return (List<Message>) method.invoke(updater, chatMessages);
|
||||||
|
}
|
||||||
|
|
||||||
private static void setField(Object target, String fieldName, Object value) throws Exception {
|
private static void setField(Object target, String fieldName, Object value) throws Exception {
|
||||||
Field field = target.getClass().getDeclaredField(fieldName);
|
Field field = target.getClass().getDeclaredField(fieldName);
|
||||||
field.setAccessible(true);
|
field.setAccessible(true);
|
||||||
@@ -107,6 +115,61 @@ class MemoryUpdaterTest {
|
|||||||
assertEquals("fresh-summary", created.getSlices().getFirst().getSummary());
|
assertEquals("fresh-summary", created.getSlices().getFirst().getSummary());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldTrimPersistedOverlapFromCurrentSnapshot() throws Exception {
|
||||||
|
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-3");
|
||||||
|
MemoryUpdater updater = new MemoryUpdater();
|
||||||
|
setField(updater, "memoryCapability", memoryCapability);
|
||||||
|
|
||||||
|
MemoryUnit existingUnit = new MemoryUnit();
|
||||||
|
existingUnit.setId("session-3");
|
||||||
|
existingUnit.setConversationMessages(new ArrayList<>(List.of(
|
||||||
|
message(Message.Character.USER, "m1"),
|
||||||
|
message(Message.Character.ASSISTANT, "m2"),
|
||||||
|
message(Message.Character.USER, "m3"),
|
||||||
|
message(Message.Character.ASSISTANT, "m4")
|
||||||
|
)));
|
||||||
|
memoryCapability.saveMemoryUnit(existingUnit);
|
||||||
|
|
||||||
|
List<Message> increment = invokeResolveChatIncrement(
|
||||||
|
updater,
|
||||||
|
List.of(
|
||||||
|
message(Message.Character.USER, "m3"),
|
||||||
|
message(Message.Character.ASSISTANT, "m4"),
|
||||||
|
message(Message.Character.USER, "m5"),
|
||||||
|
message(Message.Character.ASSISTANT, "m6")
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
assertEquals(List.of("m5", "m6"), increment.stream().map(Message::getContent).toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldReturnEmptyIncrementWhenSnapshotIsFullyPersisted() throws Exception {
|
||||||
|
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-4");
|
||||||
|
MemoryUpdater updater = new MemoryUpdater();
|
||||||
|
setField(updater, "memoryCapability", memoryCapability);
|
||||||
|
|
||||||
|
MemoryUnit existingUnit = new MemoryUnit();
|
||||||
|
existingUnit.setId("session-4");
|
||||||
|
existingUnit.setConversationMessages(new ArrayList<>(List.of(
|
||||||
|
message(Message.Character.USER, "m1"),
|
||||||
|
message(Message.Character.ASSISTANT, "m2"),
|
||||||
|
message(Message.Character.USER, "m3")
|
||||||
|
)));
|
||||||
|
memoryCapability.saveMemoryUnit(existingUnit);
|
||||||
|
|
||||||
|
List<Message> increment = invokeResolveChatIncrement(
|
||||||
|
updater,
|
||||||
|
List.of(
|
||||||
|
message(Message.Character.ASSISTANT, "m2"),
|
||||||
|
message(Message.Character.USER, "m3")
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
assertEquals(List.of(), increment);
|
||||||
|
}
|
||||||
|
|
||||||
private static final class StubMemoryCapability implements MemoryCapability {
|
private static final class StubMemoryCapability implements MemoryCapability {
|
||||||
private final String sessionId;
|
private final String sessionId;
|
||||||
private final Map<String, MemoryUnit> units = new HashMap<>();
|
private final Map<String, MemoryUnit> units = new HashMap<>();
|
||||||
|
|||||||
Reference in New Issue
Block a user