refactor(memory): decouple memory storage and runtime structures

This commit is contained in:
2026-03-10 19:41:05 +08:00
parent 760ba8300b
commit 5ad80d8b86
32 changed files with 603 additions and 1134 deletions

View File

@@ -2,11 +2,8 @@ package work.slhaf.partner.core.cognation;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.Lock;
@Capability("cognation")
@@ -20,26 +17,6 @@ public interface CognationCapability {
void rollChatMessagesWithSnapshot(int snapshotSize, int retainDivisor);
void cleanMessage(List<Message> messages);
Lock getMessageLock();
void addMetaMessage(String userId, MetaMessage metaMessage);
List<Message> unpackAndClear(String userId);
void refreshMemoryId();
void resetLastUpdatedTime();
long getLastUpdatedTime();
HashMap<String, List<MetaMessage>> getSingleMetaMessageMap();
Map<String, List<MetaMessage>> drainSingleMetaMessages();
List<MetaMessage> snapshotSingleMetaMessages(String userId);
String getCurrentMemoryId();
}

View File

@@ -1,6 +1,5 @@
package work.slhaf.partner.core.cognation;
import com.alibaba.fastjson2.JSONObject;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
@@ -9,13 +8,13 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.api.agent.runtime.interaction.AgentRuntime;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.core.PartnerCore;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.io.IOException;
import java.io.Serial;
import java.util.*;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@@ -35,10 +34,6 @@ public class CognationCore extends PartnerCore<CognationCore> {
* 主模型的聊天记录
*/
private List<Message> chatMessages = new ArrayList<>();
private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap = new HashMap<>();
private String currentMemoryId;
private long lastUpdatedTime;
public CognationCore() throws IOException, ClassNotFoundException {
}
@@ -86,112 +81,11 @@ public class CognationCore extends PartnerCore<CognationCore> {
}
}
@CapabilityMethod
public long getLastUpdatedTime() {
return lastUpdatedTime;
}
@CapabilityMethod
public HashMap<String, List<MetaMessage>> getSingleMetaMessageMap() {
return singleMetaMessageMap;
}
@CapabilityMethod
public String getCurrentMemoryId() {
return currentMemoryId;
}
@CapabilityMethod
public void cleanMessage(List<Message> messages) {
messageLock.lock();
try {
this.getChatMessages().removeAll(messages);
} finally {
messageLock.unlock();
}
}
@CapabilityMethod
public Lock getMessageLock() {
return messageLock;
}
@CapabilityMethod
public void addMetaMessage(String userId, MetaMessage metaMessage) {
log.debug("[{}] 当前会话历史: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap));
messageLock.lock();
try {
if (singleMetaMessageMap.containsKey(userId)) {
singleMetaMessageMap.get(userId).add(metaMessage);
} else {
singleMetaMessageMap.put(userId, new java.util.ArrayList<>());
singleMetaMessageMap.get(userId).add(metaMessage);
}
} finally {
messageLock.unlock();
}
log.debug("[{}] 会话历史更新: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap));
}
@CapabilityMethod
public List<Message> unpackAndClear(String userId) {
messageLock.lock();
try {
List<Message> messages = new ArrayList<>();
List<MetaMessage> metaMessages = singleMetaMessageMap.get(userId);
if (metaMessages == null) {
return messages;
}
for (MetaMessage metaMessage : metaMessages) {
messages.add(metaMessage.getUserMessage());
messages.add(metaMessage.getAssistantMessage());
}
singleMetaMessageMap.remove(userId);
return messages;
} finally {
messageLock.unlock();
}
}
@CapabilityMethod
public Map<String, List<MetaMessage>> drainSingleMetaMessages() {
messageLock.lock();
try {
Map<String, List<MetaMessage>> drained = new HashMap<>();
for (Map.Entry<String, List<MetaMessage>> entry : singleMetaMessageMap.entrySet()) {
drained.put(entry.getKey(), new ArrayList<>(entry.getValue()));
}
singleMetaMessageMap.clear();
return drained;
} finally {
messageLock.unlock();
}
}
@CapabilityMethod
public List<MetaMessage> snapshotSingleMetaMessages(String userId) {
messageLock.lock();
try {
List<MetaMessage> metaMessages = singleMetaMessageMap.get(userId);
if (metaMessages == null) {
return List.of();
}
return List.copyOf(metaMessages);
} finally {
messageLock.unlock();
}
}
@CapabilityMethod
public void refreshMemoryId() {
currentMemoryId = UUID.randomUUID().toString();
}
@CapabilityMethod
public void resetLastUpdatedTime() {
lastUpdatedTime = System.currentTimeMillis();
}
@Override
protected String getCoreKey() {
return "cognation-core";

View File

@@ -1,51 +1,36 @@
package work.slhaf.partner.core.memory;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.MemoryResult;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
@Capability(value = "memory")
public interface MemoryCapability {
void cleanSelectedSliceFilter();
void clearActivatedSlices();
String getTopicTree();
void updateActivatedSlices(List<ActivatedMemorySlice> memorySlices);
HashMap<LocalDateTime, String> getDialogMap();
boolean hasActivatedSlices();
ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId);
int getActivatedSlicesSize();
void updateDialogMap(LocalDateTime dateTime, String newDialogCache);
List<ActivatedMemorySlice> getActivatedSlices();
String getDialogMapStr();
void saveMemoryUnit(MemoryUnit memoryUnit);
String getUserDialogMapStr(String userId);
MemoryUnit getMemoryUnit(String unitId);
void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices);
MemorySlice getMemorySlice(String unitId, String sliceId);
String getActivatedSlicesStr(String userId);
Collection<MemoryUnit> listMemoryUnits();
HashMap<String, List<EvaluatedSlice>> getActivatedSlices();
void refreshMemoryId();
void clearActivatedSlices(String userId);
boolean hasActivatedSlices(String userId);
int getActivatedSlicesSize(String userId);
List<EvaluatedSlice> getActivatedSlices(String userId);
MemoryResult selectMemory(String topicPathStr);
MemoryResult selectMemory(LocalDate date);
void insertSlice(MemorySlice memorySlice, String topicPath);
String getCurrentMemoryId();
}

View File

@@ -7,19 +7,12 @@ import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.core.PartnerCore;
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.MemoryResult;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemorySliceResult;
import work.slhaf.partner.core.memory.pojo.node.MemoryNode;
import work.slhaf.partner.core.memory.pojo.node.TopicNode;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import java.io.IOException;
import java.io.Serial;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
@@ -35,571 +28,121 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
@Serial
private static final long serialVersionUID = 1L;
private final Lock sliceInsertLock = new ReentrantLock();
/**
* key: 根主题名称 value: 根主题节点
*/
private HashMap<String, TopicNode> topicNodes = new HashMap<>();
/**
* 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值
* 该部分在'主题提取LLM'的system prompt中常驻
*/
private HashMap<String /*根主题名*/, LinkedHashSet<String> /*子主题列表*/> existedTopics = new HashMap<>();
/**
* 临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所
*/
private HashMap<String /*对话id, 即slice中的字段'memoryId'*/, List<MemorySlice>> currentDateDialogSlices = new HashMap<>();
/**
* 记忆节点的日期索引, 同一日期内按照对话id区分
*/
private HashMap<LocalDate, Set<String>> dateIndex = new HashMap<>();
/**
* 已被选中的切片时间戳集合,需要及时清理
*/
private Set<Long> selectedSlices = new HashSet<>();
private HashMap<String, List<String>> userIndex = new HashMap<>();
private MemoryCache cache = new MemoryCache();
private final Lock memoryLock = new ReentrantLock();
private ConcurrentHashMap<String, MemoryUnit> memoryUnits = new ConcurrentHashMap<>();
private List<ActivatedMemorySlice> activatedSlices = new CopyOnWriteArrayList<>();
private String currentMemoryId;
public MemoryCore() throws IOException, ClassNotFoundException {
}
@CapabilityMethod
public MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException {
MemoryResult memoryResult = new MemoryResult();
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
//加载节点并获取记忆切片列表
List<List<MemorySlice>> currentDateDialogSlices = loadSlicesByDate(date);
for (List<MemorySlice> value : currentDateDialogSlices) {
for (MemorySlice memorySlice : value) {
if (selectedSlices.contains(memorySlice.getTimestamp())) {
continue;
}
MemorySliceResult memorySliceResult = new MemorySliceResult();
memorySliceResult.setMemorySlice(memorySlice);
targetSliceList.add(memorySliceResult);
selectedSlices.add(memorySlice.getTimestamp());
}
}
memoryResult.setMemorySliceResult(targetSliceList);
return cacheFilter(memoryResult);
public void clearActivatedSlices() {
activatedSlices.clear();
}
@CapabilityMethod
public void insertSlice(MemorySlice memorySlice, String topicPath) {
sliceInsertLock.lock();
List<String> topicPathList = Arrays.stream(topicPath.split("->")).toList();
public void updateActivatedSlices(List<ActivatedMemorySlice> memorySlices) {
activatedSlices = new CopyOnWriteArrayList<>(memorySlices);
}
@CapabilityMethod
public boolean hasActivatedSlices() {
return !activatedSlices.isEmpty();
}
@CapabilityMethod
public int getActivatedSlicesSize() {
return activatedSlices.size();
}
@CapabilityMethod
public List<ActivatedMemorySlice> getActivatedSlices() {
return new ArrayList<>(activatedSlices);
}
@CapabilityMethod
public void saveMemoryUnit(MemoryUnit memoryUnit) {
memoryLock.lock();
try {
//检查是否存在当天对应的memorySlice并确定是否插入
//每日刷新缓存
checkCacheDate();
//如果topicPath在memorySliceCache中存在对应缓存由于进行的插入操作则需要移除该缓存但不清除相关计数
clearCacheByTopicPath(topicPathList);
insertMemory(topicPathList, memorySlice);
if (!memorySlice.isPrivate()) {
updateUserDialogMap(memorySlice);
}
} catch (Exception e) {
log.error("插入记忆时出错: ", e);
normalizeMemoryUnit(memoryUnit);
memoryUnits.put(memoryUnit.getId(), memoryUnit);
} finally {
memoryLock.unlock();
}
log.debug("插入切片: {}, 路径: {}", memorySlice, topicPath);
sliceInsertLock.unlock();
}
@CapabilityMethod
public String getTopicTree() {
StringBuilder stringBuilder = new StringBuilder();
for (Map.Entry<String, TopicNode> entry : topicNodes.entrySet()) {
String rootName = entry.getKey();
TopicNode rootNode = entry.getValue();
stringBuilder.append(rootName).append("[root]").append("\r\n");
printSubTopicsTreeFormat(rootNode, "", stringBuilder);
}
return stringBuilder.toString();
public MemoryUnit getMemoryUnit(String unitId) {
return memoryUnits.get(unitId);
}
@CapabilityMethod
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
List<LocalDateTime> keysToRemove = new ArrayList<>();
HashMap<LocalDateTime, String> dialogMap = cache.dialogMap;
dialogMap.forEach((k, v) -> {
if (dateTime.minusDays(2).isAfter(k)) {
keysToRemove.add(k);
}
});
for (LocalDateTime temp : keysToRemove) {
dialogMap.remove(temp);
}
keysToRemove.clear();
//放入新缓存
dialogMap.put(dateTime, newDialogCache);
}
@CapabilityMethod
public HashMap<LocalDateTime, String> getDialogMap() {
return cache.dialogMap;
}
@CapabilityMethod
public ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId) {
return cache.userDialogMap.get(userId);
}
@CapabilityMethod
public String getDialogMapStr() {
StringBuilder str = new StringBuilder();
this.getDialogMap().forEach((dateTime, dialog) -> str.append("\n\n").append("[").append(dateTime).append("]\n")
.append(dialog));
return str.toString();
}
@CapabilityMethod
public String getUserDialogMapStr(String userId) {
ConcurrentHashMap<String, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = cache.userDialogMap;
if (userDialogMap.containsKey(userId)) {
StringBuilder str = new StringBuilder();
Collection<String> dialogMapValues = this.getDialogMap().values();
userDialogMap.get(userId).forEach((dateTime, dialog) -> {
if (dialogMapValues.contains(dialog)) {
return;
}
str.append("\n\n").append("[").append(dateTime).append("]\n")
.append(dialog);
});
return str.toString();
} else {
public MemorySlice getMemorySlice(String unitId, String sliceId) {
MemoryUnit memoryUnit = memoryUnits.get(unitId);
if (memoryUnit == null || memoryUnit.getSlices() == null) {
return null;
}
}
@CapabilityMethod
public MemoryResult selectMemory(String topicPathStr) {
MemoryResult memoryResult;
List<String> topicPath = List.of(topicPathStr.split("->"));
try {
List<String> path = new ArrayList<>(topicPath);
//每日刷新缓存
checkCacheDate();
//检测缓存并更新计数, 查看是否需要放入缓存
updateCacheCounter(path);
//查看是否存在缓存,如果存在,则直接返回
if ((memoryResult = selectCache(path)) != null) {
return memoryResult;
for (MemorySlice slice : memoryUnit.getSlices()) {
if (sliceId.equals(slice.getId())) {
return slice;
}
memoryResult = selectMemory(path);
//尝试更新缓存
updateCache(topicPath, memoryResult);
} catch (Exception e) {
log.error("[{}] selectMemory error: ", getCoreKey(), e);
log.error("[{}] 路径: {}", getCoreKey(), topicPathStr);
log.error("[{}] 主题树: {}", getCoreKey(), getTopicTree());
memoryResult = new MemoryResult();
memoryResult.setRelatedMemorySliceResult(new ArrayList<>());
memoryResult.setMemorySliceResult(new CopyOnWriteArrayList<>());
}
return cacheFilter(memoryResult);
}
@CapabilityMethod
public void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices) {
cache.activatedSlices.put(userId, memorySlices);
log.debug("[{}] 已更新激活切片, userId: {}", getCoreKey(), userId);
}
@CapabilityMethod
public String getActivatedSlicesStr(String userId) {
HashMap<String, List<EvaluatedSlice>> activatedSlices = cache.activatedSlices;
if (activatedSlices.containsKey(userId)) {
StringBuilder str = new StringBuilder();
activatedSlices.get(userId).forEach(slice -> str.append("\n\n").append("[").append(slice.getDate()).append("]\n")
.append(slice.getSummary()));
return str.toString();
} else {
return null;
}
}
@CapabilityMethod
public HashMap<String, List<EvaluatedSlice>> getActivatedSlices() {
return cache.activatedSlices;
}
@CapabilityMethod
public void clearActivatedSlices(String userId) {
cache.activatedSlices.remove(userId);
}
@CapabilityMethod
public boolean hasActivatedSlices(String userId) {
HashMap<String, List<EvaluatedSlice>> activatedSlices = cache.activatedSlices;
if (!activatedSlices.containsKey(userId)) {
return false;
}
return !activatedSlices.get(userId).isEmpty();
}
@CapabilityMethod
public int getActivatedSlicesSize(String userId) {
return cache.activatedSlices.get(userId).size();
}
@CapabilityMethod
public List<EvaluatedSlice> getActivatedSlices(String userId) {
return cache.activatedSlices.get(userId);
}
@CapabilityMethod
public void cleanSelectedSliceFilter() {
this.selectedSlices.clear();
}
private List<List<MemorySlice>> loadSlicesByDate(LocalDate date) throws IOException, ClassNotFoundException {
if (!dateIndex.containsKey(date)) {
throw new UnExistedDateIndexException("不存在的日期索引: " + date);
}
List<List<MemorySlice>> list = new ArrayList<>();
for (String memoryNodeId : dateIndex.get(date)) {
MemoryNode memoryNode = new MemoryNode();
memoryNode.setMemoryNodeId(memoryNodeId);
list.add(memoryNode.loadMemorySliceList());
}
return list;
}
private void printSubTopicsTreeFormat(TopicNode node, String prefix, StringBuilder stringBuilder) {
if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return;
List<Map.Entry<String, TopicNode>> entries = new ArrayList<>(node.getTopicNodes().entrySet());
for (int i = 0; i < entries.size(); i++) {
boolean last = (i == entries.size() - 1);
Map.Entry<String, TopicNode> entry = entries.get(i);
stringBuilder.append(prefix).append(last ? "└── " : "├── ").append(entry.getKey()).append("[").append(entry.getValue().getMemoryNodes().size()).append("]").append("\r\n");
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : ""), stringBuilder);
}
}
private void insertMemory(List<String> topicPath, MemorySlice slice) throws IOException, ClassNotFoundException {
LocalDate now = LocalDate.now();
boolean hasSlice = false;
MemoryNode node = null;
TopicNode lastTopicNode = generateTopicPath(topicPath);
for (MemoryNode memoryNode : lastTopicNode.getMemoryNodes()) {
if (now.equals(memoryNode.getLocalDate())) {
hasSlice = true;
node = memoryNode;
break;
}
}
if (!hasSlice) {
node = new MemoryNode();
node.setLocalDate(now);
node.setMemoryNodeId(UUID.randomUUID().toString());
node.setMemorySliceList(new CopyOnWriteArrayList<>());
lastTopicNode.getMemoryNodes().add(node);
lastTopicNode.getMemoryNodes().sort(null);
}
node.loadMemorySliceList().add(slice);
//生成relatedTopicPath
for (List<String> relatedTopic : slice.getRelatedTopics()) {
generateTopicPath(relatedTopic);
}
updateSlicePrecedent(slice);
updateDateIndex(slice);
updateUserIndex(slice);
node.saveMemorySliceList();
}
private void updateUserIndex(MemorySlice slice) {
String memoryId = slice.getMemoryId();
String userId = slice.getStartUserId();
if (!userIndex.containsKey(userId)) {
List<String> memoryIdSet = new ArrayList<>();
memoryIdSet.add(memoryId);
userIndex.put(userId, memoryIdSet);
} else {
userIndex.get(userId).add(memoryId);
}
}
private TopicNode generateTopicPath(List<String> topicPath) {
topicPath = new ArrayList<>(topicPath);
//查看是否存在根主题节点
String rootTopic = topicPath.getFirst();
topicPath.removeFirst();
if (!topicNodes.containsKey(rootTopic)) {
synchronized (this) {
if (!topicNodes.containsKey(rootTopic)) {
TopicNode rootNode = new TopicNode();
topicNodes.put(rootTopic, rootNode);
existedTopics.put(rootTopic, new LinkedHashSet<>());
}
}
}
TopicNode current = topicNodes.get(rootTopic);
Set<String> existedTopicNodes = existedTopics.get(rootTopic);
for (String topic : topicPath) {
if (existedTopicNodes.contains(topic) && current.getTopicNodes().containsKey(topic)) {
current = current.getTopicNodes().get(topic);
} else {
TopicNode newNode = new TopicNode();
current.getTopicNodes().put(topic, newNode);
current = newNode;
current.setMemoryNodes(new CopyOnWriteArrayList<>());
current.setTopicNodes(new ConcurrentHashMap<>());
existedTopicNodes.add(topic);
}
}
return current;
}
private void updateSlicePrecedent(MemorySlice slice) {
String memoryId = slice.getMemoryId();
//查看是否切换了memoryId
if (!currentDateDialogSlices.containsKey(memoryId)) {
List<MemorySlice> memorySliceList = new ArrayList<>();
currentDateDialogSlices.clear();
currentDateDialogSlices.put(memoryId, memorySliceList);
}
//处理上下文关系
List<MemorySlice> memorySliceList = currentDateDialogSlices.get(memoryId);
if (memorySliceList.isEmpty()) {
memorySliceList.add(slice);
} else {
//排序
memorySliceList.sort(null);
MemorySlice tempSlice = memorySliceList.getLast();
//设置私密状态一致
tempSlice.setPrivate(slice.isPrivate());
//末尾切片添加当前切片的引用
tempSlice.setSliceAfter(slice);
//当前切片添加前序切片的引用
slice.setSliceBefore(tempSlice);
}
}
private void updateDateIndex(MemorySlice slice) {
String memoryId = slice.getMemoryId();
LocalDate date = LocalDate.now();
if (!dateIndex.containsKey(date)) {
HashSet<String> memoryIdSet = new HashSet<>();
memoryIdSet.add(memoryId);
dateIndex.put(date, memoryIdSet);
} else {
dateIndex.get(date).add(memoryId);
}
}
public MemoryResult selectMemory(List<String> path) throws IOException, ClassNotFoundException {
MemoryResult memoryResult = new MemoryResult();
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
String targetTopic = path.getLast();
TopicNode targetParentNode = getTargetParentNode(path, targetTopic);
List<List<String>> relatedTopics = new ArrayList<>();
//终点记忆节点
MemorySliceResult sliceResult = new MemorySliceResult();
for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) {
List<MemorySlice> endpointMemorySliceList = memoryNode.loadMemorySliceList();
for (MemorySlice memorySlice : endpointMemorySliceList) {
if (selectedSlices.contains(memorySlice.getTimestamp())) {
continue;
}
sliceResult.setSliceBefore(memorySlice.getSliceBefore());
sliceResult.setMemorySlice(memorySlice);
sliceResult.setSliceAfter(memorySlice.getSliceAfter());
targetSliceList.add(sliceResult);
selectedSlices.add(memorySlice.getTimestamp());
}
for (MemorySlice memorySlice : endpointMemorySliceList) {
if (memorySlice.getRelatedTopics() != null) {
relatedTopics.addAll(memorySlice.getRelatedTopics());
}
}
}
memoryResult.setMemorySliceResult(targetSliceList);
//邻近节点
List<MemorySlice> relatedMemorySlice = new ArrayList<>();
//邻近记忆节点 联系
for (List<String> relatedTopic : relatedTopics) {
List<String> tempTopicPath = new ArrayList<>(relatedTopic);
String tempTargetTopic = tempTopicPath.getLast();
TopicNode tempTargetParentNode = getTargetParentNode(tempTopicPath, tempTargetTopic);
//获取终点节点及其最新记忆节点
TopicNode tempTargetNode = tempTargetParentNode.getTopicNodes().get(tempTopicPath.getLast());
setRelatedMemorySlices(tempTargetNode, relatedMemorySlice);
}
//邻近记忆节点 父级
setRelatedMemorySlices(targetParentNode, relatedMemorySlice);
//将上述结果包装为MemoryResult
memoryResult.setRelatedMemorySliceResult(relatedMemorySlice);
return memoryResult;
}
private void setRelatedMemorySlices(TopicNode targetParentNode, List<MemorySlice> relatedMemorySlice) throws IOException, ClassNotFoundException {
List<MemoryNode> targetParentMemoryNodes = targetParentNode.getMemoryNodes();
if (!targetParentMemoryNodes.isEmpty()) {
for (MemorySlice memorySlice : targetParentMemoryNodes.getFirst().loadMemorySliceList()) {
if (selectedSlices.contains(memorySlice.getTimestamp())) {
continue;
}
relatedMemorySlice.add(memorySlice);
selectedSlices.add(memorySlice.getTimestamp());
}
}
}
private TopicNode getTargetParentNode(List<String> topicPath, String targetTopic) {
String topTopic = topicPath.getFirst();
if (!existedTopics.containsKey(topTopic)) {
throw new UnExistedTopicException("不存在的主题: " + topTopic);
}
TopicNode targetParentNode = topicNodes.get(topTopic);
topicPath.removeFirst();
for (String topic : topicPath) {
if (!existedTopics.get(topTopic).contains(topic)) {
throw new UnExistedTopicException("不存在的主题: " + topTopic);
}
}
//逐层查找目标主题
while (!targetParentNode.getTopicNodes().containsKey(targetTopic)) {
targetParentNode = targetParentNode.getTopicNodes().get(topicPath.getFirst());
topicPath.removeFirst();
}
return targetParentNode;
}
private void updateCacheCounter(List<String> topicPath) {
ConcurrentHashMap<List<String>, Integer> memoryNodeCacheCounter = cache.memoryNodeCacheCounter;
if (memoryNodeCacheCounter.containsKey(topicPath)) {
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
memoryNodeCacheCounter.put(topicPath, ++tempCount);
} else {
memoryNodeCacheCounter.put(topicPath, 1);
}
}
private void checkCacheDate() {
if (cache.cacheDate == null || cache.cacheDate.isBefore(LocalDate.now())) {
cache.memorySliceCache.clear();
cache.memoryNodeCacheCounter.clear();
cache.cacheDate = LocalDate.now();
}
}
private void updateCache(List<String> topicPath, MemoryResult memoryResult) {
ConcurrentHashMap<List<String>, Integer> memoryNodeCacheCounter = cache.memoryNodeCacheCounter;
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
if (tempCount == null) {
log.warn("[CacheCore] tempCount为null? memoryNodeCacheCounter: {}; topicPath: {}", memoryNodeCacheCounter, topicPath);
return;
}
if (tempCount >= 5) {
cache.memorySliceCache.put(topicPath, memoryResult);
}
}
private void updateUserDialogMap(MemorySlice slice) {
String summary = slice.getSummary();
LocalDateTime now = LocalDateTime.now();
ConcurrentHashMap<String, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = cache.userDialogMap;
//更新userDialogMap
//移除两天前上下文缓存(切片总结)
List<LocalDateTime> keysToRemove = new ArrayList<>();
userDialogMap.forEach((k, v) -> v.forEach((i, j) -> {
if (now.minusDays(2).isAfter(i)) {
keysToRemove.add(i);
}
}));
for (LocalDateTime dateTime : keysToRemove) {
userDialogMap.forEach((k, v) -> v.remove(dateTime));
}
//放入新缓存
userDialogMap
.computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>())
.merge(now, summary, (oldVal, newVal) -> oldVal + " " + newVal);
}
private void clearCacheByTopicPath(List<String> topicPath) {
cache.memorySliceCache.remove(topicPath);
}
private MemoryResult selectCache(List<String> path) {
ConcurrentHashMap<List<String>, MemoryResult> memorySliceCache = cache.memorySliceCache;
if (memorySliceCache.containsKey(path)) {
return memorySliceCache.get(path);
}
return null;
}
@CapabilityMethod
public Collection<MemoryUnit> listMemoryUnits() {
return new ArrayList<>(memoryUnits.values());
}
@CapabilityMethod
public void refreshMemoryId() {
currentMemoryId = UUID.randomUUID().toString();
}
@CapabilityMethod
public String getCurrentMemoryId() {
return currentMemoryId;
}
private void normalizeMemoryUnit(MemoryUnit memoryUnit) {
if (memoryUnit.getId() == null || memoryUnit.getId().isBlank()) {
memoryUnit.setId(UUID.randomUUID().toString());
}
if (memoryUnit.getTimestamp() == null || memoryUnit.getTimestamp() <= 0) {
memoryUnit.setTimestamp(System.currentTimeMillis());
}
if (memoryUnit.getConversationMessages() == null) {
memoryUnit.setConversationMessages(new ArrayList<>());
}
if (memoryUnit.getSlices() == null) {
memoryUnit.setSlices(new ArrayList<>());
}
int maxIndex = Math.max(memoryUnit.getConversationMessages().size() - 1, 0);
for (MemorySlice slice : memoryUnit.getSlices()) {
if (slice.getId() == null || slice.getId().isBlank()) {
slice.setId(UUID.randomUUID().toString());
}
if (slice.getTimestamp() == null || slice.getTimestamp() <= 0) {
slice.setTimestamp(memoryUnit.getTimestamp());
}
if (slice.getStartIndex() == null || slice.getStartIndex() < 0) {
slice.setStartIndex(0);
}
if (slice.getEndIndex() == null || slice.getEndIndex() < slice.getStartIndex()) {
slice.setEndIndex(maxIndex);
}
if (slice.getEndIndex() > maxIndex) {
slice.setEndIndex(maxIndex);
}
}
memoryUnit.getSlices().sort(Comparator.naturalOrder());
}
@Override
protected String getCoreKey() {
return "memory-core";
}
public ConcurrentHashMap<String, ConcurrentHashMap<LocalDateTime, String>> getUserDialogMap() {
return cache.userDialogMap;
}
private MemoryResult cacheFilter(MemoryResult memoryResult) {
//过滤掉与缓存重复的切片
CopyOnWriteArrayList<MemorySliceResult> memorySliceResult = memoryResult.getMemorySliceResult();
List<MemorySlice> relatedMemorySliceResult = memoryResult.getRelatedMemorySliceResult();
cache.dialogMap.forEach((k, v) -> {
memorySliceResult.removeIf(m -> m.getMemorySlice().getSummary().equals(v));
relatedMemorySliceResult.removeIf(m -> m.getSummary().equals(v));
});
return memoryResult;
}
@SuppressWarnings("FieldMayBeFinal")
private static class MemoryCache {
/**
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键总结为值
* 该部分作为'主LLM'system prompt常驻
* 该部分作为近两日的整体对话缓存, 不区分用户
*/
private HashMap<LocalDateTime, String> dialogMap = new HashMap<>();
/**
* 近两日的区分用户的对话总结缓存在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质
*/
private ConcurrentHashMap<String/*userId*/, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = new ConcurrentHashMap<>();
/**
* memorySliceCache计数器每日清空
*/
private ConcurrentHashMap<List<String> /*触发查询的主题列表*/, Integer> memoryNodeCacheCounter = new ConcurrentHashMap<>();
/**
* 记忆切片缓存,每日清空
* 用于记录作为终点节点调用次数最多的记忆节点的切片数据
*/
private ConcurrentHashMap<List<String> /*主题路径*/, MemoryResult /*切片列表*/> memorySliceCache = new ConcurrentHashMap<>();
/**
* 缓存日期
*/
private LocalDate cacheDate;
private HashMap<String, List<EvaluatedSlice>> activatedSlices = new HashMap<>();
private MemoryCache() {
}
}
}

View File

@@ -3,20 +3,25 @@ package work.slhaf.partner.core.memory.pojo;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.common.entity.PersistableObject;
import java.io.Serial;
import java.time.LocalDate;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
@Builder
public class EvaluatedSlice extends PersistableObject {
public class ActivatedMemorySlice extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
// private List<Message> chatMessages;
private String unitId;
private String sliceId;
private LocalDate date;
private Long timestamp;
private String summary;
private List<Message> messages;
}

View File

@@ -1,26 +0,0 @@
package work.slhaf.partner.core.memory.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.common.entity.PersistableObject;
import java.io.Serial;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemoryResult extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private CopyOnWriteArrayList<MemorySliceResult> memorySliceResult;
private List<MemorySlice> relatedMemorySliceResult;
public boolean isEmpty() {
boolean a = memorySliceResult == null || memorySliceResult.isEmpty();
boolean b = relatedMemorySliceResult == null || relatedMemorySliceResult.isEmpty();
return a && b;
}
}

View File

@@ -2,12 +2,9 @@ package work.slhaf.partner.core.memory.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.common.entity.PersistableObject;
import java.io.Serial;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
@@ -16,59 +13,11 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
@Serial
private static final long serialVersionUID = 1L;
/**
* 关联的完整对话的id
*/
private String memoryId;
/**
* 该切片在关联的完整对话中的顺序, 由时间戳确定
*/
private Long timestamp;
/**
* 格式为"<日期>.slice", 如2025-04-11.slice
*/
private String id;
private Integer startIndex;
private Integer endIndex;
private String summary;
private List<Message> chatMessages;
/**
* 关联的其他主题, 即"邻近节点(联系)"
*/
private List<List<String>> relatedTopics;
/**
* 关联完整对话中的前序切片, 排序为键,完整路径为值
*/
@ToString.Exclude
private MemorySlice sliceBefore, sliceAfter;
/**
* 多用户设定
* 发起该切片对话的用户
*/
private String startUserId;
/**
* 该切片涉及到的用户uuid
*/
private List<String> involvedUserIds;
/**
* 是否仅供发起用户作为记忆参考
*/
private boolean isPrivate;
/**
* 摘要向量化结果
*/
private float[] summaryEmbedding;
/**
* 是否向量化
*/
private boolean embedded;
private Long timestamp;
@Override
public int compareTo(MemorySlice memorySlice) {
@@ -79,5 +28,4 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
}
return 0;
}
}

View File

@@ -0,0 +1,23 @@
package work.slhaf.partner.core.memory.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.common.entity.PersistableObject;
import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemoryUnit extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private String id;
private List<Message> conversationMessages = new ArrayList<>();
private Long timestamp;
private List<MemorySlice> slices = new ArrayList<>();
}

View File

@@ -1,24 +1,22 @@
package work.slhaf.partner.core.memory.pojo;
import com.alibaba.fastjson2.annotation.JSONField;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import work.slhaf.partner.api.common.entity.PersistableObject;
import java.io.Serial;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemorySliceResult extends PersistableObject {
@NoArgsConstructor
@AllArgsConstructor
public class SliceRef extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
@JSONField(serialize = false)
private MemorySlice sliceBefore;
private MemorySlice memorySlice;
@JSONField(serialize = false)
private MemorySlice sliceAfter;
private String unitId;
private String sliceId;
}

View File

@@ -367,7 +367,7 @@ public class ActionExecutor extends AbstractAgentModule.Standalone {
private ExtractorInput buildExtractorInput(MetaAction action, String source, List<HistoryAction> historyActionResults,
List<String> additionalContext) {
ExtractorInput input = new ExtractorInput();
input.setEvaluatedSlices(memoryCapability.getActivatedSlices(source));
input.setActivatedMemorySlices(memoryCapability.getActivatedSlices());
input.setRecentMessages(cognationCapability.getChatMessages());
input.setMetaActionInfo(actionCapability.loadMetaActionInfo(action.getKey()));
input.setHistoryActionResults(historyActionResults);
@@ -384,7 +384,7 @@ public class ActionExecutor extends AbstractAgentModule.Standalone {
.history(executableAction.getHistory().get(executableAction.getExecutingStage()))
.status(executableAction.getStatus())
.recentMessages(cognationCapability.getChatMessages())
.activatedSlices(memoryCapability.getActivatedSlices(source))
.activatedSlices(memoryCapability.getActivatedSlices())
.build();
}
}

View File

@@ -4,7 +4,7 @@ import lombok.Builder;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import java.util.List;
@@ -20,5 +20,5 @@ public class CorrectorInput {
private ExecutableAction.Status status;
private List<Message> recentMessages;
private List<EvaluatedSlice> activatedSlices;
private List<ActivatedMemorySlice> activatedSlices;
}

View File

@@ -3,7 +3,7 @@ package work.slhaf.partner.module.modules.action.executor.entity;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.entity.MetaActionInfo;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import java.util.List;
@@ -16,7 +16,7 @@ public class ExtractorInput {
/**
* 可参考的记忆切片
*/
private List<EvaluatedSlice> evaluatedSlices;
private List<ActivatedMemorySlice> activatedMemorySlices;
/**
* 历史行动执行结果
*/

View File

@@ -22,6 +22,7 @@ import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.Eva
import work.slhaf.partner.module.modules.action.interventor.recognizer.InterventionRecognizer;
import work.slhaf.partner.module.modules.action.interventor.recognizer.entity.RecognizerInput;
import work.slhaf.partner.module.modules.action.interventor.recognizer.entity.RecognizerResult;
import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.ArrayList;
@@ -51,6 +52,8 @@ public class ActionInterventor extends AbstractAgentModule.Running<PartnerRunnin
private CognationCapability cognationCapability;
@InjectCapability
private MemoryCapability memoryCapability;
@InjectModule
private MemoryRuntime memoryRuntime;
@Override
public void execute(PartnerRunningFlowContext context) {
@@ -146,7 +149,7 @@ public class ActionInterventor extends AbstractAgentModule.Running<PartnerRunnin
private RecognizerInput buildRecognizerInput(String userId, String input) {
RecognizerInput recognizerInput = new RecognizerInput();
recognizerInput.setInput(input);
recognizerInput.setUserDialogMapStr(memoryCapability.getUserDialogMapStr(userId));
recognizerInput.setUserDialogMapStr(memoryRuntime.getDialogMapStr());
// 参考的对话列表大小或需调整
recognizerInput.setRecentMessages(cognationCapability.getChatMessages());
recognizerInput.setExecutingActions(actionCapability.listPhaserRecords().stream().map(PhaserRecord::executableAction).toList());
@@ -159,7 +162,7 @@ public class ActionInterventor extends AbstractAgentModule.Running<PartnerRunnin
input.setExecutingInterventions(recognizerResult.getExecutingInterventions());
input.setPreparedInterventions(recognizerResult.getPreparedInterventions());
input.setRecentMessages(cognationCapability.getChatMessages());
input.setActivatedSlices(memoryCapability.getActivatedSlices(userId));
input.setActivatedSlices(memoryCapability.getActivatedSlices());
return input;
}
}

View File

@@ -8,7 +8,7 @@ import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore.ExecutorType;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.EvaluatorInput;
import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.EvaluatorResult;
import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.EvaluatorResult.EvaluatedInterventionData;
@@ -66,7 +66,7 @@ public class InterventionEvaluator extends AbstractAgentModule.Sub<EvaluatorInpu
}));
}
private String buildPrompt(List<Message> recentMessages, List<EvaluatedSlice> activatedSlices,
private String buildPrompt(List<Message> recentMessages, List<ActivatedMemorySlice> activatedSlices,
ExecutableAction executableAction, String tendency) {
JSONObject json = new JSONObject();
json.put("干预倾向", tendency);

View File

@@ -3,7 +3,7 @@ package work.slhaf.partner.module.modules.action.interventor.evaluator.entity;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import java.util.List;
import java.util.Map;
@@ -12,6 +12,6 @@ import java.util.Map;
public class EvaluatorInput {
private Map<String, ExecutableAction> executingInterventions;
private Map<String, ExecutableAction> preparedInterventions;
private List<EvaluatedSlice> activatedSlices;
private List<ActivatedMemorySlice> activatedSlices;
private List<Message> recentMessages;
}

View File

@@ -328,7 +328,7 @@ public class ActionPlanner extends AbstractAgentModule.Running<PartnerRunningFlo
input.setTendencies(extractorResult.getTendencies());
input.setUser(perceiveCapability.getUser(userId));
input.setRecentMessages(cognationCapability.snapshotChatMessages());
input.setActivatedSlices(memoryCapability.getActivatedSlices(userId));
input.setActivatedSlices(memoryCapability.getActivatedSlices());
return input;
}

View File

@@ -10,7 +10,7 @@ import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorBatchInput;
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorInput;
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorResult;
@@ -77,7 +77,7 @@ public class ActionEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, Lis
JSONObject prompt = new JSONObject();
prompt.put("[行动倾向]", batchInput.getTendency());
JSONArray memoryData = prompt.putArray("[相关记忆切片]");
for (EvaluatedSlice evaluatedSlice : batchInput.getActivatedSlices()) {
for (ActivatedMemorySlice evaluatedSlice : batchInput.getActivatedSlices()) {
JSONObject memory = memoryData.addObject();
memory.put("[日期]", evaluatedSlice.getDate());
memory.put("[摘要]", evaluatedSlice.getSummary());

View File

@@ -2,7 +2,7 @@ package work.slhaf.partner.module.modules.action.planner.evaluator.entity;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import java.util.List;
import java.util.Map;
@@ -10,7 +10,7 @@ import java.util.Map;
@Data
public class EvaluatorBatchInput {
private List<Message> recentMessages;
private List<EvaluatedSlice> activatedSlices;
private List<ActivatedMemorySlice> activatedSlices;
private Map<String, String> availableActions;
private String tendency;
}

View File

@@ -2,7 +2,7 @@ package work.slhaf.partner.module.modules.action.planner.evaluator.entity;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import work.slhaf.partner.core.perceive.pojo.User;
import java.util.List;
@@ -11,6 +11,6 @@ import java.util.List;
public class EvaluatorInput {
private List<Message> recentMessages;
private User user;
private List<EvaluatedSlice> activatedSlices;
private List<ActivatedMemorySlice> activatedSlices;
private List<String> tendencies;
}

View File

@@ -12,7 +12,6 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.agent.runtime.interaction.flow.ContextBlock;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
@@ -126,8 +125,6 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
chatMessages.add(primaryUserMessage);
Message assistantMessage = new Message(Message.Character.ASSISTANT, response);
chatMessages.add(assistantMessage);
MetaMessage metaMessage = new MetaMessage(primaryUserMessage, assistantMessage);
cognationCapability.addMetaMessage(runningFlowContext.getSource(), metaMessage);
} finally {
cognationCapability.getMessageLock().unlock();
}

View File

@@ -0,0 +1,247 @@
package work.slhaf.partner.module.modules.memory.runtime;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader;
import work.slhaf.partner.api.common.entity.PersistableObject;
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.*;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.locks.ReentrantLock;
import static work.slhaf.partner.common.Constant.Path.MEMORY_DATA;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class MemoryRuntime extends AbstractAgentModule.Standalone {
private static final String RUNTIME_KEY = "memory-runtime";
private final ReentrantLock runtimeLock = new ReentrantLock();
private Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices = new HashMap<>();
private Map<LocalDate, CopyOnWriteArrayList<SliceRef>> dateIndex = new HashMap<>();
private HashMap<LocalDateTime, String> dialogMap = new HashMap<>();
@Init
public void init() {
loadState();
Runtime.getRuntime().addShutdownHook(new Thread(this::saveStateSafely));
}
public void bindTopic(String topicPath, SliceRef sliceRef) {
String normalizedPath = normalizeTopicPath(topicPath);
runtimeLock.lock();
try {
CopyOnWriteArrayList<SliceRef> refs = topicSlices.computeIfAbsent(normalizedPath, key -> new CopyOnWriteArrayList<>());
boolean exists = refs.stream().anyMatch(ref -> Objects.equals(ref.getUnitId(), sliceRef.getUnitId())
&& Objects.equals(ref.getSliceId(), sliceRef.getSliceId()));
if (!exists) {
refs.add(sliceRef);
}
saveState();
} finally {
runtimeLock.unlock();
}
}
public void indexMemoryUnit(MemoryUnit memoryUnit) {
runtimeLock.lock();
try {
for (CopyOnWriteArrayList<SliceRef> refs : dateIndex.values()) {
refs.removeIf(ref -> memoryUnit.getId().equals(ref.getUnitId()));
}
if (memoryUnit.getSlices() != null) {
for (MemorySlice slice : memoryUnit.getSlices()) {
LocalDate date = Instant.ofEpochMilli(slice.getTimestamp())
.atZone(ZoneId.systemDefault())
.toLocalDate();
dateIndex.computeIfAbsent(date, key -> new CopyOnWriteArrayList<>())
.addIfAbsent(new SliceRef(memoryUnit.getId(), slice.getId()));
}
}
saveState();
} finally {
runtimeLock.unlock();
}
}
public List<SliceRef> findByTopicPath(String topicPath) {
String normalizedPath = normalizeTopicPath(topicPath);
List<SliceRef> refs = topicSlices.get(normalizedPath);
if (refs == null || refs.isEmpty()) {
throw new UnExistedTopicException("不存在的主题: " + normalizedPath);
}
return new ArrayList<>(refs);
}
public List<SliceRef> findByDate(LocalDate date) {
List<SliceRef> refs = dateIndex.get(date);
if (refs == null || refs.isEmpty()) {
throw new UnExistedDateIndexException("不存在的日期索引: " + date);
}
return new ArrayList<>(refs);
}
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
runtimeLock.lock();
try {
List<LocalDateTime> keysToRemove = new ArrayList<>();
dialogMap.forEach((k, v) -> {
if (dateTime.minusDays(2).isAfter(k)) {
keysToRemove.add(k);
}
});
for (LocalDateTime temp : keysToRemove) {
dialogMap.remove(temp);
}
dialogMap.put(dateTime, newDialogCache);
saveState();
} finally {
runtimeLock.unlock();
}
}
public HashMap<LocalDateTime, String> getDialogMap() {
return dialogMap;
}
public String getDialogMapStr() {
StringBuilder str = new StringBuilder();
dialogMap.entrySet().stream()
.sorted(Map.Entry.comparingByKey())
.forEach(entry -> str.append("\n\n[")
.append(entry.getKey())
.append("]\n")
.append(entry.getValue()));
return str.toString();
}
public String getTopicTree() {
TopicTreeNode root = new TopicTreeNode();
for (Map.Entry<String, CopyOnWriteArrayList<SliceRef>> entry : topicSlices.entrySet()) {
String[] parts = entry.getKey().split("->");
TopicTreeNode current = root;
for (String part : parts) {
current = current.children.computeIfAbsent(part, key -> new TopicTreeNode());
}
current.count += entry.getValue().size();
}
StringBuilder stringBuilder = new StringBuilder();
List<Map.Entry<String, TopicTreeNode>> roots = new ArrayList<>(root.children.entrySet());
for (Map.Entry<String, TopicTreeNode> entry : roots) {
stringBuilder.append(entry.getKey()).append("[root]").append("\r\n");
printSubTopicsTreeFormat(entry.getValue(), "", stringBuilder);
}
return stringBuilder.toString();
}
private void printSubTopicsTreeFormat(TopicTreeNode node, String prefix, StringBuilder stringBuilder) {
List<Map.Entry<String, TopicTreeNode>> entries = new ArrayList<>(node.children.entrySet());
for (int i = 0; i < entries.size(); i++) {
boolean last = i == entries.size() - 1;
Map.Entry<String, TopicTreeNode> entry = entries.get(i);
stringBuilder.append(prefix)
.append(last ? "└── " : "├── ")
.append(entry.getKey())
.append("[")
.append(entry.getValue().count)
.append("]")
.append("\r\n");
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : ""), stringBuilder);
}
}
private String normalizeTopicPath(String topicPath) {
return topicPath == null ? "" : topicPath.trim();
}
private void loadState() {
Path filePath = getFilePath();
if (!Files.exists(filePath)) {
return;
}
try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(filePath.toFile()))) {
RuntimeState state = (RuntimeState) ois.readObject();
topicSlices = state.topicSlices;
dateIndex = state.dateIndex;
dialogMap = state.dialogMap;
} catch (Exception e) {
log.error("[MemoryRuntime] 加载运行态失败", e);
topicSlices = new HashMap<>();
dateIndex = new HashMap<>();
dialogMap = new HashMap<>();
}
}
private void saveStateSafely() {
runtimeLock.lock();
try {
saveState();
} finally {
runtimeLock.unlock();
}
}
private void saveState() {
Path filePath = getFilePath();
Path tempPath = getTempFilePath();
try {
Files.createDirectories(Paths.get(MEMORY_DATA));
FileUtils.createParentDirectories(filePath.toFile());
try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(tempPath.toFile()))) {
RuntimeState state = new RuntimeState();
state.topicSlices = new HashMap<>(topicSlices);
state.dateIndex = new HashMap<>(dateIndex);
state.dialogMap = new HashMap<>(dialogMap);
oos.writeObject(state);
}
Files.move(tempPath, filePath, java.nio.file.StandardCopyOption.REPLACE_EXISTING);
} catch (IOException e) {
log.error("[MemoryRuntime] 保存运行态失败", e);
}
}
private Path getFilePath() {
String id = ((PartnerAgentConfigLoader) AgentConfigLoader.INSTANCE).getConfig().getAgentId();
return Paths.get(MEMORY_DATA, id + "-" + RUNTIME_KEY + ".memory");
}
private Path getTempFilePath() {
String id = ((PartnerAgentConfigLoader) AgentConfigLoader.INSTANCE).getConfig().getAgentId();
return Paths.get(MEMORY_DATA, id + "-" + RUNTIME_KEY + "-temp.memory");
}
private static final class TopicTreeNode {
private final Map<String, TopicTreeNode> children = new LinkedHashMap<>();
private int count;
}
private static final class RuntimeState extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices = new HashMap<>();
private Map<LocalDate, CopyOnWriteArrayList<SliceRef>> dateIndex = new HashMap<>();
private HashMap<LocalDateTime, String> dialogMap = new HashMap<>();
}
}

View File

@@ -6,13 +6,16 @@ import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.MemoryResult;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime;
import work.slhaf.partner.module.modules.memory.selector.evaluator.SliceSelectEvaluator;
import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorInput;
import work.slhaf.partner.module.modules.memory.selector.extractor.MemorySelectExtractor;
@@ -20,9 +23,12 @@ import work.slhaf.partner.module.modules.memory.selector.extractor.entity.Extrac
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.time.Instant;
import java.time.LocalDate;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@@ -33,89 +39,106 @@ public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFl
@InjectCapability
private CognationCapability cognationCapability;
@InjectModule
private MemoryRuntime memoryRuntime;
@InjectModule
private SliceSelectEvaluator sliceSelectEvaluator;
@InjectModule
private MemorySelectExtractor memorySelectExtractor;
@Override
public void execute(PartnerRunningFlowContext runningFlowContext) {
String userId = runningFlowContext.getSource();
//获取主题路径
ExtractorResult extractorResult = memorySelectExtractor.execute(runningFlowContext);
if (extractorResult.isRecall() || !extractorResult.getMatches().isEmpty()) {
memoryCapability.clearActivatedSlices(userId);
List<EvaluatedSlice> evaluatedSlices = selectAndEvaluateMemory(runningFlowContext, extractorResult);
memoryCapability.updateActivatedSlices(userId, evaluatedSlices);
memoryCapability.clearActivatedSlices();
List<ActivatedMemorySlice> activatedSlices = selectAndEvaluateMemory(runningFlowContext, extractorResult);
memoryCapability.updateActivatedSlices(activatedSlices);
}
setModuleContextRecall(runningFlowContext);
}
private List<EvaluatedSlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, ExtractorResult extractorResult) {
private List<ActivatedMemorySlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext,
ExtractorResult extractorResult) {
log.debug("[MemorySelector] 触发记忆回溯...");
//查找切片
String userId = runningFlowContext.getSource();
List<MemoryResult> memoryResultList = new ArrayList<>();
setMemoryResultList(memoryResultList, extractorResult.getMatches(), userId);
//评估切片
LinkedHashMap<String, ActivatedMemorySlice> candidates = new LinkedHashMap<>();
setMemoryCandidates(candidates, extractorResult.getMatches());
removeDuplicateSlice(candidates.values());
EvaluatorInput evaluatorInput = EvaluatorInput.builder()
.input(runningFlowContext.getInput())
.memoryResults(memoryResultList)
.memorySlices(new ArrayList<>(candidates.values()))
.messages(cognationCapability.getChatMessages())
.build();
log.debug("[MemorySelector] 切片评估输入: {}", JSONObject.toJSONString(evaluatorInput));
List<EvaluatedSlice> memorySlices = sliceSelectEvaluator.execute(evaluatorInput);
List<ActivatedMemorySlice> memorySlices = sliceSelectEvaluator.execute(evaluatorInput);
log.debug("[MemorySelector] 切片评估结果: {}", JSONObject.toJSONString(memorySlices));
return memorySlices;
}
private void setModuleContextRecall(PartnerRunningFlowContext runningFlowContext) {
String userId = runningFlowContext.getSource();
boolean recall = memoryCapability.hasActivatedSlices(userId);
boolean recall = memoryCapability.hasActivatedSlices();
runningFlowContext.getModuleContext().getExtraContext().put("recall", recall);
if (recall) {
runningFlowContext.getModuleContext().getExtraContext().put("recall_count", memoryCapability.getActivatedSlicesSize(userId));
runningFlowContext.getModuleContext().getExtraContext().put("recall_count", memoryCapability.getActivatedSlicesSize());
}
}
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) {
private void setMemoryCandidates(LinkedHashMap<String, ActivatedMemorySlice> candidates,
List<ExtractorMatchData> matches) {
for (ExtractorMatchData match : matches) {
try {
MemoryResult memoryResult = switch (match.getType()) {
case ExtractorMatchData.Constant.TOPIC -> memoryCapability.selectMemory(match.getText());
case ExtractorMatchData.Constant.DATE ->
memoryCapability.selectMemory(LocalDate.parse(match.getText()));
default -> null;
List<SliceRef> refs = switch (match.getType()) {
case ExtractorMatchData.Constant.TOPIC -> memoryRuntime.findByTopicPath(match.getText());
case ExtractorMatchData.Constant.DATE -> memoryRuntime.findByDate(LocalDate.parse(match.getText()));
default -> List.of();
};
if (memoryResult == null || memoryResult.isEmpty()) continue;
removeDuplicateSlice(memoryResult);
memoryResultList.add(memoryResult);
for (SliceRef ref : refs) {
ActivatedMemorySlice recalledSlice = buildActivatedMemorySlice(ref);
if (recalledSlice != null) {
candidates.putIfAbsent(ref.getUnitId() + ":" + ref.getSliceId(), recalledSlice);
}
}
} catch (UnExistedDateIndexException | UnExistedTopicException e) {
log.error("[MemorySelector] 不存在的记忆索引! 请尝试更换更合适的主题提取LLM!", e);
log.error("[MemorySelector] 不存在的记忆索引", e);
log.error("[MemorySelector] 错误索引: {}", match.getText());
}
}
//清理切片记录
memoryCapability.cleanSelectedSliceFilter();
//根据userInfo过滤是否为私人记忆
for (MemoryResult memoryResult : memoryResultList) {
//过滤终点记忆
memoryResult.getMemorySliceResult().removeIf(m -> removeOrNot(m.getMemorySlice(), userId));
//过滤邻近记忆
memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userId));
}
}
private void removeDuplicateSlice(MemoryResult memoryResult) {
Collection<String> values = memoryCapability.getDialogMap().values();
memoryResult.getRelatedMemorySliceResult().removeIf(m -> values.contains(m.getSummary()));
memoryResult.getMemorySliceResult().removeIf(m -> values.contains(m.getMemorySlice().getSummary()));
private ActivatedMemorySlice buildActivatedMemorySlice(SliceRef ref) {
MemoryUnit memoryUnit = memoryCapability.getMemoryUnit(ref.getUnitId());
MemorySlice memorySlice = memoryCapability.getMemorySlice(ref.getUnitId(), ref.getSliceId());
if (memoryUnit == null || memorySlice == null) {
return null;
}
List<Message> messages = sliceMessages(memoryUnit, memorySlice);
LocalDate date = Instant.ofEpochMilli(memorySlice.getTimestamp())
.atZone(ZoneId.systemDefault())
.toLocalDate();
return ActivatedMemorySlice.builder()
.unitId(ref.getUnitId())
.sliceId(ref.getSliceId())
.summary(memorySlice.getSummary())
.timestamp(memorySlice.getTimestamp())
.date(date)
.messages(messages)
.build();
}
private boolean removeOrNot(MemorySlice memorySlice, String userId) {
if (memorySlice.isPrivate()) {
return memorySlice.getStartUserId().equals(userId);
private List<Message> sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) {
List<Message> conversationMessages = memoryUnit.getConversationMessages();
if (conversationMessages == null || conversationMessages.isEmpty()) {
return List.of();
}
return false;
int start = Math.max(0, memorySlice.getStartIndex());
int end = Math.min(conversationMessages.size() - 1, memorySlice.getEndIndex());
if (start > end) {
return List.of();
}
return new ArrayList<>(conversationMessages.subList(start, end + 1));
}
private void removeDuplicateSlice(Collection<ActivatedMemorySlice> candidates) {
Collection<String> values = memoryRuntime.getDialogMap().values();
candidates.removeIf(m -> values.contains(m.getSummary()));
}
@Override

View File

@@ -1,6 +1,5 @@
package work.slhaf.partner.module.modules.memory.selector.evaluator;
import cn.hutool.core.date.DateUtil;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
@@ -10,10 +9,7 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.MemoryResult;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemorySliceResult;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorBatchInput;
import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorInput;
import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorResult;
@@ -27,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger;
@EqualsAndHashCode(callSuper = true)
@Data
public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, List<EvaluatedSlice>> implements ActivateModel {
public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, List<ActivatedMemorySlice>> implements ActivateModel {
private InteractionThreadPoolExecutor executor;
@Init
@@ -36,83 +32,58 @@ public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput
}
@Override
public List<EvaluatedSlice> execute(EvaluatorInput evaluatorInput) {
public List<ActivatedMemorySlice> execute(EvaluatorInput evaluatorInput) {
log.debug("[SliceSelectEvaluator] 切片评估模块开始...");
List<MemoryResult> memoryResultList = evaluatorInput.getMemoryResults();
List<ActivatedMemorySlice> memorySlices = evaluatorInput.getMemorySlices();
List<Callable<Void>> tasks = new ArrayList<>();
Queue<EvaluatedSlice> queue = new ConcurrentLinkedDeque<>();
Queue<ActivatedMemorySlice> queue = new ConcurrentLinkedDeque<>();
AtomicInteger count = new AtomicInteger(0);
for (MemoryResult memoryResult : memoryResultList) {
if (memoryResult.getMemorySliceResult().isEmpty() && memoryResult.getRelatedMemorySliceResult().isEmpty()) {
continue;
}
tasks.add(() -> {
int thisCount = count.incrementAndGet();
log.debug("[SliceSelectEvaluator] 评估[{}]开始", thisCount);
List<SliceSummary> sliceSummaryList = new ArrayList<>();
//映射查找键值
Map<Long, SliceSummary> map = new HashMap<>();
try {
setSliceSummaryList(memoryResult, sliceSummaryList, map);
EvaluatorBatchInput batchInput = EvaluatorBatchInput.builder()
.text(evaluatorInput.getInput())
.memory_slices(sliceSummaryList)
.history(evaluatorInput.getMessages())
.build();
log.debug("[SliceSelectEvaluator] 评估[{}]输入: {}", thisCount, JSONObject.toJSONString(batchInput));
EvaluatorResult evaluatorResult = formattedChat(
List.of(new Message(Message.Character.USER, JSONUtil.toJsonStr(batchInput))),
EvaluatorResult.class
);
log.debug("[SliceSelectEvaluator] 评估[{}]结果: {}", thisCount, JSONObject.toJSONString(evaluatorResult));
for (Long result : evaluatorResult.getResults()) {
SliceSummary sliceSummary = map.get(result);
EvaluatedSlice evaluatedSlice = EvaluatedSlice.builder()
.summary(sliceSummary.getSummary())
.date(sliceSummary.getDate())
.build();
// setEvaluatedSliceMessages(evaluatedSlice, memoryResult, sliceSummary.getId());
queue.offer(evaluatedSlice);
}
} catch (Exception e) {
log.error("[SliceSelectEvaluator] 评估[{}]出现错误: {}", thisCount, e.getLocalizedMessage());
}
return null;
});
if (memorySlices == null || memorySlices.isEmpty()) {
return List.of();
}
tasks.add(() -> {
int thisCount = count.incrementAndGet();
log.debug("[SliceSelectEvaluator] 评估[{}]开始", thisCount);
List<SliceSummary> sliceSummaryList = new ArrayList<>();
Map<Long, ActivatedMemorySlice> map = new HashMap<>();
try {
setSliceSummaryList(memorySlices, sliceSummaryList, map);
EvaluatorBatchInput batchInput = EvaluatorBatchInput.builder()
.text(evaluatorInput.getInput())
.memory_slices(sliceSummaryList)
.history(evaluatorInput.getMessages())
.build();
log.debug("[SliceSelectEvaluator] 评估[{}]输入: {}", thisCount, JSONObject.toJSONString(batchInput));
EvaluatorResult evaluatorResult = formattedChat(
List.of(new Message(Message.Character.USER, JSONUtil.toJsonStr(batchInput))),
EvaluatorResult.class
);
log.debug("[SliceSelectEvaluator] 评估[{}]结果: {}", thisCount, JSONObject.toJSONString(evaluatorResult));
for (Long result : evaluatorResult.getResults()) {
ActivatedMemorySlice slice = map.get(result);
if (slice != null) {
queue.offer(slice);
}
}
} catch (Exception e) {
log.error("[SliceSelectEvaluator] 评估[{}]出现错误: {}", thisCount, e.getLocalizedMessage());
}
return null;
});
executor.invokeAll(tasks, 30, TimeUnit.SECONDS);
log.debug("[SliceSelectEvaluator] 评估模块结束, 输出队列: {}", queue);
List<EvaluatedSlice> temp = queue.stream().toList();
return new ArrayList<>(temp);
return new ArrayList<>(queue);
}
private void setSliceSummaryList(MemoryResult memoryResult, List<SliceSummary> sliceSummaryList, Map<Long, SliceSummary> map) {
for (MemorySliceResult memorySliceResult : memoryResult.getMemorySliceResult()) {
SliceSummary sliceSummary = new SliceSummary();
sliceSummary.setId(memorySliceResult.getMemorySlice().getTimestamp());
StringBuilder stringBuilder = new StringBuilder();
if (memorySliceResult.getSliceBefore() != null) {
stringBuilder.append(memorySliceResult.getSliceBefore().getSummary())
.append("\r\n");
}
stringBuilder.append(memorySliceResult.getMemorySlice().getSummary());
if (memorySliceResult.getSliceAfter() != null) {
stringBuilder.append("\r\n")
.append(memorySliceResult.getSliceAfter().getSummary())
.append("\r\n");
}
sliceSummary.setSummary(stringBuilder.toString());
Long timestamp = memorySliceResult.getMemorySlice().getTimestamp();
sliceSummary.setDate(DateUtil.date(timestamp).toLocalDateTime().toLocalDate());
sliceSummaryList.add(sliceSummary);
map.put(timestamp, sliceSummary);
}
for (MemorySlice memorySlice : memoryResult.getRelatedMemorySliceResult()) {
private void setSliceSummaryList(List<ActivatedMemorySlice> memorySlices, List<SliceSummary> sliceSummaryList,
Map<Long, ActivatedMemorySlice> map) {
for (ActivatedMemorySlice memorySlice : memorySlices) {
SliceSummary sliceSummary = new SliceSummary();
sliceSummary.setId(memorySlice.getTimestamp());
sliceSummary.setSummary(memorySlice.getSummary());
sliceSummary.setDate(memorySlice.getDate());
sliceSummaryList.add(sliceSummary);
map.put(memorySlice.getTimestamp(), sliceSummary);
map.put(memorySlice.getTimestamp(), memorySlice);
}
}

View File

@@ -3,7 +3,7 @@ package work.slhaf.partner.module.modules.memory.selector.evaluator.entity;
import lombok.Builder;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.memory.pojo.MemoryResult;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import java.util.List;
@@ -12,5 +12,5 @@ import java.util.List;
public class EvaluatorInput {
private String input;
private List<Message> messages;
private List<MemoryResult> memoryResults;
private List<ActivatedMemorySlice> memorySlices;
}

View File

@@ -6,17 +6,17 @@ import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime;
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorInput;
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorMatchData;
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.ArrayList;
import java.util.List;
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
@@ -29,25 +29,21 @@ public class MemorySelectExtractor extends AbstractAgentModule.Sub<PartnerRunnin
private MemoryCapability memoryCapability;
@InjectCapability
private CognationCapability cognationCapability;
@InjectModule
private MemoryRuntime memoryRuntime;
@Override
public ExtractorResult execute(PartnerRunningFlowContext context) {
log.debug("[MemorySelectExtractor] 主题提取模块开始...");
// 结构化为指定格式
List<Message> chatMessages = new ArrayList<>();
List<MetaMessage> metaMessages = cognationCapability.snapshotSingleMetaMessages(context.getSource());
for (MetaMessage metaMessage : metaMessages) {
chatMessages.add(metaMessage.getUserMessage());
chatMessages.add(metaMessage.getAssistantMessage());
}
List<Message> chatMessages = cognationCapability.snapshotChatMessages();
ExtractorResult extractorResult;
try {
List<EvaluatedSlice> activatedMemorySlices = memoryCapability.getActivatedSlices(context.getSource());
List<ActivatedMemorySlice> activatedMemorySlices = memoryCapability.getActivatedSlices();
ExtractorInput extractorInput = ExtractorInput.builder()
.text(context.getInput())
.date(context.getInfo().getDateTime().toLocalDate())
.history(chatMessages)
.topic_tree(memoryCapability.getTopicTree())
.topic_tree(memoryRuntime.getTopicTree())
.activatedMemorySlices(activatedMemorySlices)
.build();
log.debug("[MemorySelectExtractor] 主题提取输入: {}", JSONUtil.toJsonStr(extractorInput));

View File

@@ -3,7 +3,7 @@ package work.slhaf.partner.module.modules.memory.selector.extractor.entity;
import lombok.Builder;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
import java.time.LocalDate;
import java.util.List;
@@ -15,5 +15,5 @@ public class ExtractorInput {
private String topic_tree;
private LocalDate date;
private List<Message> history;
private List<EvaluatedSlice> activatedMemorySlices;
private List<ActivatedMemorySlice> activatedMemorySlices;
}

View File

@@ -8,31 +8,29 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapabili
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.core.action.entity.Schedulable;
import work.slhaf.partner.core.action.entity.StateAction;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.perceive.PerceiveCapability;
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
import work.slhaf.partner.core.memory.pojo.SliceRef;
import work.slhaf.partner.module.common.module.PostRunningAgentModule;
import work.slhaf.partner.module.modules.action.scheduler.ActionScheduler;
import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime;
import work.slhaf.partner.module.modules.memory.updater.summarizer.MultiSummarizer;
import work.slhaf.partner.module.modules.memory.updater.summarizer.SingleSummarizer;
import work.slhaf.partner.module.modules.memory.updater.summarizer.TotalSummarizer;
import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeInput;
import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import static work.slhaf.partner.common.util.ExtractUtil.extractUserId;
@EqualsAndHashCode(callSuper = true)
@Data
@@ -46,21 +44,20 @@ public class MemoryUpdater extends PostRunningAgentModule {
private CognationCapability cognationCapability;
@InjectCapability
private MemoryCapability memoryCapability;
@InjectCapability
private PerceiveCapability perceiveCapability;
@InjectModule
private MemoryRuntime memoryRuntime;
@InjectModule
private MultiSummarizer multiSummarizer;
@InjectModule
private SingleSummarizer singleSummarizer;
@InjectModule
private TotalSummarizer totalSummarizer;
private final AtomicBoolean updating = new AtomicBoolean(false);
private InteractionThreadPoolExecutor executor;
@InjectModule
private ActionScheduler actionScheduler;
private final AtomicBoolean updating = new AtomicBoolean(false);
private InteractionThreadPoolExecutor executor;
private volatile long lastUpdatedTime;
@Init
public void init() {
executor = InteractionThreadPoolExecutor.getInstance();
@@ -86,15 +83,13 @@ public class MemoryUpdater extends PostRunningAgentModule {
@Override
public void doExecute(PartnerRunningFlowContext context) {
executor.execute(() -> {
// 如果token 大于阈值,则更新记忆
JSONObject moduleContext = context.getModuleContext().getExtraContext();
boolean recall = moduleContext.getBoolean("recall");
if (recall) {
log.debug("[MemoryUpdater] 存在回忆");
int recallCount = moduleContext.getIntValue("recall_count");
log.debug("[MemoryUpdater] 记忆切片数量 [{}]", recallCount);
log.debug("[MemoryUpdater] 当前激活记忆数量 [{}]", recallCount);
}
boolean trigger = context.getModuleContext().getExtraContext().getBoolean("post_process_trigger");
boolean trigger = moduleContext.getBoolean("post_process_trigger");
if (!trigger) {
return;
}
@@ -110,7 +105,6 @@ public class MemoryUpdater extends PostRunningAgentModule {
private void tryAutoUpdate() {
long currentTime = System.currentTimeMillis();
long lastUpdatedTime = cognationCapability.getLastUpdatedTime();
int chatCount = cognationCapability.snapshotChatMessages().size();
if (lastUpdatedTime != 0 && currentTime - lastUpdatedTime > UPDATE_TRIGGER_INTERVAL && chatCount > 1) {
triggerMemoryUpdate(true);
@@ -131,7 +125,7 @@ public class MemoryUpdater extends PostRunningAgentModule {
updateMemory(chatSnapshot);
cognationCapability.rollChatMessagesWithSnapshot(chatSnapshot.size(), CONTEXT_RETAIN_DIVISOR);
if (refreshMemoryId) {
cognationCapability.refreshMemoryId();
memoryCapability.refreshMemoryId();
}
} catch (Exception e) {
log.error("[MemoryUpdater] 记忆更新线程出错: ", e);
@@ -142,75 +136,35 @@ public class MemoryUpdater extends PostRunningAgentModule {
private void updateMemory(List<Message> chatSnapshot) {
log.debug("[MemoryUpdater] 记忆更新流程开始...");
Map<String, String> singleMemorySummary = new ConcurrentHashMap<>();
Map<String, List<Message>> singleChatMessages = drainSingleChatMessages();
// 更新单聊记忆同时从chatMessages中去掉单聊记忆
updateSingleChatSlices(singleMemorySummary, singleChatMessages);
// 更新多人场景下的记忆及相关的确定性记忆
List<Message> multiChatMessages = excludeSingleChatMessages(chatSnapshot, singleChatMessages);
updateMultiChatSlices(singleMemorySummary, multiChatMessages);
cognationCapability.resetLastUpdatedTime();
List<Message> chatMessages = getCleanedMessages(chatSnapshot);
if (chatMessages.isEmpty()) {
return;
}
SummarizeInput summarizeInput = new SummarizeInput(chatMessages, memoryRuntime.getTopicTree());
log.debug("[MemoryUpdater] 记忆更新-总结流程-输入: {}", JSONObject.toJSONString(summarizeInput));
SummarizeResult summarizeResult = summarize(summarizeInput);
log.debug("[MemoryUpdater] 记忆更新-总结流程-输出: {}", JSONObject.toJSONString(summarizeResult));
MemoryUnit memoryUnit = buildMemoryUnit(chatMessages, summarizeResult);
memoryCapability.saveMemoryUnit(memoryUnit);
MemorySlice memorySlice = memoryUnit.getSlices().getFirst();
SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId());
bindTopics(memoryUnit, summarizeResult, sliceRef);
memoryRuntime.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary());
lastUpdatedTime = System.currentTimeMillis();
log.debug("[MemoryUpdater] 记忆更新流程结束...");
}
private Map<String, List<Message>> drainSingleChatMessages() {
Map<String, List<Message>> drainedMessages = new HashMap<>();
Map<String, List<MetaMessage>> drainedMetaMessages = cognationCapability.drainSingleMetaMessages();
for (Map.Entry<String, List<MetaMessage>> entry : drainedMetaMessages.entrySet()) {
List<Message> messages = new ArrayList<>();
for (MetaMessage metaMessage : entry.getValue()) {
messages.add(metaMessage.getUserMessage());
messages.add(metaMessage.getAssistantMessage());
}
drainedMessages.put(entry.getKey(), messages);
private void bindTopics(MemoryUnit memoryUnit, SummarizeResult summarizeResult, SliceRef sliceRef) {
memoryRuntime.indexMemoryUnit(memoryUnit);
memoryRuntime.bindTopic(summarizeResult.getTopicPath(), sliceRef);
if (summarizeResult.getRelatedTopicPath() == null) {
return;
}
return drainedMessages;
}
private List<Message> excludeSingleChatMessages(List<Message> chatSnapshot, Map<String, List<Message>> singleChatMessages) {
Set<Message> singleMessages = new HashSet<>();
for (List<Message> messages : singleChatMessages.values()) {
singleMessages.addAll(messages);
for (String relatedTopicPath : summarizeResult.getRelatedTopicPath()) {
memoryRuntime.bindTopic(relatedTopicPath, sliceRef);
}
return chatSnapshot.stream()
.filter(message -> !singleMessages.contains(message))
.toList();
}
private void updateMultiChatSlices(Map<String, String> singleMemorySummary, List<Message> multiChatMessages) {
// 此时chatMessages中不再包含单聊记录直接执行摘要以及切片插入
// 对剩下的多人聊天记录进行进行摘要
Callable<Void> task = () -> {
log.debug("[MemoryUpdater] 多人聊天记忆更新流程开始...");
List<Message> chatMessages = getCleanedMessages(multiChatMessages);
if (!chatMessages.isEmpty()) {
log.debug("[MemoryUpdater] 存在多人聊天记录, 流程正常进行...");
// 以第一条user对应的id为发起用户
String userId = extractUserId(chatMessages.getFirst().getContent());
if (userId == null) {
throw new RuntimeException("未匹配到 userId!");
}
SummarizeInput summarizeInput = new SummarizeInput(chatMessages, memoryCapability.getTopicTree());
log.debug("[MemoryUpdater] 多人聊天记忆更新-总结流程-输入: {}", summarizeInput);
SummarizeResult summarizeResult = summarize(summarizeInput);
log.debug("[MemoryUpdater] 多人聊天记忆更新-总结流程-输出: {}", summarizeResult);
MemorySlice memorySlice = getMemorySlice(userId, summarizeResult, chatMessages);
// 设置involvedUserId
setInvolvedUserId(userId, memorySlice, chatMessages);
memoryCapability.insertSlice(memorySlice, summarizeResult.getTopicPath());
memoryCapability.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary());
} else {
log.debug("[MemoryUpdater] 不存在多人聊天记录, 将以单聊总结为对话缓存的主要输入: {}", singleMemorySummary);
memoryCapability.updateDialogMap(LocalDateTime.now(), totalSummarizer.execute(new HashMap<>(singleMemorySummary)));
}
log.debug("[MemoryUpdater] 对话缓存更新完毕");
log.debug("[MemoryUpdater] 多人聊天记忆更新流程结束...");
return null;
};
executor.invokeAll(List.of(task));
}
// TODO need to move time information into perceive core
private List<Message> getCleanedMessages(List<Message> chatMessages) {
return chatMessages.stream()
.map(message -> {
@@ -226,84 +180,27 @@ public class MemoryUpdater extends PostRunningAgentModule {
}).toList();
}
private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List<Message> chatMessages) {
for (Message chatMessage : chatMessages) {
if (chatMessage.getRole() == Message.Character.ASSISTANT) {
continue;
}
// 匹配userId
String userId = extractUserId(chatMessage.getContent());
if (userId == null) {
continue;
}
if (userId.equals(startUserId)) {
continue;
}
memorySlice.setInvolvedUserIds(new ArrayList<>());
memorySlice.getInvolvedUserIds().add(userId);
}
}
private void updateSingleChatSlices(Map<String, String> singleMemorySummary, Map<String, List<Message>> singleChatMessages) {
log.debug("[MemoryUpdater] 单聊记忆更新流程开始...");
List<Callable<Void>> tasks = new ArrayList<>();
AtomicInteger count = new AtomicInteger(0);
for (Map.Entry<String, List<Message>> entry : singleChatMessages.entrySet()) {
String id = entry.getKey();
List<Message> messages = entry.getValue();
if (messages.isEmpty()) {
continue;
}
tasks.add(() -> {
int thisCount = count.incrementAndGet();
log.debug("[MemoryUpdater] 单聊记忆[{}]更新: {}", thisCount, id);
try {
// 单聊记忆更新
SummarizeInput summarizeInput = new SummarizeInput(messages, memoryCapability.getTopicTree());
log.debug("[MemoryUpdater] 单聊记忆[{}]更新-总结流程-输入: {}", thisCount, JSONObject.toJSONString(summarizeInput));
SummarizeResult summarizeResult = summarize(summarizeInput);
log.debug("[MemoryUpdater] 单聊记忆[{}]更新-总结流程-输出: {}", thisCount, JSONObject.toJSONString(summarizeResult));
MemorySlice memorySlice = getMemorySlice(id, summarizeResult, messages);
// 插入时userDialogMap已经进行更新
memoryCapability.insertSlice(memorySlice, summarizeResult.getTopicPath());
// 从chatMessages中移除单聊记录
cognationCapability.cleanMessage(messages);
// 添加至singleMemorySummary
String key = perceiveCapability.getUser(id).getNickName() + "[" + id + "]";
singleMemorySummary.put(key, summarizeResult.getSummary());
log.debug("[MemoryUpdater] 单聊记忆[{}]更新成功: ", thisCount);
} catch (Exception e) {
log.error("[MemoryUpdater] 单聊记忆[{}]更新出错: ", thisCount, e);
}
return null;
});
}
executor.invokeAll(tasks);
log.debug("[MemoryUpdater] 单聊记忆更新结束...");
}
private SummarizeResult summarize(SummarizeInput summarizeInput) {
singleSummarizer.execute(summarizeInput.getChatMessages());
return multiSummarizer.execute(summarizeInput);
}
private MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List<Message> chatMessages) {
private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) {
long now = System.currentTimeMillis();
MemorySlice memorySlice = new MemorySlice();
// 设置 memoryId,timestamp
memorySlice.setMemoryId(cognationCapability.getCurrentMemoryId());
memorySlice.setTimestamp(System.currentTimeMillis());
// 补充信息
memorySlice.setPrivate(summarizeResult.isPrivate());
memorySlice.setId(UUID.randomUUID().toString());
memorySlice.setStartIndex(0);
memorySlice.setEndIndex(Math.max(chatMessages.size() - 1, 0));
memorySlice.setSummary(summarizeResult.getSummary());
memorySlice.setChatMessages(chatMessages);
memorySlice.setStartUserId(userId);
List<List<String>> relatedTopicPathList = new ArrayList<>();
for (String string : summarizeResult.getRelatedTopicPath()) {
List<String> list = Arrays.stream(string.split("->")).toList();
relatedTopicPathList.add(list);
}
memorySlice.setRelatedTopics(relatedTopicPathList);
return memorySlice;
memorySlice.setTimestamp(now);
MemoryUnit memoryUnit = new MemoryUnit();
String memoryId = memoryCapability.getCurrentMemoryId();
memoryUnit.setId(memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId);
memoryUnit.setTimestamp(now);
memoryUnit.setConversationMessages(new ArrayList<>(chatMessages));
memoryUnit.setSlices(new ArrayList<>(List.of(memorySlice)));
return memoryUnit;
}
@Override

View File

@@ -9,5 +9,4 @@ public class SummarizeResult {
private String summary;
private String topicPath;
private List<String> relatedTopicPath;
private boolean isPrivate;
}

View File

@@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.perceive.PerceiveCapability;
import work.slhaf.partner.core.perceive.pojo.User;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
@@ -18,6 +19,8 @@ public class PreprocessExecutor extends AbstractAgentModule.Running<PartnerRunni
@InjectCapability
private CognationCapability cognationCapability;
@InjectCapability
private MemoryCapability memoryCapability;
@InjectCapability
private PerceiveCapability perceiveCapability;
@Override
@@ -27,9 +30,9 @@ public class PreprocessExecutor extends AbstractAgentModule.Running<PartnerRunni
}
private void checkAndSetMemoryId() {
String currentMemoryId = cognationCapability.getCurrentMemoryId();
String currentMemoryId = memoryCapability.getCurrentMemoryId();
if (currentMemoryId == null || cognationCapability.getChatMessages().isEmpty()) {
cognationCapability.refreshMemoryId();
memoryCapability.refreshMemoryId();
}
}

View File

@@ -2,7 +2,6 @@ package experimental;
import org.junit.jupiter.api.Test;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.pojo.MemoryResult;
import java.lang.reflect.Proxy;
@@ -15,12 +14,12 @@ public class ReflectionTest {
@Test
public void proxyTest() {
MemoryCapability memory = (MemoryCapability) Proxy.newProxyInstance(this.getClass().getClassLoader(), new Class[]{MemoryCapability.class}, (proxy, method, args) -> {
if ("selectMemory".equals(method.getName())) {
if ("getCurrentMemoryId".equals(method.getName())) {
System.out.println(111);
return new MemoryResult();
return "memory-id";
}
return null;
});
memory.selectMemory("111");
memory.getCurrentMemoryId();
}
}

View File

@@ -74,7 +74,7 @@ class ActionExecutorTest {
@BeforeEach
void setUp() {
lenient().when(cognationCapability.getChatMessages()).thenReturn(Collections.emptyList());
lenient().when(memoryCapability.getActivatedSlices(anyString())).thenReturn(Collections.emptyList());
lenient().when(memoryCapability.getActivatedSlices()).thenReturn(Collections.emptyList());
lenient().when(actionCapability.putPhaserRecord(any(Phaser.class), any(ExecutableAction.class)))
.thenAnswer(inv -> new PhaserRecord(inv.getArgument(0), inv.getArgument(1)));
lenient().when(actionCapability.loadMetaActionInfo(anyString())).thenAnswer(inv -> {

View File

@@ -1,17 +1,15 @@
package work.slhaf.partner.module.modules.core;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.jetbrains.annotations.NotNull;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import work.slhaf.partner.api.agent.runtime.interaction.flow.ContextBlock;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
@@ -19,13 +17,8 @@ import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.locks.ReentrantLock;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyString;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.verify;
@ExtendWith(MockitoExtension.class)
class CommunicationProducerTest {
@@ -100,12 +93,6 @@ class CommunicationProducerTest {
Message lastAssistantMessage = producer.getChatMessages().get(3);
assertEquals("收到", lastAssistantMessage.getContent());
ArgumentCaptor<MetaMessage> metaMessageCaptor = ArgumentCaptor.forClass(MetaMessage.class);
verify(cognationCapability).addMetaMessage(anyString(), metaMessageCaptor.capture());
MetaMessage metaMessage = metaMessageCaptor.getValue();
assertNotNull(metaMessage);
assertTrue(metaMessage.getUserMessage().getContent().startsWith("[USER]: u-1: 你好,介绍一下你现在看到的上下文"));
assertEquals("收到", metaMessage.getAssistantMessage().getContent());
assertEquals("收到", context.getCoreResponse().getString("text"));
}