refactor(memory): manage state serialization via StateCenter in MemoryCore, and normalize slice and unit building

This commit is contained in:
2026-04-07 10:21:17 +08:00
parent 57bc63c57b
commit a242723727
7 changed files with 110 additions and 126 deletions

View File

@@ -1,44 +1,35 @@
package work.slhaf.partner.core.memory;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.core.PartnerCore;
import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.framework.agent.state.State;
import work.slhaf.partner.framework.agent.state.StateSerializable;
import work.slhaf.partner.framework.agent.state.StateValue;
import java.io.IOException;
import java.io.Serial;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.UUID;
import java.nio.file.Path;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@EqualsAndHashCode(callSuper = true)
@CapabilityCore(value = "memory")
@Getter
@Setter
@Slf4j
public class MemoryCore extends PartnerCore<MemoryCore> {
@Serial
private static final long serialVersionUID = 1L;
public class MemoryCore implements StateSerializable {
private final Lock memoryLock = new ReentrantLock();
private ConcurrentHashMap<String, MemoryUnit> memoryUnits = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, MemoryUnit> memoryUnits = new ConcurrentHashMap<>();
// 默认值一般只存在于智能体初次启动时
private String memorySessionId = UUID.randomUUID().toString();
private Instant memorySessionStartTime = Instant.now();
public MemoryCore() throws IOException, ClassNotFoundException {
public MemoryCore() {
register();
}
@CapabilityMethod
@@ -54,7 +45,7 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
@CapabilityMethod
public MemoryUnit getMemoryUnit(String unitId) {
return memoryUnits.get(unitId);
return memoryUnits.computeIfAbsent(unitId, MemoryUnit::new);
}
@CapabilityMethod
@@ -79,7 +70,6 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
@CapabilityMethod
public void refreshMemorySession() {
memorySessionId = UUID.randomUUID().toString();
memorySessionStartTime = Instant.now();
}
@CapabilityMethod
@@ -88,44 +78,43 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
}
private void normalizeMemoryUnit(MemoryUnit memoryUnit) {
if (memoryUnit.getId() == null || memoryUnit.getId().isBlank()) {
memoryUnit.setId(UUID.randomUUID().toString());
}
if (memoryUnit.getTimestamp() == null || memoryUnit.getTimestamp() <= 0) {
memoryUnit.setTimestamp(System.currentTimeMillis());
}
if (memoryUnit.getConversationMessages() == null) {
memoryUnit.setConversationMessages(new ArrayList<>());
}
if (memoryUnit.getSlices() == null) {
memoryUnit.setSlices(new ArrayList<>());
}
int maxEndExclusive = Math.max(memoryUnit.getConversationMessages().size(), 0);
for (MemorySlice slice : memoryUnit.getSlices()) {
if (slice.getId() == null || slice.getId().isBlank()) {
slice.setId(UUID.randomUUID().toString());
}
if (slice.getTimestamp() == null || slice.getTimestamp() <= 0) {
slice.setTimestamp(memoryUnit.getTimestamp());
}
if (slice.getStartIndex() == null || slice.getStartIndex() < 0) {
slice.setStartIndex(0);
}
if (slice.getStartIndex() > maxEndExclusive) {
slice.setStartIndex(maxEndExclusive);
}
if (slice.getEndIndex() == null || slice.getEndIndex() < slice.getStartIndex()) {
slice.setEndIndex(maxEndExclusive);
}
if (slice.getEndIndex() > maxEndExclusive) {
slice.setEndIndex(maxEndExclusive);
}
}
memoryUnit.getSlices().sort(Comparator.naturalOrder());
}
@Override
protected String getCoreKey() {
return "memory-core";
public @NotNull Path statePath() {
return Path.of("core", "memory.json");
}
@Override
public void load(@NotNull JSONObject state) {
String memorySessionId = state.getString("memory_session_id");
if (memorySessionId == null) {
throw new IllegalStateException("Memory session id is missing");
}
JSONArray array = state.getJSONArray("memory_unit_uuid_set");
if (array == null) {
throw new IllegalStateException("Memory unit uuid set is missing");
}
for (int i = 0; i < array.size(); i++) {
String unitUuid = array.getString(i);
if (unitUuid == null) {
throw new IllegalStateException("memory_unit_uuid_set is not a uuid array, index: " + i);
}
MemoryUnit memoryUnit = new MemoryUnit(unitUuid);
memoryUnits.put(unitUuid, memoryUnit);
}
}
@Override
public @NotNull State convert() {
State state = new State();
state.append("memory_session_id", StateValue.str(memorySessionId));
List<StateValue.Str> unitOverview = memoryUnits.keySet().stream()
.map(StateValue::str)
.toList();
state.append("memory_unit_uuid_set", StateValue.arr(unitOverview));
return state;
}
}

View File

@@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode;
import work.slhaf.partner.framework.agent.common.entity.PersistableObject;
import java.io.Serial;
import java.util.UUID;
@EqualsAndHashCode(callSuper = true)
@Data
@@ -13,11 +14,19 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
@Serial
private static final long serialVersionUID = 1L;
private String id;
private Integer startIndex;
private Integer endIndex;
private String summary;
private Long timestamp;
private final String id;
private final Integer startIndex;
private final Integer endIndex;
private final String summary;
private final Long timestamp;
public MemorySlice(Integer startIndex, Integer endIndex, String summary) {
this.id = UUID.randomUUID().toString();
this.timestamp = System.currentTimeMillis();
this.startIndex = startIndex;
this.endIndex = endIndex;
this.summary = summary;
}
@Override
public int compareTo(MemorySlice memorySlice) {

View File

@@ -1,23 +1,24 @@
package work.slhaf.partner.core.memory.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.framework.agent.common.entity.PersistableObject;
import lombok.Getter;
import work.slhaf.partner.framework.agent.model.pojo.Message;
import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemoryUnit extends PersistableObject {
@Getter
public class MemoryUnit {
@Serial
private static final long serialVersionUID = 1L;
private String id;
private List<Message> conversationMessages = new ArrayList<>();
private final String id;
private final List<Message> conversationMessages = new ArrayList<>();
private Long timestamp;
private List<MemorySlice> slices = new ArrayList<>();
private final List<MemorySlice> slices = new ArrayList<>();
public MemoryUnit(String id) {
this.id = id;
}
public void updateTimestamp() {
timestamp = System.currentTimeMillis();
}
}

View File

@@ -28,7 +28,6 @@ import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeInput
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
@@ -189,34 +188,23 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
}
private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) {
long now = System.currentTimeMillis();
String memoryId = memoryCapability.getMemorySessionId();
String resolvedMemoryId = memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId;
MemoryUnit existingUnit = memoryCapability.getMemoryUnit(resolvedMemoryId);
List<Message> existingMessages = existingUnit != null && existingUnit.getConversationMessages() != null
? existingUnit.getConversationMessages()
: List.of();
List<Message> existingMessages = existingUnit.getConversationMessages();
int startIndex = existingMessages.size();
MemorySlice memorySlice = new MemorySlice();
memorySlice.setId(UUID.randomUUID().toString());
memorySlice.setStartIndex(startIndex);
memorySlice.setEndIndex(startIndex + chatMessages.size());
memorySlice.setSummary(summarizeResult.getSummary());
memorySlice.setTimestamp(now);
MemorySlice memorySlice = new MemorySlice(
startIndex,
startIndex + chatMessages.size(),
summarizeResult.getSummary()
);
MemoryUnit memoryUnit = new MemoryUnit();
memoryUnit.setId(resolvedMemoryId);
memoryUnit.setTimestamp(now);
List<Message> conversationMessages = new ArrayList<>(existingMessages);
conversationMessages.addAll(chatMessages);
memoryUnit.setConversationMessages(conversationMessages);
MemoryUnit memoryUnit = new MemoryUnit(resolvedMemoryId);
memoryUnit.updateTimestamp();
memoryUnit.getConversationMessages().addAll(chatMessages);
List<MemorySlice> slices = existingUnit != null && existingUnit.getSlices() != null
? new ArrayList<>(existingUnit.getSlices())
: new ArrayList<>();
slices.add(memorySlice);
memoryUnit.setSlices(slices);
memoryUnit.getSlices().add(memorySlice);
return memoryUnit;
}