refactor(memory): manage state serialization via StateCenter in MemoryCore, and normalize slice and unit building

This commit is contained in:
2026-04-07 10:21:17 +08:00
parent 57bc63c57b
commit a242723727
7 changed files with 110 additions and 126 deletions

View File

@@ -1,44 +1,35 @@
package work.slhaf.partner.core.memory; package work.slhaf.partner.core.memory;
import lombok.EqualsAndHashCode; import com.alibaba.fastjson2.JSONArray;
import lombok.Getter; import com.alibaba.fastjson2.JSONObject;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.core.PartnerCore; import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit; import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityCore; import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityMethod; import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.framework.agent.state.State;
import work.slhaf.partner.framework.agent.state.StateSerializable;
import work.slhaf.partner.framework.agent.state.StateValue;
import java.io.IOException; import java.nio.file.Path;
import java.io.Serial; import java.util.*;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
@EqualsAndHashCode(callSuper = true)
@CapabilityCore(value = "memory") @CapabilityCore(value = "memory")
@Getter
@Setter
@Slf4j @Slf4j
public class MemoryCore extends PartnerCore<MemoryCore> { public class MemoryCore implements StateSerializable {
@Serial
private static final long serialVersionUID = 1L;
private final Lock memoryLock = new ReentrantLock(); private final Lock memoryLock = new ReentrantLock();
private ConcurrentHashMap<String, MemoryUnit> memoryUnits = new ConcurrentHashMap<>(); private final ConcurrentHashMap<String, MemoryUnit> memoryUnits = new ConcurrentHashMap<>();
// 默认值一般只存在于智能体初次启动时 // 默认值一般只存在于智能体初次启动时
private String memorySessionId = UUID.randomUUID().toString(); private String memorySessionId = UUID.randomUUID().toString();
private Instant memorySessionStartTime = Instant.now();
public MemoryCore() throws IOException, ClassNotFoundException { public MemoryCore() {
register();
} }
@CapabilityMethod @CapabilityMethod
@@ -54,7 +45,7 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
@CapabilityMethod @CapabilityMethod
public MemoryUnit getMemoryUnit(String unitId) { public MemoryUnit getMemoryUnit(String unitId) {
return memoryUnits.get(unitId); return memoryUnits.computeIfAbsent(unitId, MemoryUnit::new);
} }
@CapabilityMethod @CapabilityMethod
@@ -79,7 +70,6 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
@CapabilityMethod @CapabilityMethod
public void refreshMemorySession() { public void refreshMemorySession() {
memorySessionId = UUID.randomUUID().toString(); memorySessionId = UUID.randomUUID().toString();
memorySessionStartTime = Instant.now();
} }
@CapabilityMethod @CapabilityMethod
@@ -88,44 +78,43 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
} }
private void normalizeMemoryUnit(MemoryUnit memoryUnit) { private void normalizeMemoryUnit(MemoryUnit memoryUnit) {
if (memoryUnit.getId() == null || memoryUnit.getId().isBlank()) {
memoryUnit.setId(UUID.randomUUID().toString());
}
if (memoryUnit.getTimestamp() == null || memoryUnit.getTimestamp() <= 0) {
memoryUnit.setTimestamp(System.currentTimeMillis());
}
if (memoryUnit.getConversationMessages() == null) {
memoryUnit.setConversationMessages(new ArrayList<>());
}
if (memoryUnit.getSlices() == null) {
memoryUnit.setSlices(new ArrayList<>());
}
int maxEndExclusive = Math.max(memoryUnit.getConversationMessages().size(), 0);
for (MemorySlice slice : memoryUnit.getSlices()) {
if (slice.getId() == null || slice.getId().isBlank()) {
slice.setId(UUID.randomUUID().toString());
}
if (slice.getTimestamp() == null || slice.getTimestamp() <= 0) {
slice.setTimestamp(memoryUnit.getTimestamp());
}
if (slice.getStartIndex() == null || slice.getStartIndex() < 0) {
slice.setStartIndex(0);
}
if (slice.getStartIndex() > maxEndExclusive) {
slice.setStartIndex(maxEndExclusive);
}
if (slice.getEndIndex() == null || slice.getEndIndex() < slice.getStartIndex()) {
slice.setEndIndex(maxEndExclusive);
}
if (slice.getEndIndex() > maxEndExclusive) {
slice.setEndIndex(maxEndExclusive);
}
}
memoryUnit.getSlices().sort(Comparator.naturalOrder()); memoryUnit.getSlices().sort(Comparator.naturalOrder());
} }
@Override @Override
protected String getCoreKey() { public @NotNull Path statePath() {
return "memory-core"; return Path.of("core", "memory.json");
}
@Override
public void load(@NotNull JSONObject state) {
String memorySessionId = state.getString("memory_session_id");
if (memorySessionId == null) {
throw new IllegalStateException("Memory session id is missing");
}
JSONArray array = state.getJSONArray("memory_unit_uuid_set");
if (array == null) {
throw new IllegalStateException("Memory unit uuid set is missing");
}
for (int i = 0; i < array.size(); i++) {
String unitUuid = array.getString(i);
if (unitUuid == null) {
throw new IllegalStateException("memory_unit_uuid_set is not a uuid array, index: " + i);
}
MemoryUnit memoryUnit = new MemoryUnit(unitUuid);
memoryUnits.put(unitUuid, memoryUnit);
}
}
@Override
public @NotNull State convert() {
State state = new State();
state.append("memory_session_id", StateValue.str(memorySessionId));
List<StateValue.Str> unitOverview = memoryUnits.keySet().stream()
.map(StateValue::str)
.toList();
state.append("memory_unit_uuid_set", StateValue.arr(unitOverview));
return state;
} }
} }

View File

@@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode;
import work.slhaf.partner.framework.agent.common.entity.PersistableObject; import work.slhaf.partner.framework.agent.common.entity.PersistableObject;
import java.io.Serial; import java.io.Serial;
import java.util.UUID;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data @Data
@@ -13,11 +14,19 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
@Serial @Serial
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
private String id; private final String id;
private Integer startIndex; private final Integer startIndex;
private Integer endIndex; private final Integer endIndex;
private String summary; private final String summary;
private Long timestamp; private final Long timestamp;
public MemorySlice(Integer startIndex, Integer endIndex, String summary) {
this.id = UUID.randomUUID().toString();
this.timestamp = System.currentTimeMillis();
this.startIndex = startIndex;
this.endIndex = endIndex;
this.summary = summary;
}
@Override @Override
public int compareTo(MemorySlice memorySlice) { public int compareTo(MemorySlice memorySlice) {

View File

@@ -1,23 +1,24 @@
package work.slhaf.partner.core.memory.pojo; package work.slhaf.partner.core.memory.pojo;
import lombok.Data; import lombok.Getter;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.framework.agent.common.entity.PersistableObject;
import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.model.pojo.Message;
import java.io.Serial;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@EqualsAndHashCode(callSuper = true) @Getter
@Data public class MemoryUnit {
public class MemoryUnit extends PersistableObject {
@Serial private final String id;
private static final long serialVersionUID = 1L; private final List<Message> conversationMessages = new ArrayList<>();
private String id;
private List<Message> conversationMessages = new ArrayList<>();
private Long timestamp; private Long timestamp;
private List<MemorySlice> slices = new ArrayList<>(); private final List<MemorySlice> slices = new ArrayList<>();
public MemoryUnit(String id) {
this.id = id;
}
public void updateTimestamp() {
timestamp = System.currentTimeMillis();
}
} }

View File

@@ -28,7 +28,6 @@ import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeInput
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult; import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
@@ -189,34 +188,23 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
} }
private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) { private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) {
long now = System.currentTimeMillis();
String memoryId = memoryCapability.getMemorySessionId(); String memoryId = memoryCapability.getMemorySessionId();
String resolvedMemoryId = memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId; String resolvedMemoryId = memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId;
MemoryUnit existingUnit = memoryCapability.getMemoryUnit(resolvedMemoryId); MemoryUnit existingUnit = memoryCapability.getMemoryUnit(resolvedMemoryId);
List<Message> existingMessages = existingUnit != null && existingUnit.getConversationMessages() != null List<Message> existingMessages = existingUnit.getConversationMessages();
? existingUnit.getConversationMessages()
: List.of();
int startIndex = existingMessages.size(); int startIndex = existingMessages.size();
MemorySlice memorySlice = new MemorySlice(); MemorySlice memorySlice = new MemorySlice(
memorySlice.setId(UUID.randomUUID().toString()); startIndex,
memorySlice.setStartIndex(startIndex); startIndex + chatMessages.size(),
memorySlice.setEndIndex(startIndex + chatMessages.size()); summarizeResult.getSummary()
memorySlice.setSummary(summarizeResult.getSummary()); );
memorySlice.setTimestamp(now);
MemoryUnit memoryUnit = new MemoryUnit(); MemoryUnit memoryUnit = new MemoryUnit(resolvedMemoryId);
memoryUnit.setId(resolvedMemoryId); memoryUnit.updateTimestamp();
memoryUnit.setTimestamp(now); memoryUnit.getConversationMessages().addAll(chatMessages);
List<Message> conversationMessages = new ArrayList<>(existingMessages);
conversationMessages.addAll(chatMessages);
memoryUnit.setConversationMessages(conversationMessages);
List<MemorySlice> slices = existingUnit != null && existingUnit.getSlices() != null memoryUnit.getSlices().add(memorySlice);
? new ArrayList<>(existingUnit.getSlices())
: new ArrayList<>();
slices.add(memorySlice);
memoryUnit.setSlices(slices);
return memoryUnit; return memoryUnit;
} }

View File

@@ -11,7 +11,6 @@ import work.slhaf.partner.framework.agent.model.pojo.Message;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
@@ -57,14 +56,13 @@ class MemoryCoreTest {
slice.setStartIndex(1); slice.setStartIndex(1);
slice.setEndIndex(99); slice.setEndIndex(99);
MemoryUnit unit = new MemoryUnit(); MemoryUnit unit = new MemoryUnit("unit-1");
unit.setId("unit-1"); unit.getConversationMessages().addAll(List.of(
unit.setConversationMessages(new ArrayList<>(List.of(
new Message(Message.Character.USER, "m0"), new Message(Message.Character.USER, "m0"),
new Message(Message.Character.USER, "m1"), new Message(Message.Character.USER, "m1"),
new Message(Message.Character.USER, "m2") new Message(Message.Character.USER, "m2")
))); ));
unit.setSlices(new ArrayList<>(List.of(slice))); unit.getSlices().add(slice);
memoryCore.saveMemoryUnit(unit); memoryCore.saveMemoryUnit(unit);

View File

@@ -74,13 +74,13 @@ class MemoryRuntimeTest {
@Test @Test
void shouldSliceMessagesUsingLeftClosedRightOpenRange() throws Exception { void shouldSliceMessagesUsingLeftClosedRightOpenRange() throws Exception {
MemoryRuntime runtime = new MemoryRuntime(); MemoryRuntime runtime = new MemoryRuntime();
MemoryUnit unit = new MemoryUnit(); MemoryUnit unit = new MemoryUnit("unit-1");
unit.setConversationMessages(new ArrayList<>(List.of( unit.getConversationMessages().addAll(List.of(
message("m0"), message("m0"),
message("m1"), message("m1"),
message("m2"), message("m2"),
message("m3") message("m3")
))); ));
MemorySlice slice = new MemorySlice(); MemorySlice slice = new MemorySlice();
slice.setStartIndex(1); slice.setStartIndex(1);
@@ -105,14 +105,13 @@ class MemoryRuntimeTest {
MemoryRuntime runtime = new MemoryRuntime(); MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability); setField(runtime, "memoryCapability", memoryCapability);
MemoryUnit unit = new MemoryUnit(); MemoryUnit unit = new MemoryUnit("unit-1");
unit.setId("unit-1"); unit.getConversationMessages().addAll(List.of(
unit.setConversationMessages(new ArrayList<>(List.of(
message("m0"), message("m0"),
message("m1"), message("m1"),
message("m2"), message("m2"),
message("m3") message("m3")
))); ));
MemorySlice firstSlice = new MemorySlice(); MemorySlice firstSlice = new MemorySlice();
firstSlice.setId("slice-1"); firstSlice.setId("slice-1");
@@ -128,7 +127,7 @@ class MemoryRuntimeTest {
secondSlice.setSummary("second"); secondSlice.setSummary("second");
secondSlice.setTimestamp(2L); secondSlice.setTimestamp(2L);
unit.setSlices(new ArrayList<>(List.of(firstSlice, secondSlice))); unit.getSlices().addAll(List.of(firstSlice, secondSlice));
runtime.recordMemory(unit, "topic/main", List.of("topic/related")); runtime.recordMemory(unit, "topic/main", List.of("topic/related"));

View File

@@ -9,7 +9,10 @@ import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResul
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.*; import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
@@ -55,19 +58,18 @@ class MemoryUpdaterTest {
setField(updater, "memoryCapability", memoryCapability); setField(updater, "memoryCapability", memoryCapability);
String sessionId = memoryCapability.getMemorySessionId(); String sessionId = memoryCapability.getMemorySessionId();
MemoryUnit existingUnit = new MemoryUnit(); MemoryUnit existingUnit = new MemoryUnit(sessionId);
existingUnit.setId(sessionId); existingUnit.getConversationMessages().addAll(List.of(
existingUnit.setConversationMessages(new ArrayList<>(List.of(
message(Message.Character.USER, "old-user"), message(Message.Character.USER, "old-user"),
message(Message.Character.ASSISTANT, "old-assistant") message(Message.Character.ASSISTANT, "old-assistant")
))); ));
MemorySlice existingSlice = new MemorySlice(); MemorySlice existingSlice = new MemorySlice();
existingSlice.setId("slice-1"); existingSlice.setId("slice-1");
existingSlice.setStartIndex(0); existingSlice.setStartIndex(0);
existingSlice.setEndIndex(2); existingSlice.setEndIndex(2);
existingSlice.setSummary("old-summary"); existingSlice.setSummary("old-summary");
existingSlice.setTimestamp(1L); existingSlice.setTimestamp(1L);
existingUnit.setSlices(new ArrayList<>(List.of(existingSlice))); existingUnit.getSlices().add(existingSlice);
memoryCapability.saveMemoryUnit(existingUnit); memoryCapability.saveMemoryUnit(existingUnit);
MemoryUnit merged = invokeBuildMemoryUnit( MemoryUnit merged = invokeBuildMemoryUnit(
@@ -121,14 +123,13 @@ class MemoryUpdaterTest {
MemoryUpdater updater = new MemoryUpdater(); MemoryUpdater updater = new MemoryUpdater();
setField(updater, "memoryCapability", memoryCapability); setField(updater, "memoryCapability", memoryCapability);
MemoryUnit existingUnit = new MemoryUnit(); MemoryUnit existingUnit = new MemoryUnit("session-3");
existingUnit.setId("session-3"); existingUnit.getConversationMessages().addAll(List.of(
existingUnit.setConversationMessages(new ArrayList<>(List.of(
message(Message.Character.USER, "m1"), message(Message.Character.USER, "m1"),
message(Message.Character.ASSISTANT, "m2"), message(Message.Character.ASSISTANT, "m2"),
message(Message.Character.USER, "m3"), message(Message.Character.USER, "m3"),
message(Message.Character.ASSISTANT, "m4") message(Message.Character.ASSISTANT, "m4")
))); ));
memoryCapability.saveMemoryUnit(existingUnit); memoryCapability.saveMemoryUnit(existingUnit);
List<Message> increment = invokeResolveChatIncrement( List<Message> increment = invokeResolveChatIncrement(
@@ -150,13 +151,12 @@ class MemoryUpdaterTest {
MemoryUpdater updater = new MemoryUpdater(); MemoryUpdater updater = new MemoryUpdater();
setField(updater, "memoryCapability", memoryCapability); setField(updater, "memoryCapability", memoryCapability);
MemoryUnit existingUnit = new MemoryUnit(); MemoryUnit existingUnit = new MemoryUnit("session-4");
existingUnit.setId("session-4"); existingUnit.getConversationMessages().addAll(List.of(
existingUnit.setConversationMessages(new ArrayList<>(List.of(
message(Message.Character.USER, "m1"), message(Message.Character.USER, "m1"),
message(Message.Character.ASSISTANT, "m2"), message(Message.Character.ASSISTANT, "m2"),
message(Message.Character.USER, "m3") message(Message.Character.USER, "m3")
))); ));
memoryCapability.saveMemoryUnit(existingUnit); memoryCapability.saveMemoryUnit(existingUnit);
List<Message> increment = invokeResolveChatIncrement( List<Message> increment = invokeResolveChatIncrement(