refactor(memory): manage state serialization via StateCenter in MemoryCore, and normalize slice and unit building

This commit is contained in:
2026-04-07 10:21:17 +08:00
parent 57bc63c57b
commit a242723727
7 changed files with 110 additions and 126 deletions

View File

@@ -1,44 +1,35 @@
package work.slhaf.partner.core.memory;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.core.PartnerCore;
import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.framework.agent.state.State;
import work.slhaf.partner.framework.agent.state.StateSerializable;
import work.slhaf.partner.framework.agent.state.StateValue;
import java.io.IOException;
import java.io.Serial;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.UUID;
import java.nio.file.Path;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@EqualsAndHashCode(callSuper = true)
@CapabilityCore(value = "memory")
@Getter
@Setter
@Slf4j
public class MemoryCore extends PartnerCore<MemoryCore> {
@Serial
private static final long serialVersionUID = 1L;
public class MemoryCore implements StateSerializable {
private final Lock memoryLock = new ReentrantLock();
private ConcurrentHashMap<String, MemoryUnit> memoryUnits = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, MemoryUnit> memoryUnits = new ConcurrentHashMap<>();
// 默认值一般只存在于智能体初次启动时
private String memorySessionId = UUID.randomUUID().toString();
private Instant memorySessionStartTime = Instant.now();
public MemoryCore() throws IOException, ClassNotFoundException {
public MemoryCore() {
register();
}
@CapabilityMethod
@@ -54,7 +45,7 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
@CapabilityMethod
public MemoryUnit getMemoryUnit(String unitId) {
return memoryUnits.get(unitId);
return memoryUnits.computeIfAbsent(unitId, MemoryUnit::new);
}
@CapabilityMethod
@@ -79,7 +70,6 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
@CapabilityMethod
public void refreshMemorySession() {
memorySessionId = UUID.randomUUID().toString();
memorySessionStartTime = Instant.now();
}
@CapabilityMethod
@@ -88,44 +78,43 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
}
private void normalizeMemoryUnit(MemoryUnit memoryUnit) {
if (memoryUnit.getId() == null || memoryUnit.getId().isBlank()) {
memoryUnit.setId(UUID.randomUUID().toString());
}
if (memoryUnit.getTimestamp() == null || memoryUnit.getTimestamp() <= 0) {
memoryUnit.setTimestamp(System.currentTimeMillis());
}
if (memoryUnit.getConversationMessages() == null) {
memoryUnit.setConversationMessages(new ArrayList<>());
}
if (memoryUnit.getSlices() == null) {
memoryUnit.setSlices(new ArrayList<>());
}
int maxEndExclusive = Math.max(memoryUnit.getConversationMessages().size(), 0);
for (MemorySlice slice : memoryUnit.getSlices()) {
if (slice.getId() == null || slice.getId().isBlank()) {
slice.setId(UUID.randomUUID().toString());
}
if (slice.getTimestamp() == null || slice.getTimestamp() <= 0) {
slice.setTimestamp(memoryUnit.getTimestamp());
}
if (slice.getStartIndex() == null || slice.getStartIndex() < 0) {
slice.setStartIndex(0);
}
if (slice.getStartIndex() > maxEndExclusive) {
slice.setStartIndex(maxEndExclusive);
}
if (slice.getEndIndex() == null || slice.getEndIndex() < slice.getStartIndex()) {
slice.setEndIndex(maxEndExclusive);
}
if (slice.getEndIndex() > maxEndExclusive) {
slice.setEndIndex(maxEndExclusive);
}
}
memoryUnit.getSlices().sort(Comparator.naturalOrder());
}
@Override
protected String getCoreKey() {
return "memory-core";
public @NotNull Path statePath() {
return Path.of("core", "memory.json");
}
@Override
public void load(@NotNull JSONObject state) {
String memorySessionId = state.getString("memory_session_id");
if (memorySessionId == null) {
throw new IllegalStateException("Memory session id is missing");
}
JSONArray array = state.getJSONArray("memory_unit_uuid_set");
if (array == null) {
throw new IllegalStateException("Memory unit uuid set is missing");
}
for (int i = 0; i < array.size(); i++) {
String unitUuid = array.getString(i);
if (unitUuid == null) {
throw new IllegalStateException("memory_unit_uuid_set is not a uuid array, index: " + i);
}
MemoryUnit memoryUnit = new MemoryUnit(unitUuid);
memoryUnits.put(unitUuid, memoryUnit);
}
}
@Override
public @NotNull State convert() {
State state = new State();
state.append("memory_session_id", StateValue.str(memorySessionId));
List<StateValue.Str> unitOverview = memoryUnits.keySet().stream()
.map(StateValue::str)
.toList();
state.append("memory_unit_uuid_set", StateValue.arr(unitOverview));
return state;
}
}

View File

@@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode;
import work.slhaf.partner.framework.agent.common.entity.PersistableObject;
import java.io.Serial;
import java.util.UUID;
@EqualsAndHashCode(callSuper = true)
@Data
@@ -13,11 +14,19 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
@Serial
private static final long serialVersionUID = 1L;
private String id;
private Integer startIndex;
private Integer endIndex;
private String summary;
private Long timestamp;
private final String id;
private final Integer startIndex;
private final Integer endIndex;
private final String summary;
private final Long timestamp;
public MemorySlice(Integer startIndex, Integer endIndex, String summary) {
this.id = UUID.randomUUID().toString();
this.timestamp = System.currentTimeMillis();
this.startIndex = startIndex;
this.endIndex = endIndex;
this.summary = summary;
}
@Override
public int compareTo(MemorySlice memorySlice) {

View File

@@ -1,23 +1,24 @@
package work.slhaf.partner.core.memory.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.framework.agent.common.entity.PersistableObject;
import lombok.Getter;
import work.slhaf.partner.framework.agent.model.pojo.Message;
import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemoryUnit extends PersistableObject {
@Getter
public class MemoryUnit {
@Serial
private static final long serialVersionUID = 1L;
private String id;
private List<Message> conversationMessages = new ArrayList<>();
private final String id;
private final List<Message> conversationMessages = new ArrayList<>();
private Long timestamp;
private List<MemorySlice> slices = new ArrayList<>();
private final List<MemorySlice> slices = new ArrayList<>();
public MemoryUnit(String id) {
this.id = id;
}
public void updateTimestamp() {
timestamp = System.currentTimeMillis();
}
}

View File

@@ -28,7 +28,6 @@ import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeInput
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
@@ -189,34 +188,23 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
}
private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) {
long now = System.currentTimeMillis();
String memoryId = memoryCapability.getMemorySessionId();
String resolvedMemoryId = memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId;
MemoryUnit existingUnit = memoryCapability.getMemoryUnit(resolvedMemoryId);
List<Message> existingMessages = existingUnit != null && existingUnit.getConversationMessages() != null
? existingUnit.getConversationMessages()
: List.of();
List<Message> existingMessages = existingUnit.getConversationMessages();
int startIndex = existingMessages.size();
MemorySlice memorySlice = new MemorySlice();
memorySlice.setId(UUID.randomUUID().toString());
memorySlice.setStartIndex(startIndex);
memorySlice.setEndIndex(startIndex + chatMessages.size());
memorySlice.setSummary(summarizeResult.getSummary());
memorySlice.setTimestamp(now);
MemorySlice memorySlice = new MemorySlice(
startIndex,
startIndex + chatMessages.size(),
summarizeResult.getSummary()
);
MemoryUnit memoryUnit = new MemoryUnit();
memoryUnit.setId(resolvedMemoryId);
memoryUnit.setTimestamp(now);
List<Message> conversationMessages = new ArrayList<>(existingMessages);
conversationMessages.addAll(chatMessages);
memoryUnit.setConversationMessages(conversationMessages);
MemoryUnit memoryUnit = new MemoryUnit(resolvedMemoryId);
memoryUnit.updateTimestamp();
memoryUnit.getConversationMessages().addAll(chatMessages);
List<MemorySlice> slices = existingUnit != null && existingUnit.getSlices() != null
? new ArrayList<>(existingUnit.getSlices())
: new ArrayList<>();
slices.add(memorySlice);
memoryUnit.setSlices(slices);
memoryUnit.getSlices().add(memorySlice);
return memoryUnit;
}

View File

@@ -11,7 +11,6 @@ import work.slhaf.partner.framework.agent.model.pojo.Message;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
@@ -57,14 +56,13 @@ class MemoryCoreTest {
slice.setStartIndex(1);
slice.setEndIndex(99);
MemoryUnit unit = new MemoryUnit();
unit.setId("unit-1");
unit.setConversationMessages(new ArrayList<>(List.of(
MemoryUnit unit = new MemoryUnit("unit-1");
unit.getConversationMessages().addAll(List.of(
new Message(Message.Character.USER, "m0"),
new Message(Message.Character.USER, "m1"),
new Message(Message.Character.USER, "m2")
)));
unit.setSlices(new ArrayList<>(List.of(slice)));
));
unit.getSlices().add(slice);
memoryCore.saveMemoryUnit(unit);

View File

@@ -74,13 +74,13 @@ class MemoryRuntimeTest {
@Test
void shouldSliceMessagesUsingLeftClosedRightOpenRange() throws Exception {
MemoryRuntime runtime = new MemoryRuntime();
MemoryUnit unit = new MemoryUnit();
unit.setConversationMessages(new ArrayList<>(List.of(
MemoryUnit unit = new MemoryUnit("unit-1");
unit.getConversationMessages().addAll(List.of(
message("m0"),
message("m1"),
message("m2"),
message("m3")
)));
));
MemorySlice slice = new MemorySlice();
slice.setStartIndex(1);
@@ -105,14 +105,13 @@ class MemoryRuntimeTest {
MemoryRuntime runtime = new MemoryRuntime();
setField(runtime, "memoryCapability", memoryCapability);
MemoryUnit unit = new MemoryUnit();
unit.setId("unit-1");
unit.setConversationMessages(new ArrayList<>(List.of(
MemoryUnit unit = new MemoryUnit("unit-1");
unit.getConversationMessages().addAll(List.of(
message("m0"),
message("m1"),
message("m2"),
message("m3")
)));
));
MemorySlice firstSlice = new MemorySlice();
firstSlice.setId("slice-1");
@@ -128,7 +127,7 @@ class MemoryRuntimeTest {
secondSlice.setSummary("second");
secondSlice.setTimestamp(2L);
unit.setSlices(new ArrayList<>(List.of(firstSlice, secondSlice)));
unit.getSlices().addAll(List.of(firstSlice, secondSlice));
runtime.recordMemory(unit, "topic/main", List.of("topic/related"));

View File

@@ -9,7 +9,10 @@ import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResul
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@@ -55,19 +58,18 @@ class MemoryUpdaterTest {
setField(updater, "memoryCapability", memoryCapability);
String sessionId = memoryCapability.getMemorySessionId();
MemoryUnit existingUnit = new MemoryUnit();
existingUnit.setId(sessionId);
existingUnit.setConversationMessages(new ArrayList<>(List.of(
MemoryUnit existingUnit = new MemoryUnit(sessionId);
existingUnit.getConversationMessages().addAll(List.of(
message(Message.Character.USER, "old-user"),
message(Message.Character.ASSISTANT, "old-assistant")
)));
));
MemorySlice existingSlice = new MemorySlice();
existingSlice.setId("slice-1");
existingSlice.setStartIndex(0);
existingSlice.setEndIndex(2);
existingSlice.setSummary("old-summary");
existingSlice.setTimestamp(1L);
existingUnit.setSlices(new ArrayList<>(List.of(existingSlice)));
existingUnit.getSlices().add(existingSlice);
memoryCapability.saveMemoryUnit(existingUnit);
MemoryUnit merged = invokeBuildMemoryUnit(
@@ -121,14 +123,13 @@ class MemoryUpdaterTest {
MemoryUpdater updater = new MemoryUpdater();
setField(updater, "memoryCapability", memoryCapability);
MemoryUnit existingUnit = new MemoryUnit();
existingUnit.setId("session-3");
existingUnit.setConversationMessages(new ArrayList<>(List.of(
MemoryUnit existingUnit = new MemoryUnit("session-3");
existingUnit.getConversationMessages().addAll(List.of(
message(Message.Character.USER, "m1"),
message(Message.Character.ASSISTANT, "m2"),
message(Message.Character.USER, "m3"),
message(Message.Character.ASSISTANT, "m4")
)));
));
memoryCapability.saveMemoryUnit(existingUnit);
List<Message> increment = invokeResolveChatIncrement(
@@ -150,13 +151,12 @@ class MemoryUpdaterTest {
MemoryUpdater updater = new MemoryUpdater();
setField(updater, "memoryCapability", memoryCapability);
MemoryUnit existingUnit = new MemoryUnit();
existingUnit.setId("session-4");
existingUnit.setConversationMessages(new ArrayList<>(List.of(
MemoryUnit existingUnit = new MemoryUnit("session-4");
existingUnit.getConversationMessages().addAll(List.of(
message(Message.Character.USER, "m1"),
message(Message.Character.ASSISTANT, "m2"),
message(Message.Character.USER, "m3")
)));
));
memoryCapability.saveMemoryUnit(existingUnit);
List<Message> increment = invokeResolveChatIncrement(