refactor(MemoryUpdater): move auto-update to ActionScheduler cron and use Cognation snapshot/drain APIs for thread-safe memory refresh

This commit is contained in:
2026-03-08 13:11:20 +08:00
parent 7df0f208b5
commit 65690c65f8
7 changed files with 232 additions and 99 deletions

View File

@@ -173,10 +173,10 @@ data class StateAction @JvmOverloads constructor(
override val scheduleType: Schedulable.ScheduleType,
override val scheduleContent: String,
val trigger: Trigger,
override var enabled: Boolean = true,
override val timeout: Duration = 5.minutes,
val trigger: Trigger
) : Action(), Schedulable {
sealed interface Trigger {

View File

@@ -6,6 +6,7 @@ import work.slhaf.partner.api.chat.pojo.MetaMessage;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.Lock;
@Capability("cognation")
@@ -15,6 +16,10 @@ public interface CognationCapability {
List<Message> getChatMessages();
List<Message> snapshotChatMessages();
void rollChatMessagesWithSnapshot(int snapshotSize, int retainDivisor);
void cleanMessage(List<Message> messages);
Lock getMessageLock();
@@ -31,6 +36,10 @@ public interface CognationCapability {
HashMap<String, List<MetaMessage>> getSingleMetaMessageMap();
Map<String, List<MetaMessage>> drainSingleMetaMessages();
List<MetaMessage> snapshotSingleMetaMessages(String userId);
String getCurrentMemoryId();
}

View File

@@ -15,10 +15,7 @@ import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowCon
import java.io.IOException;
import java.io.Serial;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.*;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@@ -57,6 +54,38 @@ public class CognationCore extends PartnerCore<CognationCore> {
return chatMessages;
}
@CapabilityMethod
public List<Message> snapshotChatMessages() {
messageLock.lock();
try {
return List.copyOf(chatMessages);
} finally {
messageLock.unlock();
}
}
@CapabilityMethod
public void rollChatMessagesWithSnapshot(int snapshotSize, int retainDivisor) {
messageLock.lock();
try {
int safeSnapshotSize = Math.max(0, Math.min(snapshotSize, chatMessages.size()));
if (safeSnapshotSize == 0) {
return;
}
int safeDivisor = Math.max(retainDivisor, 1);
int retainCount = safeSnapshotSize / safeDivisor;
int retainStart = Math.max(0, safeSnapshotSize - retainCount);
List<Message> rolled = new ArrayList<>(chatMessages.subList(retainStart, safeSnapshotSize));
if (chatMessages.size() > safeSnapshotSize) {
rolled.addAll(chatMessages.subList(safeSnapshotSize, chatMessages.size()));
}
chatMessages = rolled;
} finally {
messageLock.unlock();
}
}
@CapabilityMethod
public long getLastUpdatedTime() {
return lastUpdatedTime;
@@ -75,9 +104,11 @@ public class CognationCore extends PartnerCore<CognationCore> {
@CapabilityMethod
public void cleanMessage(List<Message> messages) {
messageLock.lock();
this.getChatMessages().removeAll(messages);
messageLock.unlock();
try {
this.getChatMessages().removeAll(messages);
} finally {
messageLock.unlock();
}
}
@CapabilityMethod
@@ -88,24 +119,67 @@ public class CognationCore extends PartnerCore<CognationCore> {
@CapabilityMethod
public void addMetaMessage(String userId, MetaMessage metaMessage) {
log.debug("[{}] 当前会话历史: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap));
if (singleMetaMessageMap.containsKey(userId)) {
singleMetaMessageMap.get(userId).add(metaMessage);
} else {
singleMetaMessageMap.put(userId, new java.util.ArrayList<>());
singleMetaMessageMap.get(userId).add(metaMessage);
messageLock.lock();
try {
if (singleMetaMessageMap.containsKey(userId)) {
singleMetaMessageMap.get(userId).add(metaMessage);
} else {
singleMetaMessageMap.put(userId, new java.util.ArrayList<>());
singleMetaMessageMap.get(userId).add(metaMessage);
}
} finally {
messageLock.unlock();
}
log.debug("[{}] 会话历史更新: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap));
}
@CapabilityMethod
public List<Message> unpackAndClear(String userId) {
List<Message> messages = new ArrayList<>();
for (MetaMessage metaMessage : singleMetaMessageMap.get(userId)) {
messages.add(metaMessage.getUserMessage());
messages.add(metaMessage.getAssistantMessage());
messageLock.lock();
try {
List<Message> messages = new ArrayList<>();
List<MetaMessage> metaMessages = singleMetaMessageMap.get(userId);
if (metaMessages == null) {
return messages;
}
for (MetaMessage metaMessage : metaMessages) {
messages.add(metaMessage.getUserMessage());
messages.add(metaMessage.getAssistantMessage());
}
singleMetaMessageMap.remove(userId);
return messages;
} finally {
messageLock.unlock();
}
}
@CapabilityMethod
public Map<String, List<MetaMessage>> drainSingleMetaMessages() {
messageLock.lock();
try {
Map<String, List<MetaMessage>> drained = new HashMap<>();
for (Map.Entry<String, List<MetaMessage>> entry : singleMetaMessageMap.entrySet()) {
drained.put(entry.getKey(), new ArrayList<>(entry.getValue()));
}
singleMetaMessageMap.clear();
return drained;
} finally {
messageLock.unlock();
}
}
@CapabilityMethod
public List<MetaMessage> snapshotSingleMetaMessages(String userId) {
messageLock.lock();
try {
List<MetaMessage> metaMessages = singleMetaMessageMap.get(userId);
if (metaMessages == null) {
return List.of();
}
return List.copyOf(metaMessages);
} finally {
messageLock.unlock();
}
singleMetaMessageMap.remove(userId);
return messages;
}
@CapabilityMethod

View File

@@ -224,7 +224,7 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
private ExtractorInput buildExtractorInput(PartnerRunningFlowContext context) {
ExtractorInput input = new ExtractorInput();
input.setInput(context.getInput());
List<Message> chatMessages = cognationCapability.getChatMessages();
List<Message> chatMessages = cognationCapability.snapshotChatMessages();
List<Message> recentMessages = new ArrayList<>();
if (chatMessages.size() > 5) {
recentMessages.addAll(chatMessages.subList(chatMessages.size() - 5, chatMessages.size() - 1));
@@ -239,7 +239,7 @@ public class ActionPlanner extends PreRunningAbstractAgentModuleAbstract {
EvaluatorInput input = new EvaluatorInput();
input.setTendencies(extractorResult.getTendencies());
input.setUser(perceiveCapability.getUser(userId));
input.setRecentMessages(cognationCapability.getChatMessages());
input.setRecentMessages(cognationCapability.snapshotChatMessages());
input.setActivatedSlices(memoryCapability.getActivatedSlices(userId));
return input;
}

View File

@@ -24,6 +24,7 @@ import java.io.Closeable
import java.time.Duration
import java.time.ZonedDateTime
import java.time.temporal.ChronoUnit
import java.util.*
import java.util.stream.Collectors
import kotlin.jvm.optionals.getOrNull
import kotlin.time.Duration.Companion.milliseconds
@@ -35,6 +36,8 @@ class ActionScheduler : AbstractAgentModule.Standalone() {
@InjectModule
private lateinit var actionExecutor: ActionExecutor
private lateinit var timeWheel: TimeWheel
private val runtimeSchedulables: MutableSet<Schedulable> =
Collections.synchronizedSet(mutableSetOf())
private val schedulerScope =
CoroutineScope(Dispatchers.Default + SupervisorJob() + CoroutineName("ActionScheduler"))
@@ -45,12 +48,18 @@ class ActionScheduler : AbstractAgentModule.Standalone() {
@Init
fun init() {
fun loadScheduledActions() {
val listScheduledActions: () -> Set<SchedulableExecutableAction> = {
actionCapability.listActions(null, null)
val listScheduledActions: () -> Set<Schedulable> = {
val persistedExecutable = actionCapability.listActions(null, null)
.stream()
.filter { it is SchedulableExecutableAction }
.map { it as SchedulableExecutableAction }
.collect(Collectors.toSet())
.collect(Collectors.toSet<SchedulableExecutableAction>())
val persisted: MutableSet<Schedulable> = mutableSetOf()
persisted.addAll(persistedExecutable)
synchronized(runtimeSchedulables) {
persisted.addAll(runtimeSchedulables.filter { it.enabled })
}
persisted
}
val onTrigger: (Set<Schedulable>) -> Unit = { schedulableSet ->
schedulableSet.filterIsInstance<Action>()
@@ -81,6 +90,7 @@ class ActionScheduler : AbstractAgentModule.Standalone() {
if (!schedulableAction.enabled) {
return@launch
}
runtimeSchedulables.add(schedulableAction)
log.debug("New data to schedule: {}", schedulableAction)
timeWheel.schedule(schedulableAction)
if (schedulableAction is SchedulableExecutableAction) {

View File

@@ -37,14 +37,10 @@ public class MemorySelectExtractor extends AbstractAgentModule.Sub<PartnerRunnin
log.debug("[MemorySelectExtractor] 主题提取模块开始...");
// 结构化为指定格式
List<Message> chatMessages = new ArrayList<>();
List<MetaMessage> metaMessages = cognationCapability.getSingleMetaMessageMap().get(context.getSource());
if (metaMessages == null) {
cognationCapability.getSingleMetaMessageMap().put(context.getSource(), new ArrayList<>());
} else {
for (MetaMessage metaMessage : metaMessages) {
chatMessages.add(metaMessage.getUserMessage());
chatMessages.add(metaMessage.getAssistantMessage());
}
List<MetaMessage> metaMessages = cognationCapability.snapshotSingleMetaMessages(context.getSource());
for (MetaMessage metaMessage : metaMessages) {
chatMessages.add(metaMessage.getUserMessage());
chatMessages.add(metaMessage.getAssistantMessage());
}
ExtractorResult extractorResult;
try {

View File

@@ -1,6 +1,7 @@
package work.slhaf.partner.module.modules.memory.updater;
import com.alibaba.fastjson2.JSONObject;
import kotlin.Unit;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
@@ -8,12 +9,16 @@ import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.core.action.entity.Schedulable;
import work.slhaf.partner.core.action.entity.StateAction;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.perceive.PerceiveCapability;
import work.slhaf.partner.module.common.module.PostRunningAgentModule;
import work.slhaf.partner.module.modules.action.scheduler.ActionScheduler;
import work.slhaf.partner.module.modules.memory.updater.summarizer.MultiSummarizer;
import work.slhaf.partner.module.modules.memory.updater.summarizer.SingleSummarizer;
import work.slhaf.partner.module.modules.memory.updater.summarizer.TotalSummarizer;
@@ -24,6 +29,8 @@ import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowCon
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import static work.slhaf.partner.common.util.ExtractUtil.extractUserId;
@@ -32,8 +39,9 @@ import static work.slhaf.partner.common.util.ExtractUtil.extractUserId;
@Data
public class MemoryUpdater extends PostRunningAgentModule {
private static final long SCHEDULED_UPDATE_INTERVAL = 10 * 1000;
private static final String AUTO_UPDATE_CRON = "0/10 * * * * ?";
private static final long UPDATE_TRIGGER_INTERVAL = 60 * 60 * 1000;
private static final int CONTEXT_RETAIN_DIVISOR = 6;
@InjectCapability
private CognationCapability cognationCapability;
@@ -48,40 +56,32 @@ public class MemoryUpdater extends PostRunningAgentModule {
private SingleSummarizer singleSummarizer;
@InjectModule
private TotalSummarizer totalSummarizer;
private final AtomicBoolean updating = new AtomicBoolean(false);
private InteractionThreadPoolExecutor executor;
/**
* 用于临时存储完整对话记录在MemoryManager的分离后
*/
private List<Message> tempMessage;
@InjectModule
private ActionScheduler actionScheduler;
@Init
public void init() {
executor = InteractionThreadPoolExecutor.getInstance();
setScheduledUpdater();
registerScheduledUpdater();
}
private void setScheduledUpdater() {
executor.execute(() -> {
log.info("[MemoryUpdater] 记忆自动更新线程启动");
while (!Thread.interrupted()) {
try {
long currentTime = System.currentTimeMillis();
long lastUpdatedTime = cognationCapability.getLastUpdatedTime();
int chatCount = cognationCapability.getChatMessages().size();
if (lastUpdatedTime != 0 && currentTime - lastUpdatedTime > UPDATE_TRIGGER_INTERVAL && chatCount > 1) {
updateMemory();
cognationCapability.getChatMessages().clear();
// 重置MemoryId
cognationCapability.refreshMemoryId();
log.info("[MemoryUpdater] 记忆更新: 自动触发");
}
Thread.sleep(SCHEDULED_UPDATE_INTERVAL);
} catch (Exception e) {
log.error("[MemoryUpdater] 记忆自动更新线程出错: ", e);
}
}
log.info("[MemoryUpdater] 记忆自动更新线程结束");
});
private void registerScheduledUpdater() {
StateAction stateAction = new StateAction(
"system",
"memory-auto-update",
"定时检查并触发记忆更新",
Schedulable.ScheduleType.CYCLE,
AUTO_UPDATE_CRON,
new StateAction.Trigger.Call(() -> {
tryAutoUpdate();
return Unit.INSTANCE;
})
);
actionScheduler.schedule(stateAction);
log.info("[MemoryUpdater] 记忆自动更新已注册到 ActionScheduler, cron={}", AUTO_UPDATE_CRON);
}
@Override
@@ -99,14 +99,8 @@ public class MemoryUpdater extends PostRunningAgentModule {
if (!trigger) {
return;
}
try {
log.debug("[MemoryUpdater] 记忆更新触发");
updateMemory();
// 清空chatMessages
clearChatMessages();
} catch (Exception e) {
log.error("[MemoryUpdater] 记忆更新线程出错: ", e);
}
log.debug("[MemoryUpdater] 记忆更新触发");
triggerMemoryUpdate(false);
});
}
@@ -115,26 +109,81 @@ public class MemoryUpdater extends PostRunningAgentModule {
return true;
}
private void updateMemory() {
private void tryAutoUpdate() {
long currentTime = System.currentTimeMillis();
long lastUpdatedTime = cognationCapability.getLastUpdatedTime();
int chatCount = cognationCapability.snapshotChatMessages().size();
if (lastUpdatedTime != 0 && currentTime - lastUpdatedTime > UPDATE_TRIGGER_INTERVAL && chatCount > 1) {
triggerMemoryUpdate(true);
log.info("[MemoryUpdater] 记忆更新: 自动触发");
}
}
private void triggerMemoryUpdate(boolean refreshMemoryId) {
if (!updating.compareAndSet(false, true)) {
log.debug("[MemoryUpdater] 更新任务已在执行中,本次触发跳过");
return;
}
try {
List<Message> chatSnapshot = cognationCapability.snapshotChatMessages();
if (chatSnapshot.size() <= 1) {
return;
}
updateMemory(chatSnapshot);
cognationCapability.rollChatMessagesWithSnapshot(chatSnapshot.size(), CONTEXT_RETAIN_DIVISOR);
if (refreshMemoryId) {
cognationCapability.refreshMemoryId();
}
} catch (Exception e) {
log.error("[MemoryUpdater] 记忆更新线程出错: ", e);
} finally {
updating.set(false);
}
}
private void updateMemory(List<Message> chatSnapshot) {
log.debug("[MemoryUpdater] 记忆更新流程开始...");
tempMessage = new ArrayList<>(cognationCapability.getChatMessages());
HashMap<String, String> singleMemorySummary = new HashMap<>();
Map<String, String> singleMemorySummary = new ConcurrentHashMap<>();
Map<String, List<Message>> singleChatMessages = drainSingleChatMessages();
// 更新单聊记忆同时从chatMessages中去掉单聊记忆
updateSingleChatSlices(singleMemorySummary);
updateSingleChatSlices(singleMemorySummary, singleChatMessages);
// 更新多人场景下的记忆及相关的确定性记忆
updateMultiChatSlices(singleMemorySummary);
List<Message> multiChatMessages = excludeSingleChatMessages(chatSnapshot, singleChatMessages);
updateMultiChatSlices(singleMemorySummary, multiChatMessages);
cognationCapability.resetLastUpdatedTime();
log.debug("[MemoryUpdater] 记忆更新流程结束...");
}
private void updateMultiChatSlices(HashMap<String, String> singleMemorySummary) {
private Map<String, List<Message>> drainSingleChatMessages() {
Map<String, List<Message>> drainedMessages = new HashMap<>();
Map<String, List<MetaMessage>> drainedMetaMessages = cognationCapability.drainSingleMetaMessages();
for (Map.Entry<String, List<MetaMessage>> entry : drainedMetaMessages.entrySet()) {
List<Message> messages = new ArrayList<>();
for (MetaMessage metaMessage : entry.getValue()) {
messages.add(metaMessage.getUserMessage());
messages.add(metaMessage.getAssistantMessage());
}
drainedMessages.put(entry.getKey(), messages);
}
return drainedMessages;
}
private List<Message> excludeSingleChatMessages(List<Message> chatSnapshot, Map<String, List<Message>> singleChatMessages) {
Set<Message> singleMessages = new HashSet<>();
for (List<Message> messages : singleChatMessages.values()) {
singleMessages.addAll(messages);
}
return chatSnapshot.stream()
.filter(message -> !singleMessages.contains(message))
.toList();
}
private void updateMultiChatSlices(Map<String, String> singleMemorySummary, List<Message> multiChatMessages) {
// 此时chatMessages中不再包含单聊记录直接执行摘要以及切片插入
// 对剩下的多人聊天记录进行进行摘要
Callable<Void> task = () -> {
log.debug("[MemoryUpdater] 多人聊天记忆更新流程开始...");
cognationCapability.getMessageLock().lock();
List<Message> chatMessages = getCleanedMessages(cognationCapability.getChatMessages());
cognationCapability.getMessageLock().unlock();
List<Message> chatMessages = getCleanedMessages(multiChatMessages);
if (!chatMessages.isEmpty()) {
log.debug("[MemoryUpdater] 存在多人聊天记录, 流程正常进行...");
// 以第一条user对应的id为发起用户
@@ -153,7 +202,7 @@ public class MemoryUpdater extends PostRunningAgentModule {
memoryCapability.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary());
} else {
log.debug("[MemoryUpdater] 不存在多人聊天记录, 将以单聊总结为对话缓存的主要输入: {}", singleMemorySummary);
memoryCapability.updateDialogMap(LocalDateTime.now(), totalSummarizer.execute(singleMemorySummary));
memoryCapability.updateDialogMap(LocalDateTime.now(), totalSummarizer.execute(new HashMap<>(singleMemorySummary)));
}
log.debug("[MemoryUpdater] 对话缓存更新完毕");
log.debug("[MemoryUpdater] 多人聊天记忆更新流程结束...");
@@ -169,21 +218,15 @@ public class MemoryUpdater extends PostRunningAgentModule {
if (message.getRole().equals(ChatConstant.Character.ASSISTANT)) {
return message;
}
String time = Arrays.stream(message.getContent().split("\\*\\*")).toList().getLast();
List<String> splitResult = Arrays.stream(message.getContent().split("\\*\\*")).toList();
if (splitResult.isEmpty()) {
return message;
}
String time = splitResult.getLast();
return new Message(ChatConstant.Character.USER, message.getContent().replace("\r\n**" + time, ""));
}).toList();
}
private void clearChatMessages() {
// 不全部清空,保留一部分输入防止上下文割裂
cognationCapability.getMessageLock().lock();
List<Message> temp = new ArrayList<>(
tempMessage.subList(tempMessage.size() - tempMessage.size() / 6, tempMessage.size()));
cognationCapability.getChatMessages().removeAll(tempMessage);
cognationCapability.getChatMessages().addAll(0, temp);
cognationCapability.getMessageLock().unlock();
}
private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List<Message> chatMessages) {
for (Message chatMessage : chatMessages) {
if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) {
@@ -202,15 +245,16 @@ public class MemoryUpdater extends PostRunningAgentModule {
}
}
private void updateSingleChatSlices(HashMap<String, String> singleMemorySummary) {
private void updateSingleChatSlices(Map<String, String> singleMemorySummary, Map<String, List<Message>> singleChatMessages) {
log.debug("[MemoryUpdater] 单聊记忆更新流程开始...");
// 更新单聊记忆同时从chatMessages中去掉单聊记忆
Set<String> userIdSet = new HashSet<>(cognationCapability.getSingleMetaMessageMap().keySet());
List<Callable<Void>> tasks = new ArrayList<>();
// 多人聊天?
AtomicInteger count = new AtomicInteger(0);
for (String id : userIdSet) {
List<Message> messages = cognationCapability.unpackAndClear(id);
for (Map.Entry<String, List<Message>> entry : singleChatMessages.entrySet()) {
String id = entry.getKey();
List<Message> messages = entry.getValue();
if (messages.isEmpty()) {
continue;
}
tasks.add(() -> {
int thisCount = count.incrementAndGet();
log.debug("[MemoryUpdater] 单聊记忆[{}]更新: {}", thisCount, id);