diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java index f0970daa..53ac9feb 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/MemoryCore.java @@ -1,44 +1,35 @@ package work.slhaf.partner.core.memory; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.Setter; +import com.alibaba.fastjson2.JSONArray; +import com.alibaba.fastjson2.JSONObject; import lombok.extern.slf4j.Slf4j; -import work.slhaf.partner.core.PartnerCore; +import org.jetbrains.annotations.NotNull; import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemoryUnit; import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityCore; import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityMethod; +import work.slhaf.partner.framework.agent.state.State; +import work.slhaf.partner.framework.agent.state.StateSerializable; +import work.slhaf.partner.framework.agent.state.StateValue; -import java.io.IOException; -import java.io.Serial; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Comparator; -import java.util.UUID; +import java.nio.file.Path; +import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; -@EqualsAndHashCode(callSuper = true) @CapabilityCore(value = "memory") -@Getter -@Setter @Slf4j -public class MemoryCore extends PartnerCore { - - @Serial - private static final long serialVersionUID = 1L; +public class MemoryCore implements StateSerializable { private final Lock memoryLock = new ReentrantLock(); - private ConcurrentHashMap memoryUnits = new ConcurrentHashMap<>(); + private final ConcurrentHashMap memoryUnits = new ConcurrentHashMap<>(); // 默认值一般只存在于智能体初次启动时 private String memorySessionId = UUID.randomUUID().toString(); - private Instant memorySessionStartTime = Instant.now(); - public MemoryCore() throws IOException, ClassNotFoundException { + public MemoryCore() { + register(); } @CapabilityMethod @@ -54,7 +45,7 @@ public class MemoryCore extends PartnerCore { @CapabilityMethod public MemoryUnit getMemoryUnit(String unitId) { - return memoryUnits.get(unitId); + return memoryUnits.computeIfAbsent(unitId, MemoryUnit::new); } @CapabilityMethod @@ -79,7 +70,6 @@ public class MemoryCore extends PartnerCore { @CapabilityMethod public void refreshMemorySession() { memorySessionId = UUID.randomUUID().toString(); - memorySessionStartTime = Instant.now(); } @CapabilityMethod @@ -88,44 +78,43 @@ public class MemoryCore extends PartnerCore { } private void normalizeMemoryUnit(MemoryUnit memoryUnit) { - if (memoryUnit.getId() == null || memoryUnit.getId().isBlank()) { - memoryUnit.setId(UUID.randomUUID().toString()); - } - if (memoryUnit.getTimestamp() == null || memoryUnit.getTimestamp() <= 0) { - memoryUnit.setTimestamp(System.currentTimeMillis()); - } - if (memoryUnit.getConversationMessages() == null) { - memoryUnit.setConversationMessages(new ArrayList<>()); - } - if (memoryUnit.getSlices() == null) { - memoryUnit.setSlices(new ArrayList<>()); - } - int maxEndExclusive = Math.max(memoryUnit.getConversationMessages().size(), 0); - for (MemorySlice slice : memoryUnit.getSlices()) { - if (slice.getId() == null || slice.getId().isBlank()) { - slice.setId(UUID.randomUUID().toString()); - } - if (slice.getTimestamp() == null || slice.getTimestamp() <= 0) { - slice.setTimestamp(memoryUnit.getTimestamp()); - } - if (slice.getStartIndex() == null || slice.getStartIndex() < 0) { - slice.setStartIndex(0); - } - if (slice.getStartIndex() > maxEndExclusive) { - slice.setStartIndex(maxEndExclusive); - } - if (slice.getEndIndex() == null || slice.getEndIndex() < slice.getStartIndex()) { - slice.setEndIndex(maxEndExclusive); - } - if (slice.getEndIndex() > maxEndExclusive) { - slice.setEndIndex(maxEndExclusive); - } - } memoryUnit.getSlices().sort(Comparator.naturalOrder()); } @Override - protected String getCoreKey() { - return "memory-core"; + public @NotNull Path statePath() { + return Path.of("core", "memory.json"); + } + + @Override + public void load(@NotNull JSONObject state) { + String memorySessionId = state.getString("memory_session_id"); + if (memorySessionId == null) { + throw new IllegalStateException("Memory session id is missing"); + } + JSONArray array = state.getJSONArray("memory_unit_uuid_set"); + if (array == null) { + throw new IllegalStateException("Memory unit uuid set is missing"); + } + for (int i = 0; i < array.size(); i++) { + String unitUuid = array.getString(i); + if (unitUuid == null) { + throw new IllegalStateException("memory_unit_uuid_set is not a uuid array, index: " + i); + } + MemoryUnit memoryUnit = new MemoryUnit(unitUuid); + memoryUnits.put(unitUuid, memoryUnit); + } + } + + @Override + public @NotNull State convert() { + State state = new State(); + state.append("memory_session_id", StateValue.str(memorySessionId)); + + List unitOverview = memoryUnits.keySet().stream() + .map(StateValue::str) + .toList(); + state.append("memory_unit_uuid_set", StateValue.arr(unitOverview)); + return state; } } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySlice.java b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySlice.java index a3fd28a9..f46d7e7c 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySlice.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/memory/pojo/MemorySlice.java @@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode; import work.slhaf.partner.framework.agent.common.entity.PersistableObject; import java.io.Serial; +import java.util.UUID; @EqualsAndHashCode(callSuper = true) @Data @@ -13,11 +14,19 @@ public class MemorySlice extends PersistableObject implements Comparable conversationMessages = new ArrayList<>(); + private final String id; + private final List conversationMessages = new ArrayList<>(); private Long timestamp; - private List slices = new ArrayList<>(); + private final List slices = new ArrayList<>(); + + public MemoryUnit(String id) { + this.id = id; + } + + public void updateTimestamp() { + timestamp = System.currentTimeMillis(); + } } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryUpdater.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryUpdater.java index 35e643a6..c439eefb 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryUpdater.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/MemoryUpdater.java @@ -28,7 +28,6 @@ import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeInput import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; -import java.util.ArrayList; import java.util.List; import java.util.UUID; import java.util.concurrent.ExecutorService; @@ -189,34 +188,23 @@ public class MemoryUpdater extends AbstractAgentModule.Running chatMessages, SummarizeResult summarizeResult) { - long now = System.currentTimeMillis(); String memoryId = memoryCapability.getMemorySessionId(); String resolvedMemoryId = memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId; MemoryUnit existingUnit = memoryCapability.getMemoryUnit(resolvedMemoryId); - List existingMessages = existingUnit != null && existingUnit.getConversationMessages() != null - ? existingUnit.getConversationMessages() - : List.of(); + List existingMessages = existingUnit.getConversationMessages(); int startIndex = existingMessages.size(); - MemorySlice memorySlice = new MemorySlice(); - memorySlice.setId(UUID.randomUUID().toString()); - memorySlice.setStartIndex(startIndex); - memorySlice.setEndIndex(startIndex + chatMessages.size()); - memorySlice.setSummary(summarizeResult.getSummary()); - memorySlice.setTimestamp(now); + MemorySlice memorySlice = new MemorySlice( + startIndex, + startIndex + chatMessages.size(), + summarizeResult.getSummary() + ); - MemoryUnit memoryUnit = new MemoryUnit(); - memoryUnit.setId(resolvedMemoryId); - memoryUnit.setTimestamp(now); - List conversationMessages = new ArrayList<>(existingMessages); - conversationMessages.addAll(chatMessages); - memoryUnit.setConversationMessages(conversationMessages); + MemoryUnit memoryUnit = new MemoryUnit(resolvedMemoryId); + memoryUnit.updateTimestamp(); + memoryUnit.getConversationMessages().addAll(chatMessages); - List slices = existingUnit != null && existingUnit.getSlices() != null - ? new ArrayList<>(existingUnit.getSlices()) - : new ArrayList<>(); - slices.add(memorySlice); - memoryUnit.setSlices(slices); + memoryUnit.getSlices().add(memorySlice); return memoryUnit; } diff --git a/Partner-Core/src/test/java/work/slhaf/partner/core/memory/MemoryCoreTest.java b/Partner-Core/src/test/java/work/slhaf/partner/core/memory/MemoryCoreTest.java index 664015c3..6e9ec328 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/core/memory/MemoryCoreTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/core/memory/MemoryCoreTest.java @@ -11,7 +11,6 @@ import work.slhaf.partner.framework.agent.model.pojo.Message; import java.nio.file.Files; import java.nio.file.Path; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.UUID; @@ -57,14 +56,13 @@ class MemoryCoreTest { slice.setStartIndex(1); slice.setEndIndex(99); - MemoryUnit unit = new MemoryUnit(); - unit.setId("unit-1"); - unit.setConversationMessages(new ArrayList<>(List.of( + MemoryUnit unit = new MemoryUnit("unit-1"); + unit.getConversationMessages().addAll(List.of( new Message(Message.Character.USER, "m0"), new Message(Message.Character.USER, "m1"), new Message(Message.Character.USER, "m2") - ))); - unit.setSlices(new ArrayList<>(List.of(slice))); + )); + unit.getSlices().add(slice); memoryCore.saveMemoryUnit(unit); diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java index e0213ac7..008736cb 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java @@ -74,13 +74,13 @@ class MemoryRuntimeTest { @Test void shouldSliceMessagesUsingLeftClosedRightOpenRange() throws Exception { MemoryRuntime runtime = new MemoryRuntime(); - MemoryUnit unit = new MemoryUnit(); - unit.setConversationMessages(new ArrayList<>(List.of( + MemoryUnit unit = new MemoryUnit("unit-1"); + unit.getConversationMessages().addAll(List.of( message("m0"), message("m1"), message("m2"), message("m3") - ))); + )); MemorySlice slice = new MemorySlice(); slice.setStartIndex(1); @@ -105,14 +105,13 @@ class MemoryRuntimeTest { MemoryRuntime runtime = new MemoryRuntime(); setField(runtime, "memoryCapability", memoryCapability); - MemoryUnit unit = new MemoryUnit(); - unit.setId("unit-1"); - unit.setConversationMessages(new ArrayList<>(List.of( + MemoryUnit unit = new MemoryUnit("unit-1"); + unit.getConversationMessages().addAll(List.of( message("m0"), message("m1"), message("m2"), message("m3") - ))); + )); MemorySlice firstSlice = new MemorySlice(); firstSlice.setId("slice-1"); @@ -128,7 +127,7 @@ class MemoryRuntimeTest { secondSlice.setSummary("second"); secondSlice.setTimestamp(2L); - unit.setSlices(new ArrayList<>(List.of(firstSlice, secondSlice))); + unit.getSlices().addAll(List.of(firstSlice, secondSlice)); runtime.recordMemory(unit, "topic/main", List.of("topic/related")); diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java index f5aee76a..f3114fab 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/updater/MemoryUpdaterTest.java @@ -9,7 +9,10 @@ import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResul import java.lang.reflect.Field; import java.lang.reflect.Method; -import java.util.*; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -55,19 +58,18 @@ class MemoryUpdaterTest { setField(updater, "memoryCapability", memoryCapability); String sessionId = memoryCapability.getMemorySessionId(); - MemoryUnit existingUnit = new MemoryUnit(); - existingUnit.setId(sessionId); - existingUnit.setConversationMessages(new ArrayList<>(List.of( + MemoryUnit existingUnit = new MemoryUnit(sessionId); + existingUnit.getConversationMessages().addAll(List.of( message(Message.Character.USER, "old-user"), message(Message.Character.ASSISTANT, "old-assistant") - ))); + )); MemorySlice existingSlice = new MemorySlice(); existingSlice.setId("slice-1"); existingSlice.setStartIndex(0); existingSlice.setEndIndex(2); existingSlice.setSummary("old-summary"); existingSlice.setTimestamp(1L); - existingUnit.setSlices(new ArrayList<>(List.of(existingSlice))); + existingUnit.getSlices().add(existingSlice); memoryCapability.saveMemoryUnit(existingUnit); MemoryUnit merged = invokeBuildMemoryUnit( @@ -121,14 +123,13 @@ class MemoryUpdaterTest { MemoryUpdater updater = new MemoryUpdater(); setField(updater, "memoryCapability", memoryCapability); - MemoryUnit existingUnit = new MemoryUnit(); - existingUnit.setId("session-3"); - existingUnit.setConversationMessages(new ArrayList<>(List.of( + MemoryUnit existingUnit = new MemoryUnit("session-3"); + existingUnit.getConversationMessages().addAll(List.of( message(Message.Character.USER, "m1"), message(Message.Character.ASSISTANT, "m2"), message(Message.Character.USER, "m3"), message(Message.Character.ASSISTANT, "m4") - ))); + )); memoryCapability.saveMemoryUnit(existingUnit); List increment = invokeResolveChatIncrement( @@ -150,13 +151,12 @@ class MemoryUpdaterTest { MemoryUpdater updater = new MemoryUpdater(); setField(updater, "memoryCapability", memoryCapability); - MemoryUnit existingUnit = new MemoryUnit(); - existingUnit.setId("session-4"); - existingUnit.setConversationMessages(new ArrayList<>(List.of( + MemoryUnit existingUnit = new MemoryUnit("session-4"); + existingUnit.getConversationMessages().addAll(List.of( message(Message.Character.USER, "m1"), message(Message.Character.ASSISTANT, "m2"), message(Message.Character.USER, "m3") - ))); + )); memoryCapability.saveMemoryUnit(existingUnit); List increment = invokeResolveChatIncrement(