mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
refactor(memory): manage state serialization via StateCenter in MemoryCore, and normalize slice and unit building
This commit is contained in:
@@ -1,44 +1,35 @@
|
||||
package work.slhaf.partner.core.memory;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import com.alibaba.fastjson2.JSONArray;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.core.PartnerCore;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||
import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityCore;
|
||||
import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityMethod;
|
||||
import work.slhaf.partner.framework.agent.state.State;
|
||||
import work.slhaf.partner.framework.agent.state.StateSerializable;
|
||||
import work.slhaf.partner.framework.agent.state.StateValue;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.Serial;
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Comparator;
|
||||
import java.util.UUID;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@CapabilityCore(value = "memory")
|
||||
@Getter
|
||||
@Setter
|
||||
@Slf4j
|
||||
public class MemoryCore extends PartnerCore<MemoryCore> {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
public class MemoryCore implements StateSerializable {
|
||||
|
||||
private final Lock memoryLock = new ReentrantLock();
|
||||
private ConcurrentHashMap<String, MemoryUnit> memoryUnits = new ConcurrentHashMap<>();
|
||||
private final ConcurrentHashMap<String, MemoryUnit> memoryUnits = new ConcurrentHashMap<>();
|
||||
|
||||
// 默认值一般只存在于智能体初次启动时
|
||||
private String memorySessionId = UUID.randomUUID().toString();
|
||||
private Instant memorySessionStartTime = Instant.now();
|
||||
|
||||
public MemoryCore() throws IOException, ClassNotFoundException {
|
||||
public MemoryCore() {
|
||||
register();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
@@ -54,7 +45,7 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
|
||||
|
||||
@CapabilityMethod
|
||||
public MemoryUnit getMemoryUnit(String unitId) {
|
||||
return memoryUnits.get(unitId);
|
||||
return memoryUnits.computeIfAbsent(unitId, MemoryUnit::new);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
@@ -79,7 +70,6 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
|
||||
@CapabilityMethod
|
||||
public void refreshMemorySession() {
|
||||
memorySessionId = UUID.randomUUID().toString();
|
||||
memorySessionStartTime = Instant.now();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
@@ -88,44 +78,43 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
|
||||
}
|
||||
|
||||
private void normalizeMemoryUnit(MemoryUnit memoryUnit) {
|
||||
if (memoryUnit.getId() == null || memoryUnit.getId().isBlank()) {
|
||||
memoryUnit.setId(UUID.randomUUID().toString());
|
||||
}
|
||||
if (memoryUnit.getTimestamp() == null || memoryUnit.getTimestamp() <= 0) {
|
||||
memoryUnit.setTimestamp(System.currentTimeMillis());
|
||||
}
|
||||
if (memoryUnit.getConversationMessages() == null) {
|
||||
memoryUnit.setConversationMessages(new ArrayList<>());
|
||||
}
|
||||
if (memoryUnit.getSlices() == null) {
|
||||
memoryUnit.setSlices(new ArrayList<>());
|
||||
}
|
||||
int maxEndExclusive = Math.max(memoryUnit.getConversationMessages().size(), 0);
|
||||
for (MemorySlice slice : memoryUnit.getSlices()) {
|
||||
if (slice.getId() == null || slice.getId().isBlank()) {
|
||||
slice.setId(UUID.randomUUID().toString());
|
||||
}
|
||||
if (slice.getTimestamp() == null || slice.getTimestamp() <= 0) {
|
||||
slice.setTimestamp(memoryUnit.getTimestamp());
|
||||
}
|
||||
if (slice.getStartIndex() == null || slice.getStartIndex() < 0) {
|
||||
slice.setStartIndex(0);
|
||||
}
|
||||
if (slice.getStartIndex() > maxEndExclusive) {
|
||||
slice.setStartIndex(maxEndExclusive);
|
||||
}
|
||||
if (slice.getEndIndex() == null || slice.getEndIndex() < slice.getStartIndex()) {
|
||||
slice.setEndIndex(maxEndExclusive);
|
||||
}
|
||||
if (slice.getEndIndex() > maxEndExclusive) {
|
||||
slice.setEndIndex(maxEndExclusive);
|
||||
}
|
||||
}
|
||||
memoryUnit.getSlices().sort(Comparator.naturalOrder());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getCoreKey() {
|
||||
return "memory-core";
|
||||
public @NotNull Path statePath() {
|
||||
return Path.of("core", "memory.json");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void load(@NotNull JSONObject state) {
|
||||
String memorySessionId = state.getString("memory_session_id");
|
||||
if (memorySessionId == null) {
|
||||
throw new IllegalStateException("Memory session id is missing");
|
||||
}
|
||||
JSONArray array = state.getJSONArray("memory_unit_uuid_set");
|
||||
if (array == null) {
|
||||
throw new IllegalStateException("Memory unit uuid set is missing");
|
||||
}
|
||||
for (int i = 0; i < array.size(); i++) {
|
||||
String unitUuid = array.getString(i);
|
||||
if (unitUuid == null) {
|
||||
throw new IllegalStateException("memory_unit_uuid_set is not a uuid array, index: " + i);
|
||||
}
|
||||
MemoryUnit memoryUnit = new MemoryUnit(unitUuid);
|
||||
memoryUnits.put(unitUuid, memoryUnit);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull State convert() {
|
||||
State state = new State();
|
||||
state.append("memory_session_id", StateValue.str(memorySessionId));
|
||||
|
||||
List<StateValue.Str> unitOverview = memoryUnits.keySet().stream()
|
||||
.map(StateValue::str)
|
||||
.toList();
|
||||
state.append("memory_unit_uuid_set", StateValue.arr(unitOverview));
|
||||
return state;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.framework.agent.common.entity.PersistableObject;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.util.UUID;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@@ -13,11 +14,19 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private String id;
|
||||
private Integer startIndex;
|
||||
private Integer endIndex;
|
||||
private String summary;
|
||||
private Long timestamp;
|
||||
private final String id;
|
||||
private final Integer startIndex;
|
||||
private final Integer endIndex;
|
||||
private final String summary;
|
||||
private final Long timestamp;
|
||||
|
||||
public MemorySlice(Integer startIndex, Integer endIndex, String summary) {
|
||||
this.id = UUID.randomUUID().toString();
|
||||
this.timestamp = System.currentTimeMillis();
|
||||
this.startIndex = startIndex;
|
||||
this.endIndex = endIndex;
|
||||
this.summary = summary;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(MemorySlice memorySlice) {
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
package work.slhaf.partner.core.memory.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.framework.agent.common.entity.PersistableObject;
|
||||
import lombok.Getter;
|
||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class MemoryUnit extends PersistableObject {
|
||||
@Getter
|
||||
public class MemoryUnit {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private String id;
|
||||
private List<Message> conversationMessages = new ArrayList<>();
|
||||
private final String id;
|
||||
private final List<Message> conversationMessages = new ArrayList<>();
|
||||
private Long timestamp;
|
||||
private List<MemorySlice> slices = new ArrayList<>();
|
||||
private final List<MemorySlice> slices = new ArrayList<>();
|
||||
|
||||
public MemoryUnit(String id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public void updateTimestamp() {
|
||||
timestamp = System.currentTimeMillis();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeInput
|
||||
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
@@ -189,34 +188,23 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
|
||||
}
|
||||
|
||||
private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) {
|
||||
long now = System.currentTimeMillis();
|
||||
String memoryId = memoryCapability.getMemorySessionId();
|
||||
String resolvedMemoryId = memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId;
|
||||
MemoryUnit existingUnit = memoryCapability.getMemoryUnit(resolvedMemoryId);
|
||||
List<Message> existingMessages = existingUnit != null && existingUnit.getConversationMessages() != null
|
||||
? existingUnit.getConversationMessages()
|
||||
: List.of();
|
||||
List<Message> existingMessages = existingUnit.getConversationMessages();
|
||||
int startIndex = existingMessages.size();
|
||||
|
||||
MemorySlice memorySlice = new MemorySlice();
|
||||
memorySlice.setId(UUID.randomUUID().toString());
|
||||
memorySlice.setStartIndex(startIndex);
|
||||
memorySlice.setEndIndex(startIndex + chatMessages.size());
|
||||
memorySlice.setSummary(summarizeResult.getSummary());
|
||||
memorySlice.setTimestamp(now);
|
||||
MemorySlice memorySlice = new MemorySlice(
|
||||
startIndex,
|
||||
startIndex + chatMessages.size(),
|
||||
summarizeResult.getSummary()
|
||||
);
|
||||
|
||||
MemoryUnit memoryUnit = new MemoryUnit();
|
||||
memoryUnit.setId(resolvedMemoryId);
|
||||
memoryUnit.setTimestamp(now);
|
||||
List<Message> conversationMessages = new ArrayList<>(existingMessages);
|
||||
conversationMessages.addAll(chatMessages);
|
||||
memoryUnit.setConversationMessages(conversationMessages);
|
||||
MemoryUnit memoryUnit = new MemoryUnit(resolvedMemoryId);
|
||||
memoryUnit.updateTimestamp();
|
||||
memoryUnit.getConversationMessages().addAll(chatMessages);
|
||||
|
||||
List<MemorySlice> slices = existingUnit != null && existingUnit.getSlices() != null
|
||||
? new ArrayList<>(existingUnit.getSlices())
|
||||
: new ArrayList<>();
|
||||
slices.add(memorySlice);
|
||||
memoryUnit.setSlices(slices);
|
||||
memoryUnit.getSlices().add(memorySlice);
|
||||
return memoryUnit;
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
@@ -57,14 +56,13 @@ class MemoryCoreTest {
|
||||
slice.setStartIndex(1);
|
||||
slice.setEndIndex(99);
|
||||
|
||||
MemoryUnit unit = new MemoryUnit();
|
||||
unit.setId("unit-1");
|
||||
unit.setConversationMessages(new ArrayList<>(List.of(
|
||||
MemoryUnit unit = new MemoryUnit("unit-1");
|
||||
unit.getConversationMessages().addAll(List.of(
|
||||
new Message(Message.Character.USER, "m0"),
|
||||
new Message(Message.Character.USER, "m1"),
|
||||
new Message(Message.Character.USER, "m2")
|
||||
)));
|
||||
unit.setSlices(new ArrayList<>(List.of(slice)));
|
||||
));
|
||||
unit.getSlices().add(slice);
|
||||
|
||||
memoryCore.saveMemoryUnit(unit);
|
||||
|
||||
|
||||
@@ -74,13 +74,13 @@ class MemoryRuntimeTest {
|
||||
@Test
|
||||
void shouldSliceMessagesUsingLeftClosedRightOpenRange() throws Exception {
|
||||
MemoryRuntime runtime = new MemoryRuntime();
|
||||
MemoryUnit unit = new MemoryUnit();
|
||||
unit.setConversationMessages(new ArrayList<>(List.of(
|
||||
MemoryUnit unit = new MemoryUnit("unit-1");
|
||||
unit.getConversationMessages().addAll(List.of(
|
||||
message("m0"),
|
||||
message("m1"),
|
||||
message("m2"),
|
||||
message("m3")
|
||||
)));
|
||||
));
|
||||
|
||||
MemorySlice slice = new MemorySlice();
|
||||
slice.setStartIndex(1);
|
||||
@@ -105,14 +105,13 @@ class MemoryRuntimeTest {
|
||||
MemoryRuntime runtime = new MemoryRuntime();
|
||||
setField(runtime, "memoryCapability", memoryCapability);
|
||||
|
||||
MemoryUnit unit = new MemoryUnit();
|
||||
unit.setId("unit-1");
|
||||
unit.setConversationMessages(new ArrayList<>(List.of(
|
||||
MemoryUnit unit = new MemoryUnit("unit-1");
|
||||
unit.getConversationMessages().addAll(List.of(
|
||||
message("m0"),
|
||||
message("m1"),
|
||||
message("m2"),
|
||||
message("m3")
|
||||
)));
|
||||
));
|
||||
|
||||
MemorySlice firstSlice = new MemorySlice();
|
||||
firstSlice.setId("slice-1");
|
||||
@@ -128,7 +127,7 @@ class MemoryRuntimeTest {
|
||||
secondSlice.setSummary("second");
|
||||
secondSlice.setTimestamp(2L);
|
||||
|
||||
unit.setSlices(new ArrayList<>(List.of(firstSlice, secondSlice)));
|
||||
unit.getSlices().addAll(List.of(firstSlice, secondSlice));
|
||||
|
||||
runtime.recordMemory(unit, "topic/main", List.of("topic/related"));
|
||||
|
||||
|
||||
@@ -9,7 +9,10 @@ import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResul
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.*;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
@@ -55,19 +58,18 @@ class MemoryUpdaterTest {
|
||||
setField(updater, "memoryCapability", memoryCapability);
|
||||
|
||||
String sessionId = memoryCapability.getMemorySessionId();
|
||||
MemoryUnit existingUnit = new MemoryUnit();
|
||||
existingUnit.setId(sessionId);
|
||||
existingUnit.setConversationMessages(new ArrayList<>(List.of(
|
||||
MemoryUnit existingUnit = new MemoryUnit(sessionId);
|
||||
existingUnit.getConversationMessages().addAll(List.of(
|
||||
message(Message.Character.USER, "old-user"),
|
||||
message(Message.Character.ASSISTANT, "old-assistant")
|
||||
)));
|
||||
));
|
||||
MemorySlice existingSlice = new MemorySlice();
|
||||
existingSlice.setId("slice-1");
|
||||
existingSlice.setStartIndex(0);
|
||||
existingSlice.setEndIndex(2);
|
||||
existingSlice.setSummary("old-summary");
|
||||
existingSlice.setTimestamp(1L);
|
||||
existingUnit.setSlices(new ArrayList<>(List.of(existingSlice)));
|
||||
existingUnit.getSlices().add(existingSlice);
|
||||
memoryCapability.saveMemoryUnit(existingUnit);
|
||||
|
||||
MemoryUnit merged = invokeBuildMemoryUnit(
|
||||
@@ -121,14 +123,13 @@ class MemoryUpdaterTest {
|
||||
MemoryUpdater updater = new MemoryUpdater();
|
||||
setField(updater, "memoryCapability", memoryCapability);
|
||||
|
||||
MemoryUnit existingUnit = new MemoryUnit();
|
||||
existingUnit.setId("session-3");
|
||||
existingUnit.setConversationMessages(new ArrayList<>(List.of(
|
||||
MemoryUnit existingUnit = new MemoryUnit("session-3");
|
||||
existingUnit.getConversationMessages().addAll(List.of(
|
||||
message(Message.Character.USER, "m1"),
|
||||
message(Message.Character.ASSISTANT, "m2"),
|
||||
message(Message.Character.USER, "m3"),
|
||||
message(Message.Character.ASSISTANT, "m4")
|
||||
)));
|
||||
));
|
||||
memoryCapability.saveMemoryUnit(existingUnit);
|
||||
|
||||
List<Message> increment = invokeResolveChatIncrement(
|
||||
@@ -150,13 +151,12 @@ class MemoryUpdaterTest {
|
||||
MemoryUpdater updater = new MemoryUpdater();
|
||||
setField(updater, "memoryCapability", memoryCapability);
|
||||
|
||||
MemoryUnit existingUnit = new MemoryUnit();
|
||||
existingUnit.setId("session-4");
|
||||
existingUnit.setConversationMessages(new ArrayList<>(List.of(
|
||||
MemoryUnit existingUnit = new MemoryUnit("session-4");
|
||||
existingUnit.getConversationMessages().addAll(List.of(
|
||||
message(Message.Character.USER, "m1"),
|
||||
message(Message.Character.ASSISTANT, "m2"),
|
||||
message(Message.Character.USER, "m3")
|
||||
)));
|
||||
));
|
||||
memoryCapability.saveMemoryUnit(existingUnit);
|
||||
|
||||
List<Message> increment = invokeResolveChatIncrement(
|
||||
|
||||
Reference in New Issue
Block a user