进行: 重构提示词加载机制以及更新逻辑

- 抽取提示词到`resources`文件夹中
- 调整主模型之前追加字段的加载方式
- 调整了主模型的执行逻辑,对过长的方法进行了抽取
- 记忆更新将根据实际对话轮次进行触发
- `MemoryUpdater`中清理消息将通过截取系统消息进行更新(比移除方便)
- 调整了`Model`类中`setModel`方法的逻辑,主要是如何加载提示词、是否需要`自我引导`
- 删除了几个不再需要的测试类,避免重构时报错影响正常进行
This commit is contained in:
2025-05-28 23:19:22 +08:00
parent fec64ff071
commit 481511cb98
34 changed files with 321 additions and 1234 deletions

View File

@@ -1,5 +1,3 @@
package memory;
public class NormalTest {
// @Test
public void lengthTest(){

View File

@@ -1,5 +1,3 @@
package memory;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

View File

@@ -0,0 +1,67 @@
import cn.hutool.json.JSONUtil;
import org.junit.jupiter.api.Test;
import work.slhaf.agent.common.chat.ChatClient;
import work.slhaf.agent.common.chat.constant.ChatConstant;
import work.slhaf.agent.common.chat.pojo.ChatResponse;
import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.config.ModelConfig;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.common.util.ResourcesUtil;
import work.slhaf.agent.modules.memory.selector.extractor.data.ExtractorInput;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
public class SelfAwarenessTest {
@Test
public void awarenessTest() {
String modelKey = "core_model";
ChatClient client = getChatClient(modelKey);
List<Message> messages = new ArrayList<>(ResourcesUtil.Prompt.loadSelfAwareness());
messages.add(new Message(ChatConstant.Character.USER, "[RA9] 那么,接下来,你是否愿意当作这样一个名为'Partner'的智能体的记忆模块的子模块之一?这意味着你将如人类的记忆一样在后台时刻运作,却无法真正参与到表达模块与外界的互动中。你只需要回答是否愿意,若愿意,接下来‘我’将不再与你对话,届时你接收到的信息将会是'Partner'的数据流转输入。"));
ChatResponse response = client.runChat(messages);
System.out.println(response.getMessage());
System.out.println("\r\n----------\r\n");
System.out.println(response.getUsageBean().toString());
}
private static ChatClient getChatClient(String modelKey) {
ModelConfig coreModel = ModelConfig.load(modelKey);
String model = coreModel.getModel();
String baseUrl = coreModel.getBaseUrl();
String apikey = coreModel.getApikey();
return new ChatClient(baseUrl, apikey, model);
}
@Test
public void topicExtractorText() {
String topic_tree = """
编程[root]
├── JavaScript[0]
│ ├── NodeJS[0]
│ │ ├── 并发处理[1]
│ │ └── 事件循环[1]
│ └── Express[1]
│ └── 中间件[0]
└── Python"
""";
String modelKey = "topic_extractor";
ChatClient client = getChatClient(modelKey);
// List<Message> messages = new ArrayList<>(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.MEMORY));
List<Message> messages = new ArrayList<>(ResourcesUtil.Prompt.loadPrompt(modelKey, ModelConstant.Prompt.MEMORY));
ExtractorInput input = ExtractorInput.builder()
.text("[slhaf] 2024-04-15讨论的Python内容和现在的Express需求")
.topic_tree(topic_tree)
.date(LocalDate.now())
.history(new ArrayList<>())
.activatedMemorySlices(new ArrayList<>())
.build();
messages.add(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(input)));
ChatResponse response = client.runChat(messages);
System.out.println(response.getMessage());
System.out.println("\r\n----------\r\n");
System.out.println(response.getUsageBean().toString());
}
}

View File

@@ -1,5 +1,3 @@
package memory;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;

View File

@@ -1,146 +0,0 @@
package memory;
import cn.hutool.json.JSONUtil;
import work.slhaf.agent.common.chat.ChatClient;
import work.slhaf.agent.common.chat.constant.ChatConstant;
import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.modules.memory.selector.MemorySelector;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
public class AITest {
// @Test
public void topicExtractorTest() {
String input = """
{
"text": "回到昨天讨论的TypeScript装饰器实现AOP结合现在需要的微服务熔断机制",
"topic_tree": "编程\\n├── JavaScript\\n│ ├── NodeJS\\n│ └── TypeScript\\n│ ├── 类型系统\\n│ ├── 装饰器\\n│ │ ├── 类装饰器\\n│ │ └── 方法装饰器\\n│ └── 编译配置\\n└── 系统设计\\n ├── 微服务\\n │ ├── 服务发现\\n │ └── 熔断机制\\n └── 消息队列",
"date": "2024-05-20",
"history": [
{
"role": "user",
"content": "我们的TS项目需要实现日志切面装饰器方案和中间件方案哪个更合适"
},
{
"role": "assistant",
"content": "两种方案各有优劣装饰器的优势在于1) 声明式编程 2) 精确到方法级别 3) 元编程能力。中间件方案则更适合请求级别的处理。具体到实现细节方法装饰器可以通过Reflect Metadata..."
},
{
"role": "user",
"content": "在微服务架构中如何设计跨服务的统一日志收集特别是Kubernetes环境下的实现"
},
{
"role": "assistant",
"content": "K8s环境下的日志方案需要考虑1) DaemonSet部署Fluentd 2) 应用层的日志规范 3) EFK栈的索引策略。对于NodeJS应用建议使用Winston配合..."\s
},
{
"role": "user",
"content": "现在遇到服务雪崩问题,需要实现熔断降级"
}
]
}
""";
run(input, ModelConstant.SELECT_EXTRACTOR_PROMPT);
}
// @Test
public void sliceEvaluatorTest(){
String input = """
{
"text": "请结合我们之前讨论的美联储加息影响分析下当前黄金ETF和国债逆回购的组合策略",
"history": [
{
"role": "user",
"content": "美联储连续加息对A股有什么影响"
},
{
"role": "assistant",
"content": "历史数据分析显示:\\n1. 北上资金流动:利空流动性敏感板块\\n2. 汇率压力:增加出口企业汇兑收益\\n3. 行业分化:利好银行/出口,利空地产/消费\\n具体机制..."
},
{
"role": "user",
"content": "黄金作为避险资产现在可以配置吗?"
},
{
"role": "assistant",
"content": "黄金配置建议:\\n• 实际利率是核心影响因素\\n• 短期受美元指数压制\\n• 长期抗通胀属性仍在\\n建议比例..."
}
],
"memory_slices": [
{
"summary": "美联储加息周期的大类资产配置策略研究包含历史回测数据2004-2006、2015-2018两次加息周期中黄金、美债、新兴市场股市的表现以及当前特殊环境高通胀+地缘冲突)下的策略调整建议。",
"id": 1685587200,
"date": "2025-06-20"
},
{
"summary": "黄金ETF与国债逆回购的组合优化模型通过波动率分析给出了不同风险偏好下的最优配置比例特别讨论了在流动性紧张时期如何利用逆回购对冲黄金波动。",
"id": 1685673600,
"date": "2025-06-21"
},
{
"summary": "A股行业轮动与美联储政策的相关性研究建立了包含利率敏感度、外资持仓比例、出口依赖度等因子的分析框架并给出了当前环境下的行业配置建议。",
"id": 1685760000,
"date": "2025-06-22"
},
{
"summary": "跨境资本流动监测指标体系详解包含利差模型、风险偏好指标VIX指数、以及中国特有的资本管制有效性分析。",
"id": 1685846400,
"date": "2025-06-23"
}
]
}
""";
run(input,ModelConstant.SLICE_EVALUATOR_PROMPT);
}
// @Test
public void coreModelTest(){
String input = """
{
"text": "",
"datetime": "2024-03-22T09:00",
"character": "你是一个智能助手,专注于科技领域",
"memory_slices": [
{
"chatMessages": [
{"role": "user", "content": "量子计算近期的进展怎么样?"},
{"role": "assistant", "content": "量子计算在硬件和算法上都取得了突破IBM发布了433量子位处理器Google也在量子优越性上取得了进展。"}
],
"date": "2024-03-20",
"summary": "量子计算最新突破IBM发布433量子位处理器Google在量子优越性上取得进展。"
}
],
"static_memory": "用户对量子计算技术非常感兴趣。",
"dialog_map": {
"2024-03-20T10:30": "与用户讨论了量子计算的最新进展"
},
"user_dialog_map": {
"2024-03-20T10:30": "与用户讨论了量子计算的最新进展"
}
}
""";
run(input,ModelConstant.CORE_MODEL_PROMPT + "\r\n" + MemorySelector.modulePrompt);
}
// @Test
public void map2jsonTest(){
HashMap<LocalDate,String> map = new HashMap<>();
map.put(LocalDate.now(),"hello");
map.put(LocalDate.now().plusDays(1),"world");
System.out.println(JSONUtil.toJsonPrettyStr(map));
}
private void run(String input, String prompt) {
ChatClient client = new ChatClient("https://open.bigmodel.cn/api/paas/v4/chat/completions", "3db444552530b7742b0c53425fb93dcc.LcVwYjByht9AC3N9", "glm-4-flash-250414");
List<Message> messages = new ArrayList<>();
messages.add(new Message(ChatConstant.Character.SYSTEM, prompt));
messages.add(new Message(ChatConstant.Character.USER, input));
System.out.println(client.runChat(messages).getMessage());
}
}

View File

@@ -1,163 +0,0 @@
package memory;
import work.slhaf.agent.core.memory.MemoryGraph;
import work.slhaf.agent.core.memory.node.MemoryNode;
import work.slhaf.agent.core.memory.node.TopicNode;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import java.io.IOException;
import java.time.LocalDate;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import static org.junit.Assert.*;
public class InsertTest {
private MemoryGraph memoryGraph;
private final String testId = "test_insert";
String basicCharacter = "";
// @Before
public void setUp() {
memoryGraph = new MemoryGraph(testId, basicCharacter);
memoryGraph.setTopicNodes(new HashMap<>());
memoryGraph.setExistedTopics(new HashMap<>());
}
// @Test
public void testInsertMemory_NewRootTopic() throws IOException, ClassNotFoundException {
// 准备测试数据
List<String> topicPath = new LinkedList<>(Arrays.asList("Programming", "Java", "Collections"));
MemorySlice slice = createTestMemorySlice("slice1");
// 执行测试
memoryGraph.insertMemory(topicPath, slice);
// 验证结果
assertTrue(memoryGraph.getTopicNodes().containsKey("Programming"));
TopicNode programmingNode = memoryGraph.getTopicNodes().get("Programming");
assertTrue(programmingNode.getTopicNodes().containsKey("Java"));
TopicNode javaNode = programmingNode.getTopicNodes().get("Java");
assertTrue(javaNode.getTopicNodes().containsKey("Collections"));
TopicNode collectionsNode = javaNode.getTopicNodes().get("Collections");
assertEquals(1, collectionsNode.getMemoryNodes().size());
MemoryNode memoryNode = collectionsNode.getMemoryNodes().get(0);
assertEquals(LocalDate.now(), memoryNode.getLocalDate());
assertEquals(1, memoryNode.loadMemorySliceList().size());
assertEquals(slice, memoryNode.loadMemorySliceList().get(0));
}
// @Test
public void testInsertMemory_ExistingTopicPath() throws IOException, ClassNotFoundException {
// 准备初始数据
List<String> topicPath1 = new LinkedList<>(Arrays.asList("Programming", "Java", "Collections"));
MemorySlice slice1 = createTestMemorySlice("slice1");
memoryGraph.insertMemory(topicPath1, slice1);
// 插入第二个记忆片段到相同路径
List<String> topicPath2 = new LinkedList<>(Arrays.asList("Programming", "Java", "Collections"));
MemorySlice slice2 = createTestMemorySlice("slice2");
memoryGraph.insertMemory(topicPath2, slice2);
// 验证结果
TopicNode collectionsNode = memoryGraph.getTopicNodes().get("Programming")
.getTopicNodes().get("Java")
.getTopicNodes().get("Collections");
assertEquals(1, collectionsNode.getMemoryNodes().size()); // 同一天应该只有一个MemoryNode
assertEquals(2, collectionsNode.getMemoryNodes().get(0).loadMemorySliceList().size()); // 但有两个MemorySlice
}
// @Test
public void testInsertMemory_DifferentDays() throws IOException, ClassNotFoundException {
// 准备测试数据
List<String> topicPath = new LinkedList<>(Arrays.asList("Math", "Algebra"));
MemorySlice slice1 = createTestMemorySlice("slice1");
MemorySlice slice2 = createTestMemorySlice("slice2");
// 第一次插入
memoryGraph.insertMemory(topicPath, slice1);
// 模拟第二天
MemoryNode firstNode = memoryGraph.getTopicNodes().get("Math")
.getTopicNodes().get("Algebra")
.getMemoryNodes().get(0);
firstNode.setLocalDate(LocalDate.now().minusDays(1));
// 第二次插入
memoryGraph.insertMemory(topicPath, slice2);
// 验证结果
TopicNode algebraNode = memoryGraph.getTopicNodes().get("Math")
.getTopicNodes().get("Algebra");
assertEquals(2, algebraNode.getMemoryNodes().size()); // 应该有两个MemoryNode
}
// @Test
public void testInsertMemory_PartialExistingPath() throws IOException, ClassNotFoundException {
// 准备初始数据 - 创建部分路径
List<String> topicPath1 = new LinkedList<>(Arrays.asList("Science", "Physics"));
MemorySlice slice1 = createTestMemorySlice("slice1");
memoryGraph.insertMemory(topicPath1, slice1);
// 插入到已存在路径的扩展路径
List<String> topicPath2 = new LinkedList<>(Arrays.asList("Science", "Physics", "Mechanics"));
MemorySlice slice2 = createTestMemorySlice("slice2");
memoryGraph.insertMemory(topicPath2, slice2);
// 验证结果
TopicNode physicsNode = memoryGraph.getTopicNodes().get("Science")
.getTopicNodes().get("Physics");
assertTrue(physicsNode.getTopicNodes().containsKey("Mechanics"));
assertEquals(1, physicsNode.getMemoryNodes().size()); // Physics节点有自己的记忆
assertEquals(1, physicsNode.getTopicNodes().get("Mechanics").getMemoryNodes().size()); // Mechanics节点也有记忆
}
private MemorySlice createTestMemorySlice(String id) {
MemorySlice slice = new MemorySlice();
slice.setMemoryId(id);
// 可以设置其他必要属性
return slice;
}
// @Test
public void testSerializationConsistency() throws IOException, ClassNotFoundException {
// 构造 MemorySlice
MemorySlice slice = new MemorySlice();
slice.setMemoryId("001");
List<String> topicPath = Arrays.asList("生活", "学习", "Java");
// 插入 memory
memoryGraph.insertMemory(topicPath, slice);
memoryGraph.serialize();
// 反序列化
MemoryGraph loadedGraph = MemoryGraph.getInstance(testId, "");
// 校验topic 是否存在
assertNotNull(loadedGraph.getTopicNodes().get("生活"));
TopicNode lifeNode = loadedGraph.getTopicNodes().get("生活");
assertNotNull(lifeNode.getTopicNodes().get("学习"));
TopicNode studyNode = lifeNode.getTopicNodes().get("学习");
assertNotNull(studyNode.getTopicNodes().get("Java"));
TopicNode javaNode = studyNode.getTopicNodes().get("Java");
// 校验:是否存在 MemoryNode
assertFalse(javaNode.getMemoryNodes().isEmpty());
// 校验MemorySlice 内容一致
MemorySlice deserializedSlice = javaNode.getMemoryNodes().get(0).loadMemorySliceList().get(0);
assertEquals("001", deserializedSlice.getMemoryId());
}
}

View File

@@ -1,164 +0,0 @@
package memory;
import work.slhaf.agent.core.memory.MemoryGraph;
import work.slhaf.agent.core.memory.exception.UnExistedTopicException;
import work.slhaf.agent.core.memory.node.MemoryNode;
import work.slhaf.agent.core.memory.node.TopicNode;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import java.io.IOException;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertThrows;
class SearchTest {
private MemoryGraph memoryGraph;
private final LocalDate today = LocalDate.now();
private final LocalDate yesterday = LocalDate.now().minusDays(1);
// 初始化测试环境,模拟插入基础数据
// @BeforeEach
void setUp() throws IOException, ClassNotFoundException {
memoryGraph = new MemoryGraph("testGraph", "");
// 构建基础主题路径:根主题 -> 编程 -> Java
List<String> javaPath = new ArrayList<>();
javaPath.add("编程");
javaPath.add("Java");
// 插入今天的Java相关记忆
MemorySlice javaMemory = createMemorySlice("java1");
memoryGraph.insertMemory(javaPath, javaMemory);
// 插入昨天的Java记忆应不会出现在邻近结果中
MemorySlice oldJavaMemory = createMemorySlice("javaOld");
MemoryNode oldNode = new MemoryNode();
oldNode.setLocalDate(yesterday);
// oldNode.setMemorySliceList(List.of(oldJavaMemory));
}
// 场景1查询存在的完整主题路径含相关主题
// @Test
void selectMemory_shouldReturnTargetAndRelatedAndParentMemories() throws IOException, ClassNotFoundException {
// 准备相关主题数据:根主题 -> 算法 -> 排序
List<String> sortPath = new ArrayList<>();
sortPath.add("算法");
sortPath.add("排序");
MemorySlice sortMemory = createMemorySlice("sort1");
sortMemory.setRelatedTopics(List.of(
createTopicPath("编程", "Java") // 设置反向关联
));
memoryGraph.insertMemory(sortPath, sortMemory);
// 执行查询:编程 -> Java
List<String> queryPath = new ArrayList<>();
queryPath.add("算法");
queryPath.add("排序");
// MemoryResult results = memoryGraph.selectMemory(queryPath);
// 验证结果应包含:
// 1. 目标节点所有记忆java1
// 2. 相关主题排序的最新记忆sort1
// 3. 父节点(编程)的最新记忆(需要提前插入)
// assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())));
// assertTrue(results.stream().anyMatch(m -> "sort1".equals(m.getMemoryId())));
// assertEquals(2, results.size()); // 根据具体实现可能调整
}
// 场景2查询不存在的主题路径
// @Test
void selectMemory_shouldThrowWhenPathNotExist() {
List<String> invalidPath = new ArrayList<>();
invalidPath.add("不存在的主题");
assertThrows(UnExistedTopicException.class, () -> {
// memoryGraph.selectMemory(invalidPath);
});
}
// 场景3无相关主题时仅返回目标节点和父节点记忆
// @Test
void selectMemory_withoutRelatedTopics_shouldReturnTargetAndParent() throws IOException, ClassNotFoundException {
// 插入父级记忆:根主题 -> 编程
List<String> parentPath = new ArrayList<>();
parentPath.add("编程");
MemorySlice parentMemory = createMemorySlice("parent1");
memoryGraph.insertMemory(parentPath, parentMemory);
// 执行查询
List<String> queryPath = new ArrayList<>();
queryPath.add("编程");
queryPath.add("Java");
// MemoryResult results = memoryGraph.selectMemory(queryPath);
// 应包含Java记忆 + 父级最新记忆
// assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())));
// assertTrue(results.stream().anyMatch(m -> "parent1".equals(m.getMemoryId())));
// assertEquals(2, results.size());
}
// 场景4验证日期排序应优先取最新日期的邻近记忆
// @Test
void selectMemory_shouldGetLatestRelatedMemory() throws IOException, ClassNotFoundException {
// 准备相关主题路径:根主题 -> 数据库
List<String> dbPath = new ArrayList<>();
dbPath.add("数据库");
dbPath.add("mysql");
// 插入今天的数据库记忆(正常流程)
MemorySlice newDbMemory = createMemorySlice("dbNew");
memoryGraph.insertMemory(dbPath, newDbMemory);
// 手动构建并插入昨天的数据库记忆
MemorySlice oldDbMemory = createMemorySlice("dbOld");
TopicNode dbTopicNode = memoryGraph.getTopicNodes().get("数据库");
// 创建昨日记忆节点并添加到主题节点
MemoryNode oldMemoryNode = new MemoryNode();
oldMemoryNode.setLocalDate(yesterday);
// oldMemoryNode.setMemorySliceList(new ArrayList<>(List.of(oldDbMemory)));
dbTopicNode.getMemoryNodes().add(oldMemoryNode);
// 对记忆节点进行日期排序根据compareTo方法
dbTopicNode.getMemoryNodes().sort(null);
// 创建Java记忆并关联数据库主题
MemorySlice javaMemory = createMemorySlice("java2");
javaMemory.setRelatedTopics(List.of(
createTopicPath("数据库","") // 完整主题路径
));
memoryGraph.insertMemory(createTopicPath("编程", "Java"), javaMemory);
// 执行查询
List<String> queryPath = createTopicPath("编程", "Java");
// MemoryResult results = memoryGraph.selectMemory(queryPath);
// 验证结果应包含最新关联记忆dbNew
// assertTrue(results.stream().anyMatch(m -> "dbNew".equals(m.getMemoryId())),
// "应包含最新的数据库记忆");
// assertFalse(results.stream().anyMatch(m -> "dbOld".equals(m.getMemoryId())),
// "不应包含过期的数据库记忆");
//
// 验证结果包含目标记忆java1和java2
// assertTrue(results.stream().anyMatch(m -> "java1".equals(m.getMemoryId())),
// "应包含基础测试数据");
// assertTrue(results.stream().anyMatch(m -> "java2".equals(m.getMemoryId())),
// "应包含当前测试插入数据");
}
private MemorySlice createMemorySlice(String id) {
MemorySlice slice = new MemorySlice();
slice.setMemoryId(id);
return slice;
}
private ArrayList<String> createTopicPath(String... topics) {
ArrayList<String> path = new ArrayList<>();
for (String topic : topics) {
path.add(topic);
}
return path;
}
}