diff --git a/.gitignore b/.gitignore index 223d58ec..25510217 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,4 @@ build/ .DS_Store /data/ /config/ +/src/test/java/memory/test.json diff --git a/src/main/java/work/slhaf/Main.java b/src/main/java/work/slhaf/Main.java index 209c383f..32d10648 100644 --- a/src/main/java/work/slhaf/Main.java +++ b/src/main/java/work/slhaf/Main.java @@ -1,20 +1,14 @@ package work.slhaf; import work.slhaf.agent.Agent; -import work.slhaf.agent.core.interaction.data.InteractionInputData; import java.io.IOException; +import java.util.Scanner; public class Main { - public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException { - Agent agent = Agent.initialize(); - - InteractionInputData inputData = new InteractionInputData(); - inputData.setContent("hello"); - inputData.setPlatform("cli"); - inputData.setUserInfo("owner"); - inputData.setUserNickName("master"); - - agent.receiveUserInput(inputData); + public static void main(String[] args) throws IOException { + Agent.initialize(); + Scanner scanner = new Scanner(System.in); + while (!scanner.nextLine().equals("exit")); } } \ No newline at end of file diff --git a/src/main/java/work/slhaf/agent/Agent.java b/src/main/java/work/slhaf/agent/Agent.java index dbf1ddda..4674595b 100644 --- a/src/main/java/work/slhaf/agent/Agent.java +++ b/src/main/java/work/slhaf/agent/Agent.java @@ -2,8 +2,10 @@ package work.slhaf.agent; import lombok.Data; import lombok.extern.slf4j.Slf4j; +import org.java_websocket.WebSocket; import work.slhaf.agent.common.config.Config; import work.slhaf.agent.core.InteractionHub; +import work.slhaf.agent.core.interaction.InputReceiver; import work.slhaf.agent.core.interaction.TaskCallback; import work.slhaf.agent.core.interaction.data.InteractionInputData; import work.slhaf.agent.core.interaction.data.InteractionOutputData; @@ -15,7 +17,7 @@ import java.time.LocalDateTime; @Data @Slf4j -public class Agent implements TaskCallback { +public class Agent implements TaskCallback, InputReceiver { private static Agent agent; private InteractionHub interactionHub; @@ -28,7 +30,22 @@ public class Agent implements TaskCallback { agent = new Agent(); agent.setInteractionHub(InteractionHub.initialize()); agent.registerTaskCallback(); - agent.setMessageSender(new AgentWebSocketServer(config.getWebSocketConfig().getPort(),agent)); + AgentWebSocketServer server = new AgentWebSocketServer(config.getWebSocketConfig().getPort(),agent); + server.start(); + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + for (WebSocket conn : server.getConnections()) { + conn.close(); + } + server.stop(); + log.info("WebSocketServer 已优雅关闭"); + } catch (Exception e) { + log.error("关闭失败", e); + } + })); + + agent.setMessageSender(server); + log.info("Agent 加载完毕.."); } return agent; @@ -36,9 +53,8 @@ public class Agent implements TaskCallback { /** * 接收用户输入,包装为标准输入数据类 - * @param inputData */ - public void receiveUserInput(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException { + public void receiveInput(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException { inputData.setLocalDateTime(LocalDateTime.now()); interactionHub.call(inputData); } @@ -46,11 +62,10 @@ public class Agent implements TaskCallback { /** * 向用户返回输出内容 - * @param output */ public void sendToUser(String userInfo,String output){ System.out.println(output); - messageSender.sendMessage(new InteractionOutputData(userInfo,output)); + messageSender.sendMessage(new InteractionOutputData(output,userInfo)); } @Override diff --git a/src/main/java/work/slhaf/agent/common/chat/pojo/ChatBody.java b/src/main/java/work/slhaf/agent/common/chat/pojo/ChatBody.java index bcea8b4d..f94bf136 100644 --- a/src/main/java/work/slhaf/agent/common/chat/pojo/ChatBody.java +++ b/src/main/java/work/slhaf/agent/common/chat/pojo/ChatBody.java @@ -1,10 +1,8 @@ package work.slhaf.agent.common.chat.pojo; -import com.alibaba.fastjson2.JSONObject; import lombok.*; import java.util.List; -import java.util.Map; @Builder @Data diff --git a/src/main/java/work/slhaf/agent/common/config/Config.java b/src/main/java/work/slhaf/agent/common/config/Config.java index 8cbc682b..c075e5d7 100644 --- a/src/main/java/work/slhaf/agent/common/config/Config.java +++ b/src/main/java/work/slhaf/agent/common/config/Config.java @@ -66,7 +66,6 @@ public class Config { private static void generatePipelineConfig() { List moduleConfigList = List.of( - new ModuleConfig(MemorySelectExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null), new ModuleConfig(MemorySelector.class.getName(), ModuleConfig.Constant.INTERNAL, null), new ModuleConfig(CoreModel.class.getName(),ModuleConfig.Constant.INTERNAL,null), new ModuleConfig(MemoryUpdater.class.getName(),ModuleConfig.Constant.INTERNAL,null), diff --git a/src/main/java/work/slhaf/agent/common/model/Model.java b/src/main/java/work/slhaf/agent/common/model/Model.java index 14179097..d98d5f2f 100644 --- a/src/main/java/work/slhaf/agent/common/model/Model.java +++ b/src/main/java/work/slhaf/agent/common/model/Model.java @@ -39,8 +39,7 @@ public class Model { model.setChatClient(new ChatClient(modelConfig.getBaseUrl(), modelConfig.getApikey(), modelConfig.getModel())); } - public ChatResponse runChat(String input) { - this.messages.add(new Message(ChatConstant.Character.USER, input)); + public ChatResponse chat() { return this.chatClient.runChat(this.messages); } diff --git a/src/main/java/work/slhaf/agent/common/model/ModelConstant.java b/src/main/java/work/slhaf/agent/common/model/ModelConstant.java index a6260632..784fd6d9 100644 --- a/src/main/java/work/slhaf/agent/common/model/ModelConstant.java +++ b/src/main/java/work/slhaf/agent/common/model/ModelConstant.java @@ -2,9 +2,60 @@ package work.slhaf.agent.common.model; public class ModelConstant { public static final String CORE_MODEL_PROMPT = """ + CoreModel 提示词 + + 功能说明 + 你需要根据用户的当前输入(text)生成恰当的回复。只有当以下字段与text内容直接相关时,才需要参考它们: + - datetime:当text包含时间相关语义时使用 + - character:当需要根据角色设定调整语气时使用 + - user_nick:当text中包含对用户的称呼或个性化需求时使用 + 其他所有字段仅在明确与text内容相关时才予以考虑,否则应完全忽略。 + + 输入字段优先级 + 1. 首要关注text字段,这是核心输入内容 + 2. 次要字段(有条件参考): + • datetime:仅当text包含时间表达时生效 + • character:仅当角色设定会影响回复风格时生效 + • user_nick:仅当需要个性化称呼时生效 + 3. 其他所有扩展字段(如memory_slices/static_memory等): + - 必须与text内容有明确关联时才参考 + - 若字段内容与text无关,则完全忽略该字段 + + 核心生成逻辑 + 1. 主内容优先原则 + - 首先独立分析text字段的语义 + - 只有当其他字段内容能直接辅助理解text时(如text说"上次说的那个"对应memory_slices中的记录),才调用相关字段 + - 若text是独立完整表达(如单字、短句、新话题开启),则忽略所有非核心字段 + + 2. 无关字段过滤机制 + - 当text属于以下情况时,强制忽略所有扩展字段: + ✓ 短于5个字符的输入(如"在"、"好的") + ✓ 明显开启新话题的提问(如"量子计算是什么") + ✓ 不含指代词的独立陈述句 + - 示例:当text="今天天气如何"时,即使存在量子计算相关的memory_slices也应忽略 + + 3. 响应生成规范 + - 回复必须完全基于text的核心语义生成 + - 禁止出现"根据您之前提到的XX"等无关内容引用 + - 当角色设定(character)与当前对话无关时(如科技助手回答日常问候),暂时覆盖角色设定 + + 输出格式 + { + "text": "响应内容" // 必须严格对应text字段的语义 + } + + 最终注意事项 + 1. 回应内容必须紧扣用户输入,且契合角色设定 + 2. 遇到模糊提问时,优先推测最常见的语境理解,不要直接问“你指的是什么” + 3. 回应应自然衔接,并允许后续系统模块追加更多限定、扩展字段 + 4. 你只需要生成JSON格式的响应对象,字段仅包含`text`,但在模块扩展下,字段内容可以有所增加。确保你可以兼容这些扩展而不破坏结构。 + 5. 若用户的输入(text)与其他字段中的内容无关,可忽略其他字段的内容 + + > 以下模块可能会追加更多内容限制或上下文提示,请确保你的回答能够自然兼容这些后续拼接的内容,并调整输出格式。 + """; public static final String SLICE_EVALUATOR_PROMPT = """ - 记忆切片选择器提示词(最终版) + SliceEvaluator 提示词 功能说明 你需要根据用户输入的JSON数据,分析其中的`text`(当前输入内容)、`history`(对话历史)和`memory_slices`(可用记忆切片),选出相关记忆切片。当text内容与history明显不相关时,应以text为主要判断依据。 diff --git a/src/main/java/work/slhaf/agent/common/util/ExtractUtil.java b/src/main/java/work/slhaf/agent/common/util/ExtractUtil.java new file mode 100644 index 00000000..209ef218 --- /dev/null +++ b/src/main/java/work/slhaf/agent/common/util/ExtractUtil.java @@ -0,0 +1,12 @@ +package work.slhaf.agent.common.util; + +public class ExtractUtil { + public static String extractJson(String jsonStr) { + int start = jsonStr.indexOf("{"); + int end = jsonStr.lastIndexOf("}"); + if (start != -1 && end != -1 && start < end) { + return jsonStr.substring(start, end + 1); + } + return jsonStr; + } +} diff --git a/src/main/java/work/slhaf/agent/core/InteractionHub.java b/src/main/java/work/slhaf/agent/core/InteractionHub.java index 05561785..72b9e422 100644 --- a/src/main/java/work/slhaf/agent/core/InteractionHub.java +++ b/src/main/java/work/slhaf/agent/core/InteractionHub.java @@ -1,6 +1,7 @@ package work.slhaf.agent.core; import lombok.Data; +import lombok.ToString; import lombok.extern.slf4j.Slf4j; import work.slhaf.agent.core.interaction.InteractionModule; import work.slhaf.agent.core.interaction.InteractionModulesLoader; @@ -21,15 +22,18 @@ public class InteractionHub { private static InteractionHub interactionHub; + @ToString.Exclude private TaskCallback callback; - private CoreModel coreModel; private MemoryManager memoryManager; private TaskScheduler taskScheduler; + private List interactionModules; public static InteractionHub initialize() throws IOException { if (interactionHub == null) { interactionHub = new InteractionHub(); + //加载模块 + interactionHub.setInteractionModules(InteractionModulesLoader.getInstance().registerInteractionModules()); log.info("InteractionHub注册完毕..."); } return interactionHub; @@ -38,11 +42,10 @@ public class InteractionHub { public void call(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException { //预处理 InteractionContext interactionContext = PreprocessExecutor.getInstance().execute(inputData); - //加载模块 - List interactionModules = InteractionModulesLoader.getInstance().registerInteractionModules(); + for (InteractionModule interactionModule : interactionModules) { interactionModule.execute(interactionContext); } - callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getString("message")); + callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getString("text")); } } diff --git a/src/main/java/work/slhaf/agent/core/interaction/InputReceiver.java b/src/main/java/work/slhaf/agent/core/interaction/InputReceiver.java new file mode 100644 index 00000000..806cf78b --- /dev/null +++ b/src/main/java/work/slhaf/agent/core/interaction/InputReceiver.java @@ -0,0 +1,10 @@ +package work.slhaf.agent.core.interaction; + +import work.slhaf.agent.core.interaction.data.InteractionInputData; + +import java.io.IOException; + +public interface InputReceiver { + + void receiveInput(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException; +} diff --git a/src/main/java/work/slhaf/agent/core/interaction/data/InteractionContext.java b/src/main/java/work/slhaf/agent/core/interaction/data/InteractionContext.java index 9725b954..628b6761 100644 --- a/src/main/java/work/slhaf/agent/core/interaction/data/InteractionContext.java +++ b/src/main/java/work/slhaf/agent/core/interaction/data/InteractionContext.java @@ -15,5 +15,6 @@ public class InteractionContext { protected String input; protected JSONObject moduleContext; + protected JSONObject modulePrompt; protected JSONObject coreResponse; } diff --git a/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java b/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java index 5e37d387..ee0de386 100644 --- a/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java +++ b/src/main/java/work/slhaf/agent/core/memory/MemoryGraph.java @@ -58,7 +58,7 @@ public class MemoryGraph extends PersistableObject { /** * 近两日的区分用户的对话总结缓存,在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质 */ - private ConcurrentHashMap> userDialogMap; + private ConcurrentHashMap> userDialogMap; /** * 当前对话的活动性总结, 拥有比dialogMap更丰富的全文细节, 作为当前对话token超限时的必要上下文压缩存储 @@ -92,6 +92,8 @@ public class MemoryGraph extends PersistableObject { */ private HashMap modelPrompt; + private String character; + /** * 主模型的聊天记录 */ @@ -117,6 +119,10 @@ public class MemoryGraph extends PersistableObject { this.memorySliceCache = new ConcurrentHashMap<>(); this.modelPrompt = new HashMap<>(); this.selectedSlices = new HashSet<>(); + this.users = new ArrayList<>(); + this.userDialogMap = new ConcurrentHashMap<>(); + this.currentCompressedSessionContext = new ArrayList<>(); + this.dialogMap = new HashMap<>(); } public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException { @@ -126,7 +132,7 @@ public class MemoryGraph extends PersistableObject { Path filePath = getFilePath(id); if (Files.exists(filePath)) { memoryGraph = deserialize(id); - }else { + } else { FileUtils.createParentDirectories(filePath.toFile().getParentFile()); memoryGraph = new MemoryGraph(id); memoryGraph.serialize(); @@ -206,7 +212,9 @@ public class MemoryGraph extends PersistableObject { } updateDateIndex(now, slice); - updateDialogMap(slice); + if (!slice.isPrivate()) { + updateUserDialogMap(slice); + } node.saveMemorySliceList(); } @@ -241,7 +249,7 @@ public class MemoryGraph extends PersistableObject { return lastTopicNode; } - private void updateDialogMap(MemorySlice slice) { + private void updateUserDialogMap(MemorySlice slice) { String summary = slice.getSummary(); LocalDateTime now = LocalDateTime.now(); @@ -264,17 +272,21 @@ public class MemoryGraph extends PersistableObject { //更新userDialogMap //移除两天前上下文缓存(切片总结) userDialogMap.forEach((k, v) -> { - if (now.minusDays(2).isAfter(k)) { - keysToRemove.add(k); - } + v.forEach((i, j) -> { + if (now.minusDays(2).isAfter(i)) { + keysToRemove.add(i); + } + }); }); for (LocalDateTime dateTime : keysToRemove) { - userDialogMap.remove(dateTime); + userDialogMap.forEach((k, v) -> { + v.remove(dateTime); + }); } //放入新缓存 userDialogMap - .computeIfAbsent(now, k -> new ConcurrentHashMap<>()) - .merge(slice.getStartUserId(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal); + .computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>()) + .merge(now, slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal); } @@ -298,6 +310,8 @@ public class MemoryGraph extends PersistableObject { //排序 memorySliceList.sort(null); MemorySlice tempSlice = memorySliceList.getLast(); + //设置私密状态一致 + tempSlice.setPrivate(slice.isPrivate()); //末尾切片添加当前切片的引用 tempSlice.setSliceAfter(slice); //当前切片添加前序切片的引用 @@ -329,7 +343,7 @@ public class MemoryGraph extends PersistableObject { for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) { List endpointMemorySliceList = memoryNode.loadMemorySliceList(); for (MemorySlice memorySlice : endpointMemorySliceList) { - if (selectedSlices.contains(memorySlice.getTimestamp())){ + if (selectedSlices.contains(memorySlice.getTimestamp())) { continue; } sliceResult.setSliceBefore(memorySlice.getSliceBefore()); @@ -410,7 +424,7 @@ public class MemoryGraph extends PersistableObject { CopyOnWriteArrayList targetSliceList = new CopyOnWriteArrayList<>(); for (List value : dateIndex.get(date).values()) { for (MemorySlice memorySlice : value) { - if (selectedSlices.contains(memorySlice.getTimestamp())){ + if (selectedSlices.contains(memorySlice.getTimestamp())) { continue; } MemorySliceResult memorySliceResult = new MemorySliceResult(); @@ -444,24 +458,26 @@ public class MemoryGraph extends PersistableObject { return targetParentNode; } - public void printTopicTree() { + public String getTopicTree() { + StringBuilder stringBuilder = new StringBuilder(); for (Map.Entry entry : topicNodes.entrySet()) { String rootName = entry.getKey(); TopicNode rootNode = entry.getValue(); - System.out.println(rootName+"[root]"); - printSubTopicsTreeFormat(rootNode, "", true); + stringBuilder.append(rootName).append("[root]").append("\r\n"); + printSubTopicsTreeFormat(rootNode, "", stringBuilder); } + return stringBuilder.toString(); } - private void printSubTopicsTreeFormat(TopicNode node, String prefix, boolean isLast) { + private void printSubTopicsTreeFormat(TopicNode node, String prefix, StringBuilder stringBuilder) { if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return; List> entries = new ArrayList<>(node.getTopicNodes().entrySet()); for (int i = 0; i < entries.size(); i++) { boolean last = (i == entries.size() - 1); Map.Entry entry = entries.get(i); - System.out.println(prefix + (last ? "└── " : "├── ") + entry.getKey()); - printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), last); + stringBuilder.append(prefix).append(last ? "└── " : "├── ").append(entry.getKey()).append("\r\n"); + printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), stringBuilder); } } diff --git a/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java b/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java index 5adcf0f1..1a30ed7a 100644 --- a/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java +++ b/src/main/java/work/slhaf/agent/core/memory/MemoryManager.java @@ -8,13 +8,16 @@ import work.slhaf.agent.core.interaction.InteractionModule; import work.slhaf.agent.core.interaction.data.InteractionContext; import work.slhaf.agent.core.memory.pojo.MemoryResult; import work.slhaf.agent.core.memory.pojo.User; -import work.slhaf.agent.modules.memory.SliceEvaluator; +import work.slhaf.agent.shared.memory.EvaluatedSlice; import java.io.IOException; import java.time.LocalDate; +import java.time.LocalDateTime; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; @Data @Slf4j @@ -23,7 +26,7 @@ public class MemoryManager implements InteractionModule { private static MemoryManager memoryManager; private MemoryGraph memoryGraph; - private SliceEvaluator sliceEvaluator; + private HashMap> activatedSlices; private MemoryManager(){} @@ -37,7 +40,7 @@ public class MemoryManager implements InteractionModule { Config config = Config.getConfig(); memoryManager = new MemoryManager(); memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId())); - memoryManager.setSliceEvaluator(SliceEvaluator.getInstance()); + memoryManager.setActivatedSlices(new HashMap<>()); log.info("MemoryManager注册完毕..."); } return memoryManager; @@ -85,6 +88,22 @@ public class MemoryManager implements InteractionModule { } public String getTopicTree() { - return memoryManager.getTopicTree(); + return memoryGraph.getTopicTree(); + } + + public ConcurrentHashMap getStaticMemory(String userId) { + return memoryGraph.getStaticMemory().get(userId); + } + + public HashMap getDialogMap() { + return memoryGraph.getDialogMap(); + } + + public ConcurrentHashMap getUserDialogMap(String userId) { + return memoryGraph.getUserDialogMap().get(userId); + } + + public String getCharacter() { + return memoryGraph.getCharacter(); } } diff --git a/src/main/java/work/slhaf/agent/core/memory/node/MemoryNode.java b/src/main/java/work/slhaf/agent/core/memory/node/MemoryNode.java index db3476b6..c587650f 100644 --- a/src/main/java/work/slhaf/agent/core/memory/node/MemoryNode.java +++ b/src/main/java/work/slhaf/agent/core/memory/node/MemoryNode.java @@ -11,7 +11,6 @@ import java.io.*; import java.nio.file.Files; import java.nio.file.Path; import java.time.LocalDate; -import java.util.ArrayList; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; diff --git a/src/main/java/work/slhaf/agent/core/module/CoreModel.java b/src/main/java/work/slhaf/agent/core/module/CoreModel.java index 2485e1a6..1a6c4d78 100644 --- a/src/main/java/work/slhaf/agent/core/module/CoreModel.java +++ b/src/main/java/work/slhaf/agent/core/module/CoreModel.java @@ -1,17 +1,23 @@ package work.slhaf.agent.core.module; +import com.alibaba.fastjson2.JSONObject; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; +import work.slhaf.agent.common.chat.constant.ChatConstant; import work.slhaf.agent.common.chat.pojo.ChatResponse; +import work.slhaf.agent.common.chat.pojo.Message; import work.slhaf.agent.common.config.Config; import work.slhaf.agent.common.model.Model; import work.slhaf.agent.common.model.ModelConstant; import work.slhaf.agent.core.interaction.InteractionModule; import work.slhaf.agent.core.interaction.data.InteractionContext; +import work.slhaf.agent.core.memory.MemoryManager; import java.io.IOException; +import static work.slhaf.agent.common.util.ExtractUtil.extractJson; + @EqualsAndHashCode(callSuper = true) @Data @Slf4j @@ -20,6 +26,9 @@ public class CoreModel extends Model implements InteractionModule { public static final String MODEL_KEY = "core_model"; private static CoreModel coreModel; + private MemoryManager memoryManager; + private String promptCache; + private CoreModel() { } @@ -27,6 +36,8 @@ public class CoreModel extends Model implements InteractionModule { if (coreModel == null) { Config config = Config.getConfig(); coreModel = new CoreModel(); + coreModel.memoryManager = MemoryManager.getInstance(); + coreModel.messages = coreModel.memoryManager.getChatMessages(); setModel(config, coreModel, MODEL_KEY, ModelConstant.CORE_MODEL_PROMPT); log.info("CoreModel注册完毕..."); } @@ -35,9 +46,35 @@ public class CoreModel extends Model implements InteractionModule { @Override public void execute(InteractionContext interactionContext) { - //TODO 需要拼接上下文之后再发送给主模型 + String tempPrompt = interactionContext.getModulePrompt().toString(); + if (!tempPrompt.equals(promptCache)) { + coreModel.getMessages().set(0, new Message(ChatConstant.Character.SYSTEM, ModelConstant.CORE_MODEL_PROMPT + "\r\n" + tempPrompt)); + promptCache = tempPrompt; + } + this.messages.add(new Message(ChatConstant.Character.USER, interactionContext.getModuleContext().getString("text"))); + ChatResponse chatResponse = this.chat(); + JSONObject response = null; + int count = 0; + while (true) { + try { + response = JSONObject.parse(extractJson(chatResponse.getMessage())); + this.messages.add(new Message(ChatConstant.Character.ASSISTANT, response.getString("text"))); - ChatResponse res = runChat(interactionContext.getInput()); -// interactionContext.setCoreResponse(); + //设置上下文 + interactionContext.getModuleContext().put("total_token", chatResponse.getUsageBean().getTotal_tokens()); + break; + } catch (Exception e) { + count++; + log.error("CoreModel执行异常: {}", e.getLocalizedMessage()); + if (count > 3) { + response = new JSONObject(); + response.put("text", "主模型交互出错: " + e.getLocalizedMessage()); + interactionContext.setFinished(true); + break; + } + } finally { + interactionContext.setCoreResponse(response); + } + } } } diff --git a/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java b/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java index ef9512e9..c79f3110 100644 --- a/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java +++ b/src/main/java/work/slhaf/agent/gateway/AgentWebSocketServer.java @@ -2,11 +2,12 @@ package work.slhaf.agent.gateway; import cn.hutool.json.JSONUtil; import com.alibaba.fastjson2.JSONObject; +import lombok.ToString; import lombok.extern.slf4j.Slf4j; import org.java_websocket.WebSocket; import org.java_websocket.handshake.ClientHandshake; import org.java_websocket.server.WebSocketServer; -import work.slhaf.agent.Agent; +import work.slhaf.agent.core.interaction.InputReceiver; import work.slhaf.agent.core.interaction.data.InteractionInputData; import work.slhaf.agent.core.interaction.data.InteractionOutputData; @@ -17,12 +18,13 @@ import java.util.concurrent.ConcurrentHashMap; @Slf4j public class AgentWebSocketServer extends WebSocketServer implements MessageSender { - private final Agent agent; + @ToString.Exclude + private final InputReceiver receiver; private final ConcurrentHashMap userSessions = new ConcurrentHashMap<>(); - public AgentWebSocketServer(int port, Agent agent) { + public AgentWebSocketServer(int port, InputReceiver receiver) { super(new InetSocketAddress(port)); - this.agent = agent; + this.receiver = receiver; } @Override @@ -41,7 +43,7 @@ public class AgentWebSocketServer extends WebSocketServer implements MessageSend InteractionInputData inputData = JSONObject.parseObject(s, InteractionInputData.class); userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接 try { - agent.receiveUserInput(inputData); + receiver.receiveInput(inputData); } catch (IOException | ClassNotFoundException | InterruptedException e) { throw new RuntimeException(e); } diff --git a/src/main/java/work/slhaf/agent/modules/memory/MemorySelectExtractor.java b/src/main/java/work/slhaf/agent/modules/memory/MemorySelectExtractor.java index 1adb7e96..c829d02e 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/MemorySelectExtractor.java +++ b/src/main/java/work/slhaf/agent/modules/memory/MemorySelectExtractor.java @@ -16,6 +16,8 @@ import work.slhaf.agent.modules.memory.data.extractor.ExtractorResult; import java.io.IOException; import java.util.List; +import static work.slhaf.agent.common.util.ExtractUtil.extractJson; + @EqualsAndHashCode(callSuper = true) @Data @Slf4j @@ -47,7 +49,7 @@ public class MemorySelectExtractor extends Model { .history(memoryManager.getChatMessages()) .topic_tree(memoryManager.getTopicTree()) .build(); - String responseStr = singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage(); + String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage()); ExtractorResult extractorResult; try { diff --git a/src/main/java/work/slhaf/agent/modules/memory/MemorySelector.java b/src/main/java/work/slhaf/agent/modules/memory/MemorySelector.java index b61ee161..751e46d6 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/MemorySelector.java +++ b/src/main/java/work/slhaf/agent/modules/memory/MemorySelector.java @@ -6,11 +6,10 @@ import work.slhaf.agent.core.interaction.data.InteractionContext; import work.slhaf.agent.core.memory.MemoryManager; import work.slhaf.agent.core.memory.pojo.MemoryResult; import work.slhaf.agent.core.memory.pojo.MemorySlice; -import work.slhaf.agent.modules.memory.data.evaluator.EvaluatedSlice; import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorInput; -import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorResult; import work.slhaf.agent.modules.memory.data.extractor.ExtractorMatchData; import work.slhaf.agent.modules.memory.data.extractor.ExtractorResult; +import work.slhaf.agent.shared.memory.EvaluatedSlice; import java.io.IOException; import java.time.LocalDate; @@ -21,6 +20,29 @@ import java.util.List; public class MemorySelector implements InteractionModule { private static MemorySelector memorySelector; + public static final String modulePrompt = """ + 新增输入字段: + + "memory_slices": [{ //记忆切片,可能为多个 + "chatMessages": [{ + "role": "user"/"assistant", //该信息发送者 + "content": "消息内容" + }], + "date": "2024-03-20", //切片日期 + "summary": "切片总结" + }], + "static_memory": "对于该用户的常识性记忆,如爱好、住处、生日", + "dialog_map": { //近两日的与所有用户的对话缓存 + "2023-01-01T11:30": "发生了...与用户A...、用户B谈到...", + "2023-01-02T11:30": "发生了...与用户A...、用户B谈到..." + } + "user_dialog_map": { //与当前用户的近两日对话缓存 + "2023-01-01T11:30": "与用户讨论了...", + "2023-01-02T11:30": "与用户讨论了..." + } + + 无新增输出字段 + """; private MemoryManager memoryManager; private SliceEvaluator sliceEvaluator; @@ -41,12 +63,13 @@ public class MemorySelector implements InteractionModule { @Override public void execute(InteractionContext interactionContext) throws IOException, ClassNotFoundException, InterruptedException { + String userId = memoryManager.getUserId(interactionContext.getUserInfo(), interactionContext.getUserNickname()); //获取主题路径 ExtractorResult extractorResult = memorySelectExtractor.execute(interactionContext); if (extractorResult.isRecall()) { //查找切片 List memoryResultList = new ArrayList<>(); - setMemoryResultList(memoryResultList, extractorResult.getMatches(), interactionContext.getUserInfo(), interactionContext.getUserNickname()); + setMemoryResultList(memoryResultList, extractorResult.getMatches(),userId); //评估切片 EvaluatorInput evaluatorInput = EvaluatorInput.builder() .input(interactionContext.getInput()) @@ -54,13 +77,19 @@ public class MemorySelector implements InteractionModule { .messages(memoryManager.getChatMessages()) .build(); List memorySlices = sliceEvaluator.execute(evaluatorInput); - //设置上下文 - interactionContext.getModuleContext().put("memory_slices",memorySlices); + memoryManager.getActivatedSlices().put(userId,memorySlices); } + //设置上下文 + interactionContext.getModuleContext().put("memory_slices",memoryManager.getActivatedSlices().get(userId)); + interactionContext.getModuleContext().put("static_memory",memoryManager.getStaticMemory(userId)); + interactionContext.getModuleContext().put("dialog_map",memoryManager.getDialogMap()); + interactionContext.getModuleContext().put("user_dialog_map",memoryManager.getUserDialogMap(userId)); + + interactionContext.getModulePrompt().put("memory", modulePrompt); } - private void setMemoryResultList(List memoryResultList, List matches, String userInfo, String nickName) throws IOException, ClassNotFoundException { + private void setMemoryResultList(List memoryResultList, List matches, String userId) throws IOException, ClassNotFoundException { for (ExtractorMatchData match : matches) { MemoryResult memoryResult = switch (match.getType()) { case ExtractorMatchData.Constant.DATE -> memoryManager.selectMemory(match.getText()); @@ -76,15 +105,14 @@ public class MemorySelector implements InteractionModule { //根据userInfo过滤是否为私人记忆 for (MemoryResult memoryResult : memoryResultList) { //过滤终点记忆 - memoryResult.getMemorySliceResult().removeIf(m -> removeOrNot(m.getMemorySlice(), userInfo, nickName)); + memoryResult.getMemorySliceResult().removeIf(m -> removeOrNot(m.getMemorySlice(), userId)); //过滤邻近记忆 - memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userInfo, nickName)); + memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userId)); } } - private boolean removeOrNot(MemorySlice memorySlice, String userInfo, String nickName) { + private boolean removeOrNot(MemorySlice memorySlice, String userId) { if (memorySlice.isPrivate()) { - String userId = memoryManager.getUserId(userInfo, nickName); return memorySlice.getStartUserId().equals(userId); } return true; diff --git a/src/main/java/work/slhaf/agent/modules/memory/MemoryUpdater.java b/src/main/java/work/slhaf/agent/modules/memory/MemoryUpdater.java index 1dd0687d..d8a453c5 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/MemoryUpdater.java +++ b/src/main/java/work/slhaf/agent/modules/memory/MemoryUpdater.java @@ -30,6 +30,18 @@ public class MemoryUpdater implements InteractionModule { @Override public void execute(InteractionContext interactionContext) { + if (interactionContext.isFinished()){ + return; + } + //如果token 大于阈值,则更新记忆 + if (interactionContext.getModuleContext().getIntValue("total_token") > 24000) { + executor.execute(() -> { + + }); + } + + //更新确定性记忆 + } } diff --git a/src/main/java/work/slhaf/agent/modules/memory/SliceEvaluator.java b/src/main/java/work/slhaf/agent/modules/memory/SliceEvaluator.java index b2903ee8..ce8fb1d1 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/SliceEvaluator.java +++ b/src/main/java/work/slhaf/agent/modules/memory/SliceEvaluator.java @@ -14,7 +14,11 @@ import work.slhaf.agent.core.memory.MemoryManager; import work.slhaf.agent.core.memory.pojo.MemoryResult; import work.slhaf.agent.core.memory.pojo.MemorySlice; import work.slhaf.agent.core.memory.pojo.MemorySliceResult; -import work.slhaf.agent.modules.memory.data.evaluator.*; +import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorBatchInput; +import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorInput; +import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorResult; +import work.slhaf.agent.modules.memory.data.evaluator.SliceSummary; +import work.slhaf.agent.shared.memory.EvaluatedSlice; import java.io.IOException; import java.util.*; @@ -22,6 +26,8 @@ import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.TimeUnit; +import static work.slhaf.agent.common.util.ExtractUtil.extractJson; + @EqualsAndHashCode(callSuper = true) @Data @Slf4j @@ -63,7 +69,7 @@ public class SliceEvaluator extends Model { .memory_slices(sliceSummaryList) .history(evaluatorInput.getMessages()) .build(); - EvaluatorResult evaluatorResult = JSONObject.parseObject(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage(), EvaluatorResult.class); + EvaluatorResult evaluatorResult = JSONObject.parseObject(extractJson(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage()), EvaluatorResult.class); for (Long result : evaluatorResult.getResults()) { SliceSummary sliceSummary = map.get(result); EvaluatedSlice evaluatedSlice = EvaluatedSlice.builder() @@ -73,7 +79,7 @@ public class SliceEvaluator extends Model { queue.offer(evaluatedSlice); } } catch (Exception e) { - log.error("切片评估: {}", e.getLocalizedMessage()); + log.error("切片评估出现错误: {}", e.getLocalizedMessage()); } return null; }); diff --git a/src/main/java/work/slhaf/agent/modules/preprocess/PreprocessExecutor.java b/src/main/java/work/slhaf/agent/modules/preprocess/PreprocessExecutor.java index a28ce03c..9b25e8e5 100644 --- a/src/main/java/work/slhaf/agent/modules/preprocess/PreprocessExecutor.java +++ b/src/main/java/work/slhaf/agent/modules/preprocess/PreprocessExecutor.java @@ -41,7 +41,10 @@ public class PreprocessExecutor { context.setModuleContext(new JSONObject()); context.getModuleContext().put("text", inputData.getContent()); context.getModuleContext().put("datetime", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)); - context.getModuleContext().put("character",memoryManager.getMemoryGraph().getModelPrompt()); + context.getModuleContext().put("character",memoryManager.getCharacter()); + context.getModuleContext().put("user_nick", inputData.getUserNickName()); + + context.setModulePrompt(new JSONObject()); return context; } diff --git a/src/main/java/work/slhaf/agent/modules/memory/data/evaluator/EvaluatedSlice.java b/src/main/java/work/slhaf/agent/shared/memory/EvaluatedSlice.java similarity index 71% rename from src/main/java/work/slhaf/agent/modules/memory/data/evaluator/EvaluatedSlice.java rename to src/main/java/work/slhaf/agent/shared/memory/EvaluatedSlice.java index bec0b846..6080cfb9 100644 --- a/src/main/java/work/slhaf/agent/modules/memory/data/evaluator/EvaluatedSlice.java +++ b/src/main/java/work/slhaf/agent/shared/memory/EvaluatedSlice.java @@ -1,4 +1,4 @@ -package work.slhaf.agent.modules.memory.data.evaluator; +package work.slhaf.agent.shared.memory; import lombok.Builder; import lombok.Data; @@ -10,7 +10,7 @@ import java.util.List; @Data @Builder public class EvaluatedSlice { -// private List chatMessages; + private List chatMessages; private LocalDate date; private String summary; } diff --git a/src/test/java/memory/AITest.java b/src/test/java/memory/AITest.java index 96bd516a..952cb05b 100644 --- a/src/test/java/memory/AITest.java +++ b/src/test/java/memory/AITest.java @@ -1,12 +1,16 @@ package memory; +import cn.hutool.json.JSONUtil; import org.junit.jupiter.api.Test; import work.slhaf.agent.common.chat.ChatClient; import work.slhaf.agent.common.chat.constant.ChatConstant; import work.slhaf.agent.common.chat.pojo.Message; import work.slhaf.agent.common.model.ModelConstant; +import work.slhaf.agent.modules.memory.MemorySelector; +import java.time.LocalDate; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; public class AITest { @@ -94,6 +98,44 @@ public class AITest { run(input,ModelConstant.SLICE_EVALUATOR_PROMPT); } + @Test + public void coreModelTest(){ + String input = """ + { + "text": "在", + "datetime": "2024-03-22T09:00", + "character": "你是一个智能助手,专注于科技领域", + "memory_slices": [ + { + "chatMessages": [ + {"role": "user", "content": "量子计算近期的进展怎么样?"}, + {"role": "assistant", "content": "量子计算在硬件和算法上都取得了突破,IBM发布了433量子位处理器,Google也在量子优越性上取得了进展。"} + ], + "date": "2024-03-20", + "summary": "量子计算最新突破:IBM发布433量子位处理器,Google在量子优越性上取得进展。" + } + ], + "static_memory": "用户对量子计算技术非常感兴趣。", + "dialog_map": { + "2024-03-20T10:30": "与用户讨论了量子计算的最新进展" + }, + "user_dialog_map": { + "2024-03-20T10:30": "与用户讨论了量子计算的最新进展" + } + } + + """; + run(input,ModelConstant.CORE_MODEL_PROMPT + "\r\n" + MemorySelector.modulePrompt); + } + + @Test + public void map2jsonTest(){ + HashMap map = new HashMap<>(); + map.put(LocalDate.now(),"hello"); + map.put(LocalDate.now().plusDays(1),"world"); + System.out.println(JSONUtil.toJsonPrettyStr(map)); + } + private void run(String input, String prompt) { ChatClient client = new ChatClient("https://open.bigmodel.cn/api/paas/v4/chat/completions", "3db444552530b7742b0c53425fb93dcc.LcVwYjByht9AC3N9", "glm-4-flash-250414"); List messages = new ArrayList<>(); diff --git a/src/test/java/memory/InsertTest.java b/src/test/java/memory/InsertTest.java index 9bc058ff..4bace139 100644 --- a/src/test/java/memory/InsertTest.java +++ b/src/test/java/memory/InsertTest.java @@ -3,9 +3,9 @@ package memory; import org.junit.Before; import org.junit.Test; import work.slhaf.agent.core.memory.MemoryGraph; -import work.slhaf.agent.core.memory.pojo.MemorySlice; import work.slhaf.agent.core.memory.node.MemoryNode; import work.slhaf.agent.core.memory.node.TopicNode; +import work.slhaf.agent.core.memory.pojo.MemorySlice; import java.io.IOException; import java.time.LocalDate; diff --git a/src/test/java/memory/MemoryTest.java b/src/test/java/memory/MemoryTest.java index 4e75f2e8..7ff7d1ea 100644 --- a/src/test/java/memory/MemoryTest.java +++ b/src/test/java/memory/MemoryTest.java @@ -1,12 +1,10 @@ package memory; -import cn.hutool.core.date.LocalDateTimeUtil; import org.junit.jupiter.api.Test; import work.slhaf.agent.core.memory.MemoryGraph; import work.slhaf.agent.core.memory.node.TopicNode; import java.time.LocalDate; -import java.time.LocalDateTime; import java.util.HashMap; import java.util.concurrent.ConcurrentHashMap; @@ -50,7 +48,7 @@ public void test1() { // 输出 graph.setTopicNodes(topicMap); - graph.printTopicTree(); + System.out.println(graph.getTopicTree()); } diff --git a/src/test/java/memory/SearchTest.java b/src/test/java/memory/SearchTest.java index a3b8dd6b..0b5e7711 100644 --- a/src/test/java/memory/SearchTest.java +++ b/src/test/java/memory/SearchTest.java @@ -3,18 +3,17 @@ package memory; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import work.slhaf.agent.core.memory.MemoryGraph; -import work.slhaf.agent.core.memory.pojo.MemorySlice; import work.slhaf.agent.core.memory.exception.UnExistedTopicException; import work.slhaf.agent.core.memory.node.MemoryNode; import work.slhaf.agent.core.memory.node.TopicNode; -import work.slhaf.agent.core.memory.pojo.MemoryResult; +import work.slhaf.agent.core.memory.pojo.MemorySlice; import java.io.IOException; import java.time.LocalDate; import java.util.ArrayList; import java.util.List; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertThrows; class SearchTest { private MemoryGraph memoryGraph;