diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java index 376e5204..21ec5326 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/runtime/MemoryRuntime.java @@ -1,9 +1,10 @@ package work.slhaf.partner.module.memory.runtime; +import com.alibaba.fastjson2.JSONArray; +import com.alibaba.fastjson2.JSONObject; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.io.FileUtils; -import work.slhaf.partner.common.config.PartnerAgentConfigLoader; +import org.jetbrains.annotations.NotNull; import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException; @@ -11,18 +12,16 @@ import work.slhaf.partner.core.memory.exception.UnExistedTopicException; import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemoryUnit; import work.slhaf.partner.core.memory.pojo.SliceRef; -import work.slhaf.partner.framework.agent.common.entity.PersistableObject; -import work.slhaf.partner.framework.agent.config.AgentConfigLoader; import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.framework.agent.factory.component.annotation.Init; import work.slhaf.partner.framework.agent.model.pojo.Message; +import work.slhaf.partner.framework.agent.state.State; +import work.slhaf.partner.framework.agent.state.StateSerializable; +import work.slhaf.partner.framework.agent.state.StateValue; import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice; -import java.io.*; -import java.nio.file.Files; import java.nio.file.Path; -import java.nio.file.Paths; import java.time.Instant; import java.time.LocalDate; import java.time.ZoneId; @@ -30,13 +29,9 @@ import java.util.*; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.locks.ReentrantLock; -import static work.slhaf.partner.common.Constant.Path.MEMORY_DATA; - @EqualsAndHashCode(callSuper = true) @Slf4j -public class MemoryRuntime extends AbstractAgentModule.Standalone { - - private static final String RUNTIME_KEY = "memory-runtime"; +public class MemoryRuntime extends AbstractAgentModule.Standalone implements StateSerializable { @InjectCapability private MemoryCapability memoryCapability; @@ -49,9 +44,8 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone { @Init public void init() { - loadState(); + register(); checkAndSetMemoryId(); - Runtime.getRuntime().addShutdownHook(new Thread(this::saveStateSafely)); } private void checkAndSetMemoryId() { @@ -71,7 +65,6 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone { if (!exists) { refs.add(sliceRef); } - saveState(); } finally { runtimeLock.unlock(); } @@ -104,7 +97,6 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone { .addIfAbsent(new SliceRef(memoryUnit.getId(), slice.getId())); } } - saveState(); } finally { runtimeLock.unlock(); } @@ -192,8 +184,8 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone { return List.of(); } int size = conversationMessages.size(); - int start = Math.max(0, Math.min(memorySlice.getStartIndex(), size)); - int end = Math.max(start, Math.min(memorySlice.getEndIndex(), size)); + int start = Math.clamp(memorySlice.getStartIndex(), 0, size); + int end = Math.clamp(memorySlice.getEndIndex(), start, size); if (start >= end) { return List.of(); } @@ -220,69 +212,118 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone { return topicPath == null ? "" : topicPath.trim(); } - private void loadState() { - Path filePath = getFilePath(); - if (!Files.exists(filePath)) { - return; - } - try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(filePath.toFile()))) { - RuntimeState state = (RuntimeState) ois.readObject(); - topicSlices = state.topicSlices; - dateIndex = state.dateIndex; - } catch (Exception e) { - log.error("[MemoryRuntime] 加载运行态失败", e); - topicSlices = new HashMap<>(); - dateIndex = new HashMap<>(); - } + @Override + @NotNull + public Path statePath() { + return Path.of("module", "memory", "topic_based_memory.json"); } - private void saveStateSafely() { + @Override + public void load(@NotNull JSONObject state) { runtimeLock.lock(); try { - saveState(); + topicSlices = new HashMap<>(); + dateIndex = new HashMap<>(); + + JSONArray topicSlicesArray = state.getJSONArray("topic_slices"); + if (topicSlicesArray != null) { + for (int i = 0; i < topicSlicesArray.size(); i++) { + JSONObject topicObject = topicSlicesArray.getJSONObject(i); + if (topicObject == null) { + continue; + } + String topicPath = topicObject.getString("topic_path"); + if (topicPath == null) { + continue; + } + topicSlices.put(normalizeTopicPath(topicPath), decodeSliceRefs(topicObject.getJSONArray("refs"))); + } + } + + JSONArray dateIndexArray = state.getJSONArray("date_index"); + if (dateIndexArray != null) { + for (int i = 0; i < dateIndexArray.size(); i++) { + JSONObject dateObject = dateIndexArray.getJSONObject(i); + if (dateObject == null) { + continue; + } + String date = dateObject.getString("date"); + if (date == null) { + continue; + } + try { + dateIndex.put(LocalDate.parse(date), decodeSliceRefs(dateObject.getJSONArray("refs"))); + } catch (Exception e) { + log.warn("[MemoryRuntime] 跳过非法日期索引: {}", date, e); + } + } + } } finally { runtimeLock.unlock(); } } - private void saveState() { - Path filePath = getFilePath(); - Path tempPath = getTempFilePath(); + @Override + public @NotNull State convert() { + runtimeLock.lock(); try { - Files.createDirectories(Paths.get(MEMORY_DATA)); - FileUtils.createParentDirectories(filePath.toFile()); - try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(tempPath.toFile()))) { - RuntimeState state = new RuntimeState(); - state.topicSlices = new HashMap<>(topicSlices); - state.dateIndex = new HashMap<>(dateIndex); - oos.writeObject(state); - } - Files.move(tempPath, filePath, java.nio.file.StandardCopyOption.REPLACE_EXISTING); - } catch (IOException e) { - log.error("[MemoryRuntime] 保存运行态失败", e); + State state = new State(); + + List topicSliceStates = topicSlices.entrySet().stream() + .sorted(Map.Entry.comparingByKey()) + .map(entry -> StateValue.obj(Map.of( + "topic_path", StateValue.str(entry.getKey()), + "refs", StateValue.arr(encodeSliceRefs(entry.getValue())) + ))) + .toList(); + state.append("topic_slices", StateValue.arr(topicSliceStates)); + + List dateIndexStates = dateIndex.entrySet().stream() + .sorted(Map.Entry.comparingByKey()) + .map(entry -> StateValue.obj(Map.of( + "date", StateValue.str(entry.getKey().toString()), + "refs", StateValue.arr(encodeSliceRefs(entry.getValue())) + ))) + .toList(); + state.append("date_index", StateValue.arr(dateIndexStates)); + + return state; + } finally { + runtimeLock.unlock(); } } - private Path getFilePath() { - String id = ((PartnerAgentConfigLoader) AgentConfigLoader.INSTANCE).getConfig().getAgentId(); - return Paths.get(MEMORY_DATA, id + "-" + RUNTIME_KEY + ".memory"); + private List encodeSliceRefs(List refs) { + return refs.stream() + .map(ref -> (StateValue) StateValue.obj(Map.of( + "unit_id", StateValue.str(ref.getUnitId()), + "slice_id", StateValue.str(ref.getSliceId()) + ))) + .toList(); } - private Path getTempFilePath() { - String id = ((PartnerAgentConfigLoader) AgentConfigLoader.INSTANCE).getConfig().getAgentId(); - return Paths.get(MEMORY_DATA, id + "-" + RUNTIME_KEY + "-temp.memory"); + private CopyOnWriteArrayList decodeSliceRefs(JSONArray refsArray) { + CopyOnWriteArrayList refs = new CopyOnWriteArrayList<>(); + if (refsArray == null) { + return refs; + } + for (int i = 0; i < refsArray.size(); i++) { + JSONObject refObject = refsArray.getJSONObject(i); + if (refObject == null) { + continue; + } + String unitId = refObject.getString("unit_id"); + String sliceId = refObject.getString("slice_id"); + if (unitId == null || sliceId == null) { + continue; + } + refs.addIfAbsent(new SliceRef(unitId, sliceId)); + } + return refs; } private static final class TopicTreeNode { private final Map children = new LinkedHashMap<>(); private int count; } - - private static final class RuntimeState extends PersistableObject { - @Serial - private static final long serialVersionUID = 1L; - - private Map> topicSlices = new HashMap<>(); - private Map> dateIndex = new HashMap<>(); - } } diff --git a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java index 64b8de73..6ac884cf 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/module/memory/runtime/MemoryRuntimeTest.java @@ -1,33 +1,36 @@ package work.slhaf.partner.module.memory.runtime; +import com.alibaba.fastjson2.JSONArray; +import com.alibaba.fastjson2.JSONObject; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import work.slhaf.partner.common.config.Config; -import work.slhaf.partner.common.config.PartnerAgentConfigLoader; +import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemoryUnit; import work.slhaf.partner.core.memory.pojo.SliceRef; -import work.slhaf.partner.framework.agent.config.AgentConfigLoader; import work.slhaf.partner.framework.agent.model.pojo.Message; +import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice; import java.lang.reflect.Field; import java.lang.reflect.Method; -import java.nio.file.Files; import java.nio.file.Path; -import java.util.*; +import java.time.LocalDate; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; class MemoryRuntimeTest { - private AgentConfigLoader previousLoader; - private String runtimeAgentId; - @BeforeAll public static void beforeAll(@TempDir Path tempDir) { System.setProperty("user.home", tempDir.toAbsolutePath().toString()); @@ -53,29 +56,41 @@ class MemoryRuntimeTest { field.set(target, value); } - private static PartnerAgentConfigLoader testLoader(String agentId) { - PartnerAgentConfigLoader loader = new PartnerAgentConfigLoader(); - Config config = new Config(); - config.setAgentId(agentId); - Config.WebSocketConfig webSocketConfig = new Config.WebSocketConfig(); - webSocketConfig.setPort(18080); - config.setWebSocketConfig(webSocketConfig); - loader.setConfig(config); - loader.setModelConfigMap(new HashMap<>()); - return loader; - } - private static Message message(String content) { return new Message(Message.Character.USER, content); } - @AfterEach - void tearDown() throws Exception { - AgentConfigLoader.INSTANCE = previousLoader; - if (runtimeAgentId != null) { - Files.deleteIfExists(Path.of("data/memory", runtimeAgentId + "-memory-runtime.memory")); - Files.deleteIfExists(Path.of("data/memory", runtimeAgentId + "-memory-runtime-temp.memory")); - } + private static CognitionCapability stubCognitionCapability(List chatMessages) { + Lock lock = new ReentrantLock(); + return new CognitionCapability() { + @Override + public void initiateTurn(String input, String target, String... skippedModules) { + } + + @Override + public work.slhaf.partner.core.cognition.ContextWorkspace contextWorkspace() { + return new work.slhaf.partner.core.cognition.ContextWorkspace(); + } + + @Override + public List getChatMessages() { + return chatMessages; + } + + @Override + public List snapshotChatMessages() { + return chatMessages; + } + + @Override + public void rollChatMessagesWithSnapshot(int snapshotSize, int retainDivisor) { + } + + @Override + public Lock getMessageLock() { + return lock; + } + }; } @Test @@ -98,16 +113,16 @@ class MemoryRuntimeTest { assertTrue(invokeSliceMessages(runtime, unit, emptySlice).isEmpty()); } + @AfterEach + void tearDown() { + } + @Test void shouldBindTopicToLatestMemorySliceInsteadOfFirstSlice() throws Exception { - runtimeAgentId = "runtime-test-" + UUID.randomUUID(); - previousLoader = AgentConfigLoader.INSTANCE; - AgentConfigLoader.INSTANCE = testLoader(runtimeAgentId); - StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test"); - MemoryRuntime runtime = new MemoryRuntime(); setField(runtime, "memoryCapability", memoryCapability); + setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed")))); MemoryUnit unit = new MemoryUnit("unit-99"); unit.getConversationMessages().addAll(List.of( @@ -122,6 +137,7 @@ class MemoryRuntimeTest { MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 2L); unit.getSlices().addAll(List.of(firstSlice, secondSlice)); + memoryCapability.remember(unit); runtime.recordMemory(unit, "topic/main", List.of("topic/related")); @@ -132,6 +148,62 @@ class MemoryRuntimeTest { topicSlices.get("topic/related").stream().map(SliceRef::getSliceId).toList()); } + @Test + void shouldRoundTripTopicAndDateIndexesViaState() throws Exception { + StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test"); + MemoryRuntime runtime = new MemoryRuntime(); + setField(runtime, "memoryCapability", memoryCapability); + setField(runtime, "cognitionCapability", stubCognitionCapability(List.of(message("seed")))); + + MemoryUnit unit = new MemoryUnit("unit-100"); + unit.getConversationMessages().addAll(List.of( + message("m0"), + message("m1"), + message("m2"), + message("m3") + )); + MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 2, "first", 86_400_000L); + MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 172_800_000L); + unit.getSlices().addAll(List.of(firstSlice, secondSlice)); + memoryCapability.remember(unit); + + runtime.recordMemory(unit, "topic/main", List.of("topic/related")); + + JSONObject state = JSONObject.parseObject(runtime.convert().toString()); + JSONArray topicSlices = state.getJSONArray("topic_slices"); + assertEquals(2, topicSlices.size()); + JSONObject mainTopic = topicSlices.stream() + .map(JSONObject.class::cast) + .filter(item -> "topic/main".equals(item.getString("topic_path"))) + .findFirst() + .orElseThrow(); + assertEquals("slice-2", mainTopic.getJSONArray("refs").getJSONObject(0).getString("slice_id")); + + JSONArray dateIndex = state.getJSONArray("date_index"); + assertEquals(2, dateIndex.size()); + JSONObject secondDate = dateIndex.stream() + .map(JSONObject.class::cast) + .filter(item -> "1970-01-03".equals(item.getString("date"))) + .findFirst() + .orElseThrow(); + assertEquals("slice-2", secondDate.getJSONArray("refs").getJSONObject(0).getString("slice_id")); + + MemoryRuntime restored = new MemoryRuntime(); + setField(restored, "memoryCapability", memoryCapability); + setField(restored, "cognitionCapability", stubCognitionCapability(List.of(message("seed")))); + restored.load(state); + + List topicResult = restored.queryActivatedMemoryByTopicPath("topic/main"); + assertEquals(1, topicResult.size()); + assertEquals("slice-2", topicResult.getFirst().getSliceId()); + assertEquals(List.of("m2", "m3"), topicResult.getFirst().getMessages().stream().map(Message::getContent).toList()); + + List dateResult = restored.queryActivatedMemoryByDate(LocalDate.parse("1970-01-03")); + assertEquals(1, dateResult.size()); + assertEquals("slice-2", dateResult.getFirst().getSliceId()); + assertEquals("second", dateResult.getFirst().getSummary()); + } + private static final class StubMemoryCapability implements MemoryCapability { private final String sessionId; private final Map units = new HashMap<>(); @@ -140,6 +212,10 @@ class MemoryRuntimeTest { this.sessionId = sessionId; } + private void remember(MemoryUnit unit) { + units.put(unit.getId(), unit); + } + @Override public MemoryUnit getMemoryUnit(String unitId) { return units.get(unitId);