mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
refactor(memory): correct memorySlice-topicPath binding behavior and adjust slice index into [startIndex,endIndex)
This commit is contained in:
@@ -100,7 +100,7 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
|
|||||||
if (memoryUnit.getSlices() == null) {
|
if (memoryUnit.getSlices() == null) {
|
||||||
memoryUnit.setSlices(new ArrayList<>());
|
memoryUnit.setSlices(new ArrayList<>());
|
||||||
}
|
}
|
||||||
int maxIndex = Math.max(memoryUnit.getConversationMessages().size() - 1, 0);
|
int maxEndExclusive = Math.max(memoryUnit.getConversationMessages().size(), 0);
|
||||||
for (MemorySlice slice : memoryUnit.getSlices()) {
|
for (MemorySlice slice : memoryUnit.getSlices()) {
|
||||||
if (slice.getId() == null || slice.getId().isBlank()) {
|
if (slice.getId() == null || slice.getId().isBlank()) {
|
||||||
slice.setId(UUID.randomUUID().toString());
|
slice.setId(UUID.randomUUID().toString());
|
||||||
@@ -111,11 +111,14 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
|
|||||||
if (slice.getStartIndex() == null || slice.getStartIndex() < 0) {
|
if (slice.getStartIndex() == null || slice.getStartIndex() < 0) {
|
||||||
slice.setStartIndex(0);
|
slice.setStartIndex(0);
|
||||||
}
|
}
|
||||||
if (slice.getEndIndex() == null || slice.getEndIndex() < slice.getStartIndex()) {
|
if (slice.getStartIndex() > maxEndExclusive) {
|
||||||
slice.setEndIndex(maxIndex);
|
slice.setStartIndex(maxEndExclusive);
|
||||||
}
|
}
|
||||||
if (slice.getEndIndex() > maxIndex) {
|
if (slice.getEndIndex() == null || slice.getEndIndex() < slice.getStartIndex()) {
|
||||||
slice.setEndIndex(maxIndex);
|
slice.setEndIndex(maxEndExclusive);
|
||||||
|
}
|
||||||
|
if (slice.getEndIndex() > maxEndExclusive) {
|
||||||
|
slice.setEndIndex(maxEndExclusive);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
memoryUnit.getSlices().sort(Comparator.naturalOrder());
|
memoryUnit.getSlices().sort(Comparator.naturalOrder());
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapabili
|
|||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader;
|
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.api.common.entity.PersistableObject;
|
import work.slhaf.partner.api.common.entity.PersistableObject;
|
||||||
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
||||||
import work.slhaf.partner.core.cognition.CognitionCapability;
|
import work.slhaf.partner.core.cognition.CognitionCapability;
|
||||||
@@ -80,7 +81,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
|
|||||||
|
|
||||||
public void recordMemory(MemoryUnit memoryUnit, String topicPath, List<String> relatedTopicPaths, String dialogSummary) {
|
public void recordMemory(MemoryUnit memoryUnit, String topicPath, List<String> relatedTopicPaths, String dialogSummary) {
|
||||||
memoryCapability.saveMemoryUnit(memoryUnit);
|
memoryCapability.saveMemoryUnit(memoryUnit);
|
||||||
MemorySlice memorySlice = memoryUnit.getSlices().getFirst();
|
MemorySlice memorySlice = memoryUnit.getSlices().getLast();
|
||||||
SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId());
|
SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId());
|
||||||
indexMemoryUnit(memoryUnit);
|
indexMemoryUnit(memoryUnit);
|
||||||
bindTopic(topicPath, sliceRef);
|
bindTopic(topicPath, sliceRef);
|
||||||
@@ -209,7 +210,7 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
|
|||||||
if (memoryUnit == null || memorySlice == null) {
|
if (memoryUnit == null || memorySlice == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
List<work.slhaf.partner.api.chat.pojo.Message> messages = sliceMessages(memoryUnit, memorySlice);
|
List<Message> messages = sliceMessages(memoryUnit, memorySlice);
|
||||||
LocalDate date = Instant.ofEpochMilli(memorySlice.getTimestamp())
|
LocalDate date = Instant.ofEpochMilli(memorySlice.getTimestamp())
|
||||||
.atZone(ZoneId.systemDefault())
|
.atZone(ZoneId.systemDefault())
|
||||||
.toLocalDate();
|
.toLocalDate();
|
||||||
@@ -223,17 +224,18 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone {
|
|||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<work.slhaf.partner.api.chat.pojo.Message> sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) {
|
private List<Message> sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) {
|
||||||
List<work.slhaf.partner.api.chat.pojo.Message> conversationMessages = memoryUnit.getConversationMessages();
|
List<Message> conversationMessages = memoryUnit.getConversationMessages();
|
||||||
if (conversationMessages == null || conversationMessages.isEmpty()) {
|
if (conversationMessages == null || conversationMessages.isEmpty()) {
|
||||||
return List.of();
|
return List.of();
|
||||||
}
|
}
|
||||||
int start = Math.max(0, memorySlice.getStartIndex());
|
int size = conversationMessages.size();
|
||||||
int end = Math.min(conversationMessages.size() - 1, memorySlice.getEndIndex());
|
int start = Math.max(0, Math.min(memorySlice.getStartIndex(), size));
|
||||||
if (start > end) {
|
int end = Math.max(start, Math.min(memorySlice.getEndIndex(), size));
|
||||||
|
if (start >= end) {
|
||||||
return List.of();
|
return List.of();
|
||||||
}
|
}
|
||||||
return new ArrayList<>(conversationMessages.subList(start, end + 1));
|
return new ArrayList<>(conversationMessages.subList(start, end));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void printSubTopicsTreeFormat(TopicTreeNode node, String prefix, StringBuilder stringBuilder) {
|
private void printSubTopicsTreeFormat(TopicTreeNode node, String prefix, StringBuilder stringBuilder) {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import com.alibaba.fastjson2.JSONObject;
|
|||||||
import kotlin.Unit;
|
import kotlin.Unit;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||||
@@ -81,7 +82,7 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void execute(PartnerRunningFlowContext context) {
|
public void execute(@NotNull PartnerRunningFlowContext context) {
|
||||||
boolean trigger = cognitionCapability.getChatMessages().size() >= MEMORY_UPDATE_TRIGGER_ROLL_LIMIT;
|
boolean trigger = cognitionCapability.getChatMessages().size() >= MEMORY_UPDATE_TRIGGER_ROLL_LIMIT;
|
||||||
if (!trigger) {
|
if (!trigger) {
|
||||||
return;
|
return;
|
||||||
@@ -149,19 +150,33 @@ public class MemoryUpdater extends AbstractAgentModule.Running<PartnerRunningFlo
|
|||||||
|
|
||||||
private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) {
|
private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) {
|
||||||
long now = System.currentTimeMillis();
|
long now = System.currentTimeMillis();
|
||||||
|
String memoryId = memoryCapability.getMemorySessionId();
|
||||||
|
String resolvedMemoryId = memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId;
|
||||||
|
MemoryUnit existingUnit = memoryCapability.getMemoryUnit(resolvedMemoryId);
|
||||||
|
List<Message> existingMessages = existingUnit != null && existingUnit.getConversationMessages() != null
|
||||||
|
? existingUnit.getConversationMessages()
|
||||||
|
: List.of();
|
||||||
|
int startIndex = existingMessages.size();
|
||||||
|
|
||||||
MemorySlice memorySlice = new MemorySlice();
|
MemorySlice memorySlice = new MemorySlice();
|
||||||
memorySlice.setId(UUID.randomUUID().toString());
|
memorySlice.setId(UUID.randomUUID().toString());
|
||||||
memorySlice.setStartIndex(0);
|
memorySlice.setStartIndex(startIndex);
|
||||||
memorySlice.setEndIndex(chatMessages.size());
|
memorySlice.setEndIndex(startIndex + chatMessages.size());
|
||||||
memorySlice.setSummary(summarizeResult.getSummary());
|
memorySlice.setSummary(summarizeResult.getSummary());
|
||||||
memorySlice.setTimestamp(now);
|
memorySlice.setTimestamp(now);
|
||||||
|
|
||||||
MemoryUnit memoryUnit = new MemoryUnit();
|
MemoryUnit memoryUnit = new MemoryUnit();
|
||||||
String memoryId = memoryCapability.getMemorySessionId();
|
memoryUnit.setId(resolvedMemoryId);
|
||||||
memoryUnit.setId(memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId);
|
|
||||||
memoryUnit.setTimestamp(now);
|
memoryUnit.setTimestamp(now);
|
||||||
memoryUnit.setConversationMessages(new ArrayList<>(chatMessages));
|
List<Message> conversationMessages = new ArrayList<>(existingMessages);
|
||||||
memoryUnit.setSlices(new ArrayList<>(List.of(memorySlice)));
|
conversationMessages.addAll(chatMessages);
|
||||||
|
memoryUnit.setConversationMessages(conversationMessages);
|
||||||
|
|
||||||
|
List<MemorySlice> slices = existingUnit != null && existingUnit.getSlices() != null
|
||||||
|
? new ArrayList<>(existingUnit.getSlices())
|
||||||
|
: new ArrayList<>();
|
||||||
|
slices.add(memorySlice);
|
||||||
|
memoryUnit.setSlices(slices);
|
||||||
return memoryUnit;
|
return memoryUnit;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,75 @@
|
|||||||
|
package work.slhaf.partner.core.memory;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
|
import work.slhaf.partner.common.config.Config;
|
||||||
|
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
||||||
|
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||||
|
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||||
|
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
class MemoryCoreTest {
|
||||||
|
|
||||||
|
private AgentConfigLoader previousLoader;
|
||||||
|
private String agentId;
|
||||||
|
|
||||||
|
private static PartnerAgentConfigLoader testLoader(String agentId) {
|
||||||
|
PartnerAgentConfigLoader loader = new PartnerAgentConfigLoader();
|
||||||
|
Config config = new Config();
|
||||||
|
config.setAgentId(agentId);
|
||||||
|
Config.WebSocketConfig webSocketConfig = new Config.WebSocketConfig();
|
||||||
|
webSocketConfig.setPort(18080);
|
||||||
|
config.setWebSocketConfig(webSocketConfig);
|
||||||
|
loader.setConfig(config);
|
||||||
|
loader.setModelConfigMap(new HashMap<>());
|
||||||
|
return loader;
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void tearDown() throws Exception {
|
||||||
|
AgentConfigLoader.INSTANCE = previousLoader;
|
||||||
|
if (agentId != null) {
|
||||||
|
Files.deleteIfExists(Path.of("data/memory", agentId + "-memory-core.memory"));
|
||||||
|
Files.deleteIfExists(Path.of("data/memory", agentId + "-temp-memory-core.memory"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldNormalizeSliceEndIndexUsingExclusiveUpperBound() throws Exception {
|
||||||
|
agentId = "memory-core-test-" + UUID.randomUUID();
|
||||||
|
previousLoader = AgentConfigLoader.INSTANCE;
|
||||||
|
AgentConfigLoader.INSTANCE = testLoader(agentId);
|
||||||
|
|
||||||
|
MemoryCore memoryCore = new MemoryCore();
|
||||||
|
|
||||||
|
MemorySlice slice = new MemorySlice();
|
||||||
|
slice.setId("slice-1");
|
||||||
|
slice.setStartIndex(1);
|
||||||
|
slice.setEndIndex(99);
|
||||||
|
|
||||||
|
MemoryUnit unit = new MemoryUnit();
|
||||||
|
unit.setId("unit-1");
|
||||||
|
unit.setConversationMessages(new ArrayList<>(List.of(
|
||||||
|
new Message(Message.Character.USER, "m0"),
|
||||||
|
new Message(Message.Character.USER, "m1"),
|
||||||
|
new Message(Message.Character.USER, "m2")
|
||||||
|
)));
|
||||||
|
unit.setSlices(new ArrayList<>(List.of(slice)));
|
||||||
|
|
||||||
|
memoryCore.saveMemoryUnit(unit);
|
||||||
|
|
||||||
|
MemorySlice savedSlice = memoryCore.getMemorySlice("unit-1", "slice-1");
|
||||||
|
assertEquals(1, savedSlice.getStartIndex());
|
||||||
|
assertEquals(3, savedSlice.getEndIndex());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,186 @@
|
|||||||
|
package work.slhaf.partner.module.memory.runtime;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
|
import work.slhaf.partner.common.config.Config;
|
||||||
|
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
||||||
|
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||||
|
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||||
|
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||||
|
import work.slhaf.partner.core.memory.pojo.SliceRef;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
class MemoryRuntimeTest {
|
||||||
|
|
||||||
|
private AgentConfigLoader previousLoader;
|
||||||
|
private String runtimeAgentId;
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private static Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices(MemoryRuntime runtime) throws Exception {
|
||||||
|
Field field = MemoryRuntime.class.getDeclaredField("topicSlices");
|
||||||
|
field.setAccessible(true);
|
||||||
|
return (Map<String, CopyOnWriteArrayList<SliceRef>>) field.get(runtime);
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private static List<Message> invokeSliceMessages(MemoryRuntime runtime, MemoryUnit unit, MemorySlice slice) throws Exception {
|
||||||
|
Method method = MemoryRuntime.class.getDeclaredMethod("sliceMessages", MemoryUnit.class, MemorySlice.class);
|
||||||
|
method.setAccessible(true);
|
||||||
|
return (List<Message>) method.invoke(runtime, unit, slice);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void setField(Object target, String fieldName, Object value) throws Exception {
|
||||||
|
Field field = target.getClass().getDeclaredField(fieldName);
|
||||||
|
field.setAccessible(true);
|
||||||
|
field.set(target, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static PartnerAgentConfigLoader testLoader(String agentId) {
|
||||||
|
PartnerAgentConfigLoader loader = new PartnerAgentConfigLoader();
|
||||||
|
Config config = new Config();
|
||||||
|
config.setAgentId(agentId);
|
||||||
|
Config.WebSocketConfig webSocketConfig = new Config.WebSocketConfig();
|
||||||
|
webSocketConfig.setPort(18080);
|
||||||
|
config.setWebSocketConfig(webSocketConfig);
|
||||||
|
loader.setConfig(config);
|
||||||
|
loader.setModelConfigMap(new HashMap<>());
|
||||||
|
return loader;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Message message(String content) {
|
||||||
|
return new Message(Message.Character.USER, content);
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void tearDown() throws Exception {
|
||||||
|
AgentConfigLoader.INSTANCE = previousLoader;
|
||||||
|
if (runtimeAgentId != null) {
|
||||||
|
Files.deleteIfExists(Path.of("data/memory", runtimeAgentId + "-memory-runtime.memory"));
|
||||||
|
Files.deleteIfExists(Path.of("data/memory", runtimeAgentId + "-memory-runtime-temp.memory"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldSliceMessagesUsingLeftClosedRightOpenRange() throws Exception {
|
||||||
|
MemoryRuntime runtime = new MemoryRuntime();
|
||||||
|
MemoryUnit unit = new MemoryUnit();
|
||||||
|
unit.setConversationMessages(new ArrayList<>(List.of(
|
||||||
|
message("m0"),
|
||||||
|
message("m1"),
|
||||||
|
message("m2"),
|
||||||
|
message("m3")
|
||||||
|
)));
|
||||||
|
|
||||||
|
MemorySlice slice = new MemorySlice();
|
||||||
|
slice.setStartIndex(1);
|
||||||
|
slice.setEndIndex(3);
|
||||||
|
|
||||||
|
List<Message> messages = invokeSliceMessages(runtime, unit, slice);
|
||||||
|
assertEquals(List.of("m1", "m2"), messages.stream().map(Message::getContent).toList());
|
||||||
|
|
||||||
|
slice.setStartIndex(2);
|
||||||
|
slice.setEndIndex(2);
|
||||||
|
assertTrue(invokeSliceMessages(runtime, unit, slice).isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldBindTopicToLatestMemorySliceInsteadOfFirstSlice() throws Exception {
|
||||||
|
runtimeAgentId = "runtime-test-" + UUID.randomUUID();
|
||||||
|
previousLoader = AgentConfigLoader.INSTANCE;
|
||||||
|
AgentConfigLoader.INSTANCE = testLoader(runtimeAgentId);
|
||||||
|
|
||||||
|
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-test");
|
||||||
|
|
||||||
|
MemoryRuntime runtime = new MemoryRuntime();
|
||||||
|
setField(runtime, "memoryCapability", memoryCapability);
|
||||||
|
|
||||||
|
MemoryUnit unit = new MemoryUnit();
|
||||||
|
unit.setId("unit-1");
|
||||||
|
unit.setConversationMessages(new ArrayList<>(List.of(
|
||||||
|
message("m0"),
|
||||||
|
message("m1"),
|
||||||
|
message("m2"),
|
||||||
|
message("m3")
|
||||||
|
)));
|
||||||
|
|
||||||
|
MemorySlice firstSlice = new MemorySlice();
|
||||||
|
firstSlice.setId("slice-1");
|
||||||
|
firstSlice.setStartIndex(0);
|
||||||
|
firstSlice.setEndIndex(2);
|
||||||
|
firstSlice.setSummary("first");
|
||||||
|
firstSlice.setTimestamp(1L);
|
||||||
|
|
||||||
|
MemorySlice secondSlice = new MemorySlice();
|
||||||
|
secondSlice.setId("slice-2");
|
||||||
|
secondSlice.setStartIndex(2);
|
||||||
|
secondSlice.setEndIndex(4);
|
||||||
|
secondSlice.setSummary("second");
|
||||||
|
secondSlice.setTimestamp(2L);
|
||||||
|
|
||||||
|
unit.setSlices(new ArrayList<>(List.of(firstSlice, secondSlice)));
|
||||||
|
|
||||||
|
runtime.recordMemory(unit, "topic/main", List.of("topic/related"), "dialog-summary");
|
||||||
|
|
||||||
|
Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices = topicSlices(runtime);
|
||||||
|
assertEquals(List.of("slice-2"),
|
||||||
|
topicSlices.get("topic/main").stream().map(SliceRef::getSliceId).toList());
|
||||||
|
assertEquals(List.of("slice-2"),
|
||||||
|
topicSlices.get("topic/related").stream().map(SliceRef::getSliceId).toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class StubMemoryCapability implements MemoryCapability {
|
||||||
|
private final String sessionId;
|
||||||
|
private final Map<String, MemoryUnit> units = new HashMap<>();
|
||||||
|
|
||||||
|
private StubMemoryCapability(String sessionId) {
|
||||||
|
this.sessionId = sessionId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void saveMemoryUnit(MemoryUnit memoryUnit) {
|
||||||
|
units.put(memoryUnit.getId(), memoryUnit);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MemoryUnit getMemoryUnit(String unitId) {
|
||||||
|
return units.get(unitId);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MemorySlice getMemorySlice(String unitId, String sliceId) {
|
||||||
|
MemoryUnit unit = units.get(unitId);
|
||||||
|
if (unit == null || unit.getSlices() == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return unit.getSlices().stream()
|
||||||
|
.filter(slice -> sliceId.equals(slice.getId()))
|
||||||
|
.findFirst()
|
||||||
|
.orElse(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Collection<MemoryUnit> listMemoryUnits() {
|
||||||
|
return units.values();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void refreshMemorySession() {
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMemorySessionId() {
|
||||||
|
return sessionId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,154 @@
|
|||||||
|
package work.slhaf.partner.module.memory.updater;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
|
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||||
|
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||||
|
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||||
|
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
|
|
||||||
|
class MemoryUpdaterTest {
|
||||||
|
|
||||||
|
private static MemoryUnit invokeBuildMemoryUnit(MemoryUpdater updater,
|
||||||
|
List<Message> chatMessages,
|
||||||
|
SummarizeResult summarizeResult) throws Exception {
|
||||||
|
Method method = MemoryUpdater.class.getDeclaredMethod("buildMemoryUnit", List.class, SummarizeResult.class);
|
||||||
|
method.setAccessible(true);
|
||||||
|
return (MemoryUnit) method.invoke(updater, chatMessages, summarizeResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void setField(Object target, String fieldName, Object value) throws Exception {
|
||||||
|
Field field = target.getClass().getDeclaredField(fieldName);
|
||||||
|
field.setAccessible(true);
|
||||||
|
field.set(target, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static SummarizeResult summarizeResult(String summary) {
|
||||||
|
SummarizeResult result = new SummarizeResult();
|
||||||
|
result.setSummary(summary);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Message message(Message.Character role, String content) {
|
||||||
|
return new Message(role, content);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldAppendNewSliceToExistingMemoryUnitWithinSameSession() throws Exception {
|
||||||
|
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-1");
|
||||||
|
MemoryUpdater updater = new MemoryUpdater();
|
||||||
|
setField(updater, "memoryCapability", memoryCapability);
|
||||||
|
|
||||||
|
String sessionId = memoryCapability.getMemorySessionId();
|
||||||
|
MemoryUnit existingUnit = new MemoryUnit();
|
||||||
|
existingUnit.setId(sessionId);
|
||||||
|
existingUnit.setConversationMessages(new ArrayList<>(List.of(
|
||||||
|
message(Message.Character.USER, "old-user"),
|
||||||
|
message(Message.Character.ASSISTANT, "old-assistant")
|
||||||
|
)));
|
||||||
|
MemorySlice existingSlice = new MemorySlice();
|
||||||
|
existingSlice.setId("slice-1");
|
||||||
|
existingSlice.setStartIndex(0);
|
||||||
|
existingSlice.setEndIndex(2);
|
||||||
|
existingSlice.setSummary("old-summary");
|
||||||
|
existingSlice.setTimestamp(1L);
|
||||||
|
existingUnit.setSlices(new ArrayList<>(List.of(existingSlice)));
|
||||||
|
memoryCapability.saveMemoryUnit(existingUnit);
|
||||||
|
|
||||||
|
MemoryUnit merged = invokeBuildMemoryUnit(
|
||||||
|
updater,
|
||||||
|
List.of(
|
||||||
|
message(Message.Character.USER, "new-user"),
|
||||||
|
message(Message.Character.ASSISTANT, "new-assistant")
|
||||||
|
),
|
||||||
|
summarizeResult("new-summary")
|
||||||
|
);
|
||||||
|
|
||||||
|
assertEquals(sessionId, merged.getId());
|
||||||
|
assertEquals(4, merged.getConversationMessages().size());
|
||||||
|
assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"),
|
||||||
|
merged.getConversationMessages().stream().map(Message::getContent).toList());
|
||||||
|
assertEquals(2, merged.getSlices().size());
|
||||||
|
|
||||||
|
MemorySlice appendedSlice = merged.getSlices().get(1);
|
||||||
|
assertNotNull(appendedSlice.getId());
|
||||||
|
assertEquals(2, appendedSlice.getStartIndex());
|
||||||
|
assertEquals(4, appendedSlice.getEndIndex());
|
||||||
|
assertEquals("new-summary", appendedSlice.getSummary());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldCreateNewMemoryUnitForNewSessionId() throws Exception {
|
||||||
|
StubMemoryCapability memoryCapability = new StubMemoryCapability("session-2");
|
||||||
|
MemoryUpdater updater = new MemoryUpdater();
|
||||||
|
setField(updater, "memoryCapability", memoryCapability);
|
||||||
|
|
||||||
|
MemoryUnit created = invokeBuildMemoryUnit(
|
||||||
|
updater,
|
||||||
|
List.of(
|
||||||
|
message(Message.Character.USER, "first"),
|
||||||
|
message(Message.Character.ASSISTANT, "second")
|
||||||
|
),
|
||||||
|
summarizeResult("fresh-summary")
|
||||||
|
);
|
||||||
|
|
||||||
|
assertEquals("session-2", created.getId());
|
||||||
|
assertEquals(2, created.getConversationMessages().size());
|
||||||
|
assertEquals(1, created.getSlices().size());
|
||||||
|
assertEquals(0, created.getSlices().getFirst().getStartIndex());
|
||||||
|
assertEquals(2, created.getSlices().getFirst().getEndIndex());
|
||||||
|
assertEquals("fresh-summary", created.getSlices().getFirst().getSummary());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class StubMemoryCapability implements MemoryCapability {
|
||||||
|
private final String sessionId;
|
||||||
|
private final Map<String, MemoryUnit> units = new HashMap<>();
|
||||||
|
|
||||||
|
private StubMemoryCapability(String sessionId) {
|
||||||
|
this.sessionId = sessionId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void saveMemoryUnit(MemoryUnit memoryUnit) {
|
||||||
|
units.put(memoryUnit.getId(), memoryUnit);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MemoryUnit getMemoryUnit(String unitId) {
|
||||||
|
return units.get(unitId);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MemorySlice getMemorySlice(String unitId, String sliceId) {
|
||||||
|
MemoryUnit unit = units.get(unitId);
|
||||||
|
if (unit == null || unit.getSlices() == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return unit.getSlices().stream()
|
||||||
|
.filter(slice -> sliceId.equals(slice.getId()))
|
||||||
|
.findFirst()
|
||||||
|
.orElse(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Collection<MemoryUnit> listMemoryUnits() {
|
||||||
|
return units.values();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void refreshMemorySession() {
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMemorySessionId() {
|
||||||
|
return sessionId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user