diff --git a/src/main/java/work/slhaf/agent/common/config/Config.java b/src/main/java/work/slhaf/agent/common/config/Config.java
index b1438d00..dacbb6b2 100644
--- a/src/main/java/work/slhaf/agent/common/config/Config.java
+++ b/src/main/java/work/slhaf/agent/common/config/Config.java
@@ -27,7 +27,7 @@ public class Config {
private static Config config;
private String agentId;
- private String basicCharacter;
+// private String basicCharacter;
private WebSocketConfig webSocketConfig;
@@ -48,8 +48,8 @@ public class Config {
System.out.print("输入智能体名称: ");
config.setAgentId(scanner.nextLine());
- System.out.print("输入智能体基础角色设定: ");
- config.setBasicCharacter(scanner.nextLine());
+// System.out.print("输入智能体基础角色设定: ");
+// config.setBasicCharacter(scanner.nextLine());
System.out.println("(注意! 设定角色之后修改主配置文件将不会影响现有记忆,除非同时更换agentId)");
diff --git a/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java b/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java
index 2a6f6111..02283dbd 100644
--- a/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java
+++ b/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java
@@ -81,7 +81,7 @@ public class MemoryGraph extends PersistableObject {
* 存储确定性记忆, 如'用户爱好'等确定性信息
* 该部分作为'主LLM'system prompt常驻
*/
- private HashMap> staticMemory;
+// private HashMap> staticMemory;
/**
* memorySliceCache计数器,每日清空
@@ -121,12 +121,12 @@ public class MemoryGraph extends PersistableObject {
*/
private Set selectedSlices;
- public MemoryGraph(String id, String basicCharacter) {
+ public MemoryGraph(String id) {
this.id = id;
this.topicNodes = new HashMap<>();
this.existedTopics = new HashMap<>();
this.currentDateDialogSlices = new HashMap<>();
- this.staticMemory = new HashMap<>();
+// this.staticMemory = new HashMap<>();
this.memoryNodeCacheCounter = new ConcurrentHashMap<>();
this.memorySliceCache = new ConcurrentHashMap<>();
this.modelPrompt = new HashMap<>();
@@ -139,7 +139,7 @@ public class MemoryGraph extends PersistableObject {
this.dateIndex = new HashMap<>();
}
- public static MemoryGraph getInstance(String id, String basicCharacter) throws IOException, ClassNotFoundException {
+ public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException {
if (memoryGraph == null) {
synchronized (MemoryGraph.class) {
// 检查存储目录是否存在,不存在则创建
@@ -150,7 +150,7 @@ public class MemoryGraph extends PersistableObject {
memoryGraph = deserialize(id);
} else {
FileUtils.createParentDirectories(filePath.toFile().getParentFile());
- memoryGraph = new MemoryGraph(id, basicCharacter);
+ memoryGraph = new MemoryGraph(id);
memoryGraph.serialize();
}
log.info("MemoryGraph注册完毕...");
diff --git a/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java b/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java
index e700ed64..2aef5ce7 100644
--- a/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java
+++ b/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java
@@ -45,7 +45,7 @@ public class MemoryManager extends PersistableObject {
if (memoryManager == null) {
Config config = Config.getConfig();
memoryManager = new MemoryManager();
- memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId(), config.getBasicCharacter()));
+ memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId()));
memoryManager.setActivatedSlices(new HashMap<>());
memoryManager.setShutdownHook();
log.info("[MemoryManager] MemoryManager注册完毕...");
@@ -115,9 +115,9 @@ public class MemoryManager extends PersistableObject {
return memoryGraph.getTopicTree();
}
- public ConcurrentHashMap getStaticMemory(String userId) {
+/* public ConcurrentHashMap getStaticMemory(String userId) {
return memoryGraph.getStaticMemory().get(userId);
- }
+ }*/
public HashMap getDialogMap() {
return memoryGraph.getDialogMap();
@@ -145,12 +145,12 @@ public class MemoryManager extends PersistableObject {
messageCleanLock.unlock();
}
- public void insertStaticMemory(String userId, Map newStaticMemory) {
+/* public void insertStaticMemory(String userId, Map newStaticMemory) {
if (!memoryGraph.getStaticMemory().containsKey(userId)) {
memoryGraph.getStaticMemory().put(userId, new ConcurrentHashMap<>());
}
memoryGraph.getStaticMemory().get(userId).putAll(newStaticMemory);
- }
+ }*/
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
memoryGraph.updateDialogMap(dateTime, newDialogCache);
@@ -173,4 +173,37 @@ public class MemoryManager extends PersistableObject {
}
return null;
}
+
+ public String getActivatedSlicesStr(String userId) {
+ StringBuilder str = new StringBuilder();
+ if (memoryManager.getActivatedSlices().containsKey(userId)) {
+ memoryManager.getActivatedSlices().get(userId).forEach(slice -> {
+ str.append("\n\n").append("[").append(slice.getDate()).append("]\n")
+ .append(slice.getSummary());
+ });
+ }
+ return str.toString();
+ }
+
+ public String getDialogMapStr() {
+ StringBuilder str = new StringBuilder();
+ memoryGraph.getDialogMap().forEach((dateTime, dialog) -> {
+ str.append("\n\n").append("[").append(dateTime).append("]\n")
+ .append(dialog);
+ });
+ return str.toString();
+ }
+
+ public String getUserDialogMapStr(String userId) {
+ StringBuilder str = new StringBuilder();
+ Collection dialogMapValues = memoryGraph.getDialogMap().values();
+ memoryGraph.getUserDialogMap().get(userId).forEach((dateTime, dialog) -> {
+ if (dialogMapValues.contains(dialog)) {
+ return;
+ }
+ str.append("\n\n").append("[").append(dateTime).append("]\n")
+ .append(dialog);
+ });
+ return str.toString();
+ }
}
diff --git a/src/main/java/work/slhaf/agent/module/common/ModelConstant.java b/src/main/java/work/slhaf/agent/module/common/ModelConstant.java
index 2b4342da..e8048bd5 100644
--- a/src/main/java/work/slhaf/agent/module/common/ModelConstant.java
+++ b/src/main/java/work/slhaf/agent/module/common/ModelConstant.java
@@ -9,7 +9,7 @@ public class ModelConstant {
}
public static class CharacterPrefix {
- public static final String SYSTEM = "[system] ";
+ public static final String SYSTEM = "[SYSTEM] ";
}
}
diff --git a/src/main/java/work/slhaf/agent/module/modules/core/CoreModel.java b/src/main/java/work/slhaf/agent/module/modules/core/CoreModel.java
index b75eac86..5a385857 100644
--- a/src/main/java/work/slhaf/agent/module/modules/core/CoreModel.java
+++ b/src/main/java/work/slhaf/agent/module/modules/core/CoreModel.java
@@ -153,18 +153,45 @@ public class CoreModel extends Model implements InteractionModule {
private void setAppendedPromptMessage(List appendPrompt) {
Message appendDeclareMessage = Message.builder()
.role(ChatConstant.Character.USER)
- .content(ModelConstant.CharacterPrefix.SYSTEM + "以下为追加字段声明,可能包含用户的输入字段和你需要在回应中添加的输出字段.")
+// .content(ModelConstant.CharacterPrefix.SYSTEM + "以下为追加字段声明,可能包含用户的输入字段和你需要在回应中添加的输出字段.")
+ .content(ModelConstant.CharacterPrefix.SYSTEM + "以下为你的相关认知内容,可在对话中参考")
.build();
this.appendedMessages.add(appendDeclareMessage);
for (AppendPromptData data : appendPrompt) {
- StringBuilder str = new StringBuilder(data.getComment()).append("\r\n");
- data.getAppendedPrompt().forEach((k, v) -> str.append(k).append(": ").append(v).append("\r\n"));
- appendedMessages.add(new Message(ChatConstant.Character.USER, str.toString()));
+ setStartMessage(data);
+ setContentMessage(data);
+ setEndMessage(data);
}
Message appendEndMessage = Message.builder()
.role(ChatConstant.Character.USER)
- .content(ModelConstant.CharacterPrefix.SYSTEM + "追加字段声明结束,接下来为用户的真实输入。")
+ .content(ModelConstant.CharacterPrefix.SYSTEM + "相关认知内容结束,接下来是‘你’——‘Partner’与用户的真正交互")
.build();
this.appendedMessages.add(appendEndMessage);
}
+
+ private void setEndMessage(AppendPromptData data) {
+ Message endMessage = Message.builder()
+ .role(ChatConstant.Character.USER)
+ .content(ModelConstant.CharacterPrefix.SYSTEM + data.getComment() + "认知补充结束.")
+ .build();
+ appendedMessages.add(endMessage);
+ }
+
+ private void setContentMessage(AppendPromptData data) {
+ data.getAppendedPrompt().forEach((k, v) -> {
+ Message contentMessage = Message.builder()
+ .role(ChatConstant.Character.USER)
+ .content(ModelConstant.CharacterPrefix.SYSTEM + k + v)
+ .build();
+ appendedMessages.add(contentMessage);
+ });
+ }
+
+ private void setStartMessage(AppendPromptData data) {
+ Message startMessage = Message.builder()
+ .role(ChatConstant.Character.USER)
+ .content(ModelConstant.CharacterPrefix.SYSTEM + data.getComment() + "以下为" + data.getComment() + "相关认知.")
+ .build();
+ appendedMessages.add(startMessage);
+ }
}
diff --git a/src/main/java/work/slhaf/agent/module/modules/memory/selector/MemorySelector.java b/src/main/java/work/slhaf/agent/module/modules/memory/selector/MemorySelector.java
index c3766396..5c22bbb6 100644
--- a/src/main/java/work/slhaf/agent/module/modules/memory/selector/MemorySelector.java
+++ b/src/main/java/work/slhaf/agent/module/modules/memory/selector/MemorySelector.java
@@ -23,6 +23,7 @@ import work.slhaf.agent.shared.memory.EvaluatedSlice;
import java.io.IOException;
import java.time.LocalDate;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.HashMap;
import java.util.List;
@@ -89,21 +90,22 @@ public class MemorySelector implements InteractionModule, AppendPrompt {
//获取主题路径
ExtractorResult extractorResult = memorySelectExtractor.execute(interactionContext);
if (extractorResult.isRecall() || !extractorResult.getMatches().isEmpty()) {
- selectAndEvaluateMemory(interactionContext, extractorResult, userId);
- }
- if (extractorResult.isRecall()) {
- memoryManager.getActivatedSlices().clear();
+ memoryManager.getActivatedSlices().get(userId).clear();
+ List evaluatedSlices = selectAndEvaluateMemory(interactionContext, extractorResult);
+ memoryManager.updateActivatedSlices(userId, evaluatedSlices);
}
//设置上下文
- setModuleContext(interactionContext, userId);
+// setCoreContext(interactionContext);
//设置追加提示词
setAppendedPrompt(interactionContext);
+ setModuleContextRecall(interactionContext);
log.debug("[MemorySelector] 记忆回溯结果: {}", interactionContext);
}
- private void selectAndEvaluateMemory(InteractionContext interactionContext, ExtractorResult extractorResult, String userId) throws IOException, ClassNotFoundException, InterruptedException {
+ private List selectAndEvaluateMemory(InteractionContext interactionContext, ExtractorResult extractorResult) throws IOException, ClassNotFoundException, InterruptedException {
log.debug("[MemorySelector] 触发记忆回溯...");
//查找切片
+ String userId = interactionContext.getUserId();
List memoryResultList = new ArrayList<>();
setMemoryResultList(memoryResultList, extractorResult.getMatches(), userId);
//评估切片
@@ -115,18 +117,19 @@ public class MemorySelector implements InteractionModule, AppendPrompt {
log.debug("[MemorySelector] 切片评估输入: {}", evaluatorInput);
List memorySlices = sliceSelectEvaluator.execute(evaluatorInput);
log.debug("[MemorySelector] 切片评估结果: {}", memorySlices);
- memoryManager.updateActivatedSlices(userId, memorySlices);
+ return memorySlices;
}
- private void setModuleContext(InteractionContext interactionContext, String userId) {
+ /*private void setCoreContext(InteractionContext interactionContext) {
+ String userId = interactionContext.getUserId();
interactionContext.getCoreContext().put("memory_slices", memoryManager.getActivatedSlices().get(userId));
- interactionContext.getCoreContext().put("static_memory", memoryManager.getStaticMemory(userId));
+// interactionContext.getCoreContext().put("static_memory", memoryManager.getStaticMemory(userId));
interactionContext.getCoreContext().put("dialog_map", memoryManager.getDialogMap());
interactionContext.getCoreContext().put("user_dialog_map", memoryManager.getUserDialogMap(userId));
- setModuleContextRecall(interactionContext, userId);
- }
+ }*/
- private void setModuleContextRecall(InteractionContext interactionContext, String userId) {
+ private void setModuleContextRecall(InteractionContext interactionContext) {
+ String userId = interactionContext.getUserId();
boolean recall;
if (memoryManager.getActivatedSlices().get(userId) == null) {
recall = false;
@@ -150,6 +153,7 @@ public class MemorySelector implements InteractionModule, AppendPrompt {
default -> null;
};
if (memoryResult == null) continue;
+ removeDuplicateSlice(memoryResult);
memoryResultList.add(memoryResult);
} catch (UnExistedDateIndexException | UnExistedTopicException e) {
log.error("[MemorySelector] 不存在的记忆索引! 请尝试更换更合适的主题提取LLM!", e);
@@ -168,6 +172,12 @@ public class MemorySelector implements InteractionModule, AppendPrompt {
}
}
+ private void removeDuplicateSlice(MemoryResult memoryResult) {
+ Collection values = memoryManager.getDialogMap().values();
+ memoryResult.getRelatedMemorySliceResult().removeIf(m -> values.contains(m.getSummary()));
+ memoryResult.getMemorySliceResult().removeIf(m -> values.contains(m.getMemorySlice().getSummary()));
+ }
+
private boolean removeOrNot(MemorySlice memorySlice, String userId) {
if (memorySlice.isPrivate()) {
return memorySlice.getStartUserId().equals(userId);
@@ -177,14 +187,30 @@ public class MemorySelector implements InteractionModule, AppendPrompt {
@Override
public void setAppendedPrompt(InteractionContext context) {
- HashMap map = new HashMap<>();
- map.put("memory_slices", "本次对话可参考的记忆切片");
- map.put("static_memory", "关于本次对话对象的稳定记忆");
- map.put("dialog_map", "近两日的与所有用户的对话缓存");
- map.put("user_dialog_map", "与当前用户的近两日对话缓存");
+ String userId = context.getUserId();
+ HashMap map = getPromptDataMap(userId);
AppendPromptData data = new AppendPromptData();
- data.setComment("[system] 追加字段: 记忆模块");
+ data.setComment("[记忆模块]");
data.setAppendedPrompt(map);
context.setAppendedPrompt(data);
}
+
+ private HashMap getPromptDataMap(String userId) {
+ HashMap map = new HashMap<>();
+ String dialogMapStr = memoryManager.getDialogMapStr();
+ if (!dialogMapStr.isEmpty()) {
+ map.put("[记忆缓存] <你最近两日和所有聊天者的对话记忆印象>", dialogMapStr);
+ }
+
+ String userDialogMapStr = memoryManager.getUserDialogMapStr(userId);
+ if (!userDialogMapStr.isEmpty()) {
+ map.put("[用户记忆缓存] <与最新一条消息的发送者的近两天对话记忆印象, 可能与[记忆缓存]稍有重复>", "与当前用户的近两日对话缓存");
+ }
+
+ String sliceStr = memoryManager.getActivatedSlicesStr(userId);
+ if (!sliceStr.isEmpty()){
+ map.put("[记忆切片] <你与最新一条消息的发送者的相关回忆, 不会与[记忆缓存]重复, 如果有重复你也可以指出来()>", sliceStr);
+ }
+ return map;
+ }
}
diff --git a/src/test/java/memory/MemoryTest.java b/src/test/java/memory/MemoryTest.java
index 4fafe7ef..814e2398 100644
--- a/src/test/java/memory/MemoryTest.java
+++ b/src/test/java/memory/MemoryTest.java
@@ -12,7 +12,7 @@ public class MemoryTest {
//@Test
public void test1() {
String basicCharacter = "";
- MemoryGraph graph = new MemoryGraph("test", basicCharacter);
+ MemoryGraph graph = new MemoryGraph("test");
HashMap topicMap = new HashMap<>();
TopicNode root1 = new TopicNode();