mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
refactor(memory): manage state serialization via StateCenter in MemoryCore, and normalize slice and unit building
This commit is contained in:
@@ -1,44 +1,35 @@
|
||||
package work.slhaf.partner.core.memory;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import com.alibaba.fastjson2.JSONArray;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.core.PartnerCore;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||
import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityCore;
|
||||
import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityMethod;
|
||||
import work.slhaf.partner.framework.agent.state.State;
|
||||
import work.slhaf.partner.framework.agent.state.StateSerializable;
|
||||
import work.slhaf.partner.framework.agent.state.StateValue;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.Serial;
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Comparator;
|
||||
import java.util.UUID;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@CapabilityCore(value = "memory")
|
||||
@Getter
|
||||
@Setter
|
||||
@Slf4j
|
||||
public class MemoryCore extends PartnerCore<MemoryCore> {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
public class MemoryCore implements StateSerializable {
|
||||
|
||||
private final Lock memoryLock = new ReentrantLock();
|
||||
private ConcurrentHashMap<String, MemoryUnit> memoryUnits = new ConcurrentHashMap<>();
|
||||
private final ConcurrentHashMap<String, MemoryUnit> memoryUnits = new ConcurrentHashMap<>();
|
||||
|
||||
// 默认值一般只存在于智能体初次启动时
|
||||
private String memorySessionId = UUID.randomUUID().toString();
|
||||
private Instant memorySessionStartTime = Instant.now();
|
||||
|
||||
public MemoryCore() throws IOException, ClassNotFoundException {
|
||||
public MemoryCore() {
|
||||
register();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
@@ -54,7 +45,7 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
|
||||
|
||||
@CapabilityMethod
|
||||
public MemoryUnit getMemoryUnit(String unitId) {
|
||||
return memoryUnits.get(unitId);
|
||||
return memoryUnits.computeIfAbsent(unitId, MemoryUnit::new);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
@@ -79,7 +70,6 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
|
||||
@CapabilityMethod
|
||||
public void refreshMemorySession() {
|
||||
memorySessionId = UUID.randomUUID().toString();
|
||||
memorySessionStartTime = Instant.now();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
@@ -88,44 +78,43 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
|
||||
}
|
||||
|
||||
private void normalizeMemoryUnit(MemoryUnit memoryUnit) {
|
||||
if (memoryUnit.getId() == null || memoryUnit.getId().isBlank()) {
|
||||
memoryUnit.setId(UUID.randomUUID().toString());
|
||||
}
|
||||
if (memoryUnit.getTimestamp() == null || memoryUnit.getTimestamp() <= 0) {
|
||||
memoryUnit.setTimestamp(System.currentTimeMillis());
|
||||
}
|
||||
if (memoryUnit.getConversationMessages() == null) {
|
||||
memoryUnit.setConversationMessages(new ArrayList<>());
|
||||
}
|
||||
if (memoryUnit.getSlices() == null) {
|
||||
memoryUnit.setSlices(new ArrayList<>());
|
||||
}
|
||||
int maxEndExclusive = Math.max(memoryUnit.getConversationMessages().size(), 0);
|
||||
for (MemorySlice slice : memoryUnit.getSlices()) {
|
||||
if (slice.getId() == null || slice.getId().isBlank()) {
|
||||
slice.setId(UUID.randomUUID().toString());
|
||||
}
|
||||
if (slice.getTimestamp() == null || slice.getTimestamp() <= 0) {
|
||||
slice.setTimestamp(memoryUnit.getTimestamp());
|
||||
}
|
||||
if (slice.getStartIndex() == null || slice.getStartIndex() < 0) {
|
||||
slice.setStartIndex(0);
|
||||
}
|
||||
if (slice.getStartIndex() > maxEndExclusive) {
|
||||
slice.setStartIndex(maxEndExclusive);
|
||||
}
|
||||
if (slice.getEndIndex() == null || slice.getEndIndex() < slice.getStartIndex()) {
|
||||
slice.setEndIndex(maxEndExclusive);
|
||||
}
|
||||
if (slice.getEndIndex() > maxEndExclusive) {
|
||||
slice.setEndIndex(maxEndExclusive);
|
||||
}
|
||||
}
|
||||
memoryUnit.getSlices().sort(Comparator.naturalOrder());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getCoreKey() {
|
||||
return "memory-core";
|
||||
public @NotNull Path statePath() {
|
||||
return Path.of("core", "memory.json");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void load(@NotNull JSONObject state) {
|
||||
String memorySessionId = state.getString("memory_session_id");
|
||||
if (memorySessionId == null) {
|
||||
throw new IllegalStateException("Memory session id is missing");
|
||||
}
|
||||
JSONArray array = state.getJSONArray("memory_unit_uuid_set");
|
||||
if (array == null) {
|
||||
throw new IllegalStateException("Memory unit uuid set is missing");
|
||||
}
|
||||
for (int i = 0; i < array.size(); i++) {
|
||||
String unitUuid = array.getString(i);
|
||||
if (unitUuid == null) {
|
||||
throw new IllegalStateException("memory_unit_uuid_set is not a uuid array, index: " + i);
|
||||
}
|
||||
MemoryUnit memoryUnit = new MemoryUnit(unitUuid);
|
||||
memoryUnits.put(unitUuid, memoryUnit);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull State convert() {
|
||||
State state = new State();
|
||||
state.append("memory_session_id", StateValue.str(memorySessionId));
|
||||
|
||||
List<StateValue.Str> unitOverview = memoryUnits.keySet().stream()
|
||||
.map(StateValue::str)
|
||||
.toList();
|
||||
state.append("memory_unit_uuid_set", StateValue.arr(unitOverview));
|
||||
return state;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.framework.agent.common.entity.PersistableObject;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.util.UUID;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@@ -13,11 +14,19 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private String id;
|
||||
private Integer startIndex;
|
||||
private Integer endIndex;
|
||||
private String summary;
|
||||
private Long timestamp;
|
||||
private final String id;
|
||||
private final Integer startIndex;
|
||||
private final Integer endIndex;
|
||||
private final String summary;
|
||||
private final Long timestamp;
|
||||
|
||||
public MemorySlice(Integer startIndex, Integer endIndex, String summary) {
|
||||
this.id = UUID.randomUUID().toString();
|
||||
this.timestamp = System.currentTimeMillis();
|
||||
this.startIndex = startIndex;
|
||||
this.endIndex = endIndex;
|
||||
this.summary = summary;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(MemorySlice memorySlice) {
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
package work.slhaf.partner.core.memory.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.framework.agent.common.entity.PersistableObject;
|
||||
import lombok.Getter;
|
||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class MemoryUnit extends PersistableObject {
|
||||
@Getter
|
||||
public class MemoryUnit {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private String id;
|
||||
private List<Message> conversationMessages = new ArrayList<>();
|
||||
private final String id;
|
||||
private final List<Message> conversationMessages = new ArrayList<>();
|
||||
private Long timestamp;
|
||||
private List<MemorySlice> slices = new ArrayList<>();
|
||||
private final List<MemorySlice> slices = new ArrayList<>();
|
||||
|
||||
public MemoryUnit(String id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public void updateTimestamp() {
|
||||
timestamp = System.currentTimeMillis();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeInput
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
@@ -189,34 +188,23 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
|
||||
}
|
||||
|
||||
private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) {
|
||||
long now = System.currentTimeMillis();
|
||||
String memoryId = memoryCapability.getMemorySessionId();
|
||||
String resolvedMemoryId = memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId;
|
||||
MemoryUnit existingUnit = memoryCapability.getMemoryUnit(resolvedMemoryId);
|
||||
List<Message> existingMessages = existingUnit != null && existingUnit.getConversationMessages() != null
|
||||
? existingUnit.getConversationMessages()
|
||||
: List.of();
|
||||
List<Message> existingMessages = existingUnit.getConversationMessages();
|
||||
int startIndex = existingMessages.size();
|
||||
|
||||
MemorySlice memorySlice = new MemorySlice();
|
||||
memorySlice.setId(UUID.randomUUID().toString());
|
||||
memorySlice.setStartIndex(startIndex);
|
||||
memorySlice.setEndIndex(startIndex + chatMessages.size());
|
||||
memorySlice.setSummary(summarizeResult.getSummary());
|
||||
memorySlice.setTimestamp(now);
|
||||
MemorySlice memorySlice = new MemorySlice(
|
||||
startIndex,
|
||||
startIndex + chatMessages.size(),
|
||||
summarizeResult.getSummary()
|
||||
);
|
||||
|
||||
MemoryUnit memoryUnit = new MemoryUnit();
|
||||
memoryUnit.setId(resolvedMemoryId);
|
||||
memoryUnit.setTimestamp(now);
|
||||
List<Message> conversationMessages = new ArrayList<>(existingMessages);
|
||||
conversationMessages.addAll(chatMessages);
|
||||
memoryUnit.setConversationMessages(conversationMessages);
|
||||
MemoryUnit memoryUnit = new MemoryUnit(resolvedMemoryId);
|
||||
memoryUnit.updateTimestamp();
|
||||
memoryUnit.getConversationMessages().addAll(chatMessages);
|
||||
|
||||
List<MemorySlice> slices = existingUnit != null && existingUnit.getSlices() != null
|
||||
? new ArrayList<>(existingUnit.getSlices())
|
||||
: new ArrayList<>();
|
||||
slices.add(memorySlice);
|
||||
memoryUnit.setSlices(slices);
|
||||
memoryUnit.getSlices().add(memorySlice);
|
||||
return memoryUnit;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user