3 Commits

20 changed files with 288 additions and 164 deletions

View File

@@ -60,7 +60,7 @@ public interface CognitionCapability {
/**
* Rename the canonical subject of a known entity and refresh entity/active-entity indexes.
*/
boolean renameEntitySubject(String entityUuid, String newSubject);
boolean renameEntitySubject(String entityUuid, String newSubject, boolean keepOldSubjectAsAlias);
/**
* Add an alias or mention form for a known entity and refresh entity indexes.

View File

@@ -154,16 +154,16 @@ public class ImpressionCore implements StateSerializable {
}
/**
* Rename the canonical subject of a known entity and keep its previous subject as a historical alias.
* Rename the canonical subject of a known entity and optionally keep its previous subject as a historical alias.
*/
@CapabilityMethod
public boolean renameEntitySubject(String entityUuid, String newSubject) {
public boolean renameEntitySubject(String entityUuid, String newSubject, boolean keepOldSubjectAsAlias) {
Entity entity = knownEntitiesByUuid.get(entityUuid);
if (entity == null || newSubject == null || newSubject.isBlank()) {
return false;
}
boolean renamed = entity.renameSubject(newSubject.trim());
boolean renamed = entity.renameSubject(newSubject.trim(), keepOldSubjectAsAlias);
if (!renamed) {
return false;
}

View File

@@ -1,7 +1,7 @@
package work.slhaf.partner.core.memory;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot;
import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot;
import work.slhaf.partner.framework.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.support.Result;
@@ -12,13 +12,13 @@ import java.util.List;
@Capability(value = "memory")
public interface MemoryCapability {
MemoryUnit getMemoryUnit(String unitId);
MemoryUnitSnapshot getMemoryUnit(String unitId);
Result<MemorySlice> getMemorySlice(String unitId, String sliceId);
Result<MemorySliceSnapshot> getMemorySlice(String unitId, String sliceId);
MemoryUnit updateMemoryUnit(List<Message> chatMessages, String summary);
MemoryUnitSnapshot updateMemoryUnit(List<Message> chatMessages, String summary);
Collection<MemoryUnit> listMemoryUnits();
Collection<MemoryUnitSnapshot> listMemoryUnits();
void refreshMemorySession();

View File

@@ -5,7 +5,9 @@ import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot;
import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.framework.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.framework.agent.model.pojo.Message;
@@ -36,10 +38,10 @@ public class MemoryCore implements StateSerializable {
}
@CapabilityMethod
public MemoryUnit updateMemoryUnit(List<Message> chatMessages, String summary) {
public MemoryUnitSnapshot updateMemoryUnit(List<Message> chatMessages, String summary) {
memoryLock.lock();
try {
MemoryUnit unit = getMemoryUnit(memorySessionId);
MemoryUnit unit = getOrLoadMemoryUnit(memorySessionId);
unit.updateTimestamp();
List<Message> conversationMessages = unit.getConversationMessages();
@@ -55,14 +57,60 @@ public class MemoryCore implements StateSerializable {
unit.getSlices().add(memorySlice);
normalizeMemoryUnit(unit);
return unit;
return unit.snapshot();
} finally {
memoryLock.unlock();
}
}
@CapabilityMethod
public MemoryUnit getMemoryUnit(String unitId) {
public MemoryUnitSnapshot getMemoryUnit(String unitId) {
memoryLock.lock();
try {
MemoryUnit unit = getOrLoadMemoryUnit(unitId);
normalizeMemoryUnit(unit);
return unit.snapshot();
} finally {
memoryLock.unlock();
}
}
@CapabilityMethod
public Result<MemorySliceSnapshot> getMemorySlice(String unitId, String sliceId) {
memoryLock.lock();
try {
MemoryUnit memoryUnit = memoryUnits.get(unitId);
if (memoryUnit == null) {
return memorySliceNotFound(unitId, sliceId);
}
memoryUnit.load();
normalizeMemoryUnit(memoryUnit);
for (MemorySlice slice : memoryUnit.getSlices()) {
if (sliceId.equals(slice.getId())) {
return Result.success(slice.snapshot());
}
}
return memorySliceNotFound(unitId, sliceId);
} finally {
memoryLock.unlock();
}
}
@CapabilityMethod
public Collection<MemoryUnitSnapshot> listMemoryUnits() {
memoryLock.lock();
try {
return memoryUnits.values().stream()
.peek(MemoryUnit::load)
.peek(this::normalizeMemoryUnit)
.map(MemoryUnit::snapshot)
.toList();
} finally {
memoryLock.unlock();
}
}
private MemoryUnit getOrLoadMemoryUnit(String unitId) {
MemoryUnit unit = memoryUnits.computeIfAbsent(unitId, id -> {
MemoryUnit newUnit = new MemoryUnit(id);
newUnit.register();
@@ -72,21 +120,7 @@ public class MemoryCore implements StateSerializable {
return unit;
}
@CapabilityMethod
public Result<MemorySlice> getMemorySlice(String unitId, String sliceId) {
MemoryUnit memoryUnit = memoryUnits.get(unitId);
if (memoryUnit == null) {
return Result.failure(new MemoryLookupException(
"Memory slice not found: " + unitId + ":" + sliceId,
unitId + ":" + sliceId,
"MEMORY_SLICE"
));
}
for (MemorySlice slice : memoryUnit.getSlices()) {
if (sliceId.equals(slice.getId())) {
return Result.success(slice);
}
}
private Result<MemorySliceSnapshot> memorySliceNotFound(String unitId, String sliceId) {
return Result.failure(new MemoryLookupException(
"Memory slice not found: " + unitId + ":" + sliceId,
unitId + ":" + sliceId,
@@ -94,11 +128,6 @@ public class MemoryCore implements StateSerializable {
));
}
@CapabilityMethod
public Collection<MemoryUnit> listMemoryUnits() {
return new ArrayList<>(memoryUnits.values());
}
@CapabilityMethod
public void refreshMemorySession() {
memorySessionId = UUID.randomUUID().toString();

View File

@@ -33,6 +33,16 @@ public class MemorySlice implements Comparable<MemorySlice> {
return new MemorySlice(id, startIndex, endIndex, summary, timestamp);
}
public MemorySliceSnapshot snapshot() {
return new MemorySliceSnapshot(
id,
startIndex == null ? 0 : startIndex,
endIndex == null ? 0 : endIndex,
summary,
timestamp == null ? 0L : timestamp
);
}
@Override
public int compareTo(MemorySlice memorySlice) {
if (memorySlice.getTimestamp() > this.getTimestamp()) {

View File

@@ -0,0 +1,9 @@
package work.slhaf.partner.core.memory.pojo
data class MemorySliceSnapshot(
val id: String,
val startIndex: Int,
val endIndex: Int,
val summary: String?,
val timestamp: Long,
)

View File

@@ -31,6 +31,15 @@ public class MemoryUnit implements StateSerializable {
timestamp = System.currentTimeMillis();
}
public MemoryUnitSnapshot snapshot() {
return new MemoryUnitSnapshot(
id,
List.copyOf(conversationMessages),
timestamp == null ? 0L : timestamp,
slices.stream().map(MemorySlice::snapshot).toList()
);
}
@Override
public @NotNull Path statePath() {
return Path.of("core", "memory", "memory-unit" + id + ".json");

View File

@@ -0,0 +1,23 @@
package work.slhaf.partner.core.memory.pojo
import work.slhaf.partner.framework.agent.model.pojo.Message
data class MemoryUnitSnapshot(
val id: String,
val conversationMessages: List<Message>,
val timestamp: Long,
val slices: List<MemorySliceSnapshot>,
) {
fun messagesOf(slice: MemorySliceSnapshot): List<Message> {
if (conversationMessages.isEmpty()) {
return emptyList()
}
val start = slice.startIndex.coerceIn(0, conversationMessages.size)
val end = slice.endIndex.coerceIn(start, conversationMessages.size)
if (start >= end) {
return emptyList()
}
return conversationMessages.subList(start, end).toList()
}
}

View File

@@ -10,8 +10,8 @@ import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.cognition.context.BlockContent;
import work.slhaf.partner.core.cognition.context.ContextBlock;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot;
import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot;
import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.framework.agent.factory.component.annotation.AgentComponent;
import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
@@ -75,16 +75,16 @@ class BuiltinCapabilityActionProvider implements BuiltinActionProvider {
Function<Map<String, Object>, String> invoker = params -> {
String unitId = BuiltinActionRegistry.BuiltinActionDefinition.requireString(params, "unit_id");
String sliceId = BuiltinActionRegistry.BuiltinActionDefinition.requireString(params, "slice_id");
Result<MemorySlice> sliceResult = memoryCapability.getMemorySlice(unitId, sliceId);
Result<MemorySliceSnapshot> sliceResult = memoryCapability.getMemorySlice(unitId, sliceId);
if (sliceResult.exceptionOrNull() != null) {
return JSONObject.of(
"ok", false,
"message", sliceResult.exceptionOrNull().getLocalizedMessage()
).toJSONString();
}
MemorySlice slice = sliceResult.getOrThrow();
MemorySliceSnapshot slice = sliceResult.getOrThrow();
MemoryUnit unit = memoryCapability.getMemoryUnit(unitId);
MemoryUnitSnapshot unit = memoryCapability.getMemoryUnit(unitId);
cognitionCapability.contextWorkspace().register(new ContextBlock(
buildMemoryRecallFullBlock(unit, slice),
Set.of(ContextBlock.FocusedDomain.MEMORY),
@@ -105,13 +105,13 @@ class BuiltinCapabilityActionProvider implements BuiltinActionProvider {
);
}
private @NotNull BlockContent buildMemoryRecallFullBlock(MemoryUnit unit, MemorySlice slice) {
private @NotNull BlockContent buildMemoryRecallFullBlock(MemoryUnitSnapshot unit, MemorySliceSnapshot slice) {
return new BlockContent("memory_recall", "memory_capability") {
@Override
protected void fillXml(@NotNull Document document, @NotNull Element root) {
root.setAttribute("unit_id", unit.getId());
root.setAttribute("slice_id", slice.getId());
appendRepeatedElements(document, root, "message", unit.getConversationMessages().subList(slice.getStartIndex(), slice.getEndIndex()), (messageElement, message) -> {
appendRepeatedElements(document, root, "message", unit.messagesOf(slice), (messageElement, message) -> {
messageElement.setAttribute("role", message.getRole().name().toLowerCase(Locale.ROOT));
messageElement.setTextContent(message.getContent());
return Unit.INSTANCE;

View File

@@ -13,8 +13,8 @@ import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.cognition.context.BlockContent;
import work.slhaf.partner.core.cognition.context.ContextBlock;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot;
import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot;
import work.slhaf.partner.core.perceive.PerceiveCapability;
import work.slhaf.partner.framework.agent.exception.AgentRuntimeException;
import work.slhaf.partner.framework.agent.exception.ExceptionReporterHandler;
@@ -31,6 +31,7 @@ import work.slhaf.partner.runtime.PartnerRunningFlowContext;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -140,7 +141,7 @@ public class DialogRolling extends AbstractAgentModule.Running<PartnerRunningFlo
if (memoryId.isBlank()) {
return fullChatSnapshot;
}
MemoryUnit existingUnit = memoryCapability.getMemoryUnit(memoryId);
MemoryUnitSnapshot existingUnit = memoryCapability.getMemoryUnit(memoryId);
if (existingUnit.getConversationMessages().isEmpty()) {
return fullChatSnapshot;
}
@@ -158,8 +159,9 @@ public class DialogRolling extends AbstractAgentModule.Running<PartnerRunningFlo
@NotNull
RollingResult buildRollingResult(List<Message> chatSnapshot, int rollingSize, int retainDivisor) {
messageCompressor.execute(chatSnapshot);
Result<String> summaryResult = messageSummarizer.execute(chatSnapshot);
List<Message> rollingMessages = new ArrayList<>(chatSnapshot);
messageCompressor.execute(rollingMessages);
Result<String> summaryResult = messageSummarizer.execute(rollingMessages);
String summary = summaryResult.fold(
value -> value,
exp -> "no summary, due to exception"
@@ -167,20 +169,20 @@ public class DialogRolling extends AbstractAgentModule.Running<PartnerRunningFlo
if (summary.isBlank()) {
summary = "no summary, due to empty summarize result";
}
MemoryUnit memoryUnit = memoryCapability.updateMemoryUnit(chatSnapshot, summary);
MemorySlice newSlice = memoryUnit.getSlices().getLast();
return new RollingResult(memoryUnit, newSlice, List.copyOf(chatSnapshot), newSlice.getSummary(), rollingSize, retainDivisor);
MemoryUnitSnapshot memoryUnit = memoryCapability.updateMemoryUnit(rollingMessages, summary);
MemorySliceSnapshot newSlice = memoryUnit.getSlices().getLast();
return new RollingResult(memoryUnit, newSlice, rollingSize, retainDivisor);
}
private void applyRolling(RollingResult result) {
cognitionCapability.contextWorkspace().register(new ContextBlock(
buildDialogAbstractBlock(result.summary(), result.memoryUnit().getId(), result.memorySlice().getId()),
buildDialogAbstractBlock(result.getSummary(), result.getMemoryUnit().getId(), result.getMemorySlice().getId()),
Set.of(ContextBlock.FocusedDomain.MEMORY, ContextBlock.FocusedDomain.COMMUNICATION),
20,
5,
10
));
cognitionCapability.rollChatMessagesWithSnapshot(result.rollingSize(), result.retainDivisor());
cognitionCapability.rollChatMessagesWithSnapshot(result.getRollingSize(), result.getRetainDivisor());
}
private @NotNull BlockContent buildDialogAbstractBlock(String summary, String unitId, String sliceId) {

View File

@@ -1,17 +0,0 @@
package work.slhaf.partner.module.communication;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.framework.agent.model.pojo.Message;
import java.util.List;
public record RollingResult(
MemoryUnit memoryUnit,
MemorySlice memorySlice,
List<Message> incrementMessages,
String summary,
int rollingSize,
int retainDivisor
) {
}

View File

@@ -0,0 +1,17 @@
package work.slhaf.partner.module.communication
import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot
import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot
import work.slhaf.partner.framework.agent.model.pojo.Message
data class RollingResult(
val memoryUnit: MemoryUnitSnapshot,
val memorySlice: MemorySliceSnapshot,
val rollingSize: Int,
val retainDivisor: Int,
) {
val summary: String
get() = memorySlice.summary ?: ""
fun incrementMessages(): List<Message> = memoryUnit.messagesOf(memorySlice)
}

View File

@@ -0,0 +1,66 @@
package work.slhaf.partner.module.impression
/**
* A conservative, auditable plan produced after message rolling.
*
* The updater should treat this model as intent only: validation decides whether
* a step is safe to execute, and the applier performs mutations through
* CognitionCapability / ImpressionCore so indexes stay consistent.
*/
data class ImpressionUpdatePlan @JvmOverloads constructor(
val steps: List<ImpressionUpdateStep>,
val status: PlanStatus = PlanStatus.PREPARED,
val reason: String? = null,
)
enum class PlanStatus {
PREPARED,
CONFIRMED,
REJECTED,
}
sealed class ImpressionUpdateStep
data class UpdateExistingStep(
val entityUuid: String,
val updatePatch: UpdatePatch,
) : ImpressionUpdateStep()
data class CreateEntityStep(
val subject: String,
val impressions: List<ImpressionPatch> = emptyList(),
val features: List<FeaturePatch> = emptyList(),
val aliases: List<AliasPatch> = emptyList(),
val relations: List<RelationPatch> = emptyList(),
) : ImpressionUpdateStep()
sealed class UpdatePatch
data class ImpressionPatch @JvmOverloads constructor(
val impression: String,
val newImpression: String? = null,
val confidence: Double = 1.0,
) : UpdatePatch()
data class FeaturePatch @JvmOverloads constructor(
val feature: String,
val newFeature: String? = null,
val confidence: Double = 1.0,
) : UpdatePatch()
data class AliasPatch @JvmOverloads constructor(
val alias: String,
val deprecated: Boolean = false,
) : UpdatePatch()
data class SubjectPatch @JvmOverloads constructor(
val subject: String,
val keepOldSubjectAsAlias: Boolean = true,
) : UpdatePatch()
data class RelationPatch @JvmOverloads constructor(
val target: String,
val relation: String,
val strength: Double = 1.0,
) : UpdatePatch()

View File

@@ -4,8 +4,8 @@ import com.alibaba.fastjson2.JSONObject;
import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot;
import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import work.slhaf.partner.framework.agent.exception.ExceptionReporterHandler;
import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCapability;
@@ -52,11 +52,11 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
}
}
public void recordMemory(MemoryUnit memoryUnit,
public void recordMemory(MemoryUnitSnapshot memoryUnit,
String topicPath,
List<String> relatedTopicPaths,
ActivationProfile activationProfile) {
MemorySlice memorySlice = memoryUnit.getSlices().getLast();
MemorySliceSnapshot memorySlice = memoryUnit.getSlices().getLast();
SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId());
LocalDate date = toLocalDate(memorySlice.getTimestamp());
runtimeLock.lock();
@@ -159,13 +159,13 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
}
private ActivatedMemorySlice buildActivatedMemorySlice(SliceRef ref) {
MemoryUnit memoryUnit = memoryCapability.getMemoryUnit(ref.getUnitId());
Result<MemorySlice> memorySliceResult = memoryCapability.getMemorySlice(ref.getUnitId(), ref.getSliceId());
MemoryUnitSnapshot memoryUnit = memoryCapability.getMemoryUnit(ref.getUnitId());
Result<MemorySliceSnapshot> memorySliceResult = memoryCapability.getMemorySlice(ref.getUnitId(), ref.getSliceId());
if (memoryUnit == null || memorySliceResult.exceptionOrNull() != null) {
return null;
}
MemorySlice memorySlice = memorySliceResult.getOrThrow();
List<Message> messages = sliceMessages(memoryUnit, memorySlice);
MemorySliceSnapshot memorySlice = memorySliceResult.getOrThrow();
List<Message> messages = memoryUnit.messagesOf(memorySlice);
LocalDate date = toLocalDate(memorySlice.getTimestamp());
return ActivatedMemorySlice.builder()
.unitId(ref.getUnitId())
@@ -177,19 +177,6 @@ public class MemoryRuntime extends AbstractAgentModule.Standalone implements Sta
.build();
}
private List<Message> sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) {
List<Message> conversationMessages = memoryUnit.getConversationMessages();
if (conversationMessages.isEmpty()) {
return List.of();
}
int size = conversationMessages.size();
int start = Math.clamp(memorySlice.getStartIndex(), 0, size);
int end = Math.clamp(memorySlice.getEndIndex(), start, size);
if (start >= end) {
return List.of();
}
return new ArrayList<>(conversationMessages.subList(start, end));
}
private LocalDate toLocalDate(Long timestamp) {
return Instant.ofEpochMilli(timestamp)

View File

@@ -149,7 +149,7 @@ public class MemoryRecallProfileExtractor extends AbstractAgentModule.Standalone
@Override
public void consume(RollingResult result) {
List<Message> slicedMessages = sliceMessages(result);
List<Message> slicedMessages = result.incrementMessages();
if (slicedMessages.isEmpty()) {
return;
}
@@ -169,31 +169,21 @@ public class MemoryRecallProfileExtractor extends AbstractAgentModule.Standalone
relatedTopicPaths,
slicedMessages
);
memoryRuntime.recordMemory(result.memoryUnit(), topicPath, relatedTopicPaths, activationProfile);
memoryRuntime.recordMemory(result.getMemoryUnit(), topicPath, relatedTopicPaths, activationProfile);
}).onFailure(exp -> memoryRuntime.recordMemory(
result.memoryUnit(),
result.getMemoryUnit(),
null,
List.of(),
defaultActivationProfile()
));
}
private List<Message> sliceMessages(RollingResult result) {
int size = result.memoryUnit().getConversationMessages().size();
int start = Math.clamp(result.memorySlice().getStartIndex(), 0, size);
int end = Math.clamp(result.memorySlice().getEndIndex(), start, size);
if (start >= end) {
return List.of();
}
return result.memoryUnit().getConversationMessages().subList(start, end);
}
private Message resolveTopicTaskMessage(RollingResult result, List<Message> slicedMessages) {
return new TaskBlock() {
@Override
protected void fillXml(@NotNull Document document, @NotNull Element root) {
appendTextElement(document, root, "current_topic_tree", memoryRuntime.getTopicTree());
appendTextElement(document, root, "slice_summary", result.summary());
appendTextElement(document, root, "slice_summary", result.getSummary());
appendRepeatedElements(document, root, "message", slicedMessages, (messageElement, message) -> {
messageElement.setAttribute("role", message.roleValue());
messageElement.setTextContent(message.getContent());

View File

@@ -4,8 +4,8 @@ import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot;
import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot;
import work.slhaf.partner.framework.agent.model.pojo.Message;
import java.nio.file.Path;
@@ -32,7 +32,7 @@ class MemoryCoreTest {
void shouldCreateFirstSliceFromChatMessages() {
String sessionId = memoryCore.getMemorySessionId();
MemoryUnit updatedUnit = memoryCore.updateMemoryUnit(List.of(
MemoryUnitSnapshot updatedUnit = memoryCore.updateMemoryUnit(List.of(
new Message(Message.Character.USER, "m0"),
new Message(Message.Character.USER, "m1"),
new Message(Message.Character.USER, "m2")
@@ -43,7 +43,7 @@ class MemoryCoreTest {
updatedUnit.getConversationMessages().stream().map(Message::getContent).toList());
assertEquals(1, updatedUnit.getSlices().size());
MemorySlice firstSlice = updatedUnit.getSlices().getFirst();
MemorySliceSnapshot firstSlice = updatedUnit.getSlices().getFirst();
assertNotNull(firstSlice.getId());
assertEquals(0, firstSlice.getStartIndex());
assertEquals(3, firstSlice.getEndIndex());
@@ -60,7 +60,7 @@ class MemoryCoreTest {
new Message(Message.Character.USER, "m0")
), "first-summary");
MemoryUnit updatedUnit = memoryCore.updateMemoryUnit(List.of(
MemoryUnitSnapshot updatedUnit = memoryCore.updateMemoryUnit(List.of(
new Message(Message.Character.ASSISTANT, "m1"),
new Message(Message.Character.USER, "m2")
), "second-summary");
@@ -70,14 +70,14 @@ class MemoryCoreTest {
updatedUnit.getConversationMessages().stream().map(Message::getContent).toList());
assertEquals(2, updatedUnit.getSlices().size());
MemorySlice appendedSlice = updatedUnit.getSlices().getLast();
MemorySliceSnapshot appendedSlice = updatedUnit.getSlices().getLast();
assertNotNull(appendedSlice.getId());
assertEquals(1, appendedSlice.getStartIndex());
assertEquals(3, appendedSlice.getEndIndex());
assertEquals("second-summary", appendedSlice.getSummary());
assertTrue(appendedSlice.getTimestamp() > 0);
MemorySlice loadedSlice = memoryCore.getMemorySlice(sessionId, appendedSlice.getId()).getOrThrow();
MemorySliceSnapshot loadedSlice = memoryCore.getMemorySlice(sessionId, appendedSlice.getId()).getOrThrow();
assertNotNull(loadedSlice);
assertEquals(1, loadedSlice.getStartIndex());
assertEquals(3, loadedSlice.getEndIndex());

View File

@@ -175,7 +175,7 @@ class CommunicationProducerTest {
}
@Override
public boolean renameEntitySubject(String entityUuid, String newSubject) {
public boolean renameEntitySubject(String entityUuid, String newSubject, boolean keepOldSubjectAsAlias) {
return false;
}

View File

@@ -6,7 +6,9 @@ import org.junit.jupiter.api.io.TempDir;
import org.mockito.Mockito;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot;
import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.support.Result;
import work.slhaf.partner.module.communication.summarizer.MessageCompressor;
@@ -63,19 +65,19 @@ class DialogRollingTest {
message(Message.Character.ASSISTANT, "new-assistant")
), 4, 6);
MemoryUnit merged = memoryCapability.getMemoryUnit(sessionId);
MemoryUnitSnapshot merged = memoryCapability.getMemoryUnit(sessionId);
assertEquals(List.of("old-user", "old-assistant", "new-user", "new-assistant"),
merged.getConversationMessages().stream().map(Message::getContent).toList());
assertEquals(2, merged.getSlices().size());
MemorySlice appendedSlice = merged.getSlices().getLast();
MemorySliceSnapshot appendedSlice = merged.getSlices().getLast();
assertNotNull(appendedSlice.getId());
assertEquals(2, appendedSlice.getStartIndex());
assertEquals(4, appendedSlice.getEndIndex());
assertEquals("new-summary", appendedSlice.getSummary());
assertEquals(sessionId, rollingResult.memoryUnit().getId());
assertEquals(appendedSlice.getId(), rollingResult.memorySlice().getId());
assertEquals("new-summary", rollingResult.summary());
assertEquals(sessionId, rollingResult.getMemoryUnit().getId());
assertEquals(appendedSlice.getId(), rollingResult.getMemorySlice().getId());
assertEquals("new-summary", rollingResult.getSummary());
}
@Test
@@ -96,7 +98,7 @@ class DialogRollingTest {
message(Message.Character.ASSISTANT, "second")
), 2, 6);
MemoryUnit created = memoryCapability.getMemoryUnit(sessionId);
MemoryUnitSnapshot created = memoryCapability.getMemoryUnit(sessionId);
assertNotNull(created);
assertEquals(List.of("first", "second"),
created.getConversationMessages().stream().map(Message::getContent).toList());
@@ -104,7 +106,7 @@ class DialogRollingTest {
assertEquals(0, created.getSlices().getFirst().getStartIndex());
assertEquals(2, created.getSlices().getFirst().getEndIndex());
assertEquals("fresh-summary", created.getSlices().getFirst().getSummary());
assertEquals(created, rollingResult.memoryUnit());
assertEquals(created, rollingResult.getMemoryUnit());
}
@Test
@@ -151,8 +153,8 @@ class DialogRollingTest {
message(Message.Character.ASSISTANT, "a1")
), 2, 6);
assertEquals(sessionId, rollingResult.memoryUnit().getId());
assertEquals("no summary, due to empty summarize result", rollingResult.summary());
assertEquals(sessionId, rollingResult.getMemoryUnit().getId());
assertEquals("no summary, due to empty summarize result", rollingResult.getSummary());
}
private static final class StubMemoryCapability implements MemoryCapability {
@@ -172,28 +174,29 @@ class DialogRollingTest {
}
@Override
public MemoryUnit getMemoryUnit(String unitId) {
return units.get(unitId);
public MemoryUnitSnapshot getMemoryUnit(String unitId) {
MemoryUnit unit = units.get(unitId);
return unit == null ? null : unit.snapshot();
}
@Override
public work.slhaf.partner.framework.agent.support.Result<MemorySlice> getMemorySlice(String unitId, String sliceId) {
public work.slhaf.partner.framework.agent.support.Result<MemorySliceSnapshot> getMemorySlice(String unitId, String sliceId) {
return null;
}
@Override
public MemoryUnit updateMemoryUnit(List<Message> chatMessages, String summary) {
public MemoryUnitSnapshot updateMemoryUnit(List<Message> chatMessages, String summary) {
MemoryUnit unit = units.computeIfAbsent(sessionId, MemoryUnit::new);
unit.updateTimestamp();
int startIndex = unit.getConversationMessages().size();
unit.getConversationMessages().addAll(chatMessages);
unit.getSlices().add(new MemorySlice(startIndex, startIndex + chatMessages.size(), summary));
return unit;
return unit.snapshot();
}
@Override
public Collection<MemoryUnit> listMemoryUnits() {
return units.values();
public Collection<MemoryUnitSnapshot> listMemoryUnits() {
return units.values().stream().map(MemoryUnit::snapshot).toList();
}
@Override

View File

@@ -11,7 +11,9 @@ import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.cognition.context.ContextWorkspace;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemorySliceSnapshot;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.MemoryUnitSnapshot;
import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.support.Result;
import work.slhaf.partner.module.memory.pojo.ActivationProfile;
@@ -19,7 +21,6 @@ import work.slhaf.partner.module.memory.runtime.exception.MemoryLookupException;
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.file.Path;
import java.time.LocalDate;
import java.util.Collection;
@@ -41,11 +42,8 @@ class MemoryRuntimeTest {
System.setProperty("user.home", tempDir.toAbsolutePath().toString());
}
@SuppressWarnings("unchecked")
private static List<Message> invokeSliceMessages(MemoryRuntime runtime, MemoryUnit unit, MemorySlice slice) throws Exception {
Method method = MemoryRuntime.class.getDeclaredMethod("sliceMessages", MemoryUnit.class, MemorySlice.class);
method.setAccessible(true);
return (List<Message>) method.invoke(runtime, unit, slice);
private static List<Message> invokeSliceMessages(MemoryRuntime runtime, MemoryUnit unit, MemorySlice slice) {
return unit.snapshot().messagesOf(slice.snapshot());
}
private static void setField(Object target, String fieldName, Object value) throws Exception {
@@ -128,7 +126,7 @@ class MemoryRuntimeTest {
}
@Override
public boolean renameEntitySubject(String entityUuid, String newSubject) {
public boolean renameEntitySubject(String entityUuid, String newSubject, boolean keepOldSubjectAsAlias) {
return false;
}
@@ -200,7 +198,7 @@ class MemoryRuntimeTest {
unit.getSlices().addAll(List.of(firstSlice, secondSlice));
memoryCapability.remember(unit);
runtime.recordMemory(unit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE);
runtime.recordMemory(unit.snapshot(), "topic/main", List.of("topic/related"), DEFAULT_PROFILE);
List<ActivatedMemorySlice> topicResult = runtime.queryActivatedMemoryByTopicPath("topic/main");
assertEquals(List.of("slice-2"), topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList());
@@ -240,8 +238,8 @@ class MemoryRuntimeTest {
relatedUnit.getSlices().add(relatedSlice);
memoryCapability.remember(relatedUnit);
runtime.recordMemory(mainUnit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE);
runtime.recordMemory(relatedUnit, "topic/related", List.of(), DEFAULT_PROFILE);
runtime.recordMemory(mainUnit.snapshot(), "topic/main", List.of("topic/related"), DEFAULT_PROFILE);
runtime.recordMemory(relatedUnit.snapshot(), "topic/related", List.of(), DEFAULT_PROFILE);
List<ActivatedMemorySlice> topicResult = runtime.queryActivatedMemoryByTopicPath("topic/main");
assertEquals(List.of("slice-main", "slice-related"),
@@ -260,7 +258,7 @@ class MemoryRuntimeTest {
MemorySlice firstSlice = MemorySlice.restore("slice-1", 0, 1, "first", 86_400_000L);
firstUnitSnapshot.getSlices().add(firstSlice);
memoryCapability.remember(firstUnitSnapshot);
runtime.recordMemory(firstUnitSnapshot, "topic/main", List.of(), DEFAULT_PROFILE);
runtime.recordMemory(firstUnitSnapshot.snapshot(), "topic/main", List.of(), DEFAULT_PROFILE);
firstUnitSnapshot.getConversationMessages().clear();
firstUnitSnapshot.getConversationMessages().addAll(List.of(message("m2"), message("m3")));
@@ -268,7 +266,7 @@ class MemoryRuntimeTest {
firstUnitSnapshot.getSlices().clear();
firstUnitSnapshot.getSlices().add(secondSlice);
memoryCapability.remember(firstUnitSnapshot);
runtime.recordMemory(firstUnitSnapshot, "topic/main", List.of(), DEFAULT_PROFILE);
runtime.recordMemory(firstUnitSnapshot.snapshot(), "topic/main", List.of(), DEFAULT_PROFILE);
JSONObject state = JSONObject.parseObject(runtime.convert().toString());
JSONArray dateIndex = state.getJSONArray("date_index");
@@ -306,14 +304,14 @@ class MemoryRuntimeTest {
MemorySlice secondSlice = MemorySlice.restore("slice-2", 2, 4, "second", 172_800_000L);
mainUnit.getSlices().addAll(List.of(firstSlice, secondSlice));
memoryCapability.remember(mainUnit);
runtime.recordMemory(mainUnit, "topic/main", List.of("topic/related"), DEFAULT_PROFILE);
runtime.recordMemory(mainUnit.snapshot(), "topic/main", List.of("topic/related"), DEFAULT_PROFILE);
MemoryUnit relatedUnit = new MemoryUnit("unit-201");
relatedUnit.getConversationMessages().addAll(List.of(message("r0"), message("r1")));
MemorySlice relatedSlice = MemorySlice.restore("slice-3", 0, 2, "related", 259_200_000L);
relatedUnit.getSlices().add(relatedSlice);
memoryCapability.remember(relatedUnit);
runtime.recordMemory(relatedUnit, "topic/related", List.of(), DEFAULT_PROFILE);
runtime.recordMemory(relatedUnit.snapshot(), "topic/related", List.of(), DEFAULT_PROFILE);
JSONObject state = JSONObject.parseObject(runtime.convert().toString());
JSONArray topicSlices = state.getJSONArray("topic_slices");
@@ -380,21 +378,21 @@ class MemoryRuntimeTest {
MemorySlice primarySlice = MemorySlice.restore("slice-primary", 0, 2, "primary", System.currentTimeMillis());
primaryUnit.getSlices().add(primarySlice);
memoryCapability.remember(primaryUnit);
runtime.recordMemory(primaryUnit, "topic->main", List.of("topic->related"), new ActivationProfile(0.9f, 0.1f, 0.9f));
runtime.recordMemory(primaryUnit.snapshot(), "topic->main", List.of("topic->related"), new ActivationProfile(0.9f, 0.1f, 0.9f));
MemoryUnit relatedUnit = new MemoryUnit("unit-related-rank");
relatedUnit.getConversationMessages().addAll(List.of(message("r0"), message("r1")));
MemorySlice relatedSlice = MemorySlice.restore("slice-related-rank", 0, 2, "related", System.currentTimeMillis());
relatedUnit.getSlices().add(relatedSlice);
memoryCapability.remember(relatedUnit);
runtime.recordMemory(relatedUnit, "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f));
runtime.recordMemory(relatedUnit.snapshot(), "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f));
MemoryUnit parentUnit = new MemoryUnit("unit-parent");
parentUnit.getConversationMessages().addAll(List.of(message("x0"), message("x1")));
MemorySlice parentSlice = MemorySlice.restore("slice-parent", 0, 2, "parent", System.currentTimeMillis());
parentUnit.getSlices().add(parentSlice);
memoryCapability.remember(parentUnit);
runtime.recordMemory(parentUnit, "topic", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f));
runtime.recordMemory(parentUnit.snapshot(), "topic", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f));
List<ActivatedMemorySlice> topicResult = runtime.queryActivatedMemoryByTopicPath("topic->main");
assertEquals(List.of("slice-primary", "slice-related-rank", "slice-parent"),
@@ -414,7 +412,7 @@ class MemoryRuntimeTest {
primaryUnit.getSlices().add(primarySlice);
memoryCapability.remember(primaryUnit);
runtime.recordMemory(
primaryUnit,
primaryUnit.snapshot(),
"topic->main",
List.of("topic->related"),
new ActivationProfile(0.8f, 0.0f, 0.8f)
@@ -425,7 +423,7 @@ class MemoryRuntimeTest {
MemorySlice relatedSlice = MemorySlice.restore("slice-related-zero", 0, 2, "related", System.currentTimeMillis());
relatedUnit.getSlices().add(relatedSlice);
memoryCapability.remember(relatedUnit);
runtime.recordMemory(relatedUnit, "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f));
runtime.recordMemory(relatedUnit.snapshot(), "topic->related", List.of(), new ActivationProfile(1.0f, 1.0f, 1.0f));
List<ActivatedMemorySlice> topicResult = runtime.queryActivatedMemoryByTopicPath("topic->main");
assertEquals(List.of("slice-primary-zero"), topicResult.stream().map(ActivatedMemorySlice::getSliceId).toList());
@@ -444,10 +442,10 @@ class MemoryRuntimeTest {
unit.getSlices().add(slice);
memoryCapability.remember(unit);
runtime.recordMemory(unit, "topic->main", List.of("topic->related"), new ActivationProfile(0.2f, 0.1f, 0.2f));
runtime.recordMemory(unit.snapshot(), "topic->main", List.of("topic->related"), new ActivationProfile(0.2f, 0.1f, 0.2f));
unit.getSlices().clear();
unit.getSlices().add(MemorySlice.restore("slice-refresh", 0, 2, "summary", 172_800_000L));
runtime.recordMemory(unit, "topic->main", List.of("topic->related-2"), new ActivationProfile(0.9f, 0.8f, 0.7f));
runtime.recordMemory(unit.snapshot(), "topic->main", List.of("topic->related-2"), new ActivationProfile(0.9f, 0.8f, 0.7f));
JSONObject state = JSONObject.parseObject(runtime.convert().toString());
JSONObject mainTopic = state.getJSONArray("topic_slices").stream()
@@ -481,12 +479,13 @@ class MemoryRuntimeTest {
}
@Override
public MemoryUnit getMemoryUnit(String unitId) {
return units.get(unitId);
public MemoryUnitSnapshot getMemoryUnit(String unitId) {
MemoryUnit unit = units.get(unitId);
return unit == null ? null : unit.snapshot();
}
@Override
public Result<MemorySlice> getMemorySlice(String unitId, String sliceId) {
public Result<MemorySliceSnapshot> getMemorySlice(String unitId, String sliceId) {
MemoryUnit unit = units.get(unitId);
if (unit == null || unit.getSlices() == null) {
return Result.failure(new MemoryLookupException(
@@ -498,7 +497,7 @@ class MemoryRuntimeTest {
return unit.getSlices().stream()
.filter(slice -> sliceId.equals(slice.getId()))
.findFirst()
.map(Result::success)
.map(slice -> Result.success(slice.snapshot()))
.orElseGet(() -> Result.failure(new MemoryLookupException(
"Memory slice not found: " + unitId + ":" + sliceId,
unitId + ":" + sliceId,
@@ -507,13 +506,13 @@ class MemoryRuntimeTest {
}
@Override
public MemoryUnit updateMemoryUnit(List<Message> chatMessages, String summary) {
public MemoryUnitSnapshot updateMemoryUnit(List<Message> chatMessages, String summary) {
return null;
}
@Override
public Collection<MemoryUnit> listMemoryUnits() {
return units.values();
public Collection<MemoryUnitSnapshot> listMemoryUnits() {
return units.values().stream().map(MemoryUnit::snapshot).toList();
}
@Override

View File

@@ -79,13 +79,10 @@ class MemoryRecallProfileExtractorTest {
MemorySlice slice = new MemorySlice(2, 4, "slice-summary");
unit.getSlices().add(slice);
updater.consume(new RollingResult(unit, slice, List.of(
message(Message.Character.USER, "new"),
message(Message.Character.ASSISTANT, "new-reply")
), "slice-summary", 4, 6));
updater.consume(new RollingResult(unit.snapshot(), slice.snapshot(), 4, 6));
verify(memoryRuntime).recordMemory(
eq(unit),
eq(unit.snapshot()),
eq("root->branch"),
eq(List.of("root->related")),
argThat(profile -> profile != null
@@ -113,10 +110,10 @@ class MemoryRecallProfileExtractorTest {
MemorySlice slice = new MemorySlice(0, 2, "slice-summary");
unit.getSlices().add(slice);
updater.consume(new RollingResult(unit, slice, unit.getConversationMessages(), "slice-summary", 2, 6));
updater.consume(new RollingResult(unit.snapshot(), slice.snapshot(), 2, 6));
verify(memoryRuntime).recordMemory(
eq(unit),
eq(unit.snapshot()),
eq(null),
eq(List.of()),
argThat(profile -> profile != null
@@ -147,10 +144,10 @@ class MemoryRecallProfileExtractorTest {
MemorySlice slice = new MemorySlice(0, 1, "slice-summary");
unit.getSlices().add(slice);
updater.consume(new RollingResult(unit, slice, unit.getConversationMessages(), "slice-summary", 1, 6));
updater.consume(new RollingResult(unit.snapshot(), slice.snapshot(), 1, 6));
verify(memoryRuntime).recordMemory(
eq(unit),
eq(unit.snapshot()),
eq("root->branch"),
eq(List.of()),
argThat(profile -> profile != null