mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
进行: 调整认知补充机制; 应当可以进入第二阶段测试
- 调整模块追加认知内容补充机制, 舍弃了原有的CoreContext, 分为多条prompt message注入上文 - 完全去除记忆模块的静态记忆内容 - 在MemoryManager中添加了必要的获取相关字符串的方法,而不是依赖原始的JSON格式
This commit is contained in:
@@ -27,7 +27,7 @@ public class Config {
|
||||
private static Config config;
|
||||
|
||||
private String agentId;
|
||||
private String basicCharacter;
|
||||
// private String basicCharacter;
|
||||
|
||||
private WebSocketConfig webSocketConfig;
|
||||
|
||||
@@ -48,8 +48,8 @@ public class Config {
|
||||
System.out.print("输入智能体名称: ");
|
||||
config.setAgentId(scanner.nextLine());
|
||||
|
||||
System.out.print("输入智能体基础角色设定: ");
|
||||
config.setBasicCharacter(scanner.nextLine());
|
||||
// System.out.print("输入智能体基础角色设定: ");
|
||||
// config.setBasicCharacter(scanner.nextLine());
|
||||
|
||||
System.out.println("(注意! 设定角色之后修改主配置文件将不会影响现有记忆,除非同时更换agentId)");
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
* 存储确定性记忆, 如'用户爱好'等确定性信息
|
||||
* 该部分作为'主LLM'system prompt常驻
|
||||
*/
|
||||
private HashMap<String /*userId*/, ConcurrentHashMap<String /*memoryKey*/, String /*memoryValue*/>> staticMemory;
|
||||
// private HashMap<String /*userId*/, ConcurrentHashMap<String /*memoryKey*/, String /*memoryValue*/>> staticMemory;
|
||||
|
||||
/**
|
||||
* memorySliceCache计数器,每日清空
|
||||
@@ -121,12 +121,12 @@ public class MemoryGraph extends PersistableObject {
|
||||
*/
|
||||
private Set<Long> selectedSlices;
|
||||
|
||||
public MemoryGraph(String id, String basicCharacter) {
|
||||
public MemoryGraph(String id) {
|
||||
this.id = id;
|
||||
this.topicNodes = new HashMap<>();
|
||||
this.existedTopics = new HashMap<>();
|
||||
this.currentDateDialogSlices = new HashMap<>();
|
||||
this.staticMemory = new HashMap<>();
|
||||
// this.staticMemory = new HashMap<>();
|
||||
this.memoryNodeCacheCounter = new ConcurrentHashMap<>();
|
||||
this.memorySliceCache = new ConcurrentHashMap<>();
|
||||
this.modelPrompt = new HashMap<>();
|
||||
@@ -139,7 +139,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
this.dateIndex = new HashMap<>();
|
||||
}
|
||||
|
||||
public static MemoryGraph getInstance(String id, String basicCharacter) throws IOException, ClassNotFoundException {
|
||||
public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException {
|
||||
if (memoryGraph == null) {
|
||||
synchronized (MemoryGraph.class) {
|
||||
// 检查存储目录是否存在,不存在则创建
|
||||
@@ -150,7 +150,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
memoryGraph = deserialize(id);
|
||||
} else {
|
||||
FileUtils.createParentDirectories(filePath.toFile().getParentFile());
|
||||
memoryGraph = new MemoryGraph(id, basicCharacter);
|
||||
memoryGraph = new MemoryGraph(id);
|
||||
memoryGraph.serialize();
|
||||
}
|
||||
log.info("MemoryGraph注册完毕...");
|
||||
|
||||
@@ -45,7 +45,7 @@ public class MemoryManager extends PersistableObject {
|
||||
if (memoryManager == null) {
|
||||
Config config = Config.getConfig();
|
||||
memoryManager = new MemoryManager();
|
||||
memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId(), config.getBasicCharacter()));
|
||||
memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId()));
|
||||
memoryManager.setActivatedSlices(new HashMap<>());
|
||||
memoryManager.setShutdownHook();
|
||||
log.info("[MemoryManager] MemoryManager注册完毕...");
|
||||
@@ -115,9 +115,9 @@ public class MemoryManager extends PersistableObject {
|
||||
return memoryGraph.getTopicTree();
|
||||
}
|
||||
|
||||
public ConcurrentHashMap<String, String> getStaticMemory(String userId) {
|
||||
/* public ConcurrentHashMap<String, String> getStaticMemory(String userId) {
|
||||
return memoryGraph.getStaticMemory().get(userId);
|
||||
}
|
||||
}*/
|
||||
|
||||
public HashMap<LocalDateTime, String> getDialogMap() {
|
||||
return memoryGraph.getDialogMap();
|
||||
@@ -145,12 +145,12 @@ public class MemoryManager extends PersistableObject {
|
||||
messageCleanLock.unlock();
|
||||
}
|
||||
|
||||
public void insertStaticMemory(String userId, Map<String, String> newStaticMemory) {
|
||||
/* public void insertStaticMemory(String userId, Map<String, String> newStaticMemory) {
|
||||
if (!memoryGraph.getStaticMemory().containsKey(userId)) {
|
||||
memoryGraph.getStaticMemory().put(userId, new ConcurrentHashMap<>());
|
||||
}
|
||||
memoryGraph.getStaticMemory().get(userId).putAll(newStaticMemory);
|
||||
}
|
||||
}*/
|
||||
|
||||
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
|
||||
memoryGraph.updateDialogMap(dateTime, newDialogCache);
|
||||
@@ -173,4 +173,37 @@ public class MemoryManager extends PersistableObject {
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public String getActivatedSlicesStr(String userId) {
|
||||
StringBuilder str = new StringBuilder();
|
||||
if (memoryManager.getActivatedSlices().containsKey(userId)) {
|
||||
memoryManager.getActivatedSlices().get(userId).forEach(slice -> {
|
||||
str.append("\n\n").append("[").append(slice.getDate()).append("]\n")
|
||||
.append(slice.getSummary());
|
||||
});
|
||||
}
|
||||
return str.toString();
|
||||
}
|
||||
|
||||
public String getDialogMapStr() {
|
||||
StringBuilder str = new StringBuilder();
|
||||
memoryGraph.getDialogMap().forEach((dateTime, dialog) -> {
|
||||
str.append("\n\n").append("[").append(dateTime).append("]\n")
|
||||
.append(dialog);
|
||||
});
|
||||
return str.toString();
|
||||
}
|
||||
|
||||
public String getUserDialogMapStr(String userId) {
|
||||
StringBuilder str = new StringBuilder();
|
||||
Collection<String> dialogMapValues = memoryGraph.getDialogMap().values();
|
||||
memoryGraph.getUserDialogMap().get(userId).forEach((dateTime, dialog) -> {
|
||||
if (dialogMapValues.contains(dialog)) {
|
||||
return;
|
||||
}
|
||||
str.append("\n\n").append("[").append(dateTime).append("]\n")
|
||||
.append(dialog);
|
||||
});
|
||||
return str.toString();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ public class ModelConstant {
|
||||
}
|
||||
|
||||
public static class CharacterPrefix {
|
||||
public static final String SYSTEM = "[system] ";
|
||||
public static final String SYSTEM = "[SYSTEM] ";
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -153,18 +153,45 @@ public class CoreModel extends Model implements InteractionModule {
|
||||
private void setAppendedPromptMessage(List<AppendPromptData> appendPrompt) {
|
||||
Message appendDeclareMessage = Message.builder()
|
||||
.role(ChatConstant.Character.USER)
|
||||
.content(ModelConstant.CharacterPrefix.SYSTEM + "以下为追加字段声明,可能包含用户的输入字段和你需要在回应中添加的输出字段.")
|
||||
// .content(ModelConstant.CharacterPrefix.SYSTEM + "以下为追加字段声明,可能包含用户的输入字段和你需要在回应中添加的输出字段.")
|
||||
.content(ModelConstant.CharacterPrefix.SYSTEM + "以下为你的相关认知内容,可在对话中参考")
|
||||
.build();
|
||||
this.appendedMessages.add(appendDeclareMessage);
|
||||
for (AppendPromptData data : appendPrompt) {
|
||||
StringBuilder str = new StringBuilder(data.getComment()).append("\r\n");
|
||||
data.getAppendedPrompt().forEach((k, v) -> str.append(k).append(": ").append(v).append("\r\n"));
|
||||
appendedMessages.add(new Message(ChatConstant.Character.USER, str.toString()));
|
||||
setStartMessage(data);
|
||||
setContentMessage(data);
|
||||
setEndMessage(data);
|
||||
}
|
||||
Message appendEndMessage = Message.builder()
|
||||
.role(ChatConstant.Character.USER)
|
||||
.content(ModelConstant.CharacterPrefix.SYSTEM + "追加字段声明结束,接下来为用户的真实输入。")
|
||||
.content(ModelConstant.CharacterPrefix.SYSTEM + "相关认知内容结束,接下来是‘你’——‘Partner’与用户的真正交互")
|
||||
.build();
|
||||
this.appendedMessages.add(appendEndMessage);
|
||||
}
|
||||
|
||||
private void setEndMessage(AppendPromptData data) {
|
||||
Message endMessage = Message.builder()
|
||||
.role(ChatConstant.Character.USER)
|
||||
.content(ModelConstant.CharacterPrefix.SYSTEM + data.getComment() + "认知补充结束.")
|
||||
.build();
|
||||
appendedMessages.add(endMessage);
|
||||
}
|
||||
|
||||
private void setContentMessage(AppendPromptData data) {
|
||||
data.getAppendedPrompt().forEach((k, v) -> {
|
||||
Message contentMessage = Message.builder()
|
||||
.role(ChatConstant.Character.USER)
|
||||
.content(ModelConstant.CharacterPrefix.SYSTEM + k + v)
|
||||
.build();
|
||||
appendedMessages.add(contentMessage);
|
||||
});
|
||||
}
|
||||
|
||||
private void setStartMessage(AppendPromptData data) {
|
||||
Message startMessage = Message.builder()
|
||||
.role(ChatConstant.Character.USER)
|
||||
.content(ModelConstant.CharacterPrefix.SYSTEM + data.getComment() + "以下为" + data.getComment() + "相关认知.")
|
||||
.build();
|
||||
appendedMessages.add(startMessage);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import work.slhaf.agent.shared.memory.EvaluatedSlice;
|
||||
import java.io.IOException;
|
||||
import java.time.LocalDate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
||||
@@ -89,21 +90,22 @@ public class MemorySelector implements InteractionModule, AppendPrompt {
|
||||
//获取主题路径
|
||||
ExtractorResult extractorResult = memorySelectExtractor.execute(interactionContext);
|
||||
if (extractorResult.isRecall() || !extractorResult.getMatches().isEmpty()) {
|
||||
selectAndEvaluateMemory(interactionContext, extractorResult, userId);
|
||||
}
|
||||
if (extractorResult.isRecall()) {
|
||||
memoryManager.getActivatedSlices().clear();
|
||||
memoryManager.getActivatedSlices().get(userId).clear();
|
||||
List<EvaluatedSlice> evaluatedSlices = selectAndEvaluateMemory(interactionContext, extractorResult);
|
||||
memoryManager.updateActivatedSlices(userId, evaluatedSlices);
|
||||
}
|
||||
//设置上下文
|
||||
setModuleContext(interactionContext, userId);
|
||||
// setCoreContext(interactionContext);
|
||||
//设置追加提示词
|
||||
setAppendedPrompt(interactionContext);
|
||||
setModuleContextRecall(interactionContext);
|
||||
log.debug("[MemorySelector] 记忆回溯结果: {}", interactionContext);
|
||||
}
|
||||
|
||||
private void selectAndEvaluateMemory(InteractionContext interactionContext, ExtractorResult extractorResult, String userId) throws IOException, ClassNotFoundException, InterruptedException {
|
||||
private List<EvaluatedSlice> selectAndEvaluateMemory(InteractionContext interactionContext, ExtractorResult extractorResult) throws IOException, ClassNotFoundException, InterruptedException {
|
||||
log.debug("[MemorySelector] 触发记忆回溯...");
|
||||
//查找切片
|
||||
String userId = interactionContext.getUserId();
|
||||
List<MemoryResult> memoryResultList = new ArrayList<>();
|
||||
setMemoryResultList(memoryResultList, extractorResult.getMatches(), userId);
|
||||
//评估切片
|
||||
@@ -115,18 +117,19 @@ public class MemorySelector implements InteractionModule, AppendPrompt {
|
||||
log.debug("[MemorySelector] 切片评估输入: {}", evaluatorInput);
|
||||
List<EvaluatedSlice> memorySlices = sliceSelectEvaluator.execute(evaluatorInput);
|
||||
log.debug("[MemorySelector] 切片评估结果: {}", memorySlices);
|
||||
memoryManager.updateActivatedSlices(userId, memorySlices);
|
||||
return memorySlices;
|
||||
}
|
||||
|
||||
private void setModuleContext(InteractionContext interactionContext, String userId) {
|
||||
/*private void setCoreContext(InteractionContext interactionContext) {
|
||||
String userId = interactionContext.getUserId();
|
||||
interactionContext.getCoreContext().put("memory_slices", memoryManager.getActivatedSlices().get(userId));
|
||||
interactionContext.getCoreContext().put("static_memory", memoryManager.getStaticMemory(userId));
|
||||
// interactionContext.getCoreContext().put("static_memory", memoryManager.getStaticMemory(userId));
|
||||
interactionContext.getCoreContext().put("dialog_map", memoryManager.getDialogMap());
|
||||
interactionContext.getCoreContext().put("user_dialog_map", memoryManager.getUserDialogMap(userId));
|
||||
setModuleContextRecall(interactionContext, userId);
|
||||
}
|
||||
}*/
|
||||
|
||||
private void setModuleContextRecall(InteractionContext interactionContext, String userId) {
|
||||
private void setModuleContextRecall(InteractionContext interactionContext) {
|
||||
String userId = interactionContext.getUserId();
|
||||
boolean recall;
|
||||
if (memoryManager.getActivatedSlices().get(userId) == null) {
|
||||
recall = false;
|
||||
@@ -150,6 +153,7 @@ public class MemorySelector implements InteractionModule, AppendPrompt {
|
||||
default -> null;
|
||||
};
|
||||
if (memoryResult == null) continue;
|
||||
removeDuplicateSlice(memoryResult);
|
||||
memoryResultList.add(memoryResult);
|
||||
} catch (UnExistedDateIndexException | UnExistedTopicException e) {
|
||||
log.error("[MemorySelector] 不存在的记忆索引! 请尝试更换更合适的主题提取LLM!", e);
|
||||
@@ -168,6 +172,12 @@ public class MemorySelector implements InteractionModule, AppendPrompt {
|
||||
}
|
||||
}
|
||||
|
||||
private void removeDuplicateSlice(MemoryResult memoryResult) {
|
||||
Collection<String> values = memoryManager.getDialogMap().values();
|
||||
memoryResult.getRelatedMemorySliceResult().removeIf(m -> values.contains(m.getSummary()));
|
||||
memoryResult.getMemorySliceResult().removeIf(m -> values.contains(m.getMemorySlice().getSummary()));
|
||||
}
|
||||
|
||||
private boolean removeOrNot(MemorySlice memorySlice, String userId) {
|
||||
if (memorySlice.isPrivate()) {
|
||||
return memorySlice.getStartUserId().equals(userId);
|
||||
@@ -177,14 +187,30 @@ public class MemorySelector implements InteractionModule, AppendPrompt {
|
||||
|
||||
@Override
|
||||
public void setAppendedPrompt(InteractionContext context) {
|
||||
HashMap<String, String> map = new HashMap<>();
|
||||
map.put("memory_slices", "本次对话可参考的记忆切片");
|
||||
map.put("static_memory", "关于本次对话对象的稳定记忆");
|
||||
map.put("dialog_map", "近两日的与所有用户的对话缓存");
|
||||
map.put("user_dialog_map", "与当前用户的近两日对话缓存");
|
||||
String userId = context.getUserId();
|
||||
HashMap<String, String> map = getPromptDataMap(userId);
|
||||
AppendPromptData data = new AppendPromptData();
|
||||
data.setComment("[system] 追加字段: 记忆模块");
|
||||
data.setComment("[记忆模块]");
|
||||
data.setAppendedPrompt(map);
|
||||
context.setAppendedPrompt(data);
|
||||
}
|
||||
|
||||
private HashMap<String, String> getPromptDataMap(String userId) {
|
||||
HashMap<String, String> map = new HashMap<>();
|
||||
String dialogMapStr = memoryManager.getDialogMapStr();
|
||||
if (!dialogMapStr.isEmpty()) {
|
||||
map.put("[记忆缓存] <你最近两日和所有聊天者的对话记忆印象>", dialogMapStr);
|
||||
}
|
||||
|
||||
String userDialogMapStr = memoryManager.getUserDialogMapStr(userId);
|
||||
if (!userDialogMapStr.isEmpty()) {
|
||||
map.put("[用户记忆缓存] <与最新一条消息的发送者的近两天对话记忆印象, 可能与[记忆缓存]稍有重复>", "与当前用户的近两日对话缓存");
|
||||
}
|
||||
|
||||
String sliceStr = memoryManager.getActivatedSlicesStr(userId);
|
||||
if (!sliceStr.isEmpty()){
|
||||
map.put("[记忆切片] <你与最新一条消息的发送者的相关回忆, 不会与[记忆缓存]重复, 如果有重复你也可以指出来()>", sliceStr);
|
||||
}
|
||||
return map;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ public class MemoryTest {
|
||||
//@Test
|
||||
public void test1() {
|
||||
String basicCharacter = "";
|
||||
MemoryGraph graph = new MemoryGraph("test", basicCharacter);
|
||||
MemoryGraph graph = new MemoryGraph("test");
|
||||
HashMap<String, TopicNode> topicMap = new HashMap<>();
|
||||
|
||||
TopicNode root1 = new TopicNode();
|
||||
|
||||
Reference in New Issue
Block a user