mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
refactor(memory): decouple memory storage and runtime structures
This commit is contained in:
@@ -2,11 +2,8 @@ package work.slhaf.partner.core.cognation;
|
||||
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
|
||||
@Capability("cognation")
|
||||
@@ -20,26 +17,6 @@ public interface CognationCapability {
|
||||
|
||||
void rollChatMessagesWithSnapshot(int snapshotSize, int retainDivisor);
|
||||
|
||||
void cleanMessage(List<Message> messages);
|
||||
|
||||
Lock getMessageLock();
|
||||
|
||||
void addMetaMessage(String userId, MetaMessage metaMessage);
|
||||
|
||||
List<Message> unpackAndClear(String userId);
|
||||
|
||||
void refreshMemoryId();
|
||||
|
||||
void resetLastUpdatedTime();
|
||||
|
||||
long getLastUpdatedTime();
|
||||
|
||||
HashMap<String, List<MetaMessage>> getSingleMetaMessageMap();
|
||||
|
||||
Map<String, List<MetaMessage>> drainSingleMetaMessages();
|
||||
|
||||
List<MetaMessage> snapshotSingleMetaMessages(String userId);
|
||||
|
||||
String getCurrentMemoryId();
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package work.slhaf.partner.core.cognation;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
@@ -9,13 +8,13 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.AgentRuntime;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||
import work.slhaf.partner.core.PartnerCore;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.Serial;
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
|
||||
@@ -35,10 +34,6 @@ public class CognationCore extends PartnerCore<CognationCore> {
|
||||
* 主模型的聊天记录
|
||||
*/
|
||||
private List<Message> chatMessages = new ArrayList<>();
|
||||
private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap = new HashMap<>();
|
||||
private String currentMemoryId;
|
||||
private long lastUpdatedTime;
|
||||
|
||||
public CognationCore() throws IOException, ClassNotFoundException {
|
||||
}
|
||||
|
||||
@@ -86,112 +81,11 @@ public class CognationCore extends PartnerCore<CognationCore> {
|
||||
}
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public long getLastUpdatedTime() {
|
||||
return lastUpdatedTime;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public HashMap<String, List<MetaMessage>> getSingleMetaMessageMap() {
|
||||
return singleMetaMessageMap;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public String getCurrentMemoryId() {
|
||||
return currentMemoryId;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void cleanMessage(List<Message> messages) {
|
||||
messageLock.lock();
|
||||
try {
|
||||
this.getChatMessages().removeAll(messages);
|
||||
} finally {
|
||||
messageLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public Lock getMessageLock() {
|
||||
return messageLock;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void addMetaMessage(String userId, MetaMessage metaMessage) {
|
||||
log.debug("[{}] 当前会话历史: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap));
|
||||
messageLock.lock();
|
||||
try {
|
||||
if (singleMetaMessageMap.containsKey(userId)) {
|
||||
singleMetaMessageMap.get(userId).add(metaMessage);
|
||||
} else {
|
||||
singleMetaMessageMap.put(userId, new java.util.ArrayList<>());
|
||||
singleMetaMessageMap.get(userId).add(metaMessage);
|
||||
}
|
||||
} finally {
|
||||
messageLock.unlock();
|
||||
}
|
||||
log.debug("[{}] 会话历史更新: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap));
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public List<Message> unpackAndClear(String userId) {
|
||||
messageLock.lock();
|
||||
try {
|
||||
List<Message> messages = new ArrayList<>();
|
||||
List<MetaMessage> metaMessages = singleMetaMessageMap.get(userId);
|
||||
if (metaMessages == null) {
|
||||
return messages;
|
||||
}
|
||||
for (MetaMessage metaMessage : metaMessages) {
|
||||
messages.add(metaMessage.getUserMessage());
|
||||
messages.add(metaMessage.getAssistantMessage());
|
||||
}
|
||||
singleMetaMessageMap.remove(userId);
|
||||
return messages;
|
||||
} finally {
|
||||
messageLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public Map<String, List<MetaMessage>> drainSingleMetaMessages() {
|
||||
messageLock.lock();
|
||||
try {
|
||||
Map<String, List<MetaMessage>> drained = new HashMap<>();
|
||||
for (Map.Entry<String, List<MetaMessage>> entry : singleMetaMessageMap.entrySet()) {
|
||||
drained.put(entry.getKey(), new ArrayList<>(entry.getValue()));
|
||||
}
|
||||
singleMetaMessageMap.clear();
|
||||
return drained;
|
||||
} finally {
|
||||
messageLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public List<MetaMessage> snapshotSingleMetaMessages(String userId) {
|
||||
messageLock.lock();
|
||||
try {
|
||||
List<MetaMessage> metaMessages = singleMetaMessageMap.get(userId);
|
||||
if (metaMessages == null) {
|
||||
return List.of();
|
||||
}
|
||||
return List.copyOf(metaMessages);
|
||||
} finally {
|
||||
messageLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void refreshMemoryId() {
|
||||
currentMemoryId = UUID.randomUUID().toString();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void resetLastUpdatedTime() {
|
||||
lastUpdatedTime = System.currentTimeMillis();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getCoreKey() {
|
||||
return "cognation-core";
|
||||
|
||||
@@ -1,51 +1,36 @@
|
||||
package work.slhaf.partner.core.memory;
|
||||
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryResult;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.HashMap;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
@Capability(value = "memory")
|
||||
public interface MemoryCapability {
|
||||
|
||||
void cleanSelectedSliceFilter();
|
||||
void clearActivatedSlices();
|
||||
|
||||
String getTopicTree();
|
||||
void updateActivatedSlices(List<ActivatedMemorySlice> memorySlices);
|
||||
|
||||
HashMap<LocalDateTime, String> getDialogMap();
|
||||
boolean hasActivatedSlices();
|
||||
|
||||
ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId);
|
||||
int getActivatedSlicesSize();
|
||||
|
||||
void updateDialogMap(LocalDateTime dateTime, String newDialogCache);
|
||||
List<ActivatedMemorySlice> getActivatedSlices();
|
||||
|
||||
String getDialogMapStr();
|
||||
void saveMemoryUnit(MemoryUnit memoryUnit);
|
||||
|
||||
String getUserDialogMapStr(String userId);
|
||||
MemoryUnit getMemoryUnit(String unitId);
|
||||
|
||||
void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices);
|
||||
MemorySlice getMemorySlice(String unitId, String sliceId);
|
||||
|
||||
String getActivatedSlicesStr(String userId);
|
||||
Collection<MemoryUnit> listMemoryUnits();
|
||||
|
||||
HashMap<String, List<EvaluatedSlice>> getActivatedSlices();
|
||||
void refreshMemoryId();
|
||||
|
||||
void clearActivatedSlices(String userId);
|
||||
|
||||
boolean hasActivatedSlices(String userId);
|
||||
|
||||
int getActivatedSlicesSize(String userId);
|
||||
|
||||
List<EvaluatedSlice> getActivatedSlices(String userId);
|
||||
|
||||
MemoryResult selectMemory(String topicPathStr);
|
||||
|
||||
MemoryResult selectMemory(LocalDate date);
|
||||
|
||||
void insertSlice(MemorySlice memorySlice, String topicPath);
|
||||
String getCurrentMemoryId();
|
||||
|
||||
}
|
||||
|
||||
@@ -7,19 +7,12 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
|
||||
import work.slhaf.partner.core.PartnerCore;
|
||||
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
|
||||
import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryResult;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySliceResult;
|
||||
import work.slhaf.partner.core.memory.pojo.node.MemoryNode;
|
||||
import work.slhaf.partner.core.memory.pojo.node.TopicNode;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.Serial;
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
@@ -35,571 +28,121 @@ public class MemoryCore extends PartnerCore<MemoryCore> {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
private final Lock sliceInsertLock = new ReentrantLock();
|
||||
/**
|
||||
* key: 根主题名称 value: 根主题节点
|
||||
*/
|
||||
private HashMap<String, TopicNode> topicNodes = new HashMap<>();
|
||||
/**
|
||||
* 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值
|
||||
* 该部分在'主题提取LLM'的system prompt中常驻
|
||||
*/
|
||||
private HashMap<String /*根主题名*/, LinkedHashSet<String> /*子主题列表*/> existedTopics = new HashMap<>();
|
||||
/**
|
||||
* 临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所
|
||||
*/
|
||||
private HashMap<String /*对话id, 即slice中的字段'memoryId'*/, List<MemorySlice>> currentDateDialogSlices = new HashMap<>();
|
||||
/**
|
||||
* 记忆节点的日期索引, 同一日期内按照对话id区分
|
||||
*/
|
||||
private HashMap<LocalDate, Set<String>> dateIndex = new HashMap<>();
|
||||
/**
|
||||
* 已被选中的切片时间戳集合,需要及时清理
|
||||
*/
|
||||
private Set<Long> selectedSlices = new HashSet<>();
|
||||
private HashMap<String, List<String>> userIndex = new HashMap<>();
|
||||
private MemoryCache cache = new MemoryCache();
|
||||
|
||||
private final Lock memoryLock = new ReentrantLock();
|
||||
private ConcurrentHashMap<String, MemoryUnit> memoryUnits = new ConcurrentHashMap<>();
|
||||
private List<ActivatedMemorySlice> activatedSlices = new CopyOnWriteArrayList<>();
|
||||
private String currentMemoryId;
|
||||
|
||||
public MemoryCore() throws IOException, ClassNotFoundException {
|
||||
}
|
||||
|
||||
|
||||
@CapabilityMethod
|
||||
public MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException {
|
||||
MemoryResult memoryResult = new MemoryResult();
|
||||
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
|
||||
//加载节点并获取记忆切片列表
|
||||
List<List<MemorySlice>> currentDateDialogSlices = loadSlicesByDate(date);
|
||||
for (List<MemorySlice> value : currentDateDialogSlices) {
|
||||
for (MemorySlice memorySlice : value) {
|
||||
if (selectedSlices.contains(memorySlice.getTimestamp())) {
|
||||
continue;
|
||||
}
|
||||
MemorySliceResult memorySliceResult = new MemorySliceResult();
|
||||
memorySliceResult.setMemorySlice(memorySlice);
|
||||
targetSliceList.add(memorySliceResult);
|
||||
selectedSlices.add(memorySlice.getTimestamp());
|
||||
}
|
||||
}
|
||||
memoryResult.setMemorySliceResult(targetSliceList);
|
||||
return cacheFilter(memoryResult);
|
||||
public void clearActivatedSlices() {
|
||||
activatedSlices.clear();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void insertSlice(MemorySlice memorySlice, String topicPath) {
|
||||
sliceInsertLock.lock();
|
||||
List<String> topicPathList = Arrays.stream(topicPath.split("->")).toList();
|
||||
public void updateActivatedSlices(List<ActivatedMemorySlice> memorySlices) {
|
||||
activatedSlices = new CopyOnWriteArrayList<>(memorySlices);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public boolean hasActivatedSlices() {
|
||||
return !activatedSlices.isEmpty();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public int getActivatedSlicesSize() {
|
||||
return activatedSlices.size();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public List<ActivatedMemorySlice> getActivatedSlices() {
|
||||
return new ArrayList<>(activatedSlices);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void saveMemoryUnit(MemoryUnit memoryUnit) {
|
||||
memoryLock.lock();
|
||||
try {
|
||||
//检查是否存在当天对应的memorySlice并确定是否插入
|
||||
//每日刷新缓存
|
||||
checkCacheDate();
|
||||
//如果topicPath在memorySliceCache中存在对应缓存,由于进行的插入操作,则需要移除该缓存,但不清除相关计数
|
||||
clearCacheByTopicPath(topicPathList);
|
||||
insertMemory(topicPathList, memorySlice);
|
||||
if (!memorySlice.isPrivate()) {
|
||||
updateUserDialogMap(memorySlice);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("插入记忆时出错: ", e);
|
||||
normalizeMemoryUnit(memoryUnit);
|
||||
memoryUnits.put(memoryUnit.getId(), memoryUnit);
|
||||
} finally {
|
||||
memoryLock.unlock();
|
||||
}
|
||||
log.debug("插入切片: {}, 路径: {}", memorySlice, topicPath);
|
||||
sliceInsertLock.unlock();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public String getTopicTree() {
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
for (Map.Entry<String, TopicNode> entry : topicNodes.entrySet()) {
|
||||
String rootName = entry.getKey();
|
||||
TopicNode rootNode = entry.getValue();
|
||||
stringBuilder.append(rootName).append("[root]").append("\r\n");
|
||||
printSubTopicsTreeFormat(rootNode, "", stringBuilder);
|
||||
}
|
||||
return stringBuilder.toString();
|
||||
public MemoryUnit getMemoryUnit(String unitId) {
|
||||
return memoryUnits.get(unitId);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
|
||||
List<LocalDateTime> keysToRemove = new ArrayList<>();
|
||||
HashMap<LocalDateTime, String> dialogMap = cache.dialogMap;
|
||||
dialogMap.forEach((k, v) -> {
|
||||
if (dateTime.minusDays(2).isAfter(k)) {
|
||||
keysToRemove.add(k);
|
||||
}
|
||||
});
|
||||
for (LocalDateTime temp : keysToRemove) {
|
||||
dialogMap.remove(temp);
|
||||
}
|
||||
keysToRemove.clear();
|
||||
//放入新缓存
|
||||
dialogMap.put(dateTime, newDialogCache);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public HashMap<LocalDateTime, String> getDialogMap() {
|
||||
return cache.dialogMap;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId) {
|
||||
return cache.userDialogMap.get(userId);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public String getDialogMapStr() {
|
||||
StringBuilder str = new StringBuilder();
|
||||
this.getDialogMap().forEach((dateTime, dialog) -> str.append("\n\n").append("[").append(dateTime).append("]\n")
|
||||
.append(dialog));
|
||||
return str.toString();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public String getUserDialogMapStr(String userId) {
|
||||
ConcurrentHashMap<String, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = cache.userDialogMap;
|
||||
if (userDialogMap.containsKey(userId)) {
|
||||
StringBuilder str = new StringBuilder();
|
||||
Collection<String> dialogMapValues = this.getDialogMap().values();
|
||||
userDialogMap.get(userId).forEach((dateTime, dialog) -> {
|
||||
if (dialogMapValues.contains(dialog)) {
|
||||
return;
|
||||
}
|
||||
str.append("\n\n").append("[").append(dateTime).append("]\n")
|
||||
.append(dialog);
|
||||
});
|
||||
return str.toString();
|
||||
} else {
|
||||
public MemorySlice getMemorySlice(String unitId, String sliceId) {
|
||||
MemoryUnit memoryUnit = memoryUnits.get(unitId);
|
||||
if (memoryUnit == null || memoryUnit.getSlices() == null) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public MemoryResult selectMemory(String topicPathStr) {
|
||||
MemoryResult memoryResult;
|
||||
List<String> topicPath = List.of(topicPathStr.split("->"));
|
||||
try {
|
||||
List<String> path = new ArrayList<>(topicPath);
|
||||
//每日刷新缓存
|
||||
checkCacheDate();
|
||||
//检测缓存并更新计数, 查看是否需要放入缓存
|
||||
updateCacheCounter(path);
|
||||
//查看是否存在缓存,如果存在,则直接返回
|
||||
if ((memoryResult = selectCache(path)) != null) {
|
||||
return memoryResult;
|
||||
for (MemorySlice slice : memoryUnit.getSlices()) {
|
||||
if (sliceId.equals(slice.getId())) {
|
||||
return slice;
|
||||
}
|
||||
memoryResult = selectMemory(path);
|
||||
//尝试更新缓存
|
||||
updateCache(topicPath, memoryResult);
|
||||
} catch (Exception e) {
|
||||
log.error("[{}] selectMemory error: ", getCoreKey(), e);
|
||||
log.error("[{}] 路径: {}", getCoreKey(), topicPathStr);
|
||||
log.error("[{}] 主题树: {}", getCoreKey(), getTopicTree());
|
||||
memoryResult = new MemoryResult();
|
||||
memoryResult.setRelatedMemorySliceResult(new ArrayList<>());
|
||||
memoryResult.setMemorySliceResult(new CopyOnWriteArrayList<>());
|
||||
}
|
||||
return cacheFilter(memoryResult);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices) {
|
||||
cache.activatedSlices.put(userId, memorySlices);
|
||||
log.debug("[{}] 已更新激活切片, userId: {}", getCoreKey(), userId);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public String getActivatedSlicesStr(String userId) {
|
||||
HashMap<String, List<EvaluatedSlice>> activatedSlices = cache.activatedSlices;
|
||||
if (activatedSlices.containsKey(userId)) {
|
||||
StringBuilder str = new StringBuilder();
|
||||
activatedSlices.get(userId).forEach(slice -> str.append("\n\n").append("[").append(slice.getDate()).append("]\n")
|
||||
.append(slice.getSummary()));
|
||||
return str.toString();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public HashMap<String, List<EvaluatedSlice>> getActivatedSlices() {
|
||||
return cache.activatedSlices;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void clearActivatedSlices(String userId) {
|
||||
cache.activatedSlices.remove(userId);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public boolean hasActivatedSlices(String userId) {
|
||||
HashMap<String, List<EvaluatedSlice>> activatedSlices = cache.activatedSlices;
|
||||
if (!activatedSlices.containsKey(userId)) {
|
||||
return false;
|
||||
}
|
||||
return !activatedSlices.get(userId).isEmpty();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public int getActivatedSlicesSize(String userId) {
|
||||
return cache.activatedSlices.get(userId).size();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public List<EvaluatedSlice> getActivatedSlices(String userId) {
|
||||
return cache.activatedSlices.get(userId);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void cleanSelectedSliceFilter() {
|
||||
this.selectedSlices.clear();
|
||||
}
|
||||
|
||||
private List<List<MemorySlice>> loadSlicesByDate(LocalDate date) throws IOException, ClassNotFoundException {
|
||||
if (!dateIndex.containsKey(date)) {
|
||||
throw new UnExistedDateIndexException("不存在的日期索引: " + date);
|
||||
}
|
||||
List<List<MemorySlice>> list = new ArrayList<>();
|
||||
for (String memoryNodeId : dateIndex.get(date)) {
|
||||
MemoryNode memoryNode = new MemoryNode();
|
||||
memoryNode.setMemoryNodeId(memoryNodeId);
|
||||
list.add(memoryNode.loadMemorySliceList());
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
private void printSubTopicsTreeFormat(TopicNode node, String prefix, StringBuilder stringBuilder) {
|
||||
if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return;
|
||||
|
||||
List<Map.Entry<String, TopicNode>> entries = new ArrayList<>(node.getTopicNodes().entrySet());
|
||||
for (int i = 0; i < entries.size(); i++) {
|
||||
boolean last = (i == entries.size() - 1);
|
||||
Map.Entry<String, TopicNode> entry = entries.get(i);
|
||||
stringBuilder.append(prefix).append(last ? "└── " : "├── ").append(entry.getKey()).append("[").append(entry.getValue().getMemoryNodes().size()).append("]").append("\r\n");
|
||||
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), stringBuilder);
|
||||
}
|
||||
}
|
||||
|
||||
private void insertMemory(List<String> topicPath, MemorySlice slice) throws IOException, ClassNotFoundException {
|
||||
LocalDate now = LocalDate.now();
|
||||
boolean hasSlice = false;
|
||||
MemoryNode node = null;
|
||||
TopicNode lastTopicNode = generateTopicPath(topicPath);
|
||||
for (MemoryNode memoryNode : lastTopicNode.getMemoryNodes()) {
|
||||
if (now.equals(memoryNode.getLocalDate())) {
|
||||
hasSlice = true;
|
||||
node = memoryNode;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!hasSlice) {
|
||||
node = new MemoryNode();
|
||||
node.setLocalDate(now);
|
||||
node.setMemoryNodeId(UUID.randomUUID().toString());
|
||||
node.setMemorySliceList(new CopyOnWriteArrayList<>());
|
||||
lastTopicNode.getMemoryNodes().add(node);
|
||||
lastTopicNode.getMemoryNodes().sort(null);
|
||||
}
|
||||
node.loadMemorySliceList().add(slice);
|
||||
|
||||
//生成relatedTopicPath
|
||||
for (List<String> relatedTopic : slice.getRelatedTopics()) {
|
||||
generateTopicPath(relatedTopic);
|
||||
}
|
||||
|
||||
updateSlicePrecedent(slice);
|
||||
updateDateIndex(slice);
|
||||
updateUserIndex(slice);
|
||||
|
||||
node.saveMemorySliceList();
|
||||
|
||||
}
|
||||
|
||||
private void updateUserIndex(MemorySlice slice) {
|
||||
String memoryId = slice.getMemoryId();
|
||||
String userId = slice.getStartUserId();
|
||||
if (!userIndex.containsKey(userId)) {
|
||||
List<String> memoryIdSet = new ArrayList<>();
|
||||
memoryIdSet.add(memoryId);
|
||||
userIndex.put(userId, memoryIdSet);
|
||||
} else {
|
||||
userIndex.get(userId).add(memoryId);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private TopicNode generateTopicPath(List<String> topicPath) {
|
||||
topicPath = new ArrayList<>(topicPath);
|
||||
//查看是否存在根主题节点
|
||||
String rootTopic = topicPath.getFirst();
|
||||
topicPath.removeFirst();
|
||||
if (!topicNodes.containsKey(rootTopic)) {
|
||||
synchronized (this) {
|
||||
if (!topicNodes.containsKey(rootTopic)) {
|
||||
TopicNode rootNode = new TopicNode();
|
||||
topicNodes.put(rootTopic, rootNode);
|
||||
existedTopics.put(rootTopic, new LinkedHashSet<>());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TopicNode current = topicNodes.get(rootTopic);
|
||||
Set<String> existedTopicNodes = existedTopics.get(rootTopic);
|
||||
for (String topic : topicPath) {
|
||||
if (existedTopicNodes.contains(topic) && current.getTopicNodes().containsKey(topic)) {
|
||||
current = current.getTopicNodes().get(topic);
|
||||
} else {
|
||||
TopicNode newNode = new TopicNode();
|
||||
current.getTopicNodes().put(topic, newNode);
|
||||
current = newNode;
|
||||
|
||||
current.setMemoryNodes(new CopyOnWriteArrayList<>());
|
||||
current.setTopicNodes(new ConcurrentHashMap<>());
|
||||
existedTopicNodes.add(topic);
|
||||
}
|
||||
}
|
||||
return current;
|
||||
}
|
||||
|
||||
private void updateSlicePrecedent(MemorySlice slice) {
|
||||
String memoryId = slice.getMemoryId();
|
||||
//查看是否切换了memoryId
|
||||
if (!currentDateDialogSlices.containsKey(memoryId)) {
|
||||
List<MemorySlice> memorySliceList = new ArrayList<>();
|
||||
currentDateDialogSlices.clear();
|
||||
currentDateDialogSlices.put(memoryId, memorySliceList);
|
||||
}
|
||||
//处理上下文关系
|
||||
List<MemorySlice> memorySliceList = currentDateDialogSlices.get(memoryId);
|
||||
if (memorySliceList.isEmpty()) {
|
||||
memorySliceList.add(slice);
|
||||
} else {
|
||||
//排序
|
||||
memorySliceList.sort(null);
|
||||
MemorySlice tempSlice = memorySliceList.getLast();
|
||||
//设置私密状态一致
|
||||
tempSlice.setPrivate(slice.isPrivate());
|
||||
//末尾切片添加当前切片的引用
|
||||
tempSlice.setSliceAfter(slice);
|
||||
//当前切片添加前序切片的引用
|
||||
slice.setSliceBefore(tempSlice);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private void updateDateIndex(MemorySlice slice) {
|
||||
String memoryId = slice.getMemoryId();
|
||||
LocalDate date = LocalDate.now();
|
||||
if (!dateIndex.containsKey(date)) {
|
||||
HashSet<String> memoryIdSet = new HashSet<>();
|
||||
memoryIdSet.add(memoryId);
|
||||
dateIndex.put(date, memoryIdSet);
|
||||
} else {
|
||||
dateIndex.get(date).add(memoryId);
|
||||
}
|
||||
}
|
||||
|
||||
public MemoryResult selectMemory(List<String> path) throws IOException, ClassNotFoundException {
|
||||
MemoryResult memoryResult = new MemoryResult();
|
||||
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
|
||||
String targetTopic = path.getLast();
|
||||
TopicNode targetParentNode = getTargetParentNode(path, targetTopic);
|
||||
List<List<String>> relatedTopics = new ArrayList<>();
|
||||
|
||||
//终点记忆节点
|
||||
MemorySliceResult sliceResult = new MemorySliceResult();
|
||||
for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) {
|
||||
List<MemorySlice> endpointMemorySliceList = memoryNode.loadMemorySliceList();
|
||||
for (MemorySlice memorySlice : endpointMemorySliceList) {
|
||||
if (selectedSlices.contains(memorySlice.getTimestamp())) {
|
||||
continue;
|
||||
}
|
||||
sliceResult.setSliceBefore(memorySlice.getSliceBefore());
|
||||
sliceResult.setMemorySlice(memorySlice);
|
||||
sliceResult.setSliceAfter(memorySlice.getSliceAfter());
|
||||
targetSliceList.add(sliceResult);
|
||||
selectedSlices.add(memorySlice.getTimestamp());
|
||||
}
|
||||
for (MemorySlice memorySlice : endpointMemorySliceList) {
|
||||
if (memorySlice.getRelatedTopics() != null) {
|
||||
relatedTopics.addAll(memorySlice.getRelatedTopics());
|
||||
}
|
||||
}
|
||||
}
|
||||
memoryResult.setMemorySliceResult(targetSliceList);
|
||||
|
||||
//邻近节点
|
||||
List<MemorySlice> relatedMemorySlice = new ArrayList<>();
|
||||
//邻近记忆节点 联系
|
||||
for (List<String> relatedTopic : relatedTopics) {
|
||||
List<String> tempTopicPath = new ArrayList<>(relatedTopic);
|
||||
String tempTargetTopic = tempTopicPath.getLast();
|
||||
TopicNode tempTargetParentNode = getTargetParentNode(tempTopicPath, tempTargetTopic);
|
||||
//获取终点节点及其最新记忆节点
|
||||
TopicNode tempTargetNode = tempTargetParentNode.getTopicNodes().get(tempTopicPath.getLast());
|
||||
setRelatedMemorySlices(tempTargetNode, relatedMemorySlice);
|
||||
}
|
||||
|
||||
//邻近记忆节点 父级
|
||||
setRelatedMemorySlices(targetParentNode, relatedMemorySlice);
|
||||
|
||||
//将上述结果包装为MemoryResult
|
||||
memoryResult.setRelatedMemorySliceResult(relatedMemorySlice);
|
||||
return memoryResult;
|
||||
}
|
||||
|
||||
private void setRelatedMemorySlices(TopicNode targetParentNode, List<MemorySlice> relatedMemorySlice) throws IOException, ClassNotFoundException {
|
||||
List<MemoryNode> targetParentMemoryNodes = targetParentNode.getMemoryNodes();
|
||||
if (!targetParentMemoryNodes.isEmpty()) {
|
||||
for (MemorySlice memorySlice : targetParentMemoryNodes.getFirst().loadMemorySliceList()) {
|
||||
if (selectedSlices.contains(memorySlice.getTimestamp())) {
|
||||
continue;
|
||||
}
|
||||
relatedMemorySlice.add(memorySlice);
|
||||
selectedSlices.add(memorySlice.getTimestamp());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private TopicNode getTargetParentNode(List<String> topicPath, String targetTopic) {
|
||||
String topTopic = topicPath.getFirst();
|
||||
if (!existedTopics.containsKey(topTopic)) {
|
||||
throw new UnExistedTopicException("不存在的主题: " + topTopic);
|
||||
}
|
||||
TopicNode targetParentNode = topicNodes.get(topTopic);
|
||||
topicPath.removeFirst();
|
||||
for (String topic : topicPath) {
|
||||
if (!existedTopics.get(topTopic).contains(topic)) {
|
||||
throw new UnExistedTopicException("不存在的主题: " + topTopic);
|
||||
}
|
||||
}
|
||||
|
||||
//逐层查找目标主题
|
||||
while (!targetParentNode.getTopicNodes().containsKey(targetTopic)) {
|
||||
targetParentNode = targetParentNode.getTopicNodes().get(topicPath.getFirst());
|
||||
topicPath.removeFirst();
|
||||
}
|
||||
return targetParentNode;
|
||||
}
|
||||
|
||||
private void updateCacheCounter(List<String> topicPath) {
|
||||
ConcurrentHashMap<List<String>, Integer> memoryNodeCacheCounter = cache.memoryNodeCacheCounter;
|
||||
if (memoryNodeCacheCounter.containsKey(topicPath)) {
|
||||
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
|
||||
memoryNodeCacheCounter.put(topicPath, ++tempCount);
|
||||
} else {
|
||||
memoryNodeCacheCounter.put(topicPath, 1);
|
||||
}
|
||||
}
|
||||
|
||||
private void checkCacheDate() {
|
||||
if (cache.cacheDate == null || cache.cacheDate.isBefore(LocalDate.now())) {
|
||||
cache.memorySliceCache.clear();
|
||||
cache.memoryNodeCacheCounter.clear();
|
||||
cache.cacheDate = LocalDate.now();
|
||||
}
|
||||
}
|
||||
|
||||
private void updateCache(List<String> topicPath, MemoryResult memoryResult) {
|
||||
ConcurrentHashMap<List<String>, Integer> memoryNodeCacheCounter = cache.memoryNodeCacheCounter;
|
||||
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
|
||||
if (tempCount == null) {
|
||||
log.warn("[CacheCore] tempCount为null? memoryNodeCacheCounter: {}; topicPath: {}", memoryNodeCacheCounter, topicPath);
|
||||
return;
|
||||
}
|
||||
if (tempCount >= 5) {
|
||||
cache.memorySliceCache.put(topicPath, memoryResult);
|
||||
}
|
||||
}
|
||||
|
||||
private void updateUserDialogMap(MemorySlice slice) {
|
||||
String summary = slice.getSummary();
|
||||
LocalDateTime now = LocalDateTime.now();
|
||||
ConcurrentHashMap<String, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = cache.userDialogMap;
|
||||
|
||||
//更新userDialogMap
|
||||
//移除两天前上下文缓存(切片总结)
|
||||
List<LocalDateTime> keysToRemove = new ArrayList<>();
|
||||
userDialogMap.forEach((k, v) -> v.forEach((i, j) -> {
|
||||
if (now.minusDays(2).isAfter(i)) {
|
||||
keysToRemove.add(i);
|
||||
}
|
||||
}));
|
||||
for (LocalDateTime dateTime : keysToRemove) {
|
||||
userDialogMap.forEach((k, v) -> v.remove(dateTime));
|
||||
}
|
||||
//放入新缓存
|
||||
userDialogMap
|
||||
.computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>())
|
||||
.merge(now, summary, (oldVal, newVal) -> oldVal + " " + newVal);
|
||||
|
||||
}
|
||||
|
||||
private void clearCacheByTopicPath(List<String> topicPath) {
|
||||
cache.memorySliceCache.remove(topicPath);
|
||||
}
|
||||
|
||||
private MemoryResult selectCache(List<String> path) {
|
||||
ConcurrentHashMap<List<String>, MemoryResult> memorySliceCache = cache.memorySliceCache;
|
||||
if (memorySliceCache.containsKey(path)) {
|
||||
return memorySliceCache.get(path);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public Collection<MemoryUnit> listMemoryUnits() {
|
||||
return new ArrayList<>(memoryUnits.values());
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void refreshMemoryId() {
|
||||
currentMemoryId = UUID.randomUUID().toString();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public String getCurrentMemoryId() {
|
||||
return currentMemoryId;
|
||||
}
|
||||
|
||||
private void normalizeMemoryUnit(MemoryUnit memoryUnit) {
|
||||
if (memoryUnit.getId() == null || memoryUnit.getId().isBlank()) {
|
||||
memoryUnit.setId(UUID.randomUUID().toString());
|
||||
}
|
||||
if (memoryUnit.getTimestamp() == null || memoryUnit.getTimestamp() <= 0) {
|
||||
memoryUnit.setTimestamp(System.currentTimeMillis());
|
||||
}
|
||||
if (memoryUnit.getConversationMessages() == null) {
|
||||
memoryUnit.setConversationMessages(new ArrayList<>());
|
||||
}
|
||||
if (memoryUnit.getSlices() == null) {
|
||||
memoryUnit.setSlices(new ArrayList<>());
|
||||
}
|
||||
int maxIndex = Math.max(memoryUnit.getConversationMessages().size() - 1, 0);
|
||||
for (MemorySlice slice : memoryUnit.getSlices()) {
|
||||
if (slice.getId() == null || slice.getId().isBlank()) {
|
||||
slice.setId(UUID.randomUUID().toString());
|
||||
}
|
||||
if (slice.getTimestamp() == null || slice.getTimestamp() <= 0) {
|
||||
slice.setTimestamp(memoryUnit.getTimestamp());
|
||||
}
|
||||
if (slice.getStartIndex() == null || slice.getStartIndex() < 0) {
|
||||
slice.setStartIndex(0);
|
||||
}
|
||||
if (slice.getEndIndex() == null || slice.getEndIndex() < slice.getStartIndex()) {
|
||||
slice.setEndIndex(maxIndex);
|
||||
}
|
||||
if (slice.getEndIndex() > maxIndex) {
|
||||
slice.setEndIndex(maxIndex);
|
||||
}
|
||||
}
|
||||
memoryUnit.getSlices().sort(Comparator.naturalOrder());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getCoreKey() {
|
||||
return "memory-core";
|
||||
}
|
||||
|
||||
public ConcurrentHashMap<String, ConcurrentHashMap<LocalDateTime, String>> getUserDialogMap() {
|
||||
return cache.userDialogMap;
|
||||
}
|
||||
|
||||
|
||||
private MemoryResult cacheFilter(MemoryResult memoryResult) {
|
||||
//过滤掉与缓存重复的切片
|
||||
CopyOnWriteArrayList<MemorySliceResult> memorySliceResult = memoryResult.getMemorySliceResult();
|
||||
List<MemorySlice> relatedMemorySliceResult = memoryResult.getRelatedMemorySliceResult();
|
||||
cache.dialogMap.forEach((k, v) -> {
|
||||
memorySliceResult.removeIf(m -> m.getMemorySlice().getSummary().equals(v));
|
||||
relatedMemorySliceResult.removeIf(m -> m.getSummary().equals(v));
|
||||
});
|
||||
return memoryResult;
|
||||
}
|
||||
|
||||
@SuppressWarnings("FieldMayBeFinal")
|
||||
private static class MemoryCache {
|
||||
|
||||
/**
|
||||
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键,总结为值
|
||||
* 该部分作为'主LLM'system prompt常驻
|
||||
* 该部分作为近两日的整体对话缓存, 不区分用户
|
||||
*/
|
||||
private HashMap<LocalDateTime, String> dialogMap = new HashMap<>();
|
||||
|
||||
/**
|
||||
* 近两日的区分用户的对话总结缓存,在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质
|
||||
*/
|
||||
private ConcurrentHashMap<String/*userId*/, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
* memorySliceCache计数器,每日清空
|
||||
*/
|
||||
private ConcurrentHashMap<List<String> /*触发查询的主题列表*/, Integer> memoryNodeCacheCounter = new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
* 记忆切片缓存,每日清空
|
||||
* 用于记录作为终点节点调用次数最多的记忆节点的切片数据
|
||||
*/
|
||||
private ConcurrentHashMap<List<String> /*主题路径*/, MemoryResult /*切片列表*/> memorySliceCache = new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
* 缓存日期
|
||||
*/
|
||||
private LocalDate cacheDate;
|
||||
|
||||
private HashMap<String, List<EvaluatedSlice>> activatedSlices = new HashMap<>();
|
||||
|
||||
private MemoryCache() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,20 +3,25 @@ package work.slhaf.partner.core.memory.pojo;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.common.entity.PersistableObject;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.time.LocalDate;
|
||||
import java.util.List;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@Builder
|
||||
public class EvaluatedSlice extends PersistableObject {
|
||||
public class ActivatedMemorySlice extends PersistableObject {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
// private List<Message> chatMessages;
|
||||
private String unitId;
|
||||
private String sliceId;
|
||||
private LocalDate date;
|
||||
private Long timestamp;
|
||||
private String summary;
|
||||
private List<Message> messages;
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package work.slhaf.partner.core.memory.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.api.common.entity.PersistableObject;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class MemoryResult extends PersistableObject {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private CopyOnWriteArrayList<MemorySliceResult> memorySliceResult;
|
||||
private List<MemorySlice> relatedMemorySliceResult;
|
||||
|
||||
public boolean isEmpty() {
|
||||
boolean a = memorySliceResult == null || memorySliceResult.isEmpty();
|
||||
boolean b = relatedMemorySliceResult == null || relatedMemorySliceResult.isEmpty();
|
||||
return a && b;
|
||||
}
|
||||
}
|
||||
@@ -2,12 +2,9 @@ package work.slhaf.partner.core.memory.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.common.entity.PersistableObject;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.util.List;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@@ -16,59 +13,11 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
/**
|
||||
* 关联的完整对话的id
|
||||
*/
|
||||
private String memoryId;
|
||||
|
||||
/**
|
||||
* 该切片在关联的完整对话中的顺序, 由时间戳确定
|
||||
*/
|
||||
private Long timestamp;
|
||||
|
||||
/**
|
||||
* 格式为"<日期>.slice", 如2025-04-11.slice
|
||||
*/
|
||||
private String id;
|
||||
private Integer startIndex;
|
||||
private Integer endIndex;
|
||||
private String summary;
|
||||
|
||||
private List<Message> chatMessages;
|
||||
|
||||
/**
|
||||
* 关联的其他主题, 即"邻近节点(联系)"
|
||||
*/
|
||||
private List<List<String>> relatedTopics;
|
||||
|
||||
/**
|
||||
* 关联完整对话中的前序切片, 排序为键,完整路径为值
|
||||
*/
|
||||
@ToString.Exclude
|
||||
private MemorySlice sliceBefore, sliceAfter;
|
||||
|
||||
/**
|
||||
* 多用户设定
|
||||
* 发起该切片对话的用户
|
||||
*/
|
||||
private String startUserId;
|
||||
|
||||
/**
|
||||
* 该切片涉及到的用户uuid
|
||||
*/
|
||||
private List<String> involvedUserIds;
|
||||
|
||||
/**
|
||||
* 是否仅供发起用户作为记忆参考
|
||||
*/
|
||||
private boolean isPrivate;
|
||||
|
||||
/**
|
||||
* 摘要向量化结果
|
||||
*/
|
||||
private float[] summaryEmbedding;
|
||||
|
||||
/**
|
||||
* 是否向量化
|
||||
*/
|
||||
private boolean embedded;
|
||||
private Long timestamp;
|
||||
|
||||
@Override
|
||||
public int compareTo(MemorySlice memorySlice) {
|
||||
@@ -79,5 +28,4 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package work.slhaf.partner.core.memory.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.common.entity.PersistableObject;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class MemoryUnit extends PersistableObject {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private String id;
|
||||
private List<Message> conversationMessages = new ArrayList<>();
|
||||
private Long timestamp;
|
||||
private List<MemorySlice> slices = new ArrayList<>();
|
||||
}
|
||||
@@ -1,24 +1,22 @@
|
||||
package work.slhaf.partner.core.memory.pojo;
|
||||
|
||||
import com.alibaba.fastjson2.annotation.JSONField;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import work.slhaf.partner.api.common.entity.PersistableObject;
|
||||
|
||||
import java.io.Serial;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class MemorySliceResult extends PersistableObject {
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class SliceRef extends PersistableObject {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
@JSONField(serialize = false)
|
||||
private MemorySlice sliceBefore;
|
||||
|
||||
private MemorySlice memorySlice;
|
||||
|
||||
@JSONField(serialize = false)
|
||||
private MemorySlice sliceAfter;
|
||||
private String unitId;
|
||||
private String sliceId;
|
||||
}
|
||||
@@ -367,7 +367,7 @@ public class ActionExecutor extends AbstractAgentModule.Standalone {
|
||||
private ExtractorInput buildExtractorInput(MetaAction action, String source, List<HistoryAction> historyActionResults,
|
||||
List<String> additionalContext) {
|
||||
ExtractorInput input = new ExtractorInput();
|
||||
input.setEvaluatedSlices(memoryCapability.getActivatedSlices(source));
|
||||
input.setActivatedMemorySlices(memoryCapability.getActivatedSlices());
|
||||
input.setRecentMessages(cognationCapability.getChatMessages());
|
||||
input.setMetaActionInfo(actionCapability.loadMetaActionInfo(action.getKey()));
|
||||
input.setHistoryActionResults(historyActionResults);
|
||||
@@ -384,7 +384,7 @@ public class ActionExecutor extends AbstractAgentModule.Standalone {
|
||||
.history(executableAction.getHistory().get(executableAction.getExecutingStage()))
|
||||
.status(executableAction.getStatus())
|
||||
.recentMessages(cognationCapability.getChatMessages())
|
||||
.activatedSlices(memoryCapability.getActivatedSlices(source))
|
||||
.activatedSlices(memoryCapability.getActivatedSlices())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.core.action.entity.ExecutableAction;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -20,5 +20,5 @@ public class CorrectorInput {
|
||||
private ExecutableAction.Status status;
|
||||
|
||||
private List<Message> recentMessages;
|
||||
private List<EvaluatedSlice> activatedSlices;
|
||||
private List<ActivatedMemorySlice> activatedSlices;
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package work.slhaf.partner.module.modules.action.executor.entity;
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -16,7 +16,7 @@ public class ExtractorInput {
|
||||
/**
|
||||
* 可参考的记忆切片
|
||||
*/
|
||||
private List<EvaluatedSlice> evaluatedSlices;
|
||||
private List<ActivatedMemorySlice> activatedMemorySlices;
|
||||
/**
|
||||
* 历史行动执行结果
|
||||
*/
|
||||
|
||||
@@ -22,6 +22,7 @@ import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.Eva
|
||||
import work.slhaf.partner.module.modules.action.interventor.recognizer.InterventionRecognizer;
|
||||
import work.slhaf.partner.module.modules.action.interventor.recognizer.entity.RecognizerInput;
|
||||
import work.slhaf.partner.module.modules.action.interventor.recognizer.entity.RecognizerResult;
|
||||
import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -51,6 +52,8 @@ public class ActionInterventor extends AbstractAgentModule.Running<PartnerRunnin
|
||||
private CognationCapability cognationCapability;
|
||||
@InjectCapability
|
||||
private MemoryCapability memoryCapability;
|
||||
@InjectModule
|
||||
private MemoryRuntime memoryRuntime;
|
||||
|
||||
@Override
|
||||
public void execute(PartnerRunningFlowContext context) {
|
||||
@@ -146,7 +149,7 @@ public class ActionInterventor extends AbstractAgentModule.Running<PartnerRunnin
|
||||
private RecognizerInput buildRecognizerInput(String userId, String input) {
|
||||
RecognizerInput recognizerInput = new RecognizerInput();
|
||||
recognizerInput.setInput(input);
|
||||
recognizerInput.setUserDialogMapStr(memoryCapability.getUserDialogMapStr(userId));
|
||||
recognizerInput.setUserDialogMapStr(memoryRuntime.getDialogMapStr());
|
||||
// 参考的对话列表大小或需调整
|
||||
recognizerInput.setRecentMessages(cognationCapability.getChatMessages());
|
||||
recognizerInput.setExecutingActions(actionCapability.listPhaserRecords().stream().map(PhaserRecord::executableAction).toList());
|
||||
@@ -159,7 +162,7 @@ public class ActionInterventor extends AbstractAgentModule.Running<PartnerRunnin
|
||||
input.setExecutingInterventions(recognizerResult.getExecutingInterventions());
|
||||
input.setPreparedInterventions(recognizerResult.getPreparedInterventions());
|
||||
input.setRecentMessages(cognationCapability.getChatMessages());
|
||||
input.setActivatedSlices(memoryCapability.getActivatedSlices(userId));
|
||||
input.setActivatedSlices(memoryCapability.getActivatedSlices());
|
||||
return input;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.core.action.ActionCapability;
|
||||
import work.slhaf.partner.core.action.ActionCore.ExecutorType;
|
||||
import work.slhaf.partner.core.action.entity.ExecutableAction;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.EvaluatorInput;
|
||||
import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.EvaluatorResult;
|
||||
import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.EvaluatorResult.EvaluatedInterventionData;
|
||||
@@ -66,7 +66,7 @@ public class InterventionEvaluator extends AbstractAgentModule.Sub<EvaluatorInpu
|
||||
}));
|
||||
}
|
||||
|
||||
private String buildPrompt(List<Message> recentMessages, List<EvaluatedSlice> activatedSlices,
|
||||
private String buildPrompt(List<Message> recentMessages, List<ActivatedMemorySlice> activatedSlices,
|
||||
ExecutableAction executableAction, String tendency) {
|
||||
JSONObject json = new JSONObject();
|
||||
json.put("干预倾向", tendency);
|
||||
|
||||
@@ -3,7 +3,7 @@ package work.slhaf.partner.module.modules.action.interventor.evaluator.entity;
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.core.action.entity.ExecutableAction;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -12,6 +12,6 @@ import java.util.Map;
|
||||
public class EvaluatorInput {
|
||||
private Map<String, ExecutableAction> executingInterventions;
|
||||
private Map<String, ExecutableAction> preparedInterventions;
|
||||
private List<EvaluatedSlice> activatedSlices;
|
||||
private List<ActivatedMemorySlice> activatedSlices;
|
||||
private List<Message> recentMessages;
|
||||
}
|
||||
|
||||
@@ -328,7 +328,7 @@ public class ActionPlanner extends AbstractAgentModule.Running<PartnerRunningFlo
|
||||
input.setTendencies(extractorResult.getTendencies());
|
||||
input.setUser(perceiveCapability.getUser(userId));
|
||||
input.setRecentMessages(cognationCapability.snapshotChatMessages());
|
||||
input.setActivatedSlices(memoryCapability.getActivatedSlices(userId));
|
||||
input.setActivatedSlices(memoryCapability.getActivatedSlices());
|
||||
return input;
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
||||
import work.slhaf.partner.core.action.ActionCapability;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorBatchInput;
|
||||
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorInput;
|
||||
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorResult;
|
||||
@@ -77,7 +77,7 @@ public class ActionEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, Lis
|
||||
JSONObject prompt = new JSONObject();
|
||||
prompt.put("[行动倾向]", batchInput.getTendency());
|
||||
JSONArray memoryData = prompt.putArray("[相关记忆切片]");
|
||||
for (EvaluatedSlice evaluatedSlice : batchInput.getActivatedSlices()) {
|
||||
for (ActivatedMemorySlice evaluatedSlice : batchInput.getActivatedSlices()) {
|
||||
JSONObject memory = memoryData.addObject();
|
||||
memory.put("[日期]", evaluatedSlice.getDate());
|
||||
memory.put("[摘要]", evaluatedSlice.getSummary());
|
||||
|
||||
@@ -2,7 +2,7 @@ package work.slhaf.partner.module.modules.action.planner.evaluator.entity;
|
||||
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -10,7 +10,7 @@ import java.util.Map;
|
||||
@Data
|
||||
public class EvaluatorBatchInput {
|
||||
private List<Message> recentMessages;
|
||||
private List<EvaluatedSlice> activatedSlices;
|
||||
private List<ActivatedMemorySlice> activatedSlices;
|
||||
private Map<String, String> availableActions;
|
||||
private String tendency;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package work.slhaf.partner.module.modules.action.planner.evaluator.entity;
|
||||
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
import work.slhaf.partner.core.perceive.pojo.User;
|
||||
|
||||
import java.util.List;
|
||||
@@ -11,6 +11,6 @@ import java.util.List;
|
||||
public class EvaluatorInput {
|
||||
private List<Message> recentMessages;
|
||||
private User user;
|
||||
private List<EvaluatedSlice> activatedSlices;
|
||||
private List<ActivatedMemorySlice> activatedSlices;
|
||||
private List<String> tendencies;
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.ContextBlock;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
@@ -126,8 +125,6 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
||||
chatMessages.add(primaryUserMessage);
|
||||
Message assistantMessage = new Message(Message.Character.ASSISTANT, response);
|
||||
chatMessages.add(assistantMessage);
|
||||
MetaMessage metaMessage = new MetaMessage(primaryUserMessage, assistantMessage);
|
||||
cognationCapability.addMetaMessage(runningFlowContext.getSource(), metaMessage);
|
||||
} finally {
|
||||
cognationCapability.getMessageLock().unlock();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
package work.slhaf.partner.module.modules.memory.runtime;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader;
|
||||
import work.slhaf.partner.api.common.entity.PersistableObject;
|
||||
import work.slhaf.partner.common.config.PartnerAgentConfigLoader;
|
||||
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
|
||||
import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||
import work.slhaf.partner.core.memory.pojo.SliceRef;
|
||||
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.time.Instant;
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.ZoneId;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
|
||||
import static work.slhaf.partner.common.Constant.Path.MEMORY_DATA;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@Slf4j
|
||||
public class MemoryRuntime extends AbstractAgentModule.Standalone {
|
||||
|
||||
private static final String RUNTIME_KEY = "memory-runtime";
|
||||
|
||||
private final ReentrantLock runtimeLock = new ReentrantLock();
|
||||
private Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices = new HashMap<>();
|
||||
private Map<LocalDate, CopyOnWriteArrayList<SliceRef>> dateIndex = new HashMap<>();
|
||||
private HashMap<LocalDateTime, String> dialogMap = new HashMap<>();
|
||||
|
||||
@Init
|
||||
public void init() {
|
||||
loadState();
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(this::saveStateSafely));
|
||||
}
|
||||
|
||||
public void bindTopic(String topicPath, SliceRef sliceRef) {
|
||||
String normalizedPath = normalizeTopicPath(topicPath);
|
||||
runtimeLock.lock();
|
||||
try {
|
||||
CopyOnWriteArrayList<SliceRef> refs = topicSlices.computeIfAbsent(normalizedPath, key -> new CopyOnWriteArrayList<>());
|
||||
boolean exists = refs.stream().anyMatch(ref -> Objects.equals(ref.getUnitId(), sliceRef.getUnitId())
|
||||
&& Objects.equals(ref.getSliceId(), sliceRef.getSliceId()));
|
||||
if (!exists) {
|
||||
refs.add(sliceRef);
|
||||
}
|
||||
saveState();
|
||||
} finally {
|
||||
runtimeLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public void indexMemoryUnit(MemoryUnit memoryUnit) {
|
||||
runtimeLock.lock();
|
||||
try {
|
||||
for (CopyOnWriteArrayList<SliceRef> refs : dateIndex.values()) {
|
||||
refs.removeIf(ref -> memoryUnit.getId().equals(ref.getUnitId()));
|
||||
}
|
||||
if (memoryUnit.getSlices() != null) {
|
||||
for (MemorySlice slice : memoryUnit.getSlices()) {
|
||||
LocalDate date = Instant.ofEpochMilli(slice.getTimestamp())
|
||||
.atZone(ZoneId.systemDefault())
|
||||
.toLocalDate();
|
||||
dateIndex.computeIfAbsent(date, key -> new CopyOnWriteArrayList<>())
|
||||
.addIfAbsent(new SliceRef(memoryUnit.getId(), slice.getId()));
|
||||
}
|
||||
}
|
||||
saveState();
|
||||
} finally {
|
||||
runtimeLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public List<SliceRef> findByTopicPath(String topicPath) {
|
||||
String normalizedPath = normalizeTopicPath(topicPath);
|
||||
List<SliceRef> refs = topicSlices.get(normalizedPath);
|
||||
if (refs == null || refs.isEmpty()) {
|
||||
throw new UnExistedTopicException("不存在的主题: " + normalizedPath);
|
||||
}
|
||||
return new ArrayList<>(refs);
|
||||
}
|
||||
|
||||
public List<SliceRef> findByDate(LocalDate date) {
|
||||
List<SliceRef> refs = dateIndex.get(date);
|
||||
if (refs == null || refs.isEmpty()) {
|
||||
throw new UnExistedDateIndexException("不存在的日期索引: " + date);
|
||||
}
|
||||
return new ArrayList<>(refs);
|
||||
}
|
||||
|
||||
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
|
||||
runtimeLock.lock();
|
||||
try {
|
||||
List<LocalDateTime> keysToRemove = new ArrayList<>();
|
||||
dialogMap.forEach((k, v) -> {
|
||||
if (dateTime.minusDays(2).isAfter(k)) {
|
||||
keysToRemove.add(k);
|
||||
}
|
||||
});
|
||||
for (LocalDateTime temp : keysToRemove) {
|
||||
dialogMap.remove(temp);
|
||||
}
|
||||
dialogMap.put(dateTime, newDialogCache);
|
||||
saveState();
|
||||
} finally {
|
||||
runtimeLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public HashMap<LocalDateTime, String> getDialogMap() {
|
||||
return dialogMap;
|
||||
}
|
||||
|
||||
public String getDialogMapStr() {
|
||||
StringBuilder str = new StringBuilder();
|
||||
dialogMap.entrySet().stream()
|
||||
.sorted(Map.Entry.comparingByKey())
|
||||
.forEach(entry -> str.append("\n\n[")
|
||||
.append(entry.getKey())
|
||||
.append("]\n")
|
||||
.append(entry.getValue()));
|
||||
return str.toString();
|
||||
}
|
||||
|
||||
public String getTopicTree() {
|
||||
TopicTreeNode root = new TopicTreeNode();
|
||||
for (Map.Entry<String, CopyOnWriteArrayList<SliceRef>> entry : topicSlices.entrySet()) {
|
||||
String[] parts = entry.getKey().split("->");
|
||||
TopicTreeNode current = root;
|
||||
for (String part : parts) {
|
||||
current = current.children.computeIfAbsent(part, key -> new TopicTreeNode());
|
||||
}
|
||||
current.count += entry.getValue().size();
|
||||
}
|
||||
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
List<Map.Entry<String, TopicTreeNode>> roots = new ArrayList<>(root.children.entrySet());
|
||||
for (Map.Entry<String, TopicTreeNode> entry : roots) {
|
||||
stringBuilder.append(entry.getKey()).append("[root]").append("\r\n");
|
||||
printSubTopicsTreeFormat(entry.getValue(), "", stringBuilder);
|
||||
}
|
||||
return stringBuilder.toString();
|
||||
}
|
||||
|
||||
private void printSubTopicsTreeFormat(TopicTreeNode node, String prefix, StringBuilder stringBuilder) {
|
||||
List<Map.Entry<String, TopicTreeNode>> entries = new ArrayList<>(node.children.entrySet());
|
||||
for (int i = 0; i < entries.size(); i++) {
|
||||
boolean last = i == entries.size() - 1;
|
||||
Map.Entry<String, TopicTreeNode> entry = entries.get(i);
|
||||
stringBuilder.append(prefix)
|
||||
.append(last ? "└── " : "├── ")
|
||||
.append(entry.getKey())
|
||||
.append("[")
|
||||
.append(entry.getValue().count)
|
||||
.append("]")
|
||||
.append("\r\n");
|
||||
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), stringBuilder);
|
||||
}
|
||||
}
|
||||
|
||||
private String normalizeTopicPath(String topicPath) {
|
||||
return topicPath == null ? "" : topicPath.trim();
|
||||
}
|
||||
|
||||
private void loadState() {
|
||||
Path filePath = getFilePath();
|
||||
if (!Files.exists(filePath)) {
|
||||
return;
|
||||
}
|
||||
try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(filePath.toFile()))) {
|
||||
RuntimeState state = (RuntimeState) ois.readObject();
|
||||
topicSlices = state.topicSlices;
|
||||
dateIndex = state.dateIndex;
|
||||
dialogMap = state.dialogMap;
|
||||
} catch (Exception e) {
|
||||
log.error("[MemoryRuntime] 加载运行态失败", e);
|
||||
topicSlices = new HashMap<>();
|
||||
dateIndex = new HashMap<>();
|
||||
dialogMap = new HashMap<>();
|
||||
}
|
||||
}
|
||||
|
||||
private void saveStateSafely() {
|
||||
runtimeLock.lock();
|
||||
try {
|
||||
saveState();
|
||||
} finally {
|
||||
runtimeLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
private void saveState() {
|
||||
Path filePath = getFilePath();
|
||||
Path tempPath = getTempFilePath();
|
||||
try {
|
||||
Files.createDirectories(Paths.get(MEMORY_DATA));
|
||||
FileUtils.createParentDirectories(filePath.toFile());
|
||||
try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(tempPath.toFile()))) {
|
||||
RuntimeState state = new RuntimeState();
|
||||
state.topicSlices = new HashMap<>(topicSlices);
|
||||
state.dateIndex = new HashMap<>(dateIndex);
|
||||
state.dialogMap = new HashMap<>(dialogMap);
|
||||
oos.writeObject(state);
|
||||
}
|
||||
Files.move(tempPath, filePath, java.nio.file.StandardCopyOption.REPLACE_EXISTING);
|
||||
} catch (IOException e) {
|
||||
log.error("[MemoryRuntime] 保存运行态失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private Path getFilePath() {
|
||||
String id = ((PartnerAgentConfigLoader) AgentConfigLoader.INSTANCE).getConfig().getAgentId();
|
||||
return Paths.get(MEMORY_DATA, id + "-" + RUNTIME_KEY + ".memory");
|
||||
}
|
||||
|
||||
private Path getTempFilePath() {
|
||||
String id = ((PartnerAgentConfigLoader) AgentConfigLoader.INSTANCE).getConfig().getAgentId();
|
||||
return Paths.get(MEMORY_DATA, id + "-" + RUNTIME_KEY + "-temp.memory");
|
||||
}
|
||||
|
||||
private static final class TopicTreeNode {
|
||||
private final Map<String, TopicTreeNode> children = new LinkedHashMap<>();
|
||||
private int count;
|
||||
}
|
||||
|
||||
private static final class RuntimeState extends PersistableObject {
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private Map<String, CopyOnWriteArrayList<SliceRef>> topicSlices = new HashMap<>();
|
||||
private Map<LocalDate, CopyOnWriteArrayList<SliceRef>> dateIndex = new HashMap<>();
|
||||
private HashMap<LocalDateTime, String> dialogMap = new HashMap<>();
|
||||
}
|
||||
}
|
||||
@@ -6,13 +6,16 @@ import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
|
||||
import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryResult;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||
import work.slhaf.partner.core.memory.pojo.SliceRef;
|
||||
import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime;
|
||||
import work.slhaf.partner.module.modules.memory.selector.evaluator.SliceSelectEvaluator;
|
||||
import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorInput;
|
||||
import work.slhaf.partner.module.modules.memory.selector.extractor.MemorySelectExtractor;
|
||||
@@ -20,9 +23,12 @@ import work.slhaf.partner.module.modules.memory.selector.extractor.entity.Extrac
|
||||
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorResult;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.time.LocalDate;
|
||||
import java.time.ZoneId;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@@ -33,89 +39,106 @@ public class MemorySelector extends AbstractAgentModule.Running<PartnerRunningFl
|
||||
@InjectCapability
|
||||
private CognationCapability cognationCapability;
|
||||
@InjectModule
|
||||
private MemoryRuntime memoryRuntime;
|
||||
@InjectModule
|
||||
private SliceSelectEvaluator sliceSelectEvaluator;
|
||||
@InjectModule
|
||||
private MemorySelectExtractor memorySelectExtractor;
|
||||
|
||||
@Override
|
||||
public void execute(PartnerRunningFlowContext runningFlowContext) {
|
||||
String userId = runningFlowContext.getSource();
|
||||
//获取主题路径
|
||||
ExtractorResult extractorResult = memorySelectExtractor.execute(runningFlowContext);
|
||||
if (extractorResult.isRecall() || !extractorResult.getMatches().isEmpty()) {
|
||||
memoryCapability.clearActivatedSlices(userId);
|
||||
List<EvaluatedSlice> evaluatedSlices = selectAndEvaluateMemory(runningFlowContext, extractorResult);
|
||||
memoryCapability.updateActivatedSlices(userId, evaluatedSlices);
|
||||
memoryCapability.clearActivatedSlices();
|
||||
List<ActivatedMemorySlice> activatedSlices = selectAndEvaluateMemory(runningFlowContext, extractorResult);
|
||||
memoryCapability.updateActivatedSlices(activatedSlices);
|
||||
}
|
||||
setModuleContextRecall(runningFlowContext);
|
||||
}
|
||||
|
||||
private List<EvaluatedSlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, ExtractorResult extractorResult) {
|
||||
private List<ActivatedMemorySlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext,
|
||||
ExtractorResult extractorResult) {
|
||||
log.debug("[MemorySelector] 触发记忆回溯...");
|
||||
//查找切片
|
||||
String userId = runningFlowContext.getSource();
|
||||
List<MemoryResult> memoryResultList = new ArrayList<>();
|
||||
setMemoryResultList(memoryResultList, extractorResult.getMatches(), userId);
|
||||
//评估切片
|
||||
LinkedHashMap<String, ActivatedMemorySlice> candidates = new LinkedHashMap<>();
|
||||
setMemoryCandidates(candidates, extractorResult.getMatches());
|
||||
removeDuplicateSlice(candidates.values());
|
||||
EvaluatorInput evaluatorInput = EvaluatorInput.builder()
|
||||
.input(runningFlowContext.getInput())
|
||||
.memoryResults(memoryResultList)
|
||||
.memorySlices(new ArrayList<>(candidates.values()))
|
||||
.messages(cognationCapability.getChatMessages())
|
||||
.build();
|
||||
log.debug("[MemorySelector] 切片评估输入: {}", JSONObject.toJSONString(evaluatorInput));
|
||||
List<EvaluatedSlice> memorySlices = sliceSelectEvaluator.execute(evaluatorInput);
|
||||
List<ActivatedMemorySlice> memorySlices = sliceSelectEvaluator.execute(evaluatorInput);
|
||||
log.debug("[MemorySelector] 切片评估结果: {}", JSONObject.toJSONString(memorySlices));
|
||||
return memorySlices;
|
||||
}
|
||||
|
||||
private void setModuleContextRecall(PartnerRunningFlowContext runningFlowContext) {
|
||||
String userId = runningFlowContext.getSource();
|
||||
boolean recall = memoryCapability.hasActivatedSlices(userId);
|
||||
boolean recall = memoryCapability.hasActivatedSlices();
|
||||
runningFlowContext.getModuleContext().getExtraContext().put("recall", recall);
|
||||
if (recall) {
|
||||
runningFlowContext.getModuleContext().getExtraContext().put("recall_count", memoryCapability.getActivatedSlicesSize(userId));
|
||||
runningFlowContext.getModuleContext().getExtraContext().put("recall_count", memoryCapability.getActivatedSlicesSize());
|
||||
}
|
||||
}
|
||||
|
||||
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) {
|
||||
private void setMemoryCandidates(LinkedHashMap<String, ActivatedMemorySlice> candidates,
|
||||
List<ExtractorMatchData> matches) {
|
||||
for (ExtractorMatchData match : matches) {
|
||||
try {
|
||||
MemoryResult memoryResult = switch (match.getType()) {
|
||||
case ExtractorMatchData.Constant.TOPIC -> memoryCapability.selectMemory(match.getText());
|
||||
case ExtractorMatchData.Constant.DATE ->
|
||||
memoryCapability.selectMemory(LocalDate.parse(match.getText()));
|
||||
default -> null;
|
||||
List<SliceRef> refs = switch (match.getType()) {
|
||||
case ExtractorMatchData.Constant.TOPIC -> memoryRuntime.findByTopicPath(match.getText());
|
||||
case ExtractorMatchData.Constant.DATE -> memoryRuntime.findByDate(LocalDate.parse(match.getText()));
|
||||
default -> List.of();
|
||||
};
|
||||
if (memoryResult == null || memoryResult.isEmpty()) continue;
|
||||
removeDuplicateSlice(memoryResult);
|
||||
memoryResultList.add(memoryResult);
|
||||
for (SliceRef ref : refs) {
|
||||
ActivatedMemorySlice recalledSlice = buildActivatedMemorySlice(ref);
|
||||
if (recalledSlice != null) {
|
||||
candidates.putIfAbsent(ref.getUnitId() + ":" + ref.getSliceId(), recalledSlice);
|
||||
}
|
||||
}
|
||||
} catch (UnExistedDateIndexException | UnExistedTopicException e) {
|
||||
log.error("[MemorySelector] 不存在的记忆索引! 请尝试更换更合适的主题提取LLM!", e);
|
||||
log.error("[MemorySelector] 不存在的记忆索引", e);
|
||||
log.error("[MemorySelector] 错误索引: {}", match.getText());
|
||||
}
|
||||
}
|
||||
//清理切片记录
|
||||
memoryCapability.cleanSelectedSliceFilter();
|
||||
//根据userInfo过滤是否为私人记忆
|
||||
for (MemoryResult memoryResult : memoryResultList) {
|
||||
//过滤终点记忆
|
||||
memoryResult.getMemorySliceResult().removeIf(m -> removeOrNot(m.getMemorySlice(), userId));
|
||||
//过滤邻近记忆
|
||||
memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userId));
|
||||
}
|
||||
}
|
||||
|
||||
private void removeDuplicateSlice(MemoryResult memoryResult) {
|
||||
Collection<String> values = memoryCapability.getDialogMap().values();
|
||||
memoryResult.getRelatedMemorySliceResult().removeIf(m -> values.contains(m.getSummary()));
|
||||
memoryResult.getMemorySliceResult().removeIf(m -> values.contains(m.getMemorySlice().getSummary()));
|
||||
private ActivatedMemorySlice buildActivatedMemorySlice(SliceRef ref) {
|
||||
MemoryUnit memoryUnit = memoryCapability.getMemoryUnit(ref.getUnitId());
|
||||
MemorySlice memorySlice = memoryCapability.getMemorySlice(ref.getUnitId(), ref.getSliceId());
|
||||
if (memoryUnit == null || memorySlice == null) {
|
||||
return null;
|
||||
}
|
||||
List<Message> messages = sliceMessages(memoryUnit, memorySlice);
|
||||
LocalDate date = Instant.ofEpochMilli(memorySlice.getTimestamp())
|
||||
.atZone(ZoneId.systemDefault())
|
||||
.toLocalDate();
|
||||
return ActivatedMemorySlice.builder()
|
||||
.unitId(ref.getUnitId())
|
||||
.sliceId(ref.getSliceId())
|
||||
.summary(memorySlice.getSummary())
|
||||
.timestamp(memorySlice.getTimestamp())
|
||||
.date(date)
|
||||
.messages(messages)
|
||||
.build();
|
||||
}
|
||||
|
||||
private boolean removeOrNot(MemorySlice memorySlice, String userId) {
|
||||
if (memorySlice.isPrivate()) {
|
||||
return memorySlice.getStartUserId().equals(userId);
|
||||
private List<Message> sliceMessages(MemoryUnit memoryUnit, MemorySlice memorySlice) {
|
||||
List<Message> conversationMessages = memoryUnit.getConversationMessages();
|
||||
if (conversationMessages == null || conversationMessages.isEmpty()) {
|
||||
return List.of();
|
||||
}
|
||||
return false;
|
||||
int start = Math.max(0, memorySlice.getStartIndex());
|
||||
int end = Math.min(conversationMessages.size() - 1, memorySlice.getEndIndex());
|
||||
if (start > end) {
|
||||
return List.of();
|
||||
}
|
||||
return new ArrayList<>(conversationMessages.subList(start, end + 1));
|
||||
}
|
||||
|
||||
private void removeDuplicateSlice(Collection<ActivatedMemorySlice> candidates) {
|
||||
Collection<String> values = memoryRuntime.getDialogMap().values();
|
||||
candidates.removeIf(m -> values.contains(m.getSummary()));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package work.slhaf.partner.module.modules.memory.selector.evaluator;
|
||||
|
||||
import cn.hutool.core.date.DateUtil;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
@@ -10,10 +9,7 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryResult;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySliceResult;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorBatchInput;
|
||||
import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorInput;
|
||||
import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorResult;
|
||||
@@ -27,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, List<EvaluatedSlice>> implements ActivateModel {
|
||||
public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, List<ActivatedMemorySlice>> implements ActivateModel {
|
||||
private InteractionThreadPoolExecutor executor;
|
||||
|
||||
@Init
|
||||
@@ -36,83 +32,58 @@ public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<EvaluatedSlice> execute(EvaluatorInput evaluatorInput) {
|
||||
public List<ActivatedMemorySlice> execute(EvaluatorInput evaluatorInput) {
|
||||
log.debug("[SliceSelectEvaluator] 切片评估模块开始...");
|
||||
List<MemoryResult> memoryResultList = evaluatorInput.getMemoryResults();
|
||||
List<ActivatedMemorySlice> memorySlices = evaluatorInput.getMemorySlices();
|
||||
List<Callable<Void>> tasks = new ArrayList<>();
|
||||
Queue<EvaluatedSlice> queue = new ConcurrentLinkedDeque<>();
|
||||
Queue<ActivatedMemorySlice> queue = new ConcurrentLinkedDeque<>();
|
||||
AtomicInteger count = new AtomicInteger(0);
|
||||
for (MemoryResult memoryResult : memoryResultList) {
|
||||
if (memoryResult.getMemorySliceResult().isEmpty() && memoryResult.getRelatedMemorySliceResult().isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
tasks.add(() -> {
|
||||
int thisCount = count.incrementAndGet();
|
||||
log.debug("[SliceSelectEvaluator] 评估[{}]开始", thisCount);
|
||||
List<SliceSummary> sliceSummaryList = new ArrayList<>();
|
||||
//映射查找键值
|
||||
Map<Long, SliceSummary> map = new HashMap<>();
|
||||
try {
|
||||
setSliceSummaryList(memoryResult, sliceSummaryList, map);
|
||||
EvaluatorBatchInput batchInput = EvaluatorBatchInput.builder()
|
||||
.text(evaluatorInput.getInput())
|
||||
.memory_slices(sliceSummaryList)
|
||||
.history(evaluatorInput.getMessages())
|
||||
.build();
|
||||
log.debug("[SliceSelectEvaluator] 评估[{}]输入: {}", thisCount, JSONObject.toJSONString(batchInput));
|
||||
EvaluatorResult evaluatorResult = formattedChat(
|
||||
List.of(new Message(Message.Character.USER, JSONUtil.toJsonStr(batchInput))),
|
||||
EvaluatorResult.class
|
||||
);
|
||||
log.debug("[SliceSelectEvaluator] 评估[{}]结果: {}", thisCount, JSONObject.toJSONString(evaluatorResult));
|
||||
for (Long result : evaluatorResult.getResults()) {
|
||||
SliceSummary sliceSummary = map.get(result);
|
||||
EvaluatedSlice evaluatedSlice = EvaluatedSlice.builder()
|
||||
.summary(sliceSummary.getSummary())
|
||||
.date(sliceSummary.getDate())
|
||||
.build();
|
||||
// setEvaluatedSliceMessages(evaluatedSlice, memoryResult, sliceSummary.getId());
|
||||
queue.offer(evaluatedSlice);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("[SliceSelectEvaluator] 评估[{}]出现错误: {}", thisCount, e.getLocalizedMessage());
|
||||
}
|
||||
return null;
|
||||
});
|
||||
if (memorySlices == null || memorySlices.isEmpty()) {
|
||||
return List.of();
|
||||
}
|
||||
tasks.add(() -> {
|
||||
int thisCount = count.incrementAndGet();
|
||||
log.debug("[SliceSelectEvaluator] 评估[{}]开始", thisCount);
|
||||
List<SliceSummary> sliceSummaryList = new ArrayList<>();
|
||||
Map<Long, ActivatedMemorySlice> map = new HashMap<>();
|
||||
try {
|
||||
setSliceSummaryList(memorySlices, sliceSummaryList, map);
|
||||
EvaluatorBatchInput batchInput = EvaluatorBatchInput.builder()
|
||||
.text(evaluatorInput.getInput())
|
||||
.memory_slices(sliceSummaryList)
|
||||
.history(evaluatorInput.getMessages())
|
||||
.build();
|
||||
log.debug("[SliceSelectEvaluator] 评估[{}]输入: {}", thisCount, JSONObject.toJSONString(batchInput));
|
||||
EvaluatorResult evaluatorResult = formattedChat(
|
||||
List.of(new Message(Message.Character.USER, JSONUtil.toJsonStr(batchInput))),
|
||||
EvaluatorResult.class
|
||||
);
|
||||
log.debug("[SliceSelectEvaluator] 评估[{}]结果: {}", thisCount, JSONObject.toJSONString(evaluatorResult));
|
||||
for (Long result : evaluatorResult.getResults()) {
|
||||
ActivatedMemorySlice slice = map.get(result);
|
||||
if (slice != null) {
|
||||
queue.offer(slice);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("[SliceSelectEvaluator] 评估[{}]出现错误: {}", thisCount, e.getLocalizedMessage());
|
||||
}
|
||||
return null;
|
||||
});
|
||||
executor.invokeAll(tasks, 30, TimeUnit.SECONDS);
|
||||
log.debug("[SliceSelectEvaluator] 评估模块结束, 输出队列: {}", queue);
|
||||
List<EvaluatedSlice> temp = queue.stream().toList();
|
||||
return new ArrayList<>(temp);
|
||||
return new ArrayList<>(queue);
|
||||
}
|
||||
|
||||
private void setSliceSummaryList(MemoryResult memoryResult, List<SliceSummary> sliceSummaryList, Map<Long, SliceSummary> map) {
|
||||
for (MemorySliceResult memorySliceResult : memoryResult.getMemorySliceResult()) {
|
||||
SliceSummary sliceSummary = new SliceSummary();
|
||||
sliceSummary.setId(memorySliceResult.getMemorySlice().getTimestamp());
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
if (memorySliceResult.getSliceBefore() != null) {
|
||||
stringBuilder.append(memorySliceResult.getSliceBefore().getSummary())
|
||||
.append("\r\n");
|
||||
}
|
||||
stringBuilder.append(memorySliceResult.getMemorySlice().getSummary());
|
||||
if (memorySliceResult.getSliceAfter() != null) {
|
||||
stringBuilder.append("\r\n")
|
||||
.append(memorySliceResult.getSliceAfter().getSummary())
|
||||
.append("\r\n");
|
||||
}
|
||||
sliceSummary.setSummary(stringBuilder.toString());
|
||||
Long timestamp = memorySliceResult.getMemorySlice().getTimestamp();
|
||||
sliceSummary.setDate(DateUtil.date(timestamp).toLocalDateTime().toLocalDate());
|
||||
sliceSummaryList.add(sliceSummary);
|
||||
map.put(timestamp, sliceSummary);
|
||||
}
|
||||
for (MemorySlice memorySlice : memoryResult.getRelatedMemorySliceResult()) {
|
||||
private void setSliceSummaryList(List<ActivatedMemorySlice> memorySlices, List<SliceSummary> sliceSummaryList,
|
||||
Map<Long, ActivatedMemorySlice> map) {
|
||||
for (ActivatedMemorySlice memorySlice : memorySlices) {
|
||||
SliceSummary sliceSummary = new SliceSummary();
|
||||
sliceSummary.setId(memorySlice.getTimestamp());
|
||||
sliceSummary.setSummary(memorySlice.getSummary());
|
||||
sliceSummary.setDate(memorySlice.getDate());
|
||||
sliceSummaryList.add(sliceSummary);
|
||||
map.put(memorySlice.getTimestamp(), sliceSummary);
|
||||
map.put(memorySlice.getTimestamp(), memorySlice);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ package work.slhaf.partner.module.modules.memory.selector.evaluator.entity;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryResult;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -12,5 +12,5 @@ import java.util.List;
|
||||
public class EvaluatorInput {
|
||||
private String input;
|
||||
private List<Message> messages;
|
||||
private List<MemoryResult> memoryResults;
|
||||
private List<ActivatedMemorySlice> memorySlices;
|
||||
}
|
||||
|
||||
@@ -6,17 +6,17 @@ import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime;
|
||||
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorInput;
|
||||
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorMatchData;
|
||||
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorResult;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
|
||||
@@ -29,25 +29,21 @@ public class MemorySelectExtractor extends AbstractAgentModule.Sub<PartnerRunnin
|
||||
private MemoryCapability memoryCapability;
|
||||
@InjectCapability
|
||||
private CognationCapability cognationCapability;
|
||||
@InjectModule
|
||||
private MemoryRuntime memoryRuntime;
|
||||
|
||||
@Override
|
||||
public ExtractorResult execute(PartnerRunningFlowContext context) {
|
||||
log.debug("[MemorySelectExtractor] 主题提取模块开始...");
|
||||
// 结构化为指定格式
|
||||
List<Message> chatMessages = new ArrayList<>();
|
||||
List<MetaMessage> metaMessages = cognationCapability.snapshotSingleMetaMessages(context.getSource());
|
||||
for (MetaMessage metaMessage : metaMessages) {
|
||||
chatMessages.add(metaMessage.getUserMessage());
|
||||
chatMessages.add(metaMessage.getAssistantMessage());
|
||||
}
|
||||
List<Message> chatMessages = cognationCapability.snapshotChatMessages();
|
||||
ExtractorResult extractorResult;
|
||||
try {
|
||||
List<EvaluatedSlice> activatedMemorySlices = memoryCapability.getActivatedSlices(context.getSource());
|
||||
List<ActivatedMemorySlice> activatedMemorySlices = memoryCapability.getActivatedSlices();
|
||||
ExtractorInput extractorInput = ExtractorInput.builder()
|
||||
.text(context.getInput())
|
||||
.date(context.getInfo().getDateTime().toLocalDate())
|
||||
.history(chatMessages)
|
||||
.topic_tree(memoryCapability.getTopicTree())
|
||||
.topic_tree(memoryRuntime.getTopicTree())
|
||||
.activatedMemorySlices(activatedMemorySlices)
|
||||
.build();
|
||||
log.debug("[MemorySelectExtractor] 主题提取输入: {}", JSONUtil.toJsonStr(extractorInput));
|
||||
|
||||
@@ -3,7 +3,7 @@ package work.slhaf.partner.module.modules.memory.selector.extractor.entity;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.ActivatedMemorySlice;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.List;
|
||||
@@ -15,5 +15,5 @@ public class ExtractorInput {
|
||||
private String topic_tree;
|
||||
private LocalDate date;
|
||||
private List<Message> history;
|
||||
private List<EvaluatedSlice> activatedMemorySlices;
|
||||
private List<ActivatedMemorySlice> activatedMemorySlices;
|
||||
}
|
||||
|
||||
@@ -8,31 +8,29 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapabili
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
||||
import work.slhaf.partner.core.action.entity.Schedulable;
|
||||
import work.slhaf.partner.core.action.entity.StateAction;
|
||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.partner.core.perceive.PerceiveCapability;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryUnit;
|
||||
import work.slhaf.partner.core.memory.pojo.SliceRef;
|
||||
import work.slhaf.partner.module.common.module.PostRunningAgentModule;
|
||||
import work.slhaf.partner.module.modules.action.scheduler.ActionScheduler;
|
||||
import work.slhaf.partner.module.modules.memory.runtime.MemoryRuntime;
|
||||
import work.slhaf.partner.module.modules.memory.updater.summarizer.MultiSummarizer;
|
||||
import work.slhaf.partner.module.modules.memory.updater.summarizer.SingleSummarizer;
|
||||
import work.slhaf.partner.module.modules.memory.updater.summarizer.TotalSummarizer;
|
||||
import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeInput;
|
||||
import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeResult;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import static work.slhaf.partner.common.util.ExtractUtil.extractUserId;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@@ -46,21 +44,20 @@ public class MemoryUpdater extends PostRunningAgentModule {
|
||||
private CognationCapability cognationCapability;
|
||||
@InjectCapability
|
||||
private MemoryCapability memoryCapability;
|
||||
@InjectCapability
|
||||
private PerceiveCapability perceiveCapability;
|
||||
|
||||
@InjectModule
|
||||
private MemoryRuntime memoryRuntime;
|
||||
@InjectModule
|
||||
private MultiSummarizer multiSummarizer;
|
||||
@InjectModule
|
||||
private SingleSummarizer singleSummarizer;
|
||||
@InjectModule
|
||||
private TotalSummarizer totalSummarizer;
|
||||
private final AtomicBoolean updating = new AtomicBoolean(false);
|
||||
|
||||
private InteractionThreadPoolExecutor executor;
|
||||
@InjectModule
|
||||
private ActionScheduler actionScheduler;
|
||||
|
||||
private final AtomicBoolean updating = new AtomicBoolean(false);
|
||||
private InteractionThreadPoolExecutor executor;
|
||||
private volatile long lastUpdatedTime;
|
||||
|
||||
@Init
|
||||
public void init() {
|
||||
executor = InteractionThreadPoolExecutor.getInstance();
|
||||
@@ -86,15 +83,13 @@ public class MemoryUpdater extends PostRunningAgentModule {
|
||||
@Override
|
||||
public void doExecute(PartnerRunningFlowContext context) {
|
||||
executor.execute(() -> {
|
||||
// 如果token 大于阈值,则更新记忆
|
||||
JSONObject moduleContext = context.getModuleContext().getExtraContext();
|
||||
boolean recall = moduleContext.getBoolean("recall");
|
||||
if (recall) {
|
||||
log.debug("[MemoryUpdater] 存在回忆");
|
||||
int recallCount = moduleContext.getIntValue("recall_count");
|
||||
log.debug("[MemoryUpdater] 记忆切片数量 [{}]", recallCount);
|
||||
log.debug("[MemoryUpdater] 当前激活记忆数量 [{}]", recallCount);
|
||||
}
|
||||
boolean trigger = context.getModuleContext().getExtraContext().getBoolean("post_process_trigger");
|
||||
boolean trigger = moduleContext.getBoolean("post_process_trigger");
|
||||
if (!trigger) {
|
||||
return;
|
||||
}
|
||||
@@ -110,7 +105,6 @@ public class MemoryUpdater extends PostRunningAgentModule {
|
||||
|
||||
private void tryAutoUpdate() {
|
||||
long currentTime = System.currentTimeMillis();
|
||||
long lastUpdatedTime = cognationCapability.getLastUpdatedTime();
|
||||
int chatCount = cognationCapability.snapshotChatMessages().size();
|
||||
if (lastUpdatedTime != 0 && currentTime - lastUpdatedTime > UPDATE_TRIGGER_INTERVAL && chatCount > 1) {
|
||||
triggerMemoryUpdate(true);
|
||||
@@ -131,7 +125,7 @@ public class MemoryUpdater extends PostRunningAgentModule {
|
||||
updateMemory(chatSnapshot);
|
||||
cognationCapability.rollChatMessagesWithSnapshot(chatSnapshot.size(), CONTEXT_RETAIN_DIVISOR);
|
||||
if (refreshMemoryId) {
|
||||
cognationCapability.refreshMemoryId();
|
||||
memoryCapability.refreshMemoryId();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("[MemoryUpdater] 记忆更新线程出错: ", e);
|
||||
@@ -142,75 +136,35 @@ public class MemoryUpdater extends PostRunningAgentModule {
|
||||
|
||||
private void updateMemory(List<Message> chatSnapshot) {
|
||||
log.debug("[MemoryUpdater] 记忆更新流程开始...");
|
||||
Map<String, String> singleMemorySummary = new ConcurrentHashMap<>();
|
||||
Map<String, List<Message>> singleChatMessages = drainSingleChatMessages();
|
||||
// 更新单聊记忆,同时从chatMessages中去掉单聊记忆
|
||||
updateSingleChatSlices(singleMemorySummary, singleChatMessages);
|
||||
// 更新多人场景下的记忆及相关的确定性记忆
|
||||
List<Message> multiChatMessages = excludeSingleChatMessages(chatSnapshot, singleChatMessages);
|
||||
updateMultiChatSlices(singleMemorySummary, multiChatMessages);
|
||||
cognationCapability.resetLastUpdatedTime();
|
||||
List<Message> chatMessages = getCleanedMessages(chatSnapshot);
|
||||
if (chatMessages.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
SummarizeInput summarizeInput = new SummarizeInput(chatMessages, memoryRuntime.getTopicTree());
|
||||
log.debug("[MemoryUpdater] 记忆更新-总结流程-输入: {}", JSONObject.toJSONString(summarizeInput));
|
||||
SummarizeResult summarizeResult = summarize(summarizeInput);
|
||||
log.debug("[MemoryUpdater] 记忆更新-总结流程-输出: {}", JSONObject.toJSONString(summarizeResult));
|
||||
MemoryUnit memoryUnit = buildMemoryUnit(chatMessages, summarizeResult);
|
||||
memoryCapability.saveMemoryUnit(memoryUnit);
|
||||
MemorySlice memorySlice = memoryUnit.getSlices().getFirst();
|
||||
SliceRef sliceRef = new SliceRef(memoryUnit.getId(), memorySlice.getId());
|
||||
bindTopics(memoryUnit, summarizeResult, sliceRef);
|
||||
memoryRuntime.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary());
|
||||
lastUpdatedTime = System.currentTimeMillis();
|
||||
log.debug("[MemoryUpdater] 记忆更新流程结束...");
|
||||
}
|
||||
|
||||
private Map<String, List<Message>> drainSingleChatMessages() {
|
||||
Map<String, List<Message>> drainedMessages = new HashMap<>();
|
||||
Map<String, List<MetaMessage>> drainedMetaMessages = cognationCapability.drainSingleMetaMessages();
|
||||
for (Map.Entry<String, List<MetaMessage>> entry : drainedMetaMessages.entrySet()) {
|
||||
List<Message> messages = new ArrayList<>();
|
||||
for (MetaMessage metaMessage : entry.getValue()) {
|
||||
messages.add(metaMessage.getUserMessage());
|
||||
messages.add(metaMessage.getAssistantMessage());
|
||||
}
|
||||
drainedMessages.put(entry.getKey(), messages);
|
||||
private void bindTopics(MemoryUnit memoryUnit, SummarizeResult summarizeResult, SliceRef sliceRef) {
|
||||
memoryRuntime.indexMemoryUnit(memoryUnit);
|
||||
memoryRuntime.bindTopic(summarizeResult.getTopicPath(), sliceRef);
|
||||
if (summarizeResult.getRelatedTopicPath() == null) {
|
||||
return;
|
||||
}
|
||||
return drainedMessages;
|
||||
}
|
||||
|
||||
private List<Message> excludeSingleChatMessages(List<Message> chatSnapshot, Map<String, List<Message>> singleChatMessages) {
|
||||
Set<Message> singleMessages = new HashSet<>();
|
||||
for (List<Message> messages : singleChatMessages.values()) {
|
||||
singleMessages.addAll(messages);
|
||||
for (String relatedTopicPath : summarizeResult.getRelatedTopicPath()) {
|
||||
memoryRuntime.bindTopic(relatedTopicPath, sliceRef);
|
||||
}
|
||||
return chatSnapshot.stream()
|
||||
.filter(message -> !singleMessages.contains(message))
|
||||
.toList();
|
||||
}
|
||||
|
||||
private void updateMultiChatSlices(Map<String, String> singleMemorySummary, List<Message> multiChatMessages) {
|
||||
// 此时chatMessages中不再包含单聊记录,直接执行摘要以及切片插入
|
||||
// 对剩下的多人聊天记录进行进行摘要
|
||||
Callable<Void> task = () -> {
|
||||
log.debug("[MemoryUpdater] 多人聊天记忆更新流程开始...");
|
||||
List<Message> chatMessages = getCleanedMessages(multiChatMessages);
|
||||
if (!chatMessages.isEmpty()) {
|
||||
log.debug("[MemoryUpdater] 存在多人聊天记录, 流程正常进行...");
|
||||
// 以第一条user对应的id为发起用户
|
||||
String userId = extractUserId(chatMessages.getFirst().getContent());
|
||||
if (userId == null) {
|
||||
throw new RuntimeException("未匹配到 userId!");
|
||||
}
|
||||
SummarizeInput summarizeInput = new SummarizeInput(chatMessages, memoryCapability.getTopicTree());
|
||||
log.debug("[MemoryUpdater] 多人聊天记忆更新-总结流程-输入: {}", summarizeInput);
|
||||
SummarizeResult summarizeResult = summarize(summarizeInput);
|
||||
log.debug("[MemoryUpdater] 多人聊天记忆更新-总结流程-输出: {}", summarizeResult);
|
||||
MemorySlice memorySlice = getMemorySlice(userId, summarizeResult, chatMessages);
|
||||
// 设置involvedUserId
|
||||
setInvolvedUserId(userId, memorySlice, chatMessages);
|
||||
memoryCapability.insertSlice(memorySlice, summarizeResult.getTopicPath());
|
||||
memoryCapability.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary());
|
||||
} else {
|
||||
log.debug("[MemoryUpdater] 不存在多人聊天记录, 将以单聊总结为对话缓存的主要输入: {}", singleMemorySummary);
|
||||
memoryCapability.updateDialogMap(LocalDateTime.now(), totalSummarizer.execute(new HashMap<>(singleMemorySummary)));
|
||||
}
|
||||
log.debug("[MemoryUpdater] 对话缓存更新完毕");
|
||||
log.debug("[MemoryUpdater] 多人聊天记忆更新流程结束...");
|
||||
return null;
|
||||
};
|
||||
executor.invokeAll(List.of(task));
|
||||
}
|
||||
|
||||
// TODO need to move time information into perceive core
|
||||
private List<Message> getCleanedMessages(List<Message> chatMessages) {
|
||||
return chatMessages.stream()
|
||||
.map(message -> {
|
||||
@@ -226,84 +180,27 @@ public class MemoryUpdater extends PostRunningAgentModule {
|
||||
}).toList();
|
||||
}
|
||||
|
||||
private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List<Message> chatMessages) {
|
||||
for (Message chatMessage : chatMessages) {
|
||||
if (chatMessage.getRole() == Message.Character.ASSISTANT) {
|
||||
continue;
|
||||
}
|
||||
// 匹配userId
|
||||
String userId = extractUserId(chatMessage.getContent());
|
||||
if (userId == null) {
|
||||
continue;
|
||||
}
|
||||
if (userId.equals(startUserId)) {
|
||||
continue;
|
||||
}
|
||||
memorySlice.setInvolvedUserIds(new ArrayList<>());
|
||||
memorySlice.getInvolvedUserIds().add(userId);
|
||||
}
|
||||
}
|
||||
|
||||
private void updateSingleChatSlices(Map<String, String> singleMemorySummary, Map<String, List<Message>> singleChatMessages) {
|
||||
log.debug("[MemoryUpdater] 单聊记忆更新流程开始...");
|
||||
List<Callable<Void>> tasks = new ArrayList<>();
|
||||
AtomicInteger count = new AtomicInteger(0);
|
||||
for (Map.Entry<String, List<Message>> entry : singleChatMessages.entrySet()) {
|
||||
String id = entry.getKey();
|
||||
List<Message> messages = entry.getValue();
|
||||
if (messages.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
tasks.add(() -> {
|
||||
int thisCount = count.incrementAndGet();
|
||||
log.debug("[MemoryUpdater] 单聊记忆[{}]更新: {}", thisCount, id);
|
||||
try {
|
||||
// 单聊记忆更新
|
||||
SummarizeInput summarizeInput = new SummarizeInput(messages, memoryCapability.getTopicTree());
|
||||
log.debug("[MemoryUpdater] 单聊记忆[{}]更新-总结流程-输入: {}", thisCount, JSONObject.toJSONString(summarizeInput));
|
||||
SummarizeResult summarizeResult = summarize(summarizeInput);
|
||||
log.debug("[MemoryUpdater] 单聊记忆[{}]更新-总结流程-输出: {}", thisCount, JSONObject.toJSONString(summarizeResult));
|
||||
MemorySlice memorySlice = getMemorySlice(id, summarizeResult, messages);
|
||||
// 插入时userDialogMap已经进行更新
|
||||
memoryCapability.insertSlice(memorySlice, summarizeResult.getTopicPath());
|
||||
// 从chatMessages中移除单聊记录
|
||||
cognationCapability.cleanMessage(messages);
|
||||
// 添加至singleMemorySummary
|
||||
String key = perceiveCapability.getUser(id).getNickName() + "[" + id + "]";
|
||||
singleMemorySummary.put(key, summarizeResult.getSummary());
|
||||
log.debug("[MemoryUpdater] 单聊记忆[{}]更新成功: ", thisCount);
|
||||
} catch (Exception e) {
|
||||
log.error("[MemoryUpdater] 单聊记忆[{}]更新出错: ", thisCount, e);
|
||||
}
|
||||
return null;
|
||||
});
|
||||
}
|
||||
executor.invokeAll(tasks);
|
||||
log.debug("[MemoryUpdater] 单聊记忆更新结束...");
|
||||
}
|
||||
|
||||
private SummarizeResult summarize(SummarizeInput summarizeInput) {
|
||||
singleSummarizer.execute(summarizeInput.getChatMessages());
|
||||
return multiSummarizer.execute(summarizeInput);
|
||||
}
|
||||
|
||||
private MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List<Message> chatMessages) {
|
||||
private MemoryUnit buildMemoryUnit(List<Message> chatMessages, SummarizeResult summarizeResult) {
|
||||
long now = System.currentTimeMillis();
|
||||
MemorySlice memorySlice = new MemorySlice();
|
||||
// 设置 memoryId,timestamp
|
||||
memorySlice.setMemoryId(cognationCapability.getCurrentMemoryId());
|
||||
memorySlice.setTimestamp(System.currentTimeMillis());
|
||||
// 补充信息
|
||||
memorySlice.setPrivate(summarizeResult.isPrivate());
|
||||
memorySlice.setId(UUID.randomUUID().toString());
|
||||
memorySlice.setStartIndex(0);
|
||||
memorySlice.setEndIndex(Math.max(chatMessages.size() - 1, 0));
|
||||
memorySlice.setSummary(summarizeResult.getSummary());
|
||||
memorySlice.setChatMessages(chatMessages);
|
||||
memorySlice.setStartUserId(userId);
|
||||
List<List<String>> relatedTopicPathList = new ArrayList<>();
|
||||
for (String string : summarizeResult.getRelatedTopicPath()) {
|
||||
List<String> list = Arrays.stream(string.split("->")).toList();
|
||||
relatedTopicPathList.add(list);
|
||||
}
|
||||
memorySlice.setRelatedTopics(relatedTopicPathList);
|
||||
return memorySlice;
|
||||
memorySlice.setTimestamp(now);
|
||||
|
||||
MemoryUnit memoryUnit = new MemoryUnit();
|
||||
String memoryId = memoryCapability.getCurrentMemoryId();
|
||||
memoryUnit.setId(memoryId == null || memoryId.isBlank() ? UUID.randomUUID().toString() : memoryId);
|
||||
memoryUnit.setTimestamp(now);
|
||||
memoryUnit.setConversationMessages(new ArrayList<>(chatMessages));
|
||||
memoryUnit.setSlices(new ArrayList<>(List.of(memorySlice)));
|
||||
return memoryUnit;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -9,5 +9,4 @@ public class SummarizeResult {
|
||||
private String summary;
|
||||
private String topicPath;
|
||||
private List<String> relatedTopicPath;
|
||||
private boolean isPrivate;
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||
import work.slhaf.partner.core.perceive.PerceiveCapability;
|
||||
import work.slhaf.partner.core.perceive.pojo.User;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
@@ -18,6 +19,8 @@ public class PreprocessExecutor extends AbstractAgentModule.Running<PartnerRunni
|
||||
@InjectCapability
|
||||
private CognationCapability cognationCapability;
|
||||
@InjectCapability
|
||||
private MemoryCapability memoryCapability;
|
||||
@InjectCapability
|
||||
private PerceiveCapability perceiveCapability;
|
||||
|
||||
@Override
|
||||
@@ -27,9 +30,9 @@ public class PreprocessExecutor extends AbstractAgentModule.Running<PartnerRunni
|
||||
}
|
||||
|
||||
private void checkAndSetMemoryId() {
|
||||
String currentMemoryId = cognationCapability.getCurrentMemoryId();
|
||||
String currentMemoryId = memoryCapability.getCurrentMemoryId();
|
||||
if (currentMemoryId == null || cognationCapability.getChatMessages().isEmpty()) {
|
||||
cognationCapability.refreshMemoryId();
|
||||
memoryCapability.refreshMemoryId();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package experimental;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryResult;
|
||||
|
||||
import java.lang.reflect.Proxy;
|
||||
|
||||
@@ -15,12 +14,12 @@ public class ReflectionTest {
|
||||
@Test
|
||||
public void proxyTest() {
|
||||
MemoryCapability memory = (MemoryCapability) Proxy.newProxyInstance(this.getClass().getClassLoader(), new Class[]{MemoryCapability.class}, (proxy, method, args) -> {
|
||||
if ("selectMemory".equals(method.getName())) {
|
||||
if ("getCurrentMemoryId".equals(method.getName())) {
|
||||
System.out.println(111);
|
||||
return new MemoryResult();
|
||||
return "memory-id";
|
||||
}
|
||||
return null;
|
||||
});
|
||||
memory.selectMemory("111");
|
||||
memory.getCurrentMemoryId();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ class ActionExecutorTest {
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
lenient().when(cognationCapability.getChatMessages()).thenReturn(Collections.emptyList());
|
||||
lenient().when(memoryCapability.getActivatedSlices(anyString())).thenReturn(Collections.emptyList());
|
||||
lenient().when(memoryCapability.getActivatedSlices()).thenReturn(Collections.emptyList());
|
||||
lenient().when(actionCapability.putPhaserRecord(any(Phaser.class), any(ExecutableAction.class)))
|
||||
.thenAnswer(inv -> new PhaserRecord(inv.getArgument(0), inv.getArgument(1)));
|
||||
lenient().when(actionCapability.loadMetaActionInfo(anyString())).thenAnswer(inv -> {
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
package work.slhaf.partner.module.modules.core;
|
||||
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import org.w3c.dom.Document;
|
||||
import org.w3c.dom.Element;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.ContextBlock;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
@@ -19,13 +17,8 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.mockito.Mockito.lenient;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class CommunicationProducerTest {
|
||||
@@ -100,12 +93,6 @@ class CommunicationProducerTest {
|
||||
Message lastAssistantMessage = producer.getChatMessages().get(3);
|
||||
assertEquals("收到", lastAssistantMessage.getContent());
|
||||
|
||||
ArgumentCaptor<MetaMessage> metaMessageCaptor = ArgumentCaptor.forClass(MetaMessage.class);
|
||||
verify(cognationCapability).addMetaMessage(anyString(), metaMessageCaptor.capture());
|
||||
MetaMessage metaMessage = metaMessageCaptor.getValue();
|
||||
assertNotNull(metaMessage);
|
||||
assertTrue(metaMessage.getUserMessage().getContent().startsWith("[USER]: u-1: 你好,介绍一下你现在看到的上下文"));
|
||||
assertEquals("收到", metaMessage.getAssistantMessage().getContent());
|
||||
assertEquals("收到", context.getCoreResponse().getString("text"));
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user