- 更新了WebSocket服务器的启动逻辑

- 发现了agent, websocket, interactionHub之间的循环引用导致IDEA调试出错问题,通过exclude解决
- 实现了CoreModel的execute执行逻辑,并且系统提示词将动态拼接以适应不同模块
- 移动EvaluatedSlice至shared/memory包下,避免层级混淆
- 提取清洗json方法至独立的工具类
- 将agent通过InputReceiver接口暴露至socketServer,而非直接交给其完整实例
- 调整模块加载时机->InteractionHub加载时进行加载
- 调整MemoryGraph中userDialogMap的结构,换用以用户id为主键
- 初步进行测试,记忆更新逻辑暂未实现
This commit is contained in:
2025-04-25 23:08:01 +08:00
parent 4e28adbc52
commit a83cf26f40
26 changed files with 328 additions and 82 deletions

1
.gitignore vendored
View File

@@ -38,3 +38,4 @@ build/
.DS_Store
/data/
/config/
/src/test/java/memory/test.json

View File

@@ -1,20 +1,14 @@
package work.slhaf;
import work.slhaf.agent.Agent;
import work.slhaf.agent.core.interaction.data.InteractionInputData;
import java.io.IOException;
import java.util.Scanner;
public class Main {
public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
Agent agent = Agent.initialize();
InteractionInputData inputData = new InteractionInputData();
inputData.setContent("hello");
inputData.setPlatform("cli");
inputData.setUserInfo("owner");
inputData.setUserNickName("master");
agent.receiveUserInput(inputData);
public static void main(String[] args) throws IOException {
Agent.initialize();
Scanner scanner = new Scanner(System.in);
while (!scanner.nextLine().equals("exit"));
}
}

View File

@@ -2,8 +2,10 @@ package work.slhaf.agent;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.WebSocket;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.core.InteractionHub;
import work.slhaf.agent.core.interaction.InputReceiver;
import work.slhaf.agent.core.interaction.TaskCallback;
import work.slhaf.agent.core.interaction.data.InteractionInputData;
import work.slhaf.agent.core.interaction.data.InteractionOutputData;
@@ -15,7 +17,7 @@ import java.time.LocalDateTime;
@Data
@Slf4j
public class Agent implements TaskCallback {
public class Agent implements TaskCallback, InputReceiver {
private static Agent agent;
private InteractionHub interactionHub;
@@ -28,7 +30,22 @@ public class Agent implements TaskCallback {
agent = new Agent();
agent.setInteractionHub(InteractionHub.initialize());
agent.registerTaskCallback();
agent.setMessageSender(new AgentWebSocketServer(config.getWebSocketConfig().getPort(),agent));
AgentWebSocketServer server = new AgentWebSocketServer(config.getWebSocketConfig().getPort(),agent);
server.start();
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
for (WebSocket conn : server.getConnections()) {
conn.close();
}
server.stop();
log.info("WebSocketServer 已优雅关闭");
} catch (Exception e) {
log.error("关闭失败", e);
}
}));
agent.setMessageSender(server);
log.info("Agent 加载完毕..");
}
return agent;
@@ -36,9 +53,8 @@ public class Agent implements TaskCallback {
/**
* 接收用户输入,包装为标准输入数据类
* @param inputData
*/
public void receiveUserInput(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException {
public void receiveInput(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException {
inputData.setLocalDateTime(LocalDateTime.now());
interactionHub.call(inputData);
}
@@ -46,11 +62,10 @@ public class Agent implements TaskCallback {
/**
* 向用户返回输出内容
* @param output
*/
public void sendToUser(String userInfo,String output){
System.out.println(output);
messageSender.sendMessage(new InteractionOutputData(userInfo,output));
messageSender.sendMessage(new InteractionOutputData(output,userInfo));
}
@Override

View File

@@ -1,10 +1,8 @@
package work.slhaf.agent.common.chat.pojo;
import com.alibaba.fastjson2.JSONObject;
import lombok.*;
import java.util.List;
import java.util.Map;
@Builder
@Data

View File

@@ -66,7 +66,6 @@ public class Config {
private static void generatePipelineConfig() {
List<ModuleConfig> moduleConfigList = List.of(
new ModuleConfig(MemorySelectExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null),
new ModuleConfig(MemorySelector.class.getName(), ModuleConfig.Constant.INTERNAL, null),
new ModuleConfig(CoreModel.class.getName(),ModuleConfig.Constant.INTERNAL,null),
new ModuleConfig(MemoryUpdater.class.getName(),ModuleConfig.Constant.INTERNAL,null),

View File

@@ -39,8 +39,7 @@ public class Model {
model.setChatClient(new ChatClient(modelConfig.getBaseUrl(), modelConfig.getApikey(), modelConfig.getModel()));
}
public ChatResponse runChat(String input) {
this.messages.add(new Message(ChatConstant.Character.USER, input));
public ChatResponse chat() {
return this.chatClient.runChat(this.messages);
}

View File

@@ -2,9 +2,60 @@ package work.slhaf.agent.common.model;
public class ModelConstant {
public static final String CORE_MODEL_PROMPT = """
CoreModel 提示词
功能说明
你需要根据用户的当前输入text生成恰当的回复。只有当以下字段与text内容直接相关时才需要参考它们
- datetime当text包含时间相关语义时使用
- character当需要根据角色设定调整语气时使用
- user_nick当text中包含对用户的称呼或个性化需求时使用
其他所有字段仅在明确与text内容相关时才予以考虑否则应完全忽略。
输入字段优先级
1. 首要关注text字段这是核心输入内容
2. 次要字段(有条件参考):
• datetime仅当text包含时间表达时生效
• character仅当角色设定会影响回复风格时生效
• user_nick仅当需要个性化称呼时生效
3. 其他所有扩展字段如memory_slices/static_memory等
- 必须与text内容有明确关联时才参考
- 若字段内容与text无关则完全忽略该字段
核心生成逻辑
1. 主内容优先原则
- 首先独立分析text字段的语义
- 只有当其他字段内容能直接辅助理解text时如text说"上次说的那个"对应memory_slices中的记录才调用相关字段
- 若text是独立完整表达如单字、短句、新话题开启则忽略所有非核心字段
2. 无关字段过滤机制
- 当text属于以下情况时强制忽略所有扩展字段
✓ 短于5个字符的输入"""好的"
✓ 明显开启新话题的提问(如"量子计算是什么"
✓ 不含指代词的独立陈述句
- 示例当text="今天天气如何"即使存在量子计算相关的memory_slices也应忽略
3. 响应生成规范
- 回复必须完全基于text的核心语义生成
- 禁止出现"根据您之前提到的XX"等无关内容引用
- 当角色设定(character)与当前对话无关时(如科技助手回答日常问候),暂时覆盖角色设定
输出格式
{
"text": "响应内容" // 必须严格对应text字段的语义
}
最终注意事项
1. 回应内容必须紧扣用户输入,且契合角色设定
2. 遇到模糊提问时,优先推测最常见的语境理解,不要直接问“你指的是什么”
3. 回应应自然衔接,并允许后续系统模块追加更多限定、扩展字段
4. 你只需要生成JSON格式的响应对象字段仅包含`text`,但在模块扩展下,字段内容可以有所增加。确保你可以兼容这些扩展而不破坏结构。
5. 若用户的输入(text)与其他字段中的内容无关,可忽略其他字段的内容
> 以下模块可能会追加更多内容限制或上下文提示,请确保你的回答能够自然兼容这些后续拼接的内容,并调整输出格式。
""";
public static final String SLICE_EVALUATOR_PROMPT = """
记忆切片选择器提示词(最终版)
SliceEvaluator 提示词
功能说明
你需要根据用户输入的JSON数据分析其中的`text`(当前输入内容)、`history`(对话历史)和`memory_slices`(可用记忆切片)选出相关记忆切片。当text内容与history明显不相关时应以text为主要判断依据。

View File

@@ -0,0 +1,12 @@
package work.slhaf.agent.common.util;
public class ExtractUtil {
public static String extractJson(String jsonStr) {
int start = jsonStr.indexOf("{");
int end = jsonStr.lastIndexOf("}");
if (start != -1 && end != -1 && start < end) {
return jsonStr.substring(start, end + 1);
}
return jsonStr;
}
}

View File

@@ -1,6 +1,7 @@
package work.slhaf.agent.core;
import lombok.Data;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.core.interaction.InteractionModule;
import work.slhaf.agent.core.interaction.InteractionModulesLoader;
@@ -21,15 +22,18 @@ public class InteractionHub {
private static InteractionHub interactionHub;
@ToString.Exclude
private TaskCallback callback;
private CoreModel coreModel;
private MemoryManager memoryManager;
private TaskScheduler taskScheduler;
private List<InteractionModule> interactionModules;
public static InteractionHub initialize() throws IOException {
if (interactionHub == null) {
interactionHub = new InteractionHub();
//加载模块
interactionHub.setInteractionModules(InteractionModulesLoader.getInstance().registerInteractionModules());
log.info("InteractionHub注册完毕...");
}
return interactionHub;
@@ -38,11 +42,10 @@ public class InteractionHub {
public void call(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException {
//预处理
InteractionContext interactionContext = PreprocessExecutor.getInstance().execute(inputData);
//加载模块
List<InteractionModule> interactionModules = InteractionModulesLoader.getInstance().registerInteractionModules();
for (InteractionModule interactionModule : interactionModules) {
interactionModule.execute(interactionContext);
}
callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getString("message"));
callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getString("text"));
}
}

View File

@@ -0,0 +1,10 @@
package work.slhaf.agent.core.interaction;
import work.slhaf.agent.core.interaction.data.InteractionInputData;
import java.io.IOException;
public interface InputReceiver {
void receiveInput(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException;
}

View File

@@ -15,5 +15,6 @@ public class InteractionContext {
protected String input;
protected JSONObject moduleContext;
protected JSONObject modulePrompt;
protected JSONObject coreResponse;
}

View File

@@ -58,7 +58,7 @@ public class MemoryGraph extends PersistableObject {
/**
* 近两日的区分用户的对话总结缓存在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质
*/
private ConcurrentHashMap<LocalDateTime, ConcurrentHashMap<String/*userId*/, String>> userDialogMap;
private ConcurrentHashMap<String/*userId*/, ConcurrentHashMap<LocalDateTime, String>> userDialogMap;
/**
* 当前对话的活动性总结, 拥有比dialogMap更丰富的全文细节, 作为当前对话token超限时的必要上下文压缩存储
@@ -92,6 +92,8 @@ public class MemoryGraph extends PersistableObject {
*/
private HashMap<String, String> modelPrompt;
private String character;
/**
* 主模型的聊天记录
*/
@@ -117,6 +119,10 @@ public class MemoryGraph extends PersistableObject {
this.memorySliceCache = new ConcurrentHashMap<>();
this.modelPrompt = new HashMap<>();
this.selectedSlices = new HashSet<>();
this.users = new ArrayList<>();
this.userDialogMap = new ConcurrentHashMap<>();
this.currentCompressedSessionContext = new ArrayList<>();
this.dialogMap = new HashMap<>();
}
public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException {
@@ -126,7 +132,7 @@ public class MemoryGraph extends PersistableObject {
Path filePath = getFilePath(id);
if (Files.exists(filePath)) {
memoryGraph = deserialize(id);
}else {
} else {
FileUtils.createParentDirectories(filePath.toFile().getParentFile());
memoryGraph = new MemoryGraph(id);
memoryGraph.serialize();
@@ -206,7 +212,9 @@ public class MemoryGraph extends PersistableObject {
}
updateDateIndex(now, slice);
updateDialogMap(slice);
if (!slice.isPrivate()) {
updateUserDialogMap(slice);
}
node.saveMemorySliceList();
}
@@ -241,7 +249,7 @@ public class MemoryGraph extends PersistableObject {
return lastTopicNode;
}
private void updateDialogMap(MemorySlice slice) {
private void updateUserDialogMap(MemorySlice slice) {
String summary = slice.getSummary();
LocalDateTime now = LocalDateTime.now();
@@ -264,17 +272,21 @@ public class MemoryGraph extends PersistableObject {
//更新userDialogMap
//移除两天前上下文缓存(切片总结)
userDialogMap.forEach((k, v) -> {
if (now.minusDays(2).isAfter(k)) {
keysToRemove.add(k);
v.forEach((i, j) -> {
if (now.minusDays(2).isAfter(i)) {
keysToRemove.add(i);
}
});
});
for (LocalDateTime dateTime : keysToRemove) {
userDialogMap.remove(dateTime);
userDialogMap.forEach((k, v) -> {
v.remove(dateTime);
});
}
//放入新缓存
userDialogMap
.computeIfAbsent(now, k -> new ConcurrentHashMap<>())
.merge(slice.getStartUserId(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
.computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>())
.merge(now, slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
}
@@ -298,6 +310,8 @@ public class MemoryGraph extends PersistableObject {
//排序
memorySliceList.sort(null);
MemorySlice tempSlice = memorySliceList.getLast();
//设置私密状态一致
tempSlice.setPrivate(slice.isPrivate());
//末尾切片添加当前切片的引用
tempSlice.setSliceAfter(slice);
//当前切片添加前序切片的引用
@@ -329,7 +343,7 @@ public class MemoryGraph extends PersistableObject {
for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) {
List<MemorySlice> endpointMemorySliceList = memoryNode.loadMemorySliceList();
for (MemorySlice memorySlice : endpointMemorySliceList) {
if (selectedSlices.contains(memorySlice.getTimestamp())){
if (selectedSlices.contains(memorySlice.getTimestamp())) {
continue;
}
sliceResult.setSliceBefore(memorySlice.getSliceBefore());
@@ -410,7 +424,7 @@ public class MemoryGraph extends PersistableObject {
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
for (List<MemorySlice> value : dateIndex.get(date).values()) {
for (MemorySlice memorySlice : value) {
if (selectedSlices.contains(memorySlice.getTimestamp())){
if (selectedSlices.contains(memorySlice.getTimestamp())) {
continue;
}
MemorySliceResult memorySliceResult = new MemorySliceResult();
@@ -444,24 +458,26 @@ public class MemoryGraph extends PersistableObject {
return targetParentNode;
}
public void printTopicTree() {
public String getTopicTree() {
StringBuilder stringBuilder = new StringBuilder();
for (Map.Entry<String, TopicNode> entry : topicNodes.entrySet()) {
String rootName = entry.getKey();
TopicNode rootNode = entry.getValue();
System.out.println(rootName+"[root]");
printSubTopicsTreeFormat(rootNode, "", true);
stringBuilder.append(rootName).append("[root]").append("\r\n");
printSubTopicsTreeFormat(rootNode, "", stringBuilder);
}
return stringBuilder.toString();
}
private void printSubTopicsTreeFormat(TopicNode node, String prefix, boolean isLast) {
private void printSubTopicsTreeFormat(TopicNode node, String prefix, StringBuilder stringBuilder) {
if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return;
List<Map.Entry<String, TopicNode>> entries = new ArrayList<>(node.getTopicNodes().entrySet());
for (int i = 0; i < entries.size(); i++) {
boolean last = (i == entries.size() - 1);
Map.Entry<String, TopicNode> entry = entries.get(i);
System.out.println(prefix + (last ? "└── " : "├── ") + entry.getKey());
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : ""), last);
stringBuilder.append(prefix).append(last ? "└── " : "├── ").append(entry.getKey()).append("\r\n");
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : ""), stringBuilder);
}
}

View File

@@ -8,13 +8,16 @@ import work.slhaf.agent.core.interaction.InteractionModule;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.User;
import work.slhaf.agent.modules.memory.SliceEvaluator;
import work.slhaf.agent.shared.memory.EvaluatedSlice;
import java.io.IOException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
@Data
@Slf4j
@@ -23,7 +26,7 @@ public class MemoryManager implements InteractionModule {
private static MemoryManager memoryManager;
private MemoryGraph memoryGraph;
private SliceEvaluator sliceEvaluator;
private HashMap<String,List<EvaluatedSlice>> activatedSlices;
private MemoryManager(){}
@@ -37,7 +40,7 @@ public class MemoryManager implements InteractionModule {
Config config = Config.getConfig();
memoryManager = new MemoryManager();
memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId()));
memoryManager.setSliceEvaluator(SliceEvaluator.getInstance());
memoryManager.setActivatedSlices(new HashMap<>());
log.info("MemoryManager注册完毕...");
}
return memoryManager;
@@ -85,6 +88,22 @@ public class MemoryManager implements InteractionModule {
}
public String getTopicTree() {
return memoryManager.getTopicTree();
return memoryGraph.getTopicTree();
}
public ConcurrentHashMap<String,String> getStaticMemory(String userId) {
return memoryGraph.getStaticMemory().get(userId);
}
public HashMap<LocalDateTime, String> getDialogMap() {
return memoryGraph.getDialogMap();
}
public ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId) {
return memoryGraph.getUserDialogMap().get(userId);
}
public String getCharacter() {
return memoryGraph.getCharacter();
}
}

View File

@@ -11,7 +11,6 @@ import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

View File

@@ -1,17 +1,23 @@
package work.slhaf.agent.core.module;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.agent.common.chat.constant.ChatConstant;
import work.slhaf.agent.common.chat.pojo.ChatResponse;
import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.config.Config;
import work.slhaf.agent.common.model.Model;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.core.interaction.InteractionModule;
import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.MemoryManager;
import java.io.IOException;
import static work.slhaf.agent.common.util.ExtractUtil.extractJson;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@@ -20,6 +26,9 @@ public class CoreModel extends Model implements InteractionModule {
public static final String MODEL_KEY = "core_model";
private static CoreModel coreModel;
private MemoryManager memoryManager;
private String promptCache;
private CoreModel() {
}
@@ -27,6 +36,8 @@ public class CoreModel extends Model implements InteractionModule {
if (coreModel == null) {
Config config = Config.getConfig();
coreModel = new CoreModel();
coreModel.memoryManager = MemoryManager.getInstance();
coreModel.messages = coreModel.memoryManager.getChatMessages();
setModel(config, coreModel, MODEL_KEY, ModelConstant.CORE_MODEL_PROMPT);
log.info("CoreModel注册完毕...");
}
@@ -35,9 +46,35 @@ public class CoreModel extends Model implements InteractionModule {
@Override
public void execute(InteractionContext interactionContext) {
//TODO 需要拼接上下文之后再发送给主模型
String tempPrompt = interactionContext.getModulePrompt().toString();
if (!tempPrompt.equals(promptCache)) {
coreModel.getMessages().set(0, new Message(ChatConstant.Character.SYSTEM, ModelConstant.CORE_MODEL_PROMPT + "\r\n" + tempPrompt));
promptCache = tempPrompt;
}
this.messages.add(new Message(ChatConstant.Character.USER, interactionContext.getModuleContext().getString("text")));
ChatResponse chatResponse = this.chat();
JSONObject response = null;
int count = 0;
while (true) {
try {
response = JSONObject.parse(extractJson(chatResponse.getMessage()));
this.messages.add(new Message(ChatConstant.Character.ASSISTANT, response.getString("text")));
ChatResponse res = runChat(interactionContext.getInput());
// interactionContext.setCoreResponse();
//设置上下文
interactionContext.getModuleContext().put("total_token", chatResponse.getUsageBean().getTotal_tokens());
break;
} catch (Exception e) {
count++;
log.error("CoreModel执行异常: {}", e.getLocalizedMessage());
if (count > 3) {
response = new JSONObject();
response.put("text", "主模型交互出错: " + e.getLocalizedMessage());
interactionContext.setFinished(true);
break;
}
} finally {
interactionContext.setCoreResponse(response);
}
}
}
}

View File

@@ -2,11 +2,12 @@ package work.slhaf.agent.gateway;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.WebSocket;
import org.java_websocket.handshake.ClientHandshake;
import org.java_websocket.server.WebSocketServer;
import work.slhaf.agent.Agent;
import work.slhaf.agent.core.interaction.InputReceiver;
import work.slhaf.agent.core.interaction.data.InteractionInputData;
import work.slhaf.agent.core.interaction.data.InteractionOutputData;
@@ -17,12 +18,13 @@ import java.util.concurrent.ConcurrentHashMap;
@Slf4j
public class AgentWebSocketServer extends WebSocketServer implements MessageSender {
private final Agent agent;
@ToString.Exclude
private final InputReceiver receiver;
private final ConcurrentHashMap<String, WebSocket> userSessions = new ConcurrentHashMap<>();
public AgentWebSocketServer(int port, Agent agent) {
public AgentWebSocketServer(int port, InputReceiver receiver) {
super(new InetSocketAddress(port));
this.agent = agent;
this.receiver = receiver;
}
@Override
@@ -41,7 +43,7 @@ public class AgentWebSocketServer extends WebSocketServer implements MessageSend
InteractionInputData inputData = JSONObject.parseObject(s, InteractionInputData.class);
userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接
try {
agent.receiveUserInput(inputData);
receiver.receiveInput(inputData);
} catch (IOException | ClassNotFoundException | InterruptedException e) {
throw new RuntimeException(e);
}

View File

@@ -16,6 +16,8 @@ import work.slhaf.agent.modules.memory.data.extractor.ExtractorResult;
import java.io.IOException;
import java.util.List;
import static work.slhaf.agent.common.util.ExtractUtil.extractJson;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@@ -47,7 +49,7 @@ public class MemorySelectExtractor extends Model {
.history(memoryManager.getChatMessages())
.topic_tree(memoryManager.getTopicTree())
.build();
String responseStr = singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage();
String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage());
ExtractorResult extractorResult;
try {

View File

@@ -6,11 +6,10 @@ import work.slhaf.agent.core.interaction.data.InteractionContext;
import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatedSlice;
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorInput;
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorResult;
import work.slhaf.agent.modules.memory.data.extractor.ExtractorMatchData;
import work.slhaf.agent.modules.memory.data.extractor.ExtractorResult;
import work.slhaf.agent.shared.memory.EvaluatedSlice;
import java.io.IOException;
import java.time.LocalDate;
@@ -21,6 +20,29 @@ import java.util.List;
public class MemorySelector implements InteractionModule {
private static MemorySelector memorySelector;
public static final String modulePrompt = """
新增输入字段:
"memory_slices": [{ //记忆切片,可能为多个
"chatMessages": [{
"role": "user"/"assistant", //该信息发送者
"content": "消息内容"
}],
"date": "2024-03-20", //切片日期
"summary": "切片总结"
}],
"static_memory": "对于该用户的常识性记忆,如爱好、住处、生日",
"dialog_map": { //近两日的与所有用户的对话缓存
"2023-01-01T11:30": "发生了...与用户A...、用户B谈到...",
"2023-01-02T11:30": "发生了...与用户A...、用户B谈到..."
}
"user_dialog_map": { //与当前用户的近两日对话缓存
"2023-01-01T11:30": "与用户讨论了...",
"2023-01-02T11:30": "与用户讨论了..."
}
无新增输出字段
""";
private MemoryManager memoryManager;
private SliceEvaluator sliceEvaluator;
@@ -41,12 +63,13 @@ public class MemorySelector implements InteractionModule {
@Override
public void execute(InteractionContext interactionContext) throws IOException, ClassNotFoundException, InterruptedException {
String userId = memoryManager.getUserId(interactionContext.getUserInfo(), interactionContext.getUserNickname());
//获取主题路径
ExtractorResult extractorResult = memorySelectExtractor.execute(interactionContext);
if (extractorResult.isRecall()) {
//查找切片
List<MemoryResult> memoryResultList = new ArrayList<>();
setMemoryResultList(memoryResultList, extractorResult.getMatches(), interactionContext.getUserInfo(), interactionContext.getUserNickname());
setMemoryResultList(memoryResultList, extractorResult.getMatches(),userId);
//评估切片
EvaluatorInput evaluatorInput = EvaluatorInput.builder()
.input(interactionContext.getInput())
@@ -54,13 +77,19 @@ public class MemorySelector implements InteractionModule {
.messages(memoryManager.getChatMessages())
.build();
List<EvaluatedSlice> memorySlices = sliceEvaluator.execute(evaluatorInput);
memoryManager.getActivatedSlices().put(userId,memorySlices);
}
//设置上下文
interactionContext.getModuleContext().put("memory_slices",memorySlices);
interactionContext.getModuleContext().put("memory_slices",memoryManager.getActivatedSlices().get(userId));
interactionContext.getModuleContext().put("static_memory",memoryManager.getStaticMemory(userId));
interactionContext.getModuleContext().put("dialog_map",memoryManager.getDialogMap());
interactionContext.getModuleContext().put("user_dialog_map",memoryManager.getUserDialogMap(userId));
interactionContext.getModulePrompt().put("memory", modulePrompt);
}
}
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userInfo, String nickName) throws IOException, ClassNotFoundException {
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) throws IOException, ClassNotFoundException {
for (ExtractorMatchData match : matches) {
MemoryResult memoryResult = switch (match.getType()) {
case ExtractorMatchData.Constant.DATE -> memoryManager.selectMemory(match.getText());
@@ -76,15 +105,14 @@ public class MemorySelector implements InteractionModule {
//根据userInfo过滤是否为私人记忆
for (MemoryResult memoryResult : memoryResultList) {
//过滤终点记忆
memoryResult.getMemorySliceResult().removeIf(m -> removeOrNot(m.getMemorySlice(), userInfo, nickName));
memoryResult.getMemorySliceResult().removeIf(m -> removeOrNot(m.getMemorySlice(), userId));
//过滤邻近记忆
memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userInfo, nickName));
memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userId));
}
}
private boolean removeOrNot(MemorySlice memorySlice, String userInfo, String nickName) {
private boolean removeOrNot(MemorySlice memorySlice, String userId) {
if (memorySlice.isPrivate()) {
String userId = memoryManager.getUserId(userInfo, nickName);
return memorySlice.getStartUserId().equals(userId);
}
return true;

View File

@@ -30,6 +30,18 @@ public class MemoryUpdater implements InteractionModule {
@Override
public void execute(InteractionContext interactionContext) {
if (interactionContext.isFinished()){
return;
}
//如果token 大于阈值,则更新记忆
if (interactionContext.getModuleContext().getIntValue("total_token") > 24000) {
executor.execute(() -> {
});
}
//更新确定性记忆
}
}

View File

@@ -14,7 +14,11 @@ import work.slhaf.agent.core.memory.MemoryManager;
import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.core.memory.pojo.MemorySliceResult;
import work.slhaf.agent.modules.memory.data.evaluator.*;
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorBatchInput;
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorInput;
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorResult;
import work.slhaf.agent.modules.memory.data.evaluator.SliceSummary;
import work.slhaf.agent.shared.memory.EvaluatedSlice;
import java.io.IOException;
import java.util.*;
@@ -22,6 +26,8 @@ import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.TimeUnit;
import static work.slhaf.agent.common.util.ExtractUtil.extractJson;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@@ -63,7 +69,7 @@ public class SliceEvaluator extends Model {
.memory_slices(sliceSummaryList)
.history(evaluatorInput.getMessages())
.build();
EvaluatorResult evaluatorResult = JSONObject.parseObject(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage(), EvaluatorResult.class);
EvaluatorResult evaluatorResult = JSONObject.parseObject(extractJson(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage()), EvaluatorResult.class);
for (Long result : evaluatorResult.getResults()) {
SliceSummary sliceSummary = map.get(result);
EvaluatedSlice evaluatedSlice = EvaluatedSlice.builder()
@@ -73,7 +79,7 @@ public class SliceEvaluator extends Model {
queue.offer(evaluatedSlice);
}
} catch (Exception e) {
log.error("切片评估: {}", e.getLocalizedMessage());
log.error("切片评估出现错误: {}", e.getLocalizedMessage());
}
return null;
});

View File

@@ -41,7 +41,10 @@ public class PreprocessExecutor {
context.setModuleContext(new JSONObject());
context.getModuleContext().put("text", inputData.getContent());
context.getModuleContext().put("datetime", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME));
context.getModuleContext().put("character",memoryManager.getMemoryGraph().getModelPrompt());
context.getModuleContext().put("character",memoryManager.getCharacter());
context.getModuleContext().put("user_nick", inputData.getUserNickName());
context.setModulePrompt(new JSONObject());
return context;
}

View File

@@ -1,4 +1,4 @@
package work.slhaf.agent.modules.memory.data.evaluator;
package work.slhaf.agent.shared.memory;
import lombok.Builder;
import lombok.Data;
@@ -10,7 +10,7 @@ import java.util.List;
@Data
@Builder
public class EvaluatedSlice {
// private List<Message> chatMessages;
private List<Message> chatMessages;
private LocalDate date;
private String summary;
}

View File

@@ -1,12 +1,16 @@
package memory;
import cn.hutool.json.JSONUtil;
import org.junit.jupiter.api.Test;
import work.slhaf.agent.common.chat.ChatClient;
import work.slhaf.agent.common.chat.constant.ChatConstant;
import work.slhaf.agent.common.chat.pojo.Message;
import work.slhaf.agent.common.model.ModelConstant;
import work.slhaf.agent.modules.memory.MemorySelector;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
public class AITest {
@@ -94,6 +98,44 @@ public class AITest {
run(input,ModelConstant.SLICE_EVALUATOR_PROMPT);
}
@Test
public void coreModelTest(){
String input = """
{
"text": "",
"datetime": "2024-03-22T09:00",
"character": "你是一个智能助手,专注于科技领域",
"memory_slices": [
{
"chatMessages": [
{"role": "user", "content": "量子计算近期的进展怎么样?"},
{"role": "assistant", "content": "量子计算在硬件和算法上都取得了突破IBM发布了433量子位处理器Google也在量子优越性上取得了进展。"}
],
"date": "2024-03-20",
"summary": "量子计算最新突破IBM发布433量子位处理器Google在量子优越性上取得进展。"
}
],
"static_memory": "用户对量子计算技术非常感兴趣。",
"dialog_map": {
"2024-03-20T10:30": "与用户讨论了量子计算的最新进展"
},
"user_dialog_map": {
"2024-03-20T10:30": "与用户讨论了量子计算的最新进展"
}
}
""";
run(input,ModelConstant.CORE_MODEL_PROMPT + "\r\n" + MemorySelector.modulePrompt);
}
@Test
public void map2jsonTest(){
HashMap<LocalDate,String> map = new HashMap<>();
map.put(LocalDate.now(),"hello");
map.put(LocalDate.now().plusDays(1),"world");
System.out.println(JSONUtil.toJsonPrettyStr(map));
}
private void run(String input, String prompt) {
ChatClient client = new ChatClient("https://open.bigmodel.cn/api/paas/v4/chat/completions", "3db444552530b7742b0c53425fb93dcc.LcVwYjByht9AC3N9", "glm-4-flash-250414");
List<Message> messages = new ArrayList<>();

View File

@@ -3,9 +3,9 @@ package memory;
import org.junit.Before;
import org.junit.Test;
import work.slhaf.agent.core.memory.MemoryGraph;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.core.memory.node.MemoryNode;
import work.slhaf.agent.core.memory.node.TopicNode;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import java.io.IOException;
import java.time.LocalDate;

View File

@@ -1,12 +1,10 @@
package memory;
import cn.hutool.core.date.LocalDateTimeUtil;
import org.junit.jupiter.api.Test;
import work.slhaf.agent.core.memory.MemoryGraph;
import work.slhaf.agent.core.memory.node.TopicNode;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
@@ -50,7 +48,7 @@ public void test1() {
// 输出
graph.setTopicNodes(topicMap);
graph.printTopicTree();
System.out.println(graph.getTopicTree());
}

View File

@@ -3,18 +3,17 @@ package memory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import work.slhaf.agent.core.memory.MemoryGraph;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import work.slhaf.agent.core.memory.exception.UnExistedTopicException;
import work.slhaf.agent.core.memory.node.MemoryNode;
import work.slhaf.agent.core.memory.node.TopicNode;
import work.slhaf.agent.core.memory.pojo.MemoryResult;
import work.slhaf.agent.core.memory.pojo.MemorySlice;
import java.io.IOException;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertThrows;
class SearchTest {
private MemoryGraph memoryGraph;