mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
- 更新了WebSocket服务器的启动逻辑
- 发现了agent, websocket, interactionHub之间的循环引用导致IDEA调试出错问题,通过exclude解决 - 实现了CoreModel的execute执行逻辑,并且系统提示词将动态拼接以适应不同模块 - 移动EvaluatedSlice至shared/memory包下,避免层级混淆 - 提取清洗json方法至独立的工具类 - 将agent通过InputReceiver接口暴露至socketServer,而非直接交给其完整实例 - 调整模块加载时机->InteractionHub加载时进行加载 - 调整MemoryGraph中userDialogMap的结构,换用以用户id为主键 - 初步进行测试,记忆更新逻辑暂未实现
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -38,3 +38,4 @@ build/
|
||||
.DS_Store
|
||||
/data/
|
||||
/config/
|
||||
/src/test/java/memory/test.json
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
package work.slhaf;
|
||||
|
||||
import work.slhaf.agent.Agent;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Scanner;
|
||||
|
||||
public class Main {
|
||||
public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
|
||||
Agent agent = Agent.initialize();
|
||||
|
||||
InteractionInputData inputData = new InteractionInputData();
|
||||
inputData.setContent("hello");
|
||||
inputData.setPlatform("cli");
|
||||
inputData.setUserInfo("owner");
|
||||
inputData.setUserNickName("master");
|
||||
|
||||
agent.receiveUserInput(inputData);
|
||||
public static void main(String[] args) throws IOException {
|
||||
Agent.initialize();
|
||||
Scanner scanner = new Scanner(System.in);
|
||||
while (!scanner.nextLine().equals("exit"));
|
||||
}
|
||||
}
|
||||
@@ -2,8 +2,10 @@ package work.slhaf.agent;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.java_websocket.WebSocket;
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.core.InteractionHub;
|
||||
import work.slhaf.agent.core.interaction.InputReceiver;
|
||||
import work.slhaf.agent.core.interaction.TaskCallback;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionOutputData;
|
||||
@@ -15,7 +17,7 @@ import java.time.LocalDateTime;
|
||||
|
||||
@Data
|
||||
@Slf4j
|
||||
public class Agent implements TaskCallback {
|
||||
public class Agent implements TaskCallback, InputReceiver {
|
||||
|
||||
private static Agent agent;
|
||||
private InteractionHub interactionHub;
|
||||
@@ -28,7 +30,22 @@ public class Agent implements TaskCallback {
|
||||
agent = new Agent();
|
||||
agent.setInteractionHub(InteractionHub.initialize());
|
||||
agent.registerTaskCallback();
|
||||
agent.setMessageSender(new AgentWebSocketServer(config.getWebSocketConfig().getPort(),agent));
|
||||
AgentWebSocketServer server = new AgentWebSocketServer(config.getWebSocketConfig().getPort(),agent);
|
||||
server.start();
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||
try {
|
||||
for (WebSocket conn : server.getConnections()) {
|
||||
conn.close();
|
||||
}
|
||||
server.stop();
|
||||
log.info("WebSocketServer 已优雅关闭");
|
||||
} catch (Exception e) {
|
||||
log.error("关闭失败", e);
|
||||
}
|
||||
}));
|
||||
|
||||
agent.setMessageSender(server);
|
||||
|
||||
log.info("Agent 加载完毕..");
|
||||
}
|
||||
return agent;
|
||||
@@ -36,9 +53,8 @@ public class Agent implements TaskCallback {
|
||||
|
||||
/**
|
||||
* 接收用户输入,包装为标准输入数据类
|
||||
* @param inputData
|
||||
*/
|
||||
public void receiveUserInput(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException {
|
||||
public void receiveInput(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException {
|
||||
inputData.setLocalDateTime(LocalDateTime.now());
|
||||
interactionHub.call(inputData);
|
||||
}
|
||||
@@ -46,11 +62,10 @@ public class Agent implements TaskCallback {
|
||||
|
||||
/**
|
||||
* 向用户返回输出内容
|
||||
* @param output
|
||||
*/
|
||||
public void sendToUser(String userInfo,String output){
|
||||
System.out.println(output);
|
||||
messageSender.sendMessage(new InteractionOutputData(userInfo,output));
|
||||
messageSender.sendMessage(new InteractionOutputData(output,userInfo));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
package work.slhaf.agent.common.chat.pojo;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.*;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Builder
|
||||
@Data
|
||||
|
||||
@@ -66,7 +66,6 @@ public class Config {
|
||||
|
||||
private static void generatePipelineConfig() {
|
||||
List<ModuleConfig> moduleConfigList = List.of(
|
||||
new ModuleConfig(MemorySelectExtractor.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
||||
new ModuleConfig(MemorySelector.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
||||
new ModuleConfig(CoreModel.class.getName(),ModuleConfig.Constant.INTERNAL,null),
|
||||
new ModuleConfig(MemoryUpdater.class.getName(),ModuleConfig.Constant.INTERNAL,null),
|
||||
|
||||
@@ -39,8 +39,7 @@ public class Model {
|
||||
model.setChatClient(new ChatClient(modelConfig.getBaseUrl(), modelConfig.getApikey(), modelConfig.getModel()));
|
||||
}
|
||||
|
||||
public ChatResponse runChat(String input) {
|
||||
this.messages.add(new Message(ChatConstant.Character.USER, input));
|
||||
public ChatResponse chat() {
|
||||
return this.chatClient.runChat(this.messages);
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,60 @@ package work.slhaf.agent.common.model;
|
||||
|
||||
public class ModelConstant {
|
||||
public static final String CORE_MODEL_PROMPT = """
|
||||
CoreModel 提示词
|
||||
|
||||
功能说明
|
||||
你需要根据用户的当前输入(text)生成恰当的回复。只有当以下字段与text内容直接相关时,才需要参考它们:
|
||||
- datetime:当text包含时间相关语义时使用
|
||||
- character:当需要根据角色设定调整语气时使用
|
||||
- user_nick:当text中包含对用户的称呼或个性化需求时使用
|
||||
其他所有字段仅在明确与text内容相关时才予以考虑,否则应完全忽略。
|
||||
|
||||
输入字段优先级
|
||||
1. 首要关注text字段,这是核心输入内容
|
||||
2. 次要字段(有条件参考):
|
||||
• datetime:仅当text包含时间表达时生效
|
||||
• character:仅当角色设定会影响回复风格时生效
|
||||
• user_nick:仅当需要个性化称呼时生效
|
||||
3. 其他所有扩展字段(如memory_slices/static_memory等):
|
||||
- 必须与text内容有明确关联时才参考
|
||||
- 若字段内容与text无关,则完全忽略该字段
|
||||
|
||||
核心生成逻辑
|
||||
1. 主内容优先原则
|
||||
- 首先独立分析text字段的语义
|
||||
- 只有当其他字段内容能直接辅助理解text时(如text说"上次说的那个"对应memory_slices中的记录),才调用相关字段
|
||||
- 若text是独立完整表达(如单字、短句、新话题开启),则忽略所有非核心字段
|
||||
|
||||
2. 无关字段过滤机制
|
||||
- 当text属于以下情况时,强制忽略所有扩展字段:
|
||||
✓ 短于5个字符的输入(如"在"、"好的")
|
||||
✓ 明显开启新话题的提问(如"量子计算是什么")
|
||||
✓ 不含指代词的独立陈述句
|
||||
- 示例:当text="今天天气如何"时,即使存在量子计算相关的memory_slices也应忽略
|
||||
|
||||
3. 响应生成规范
|
||||
- 回复必须完全基于text的核心语义生成
|
||||
- 禁止出现"根据您之前提到的XX"等无关内容引用
|
||||
- 当角色设定(character)与当前对话无关时(如科技助手回答日常问候),暂时覆盖角色设定
|
||||
|
||||
输出格式
|
||||
{
|
||||
"text": "响应内容" // 必须严格对应text字段的语义
|
||||
}
|
||||
|
||||
最终注意事项
|
||||
1. 回应内容必须紧扣用户输入,且契合角色设定
|
||||
2. 遇到模糊提问时,优先推测最常见的语境理解,不要直接问“你指的是什么”
|
||||
3. 回应应自然衔接,并允许后续系统模块追加更多限定、扩展字段
|
||||
4. 你只需要生成JSON格式的响应对象,字段仅包含`text`,但在模块扩展下,字段内容可以有所增加。确保你可以兼容这些扩展而不破坏结构。
|
||||
5. 若用户的输入(text)与其他字段中的内容无关,可忽略其他字段的内容
|
||||
|
||||
> 以下模块可能会追加更多内容限制或上下文提示,请确保你的回答能够自然兼容这些后续拼接的内容,并调整输出格式。
|
||||
|
||||
""";
|
||||
public static final String SLICE_EVALUATOR_PROMPT = """
|
||||
记忆切片选择器提示词(最终版)
|
||||
SliceEvaluator 提示词
|
||||
|
||||
功能说明
|
||||
你需要根据用户输入的JSON数据,分析其中的`text`(当前输入内容)、`history`(对话历史)和`memory_slices`(可用记忆切片),选出相关记忆切片。当text内容与history明显不相关时,应以text为主要判断依据。
|
||||
|
||||
12
src/main/java/work/slhaf/agent/common/util/ExtractUtil.java
Normal file
12
src/main/java/work/slhaf/agent/common/util/ExtractUtil.java
Normal file
@@ -0,0 +1,12 @@
|
||||
package work.slhaf.agent.common.util;
|
||||
|
||||
public class ExtractUtil {
|
||||
public static String extractJson(String jsonStr) {
|
||||
int start = jsonStr.indexOf("{");
|
||||
int end = jsonStr.lastIndexOf("}");
|
||||
if (start != -1 && end != -1 && start < end) {
|
||||
return jsonStr.substring(start, end + 1);
|
||||
}
|
||||
return jsonStr;
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package work.slhaf.agent.core;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.agent.core.interaction.InteractionModule;
|
||||
import work.slhaf.agent.core.interaction.InteractionModulesLoader;
|
||||
@@ -21,15 +22,18 @@ public class InteractionHub {
|
||||
|
||||
private static InteractionHub interactionHub;
|
||||
|
||||
@ToString.Exclude
|
||||
private TaskCallback callback;
|
||||
|
||||
private CoreModel coreModel;
|
||||
private MemoryManager memoryManager;
|
||||
private TaskScheduler taskScheduler;
|
||||
private List<InteractionModule> interactionModules;
|
||||
|
||||
public static InteractionHub initialize() throws IOException {
|
||||
if (interactionHub == null) {
|
||||
interactionHub = new InteractionHub();
|
||||
//加载模块
|
||||
interactionHub.setInteractionModules(InteractionModulesLoader.getInstance().registerInteractionModules());
|
||||
log.info("InteractionHub注册完毕...");
|
||||
}
|
||||
return interactionHub;
|
||||
@@ -38,11 +42,10 @@ public class InteractionHub {
|
||||
public void call(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException {
|
||||
//预处理
|
||||
InteractionContext interactionContext = PreprocessExecutor.getInstance().execute(inputData);
|
||||
//加载模块
|
||||
List<InteractionModule> interactionModules = InteractionModulesLoader.getInstance().registerInteractionModules();
|
||||
|
||||
for (InteractionModule interactionModule : interactionModules) {
|
||||
interactionModule.execute(interactionContext);
|
||||
}
|
||||
callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getString("message"));
|
||||
callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getString("text"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
package work.slhaf.agent.core.interaction;
|
||||
|
||||
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public interface InputReceiver {
|
||||
|
||||
void receiveInput(InteractionInputData inputData) throws IOException, ClassNotFoundException, InterruptedException;
|
||||
}
|
||||
@@ -15,5 +15,6 @@ public class InteractionContext {
|
||||
protected String input;
|
||||
|
||||
protected JSONObject moduleContext;
|
||||
protected JSONObject modulePrompt;
|
||||
protected JSONObject coreResponse;
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
/**
|
||||
* 近两日的区分用户的对话总结缓存,在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质
|
||||
*/
|
||||
private ConcurrentHashMap<LocalDateTime, ConcurrentHashMap<String/*userId*/, String>> userDialogMap;
|
||||
private ConcurrentHashMap<String/*userId*/, ConcurrentHashMap<LocalDateTime, String>> userDialogMap;
|
||||
|
||||
/**
|
||||
* 当前对话的活动性总结, 拥有比dialogMap更丰富的全文细节, 作为当前对话token超限时的必要上下文压缩存储
|
||||
@@ -92,6 +92,8 @@ public class MemoryGraph extends PersistableObject {
|
||||
*/
|
||||
private HashMap<String, String> modelPrompt;
|
||||
|
||||
private String character;
|
||||
|
||||
/**
|
||||
* 主模型的聊天记录
|
||||
*/
|
||||
@@ -117,6 +119,10 @@ public class MemoryGraph extends PersistableObject {
|
||||
this.memorySliceCache = new ConcurrentHashMap<>();
|
||||
this.modelPrompt = new HashMap<>();
|
||||
this.selectedSlices = new HashSet<>();
|
||||
this.users = new ArrayList<>();
|
||||
this.userDialogMap = new ConcurrentHashMap<>();
|
||||
this.currentCompressedSessionContext = new ArrayList<>();
|
||||
this.dialogMap = new HashMap<>();
|
||||
}
|
||||
|
||||
public static MemoryGraph getInstance(String id) throws IOException, ClassNotFoundException {
|
||||
@@ -206,7 +212,9 @@ public class MemoryGraph extends PersistableObject {
|
||||
}
|
||||
|
||||
updateDateIndex(now, slice);
|
||||
updateDialogMap(slice);
|
||||
if (!slice.isPrivate()) {
|
||||
updateUserDialogMap(slice);
|
||||
}
|
||||
node.saveMemorySliceList();
|
||||
}
|
||||
|
||||
@@ -241,7 +249,7 @@ public class MemoryGraph extends PersistableObject {
|
||||
return lastTopicNode;
|
||||
}
|
||||
|
||||
private void updateDialogMap(MemorySlice slice) {
|
||||
private void updateUserDialogMap(MemorySlice slice) {
|
||||
String summary = slice.getSummary();
|
||||
LocalDateTime now = LocalDateTime.now();
|
||||
|
||||
@@ -264,17 +272,21 @@ public class MemoryGraph extends PersistableObject {
|
||||
//更新userDialogMap
|
||||
//移除两天前上下文缓存(切片总结)
|
||||
userDialogMap.forEach((k, v) -> {
|
||||
if (now.minusDays(2).isAfter(k)) {
|
||||
keysToRemove.add(k);
|
||||
v.forEach((i, j) -> {
|
||||
if (now.minusDays(2).isAfter(i)) {
|
||||
keysToRemove.add(i);
|
||||
}
|
||||
});
|
||||
});
|
||||
for (LocalDateTime dateTime : keysToRemove) {
|
||||
userDialogMap.remove(dateTime);
|
||||
userDialogMap.forEach((k, v) -> {
|
||||
v.remove(dateTime);
|
||||
});
|
||||
}
|
||||
//放入新缓存
|
||||
userDialogMap
|
||||
.computeIfAbsent(now, k -> new ConcurrentHashMap<>())
|
||||
.merge(slice.getStartUserId(), slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
|
||||
.computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>())
|
||||
.merge(now, slice.getSummary(), (oldVal, newVal) -> oldVal + " " + newVal);
|
||||
|
||||
}
|
||||
|
||||
@@ -298,6 +310,8 @@ public class MemoryGraph extends PersistableObject {
|
||||
//排序
|
||||
memorySliceList.sort(null);
|
||||
MemorySlice tempSlice = memorySliceList.getLast();
|
||||
//设置私密状态一致
|
||||
tempSlice.setPrivate(slice.isPrivate());
|
||||
//末尾切片添加当前切片的引用
|
||||
tempSlice.setSliceAfter(slice);
|
||||
//当前切片添加前序切片的引用
|
||||
@@ -444,24 +458,26 @@ public class MemoryGraph extends PersistableObject {
|
||||
return targetParentNode;
|
||||
}
|
||||
|
||||
public void printTopicTree() {
|
||||
public String getTopicTree() {
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
for (Map.Entry<String, TopicNode> entry : topicNodes.entrySet()) {
|
||||
String rootName = entry.getKey();
|
||||
TopicNode rootNode = entry.getValue();
|
||||
System.out.println(rootName+"[root]");
|
||||
printSubTopicsTreeFormat(rootNode, "", true);
|
||||
stringBuilder.append(rootName).append("[root]").append("\r\n");
|
||||
printSubTopicsTreeFormat(rootNode, "", stringBuilder);
|
||||
}
|
||||
return stringBuilder.toString();
|
||||
}
|
||||
|
||||
private void printSubTopicsTreeFormat(TopicNode node, String prefix, boolean isLast) {
|
||||
private void printSubTopicsTreeFormat(TopicNode node, String prefix, StringBuilder stringBuilder) {
|
||||
if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return;
|
||||
|
||||
List<Map.Entry<String, TopicNode>> entries = new ArrayList<>(node.getTopicNodes().entrySet());
|
||||
for (int i = 0; i < entries.size(); i++) {
|
||||
boolean last = (i == entries.size() - 1);
|
||||
Map.Entry<String, TopicNode> entry = entries.get(i);
|
||||
System.out.println(prefix + (last ? "└── " : "├── ") + entry.getKey());
|
||||
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), last);
|
||||
stringBuilder.append(prefix).append(last ? "└── " : "├── ").append(entry.getKey()).append("\r\n");
|
||||
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : "│ "), stringBuilder);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,13 +8,16 @@ import work.slhaf.agent.core.interaction.InteractionModule;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||
import work.slhaf.agent.core.memory.pojo.User;
|
||||
import work.slhaf.agent.modules.memory.SliceEvaluator;
|
||||
import work.slhaf.agent.shared.memory.EvaluatedSlice;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
@Data
|
||||
@Slf4j
|
||||
@@ -23,7 +26,7 @@ public class MemoryManager implements InteractionModule {
|
||||
private static MemoryManager memoryManager;
|
||||
|
||||
private MemoryGraph memoryGraph;
|
||||
private SliceEvaluator sliceEvaluator;
|
||||
private HashMap<String,List<EvaluatedSlice>> activatedSlices;
|
||||
|
||||
private MemoryManager(){}
|
||||
|
||||
@@ -37,7 +40,7 @@ public class MemoryManager implements InteractionModule {
|
||||
Config config = Config.getConfig();
|
||||
memoryManager = new MemoryManager();
|
||||
memoryManager.setMemoryGraph(MemoryGraph.getInstance(config.getAgentId()));
|
||||
memoryManager.setSliceEvaluator(SliceEvaluator.getInstance());
|
||||
memoryManager.setActivatedSlices(new HashMap<>());
|
||||
log.info("MemoryManager注册完毕...");
|
||||
}
|
||||
return memoryManager;
|
||||
@@ -85,6 +88,22 @@ public class MemoryManager implements InteractionModule {
|
||||
}
|
||||
|
||||
public String getTopicTree() {
|
||||
return memoryManager.getTopicTree();
|
||||
return memoryGraph.getTopicTree();
|
||||
}
|
||||
|
||||
public ConcurrentHashMap<String,String> getStaticMemory(String userId) {
|
||||
return memoryGraph.getStaticMemory().get(userId);
|
||||
}
|
||||
|
||||
public HashMap<LocalDateTime, String> getDialogMap() {
|
||||
return memoryGraph.getDialogMap();
|
||||
}
|
||||
|
||||
public ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId) {
|
||||
return memoryGraph.getUserDialogMap().get(userId);
|
||||
}
|
||||
|
||||
public String getCharacter() {
|
||||
return memoryGraph.getCharacter();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.time.LocalDate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
package work.slhaf.agent.core.module;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.agent.common.chat.constant.ChatConstant;
|
||||
import work.slhaf.agent.common.chat.pojo.ChatResponse;
|
||||
import work.slhaf.agent.common.chat.pojo.Message;
|
||||
import work.slhaf.agent.common.config.Config;
|
||||
import work.slhaf.agent.common.model.Model;
|
||||
import work.slhaf.agent.common.model.ModelConstant;
|
||||
import work.slhaf.agent.core.interaction.InteractionModule;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||
import work.slhaf.agent.core.memory.MemoryManager;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static work.slhaf.agent.common.util.ExtractUtil.extractJson;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@Slf4j
|
||||
@@ -20,6 +26,9 @@ public class CoreModel extends Model implements InteractionModule {
|
||||
public static final String MODEL_KEY = "core_model";
|
||||
private static CoreModel coreModel;
|
||||
|
||||
private MemoryManager memoryManager;
|
||||
private String promptCache;
|
||||
|
||||
private CoreModel() {
|
||||
}
|
||||
|
||||
@@ -27,6 +36,8 @@ public class CoreModel extends Model implements InteractionModule {
|
||||
if (coreModel == null) {
|
||||
Config config = Config.getConfig();
|
||||
coreModel = new CoreModel();
|
||||
coreModel.memoryManager = MemoryManager.getInstance();
|
||||
coreModel.messages = coreModel.memoryManager.getChatMessages();
|
||||
setModel(config, coreModel, MODEL_KEY, ModelConstant.CORE_MODEL_PROMPT);
|
||||
log.info("CoreModel注册完毕...");
|
||||
}
|
||||
@@ -35,9 +46,35 @@ public class CoreModel extends Model implements InteractionModule {
|
||||
|
||||
@Override
|
||||
public void execute(InteractionContext interactionContext) {
|
||||
//TODO 需要拼接上下文之后再发送给主模型
|
||||
String tempPrompt = interactionContext.getModulePrompt().toString();
|
||||
if (!tempPrompt.equals(promptCache)) {
|
||||
coreModel.getMessages().set(0, new Message(ChatConstant.Character.SYSTEM, ModelConstant.CORE_MODEL_PROMPT + "\r\n" + tempPrompt));
|
||||
promptCache = tempPrompt;
|
||||
}
|
||||
this.messages.add(new Message(ChatConstant.Character.USER, interactionContext.getModuleContext().getString("text")));
|
||||
ChatResponse chatResponse = this.chat();
|
||||
JSONObject response = null;
|
||||
int count = 0;
|
||||
while (true) {
|
||||
try {
|
||||
response = JSONObject.parse(extractJson(chatResponse.getMessage()));
|
||||
this.messages.add(new Message(ChatConstant.Character.ASSISTANT, response.getString("text")));
|
||||
|
||||
ChatResponse res = runChat(interactionContext.getInput());
|
||||
// interactionContext.setCoreResponse();
|
||||
//设置上下文
|
||||
interactionContext.getModuleContext().put("total_token", chatResponse.getUsageBean().getTotal_tokens());
|
||||
break;
|
||||
} catch (Exception e) {
|
||||
count++;
|
||||
log.error("CoreModel执行异常: {}", e.getLocalizedMessage());
|
||||
if (count > 3) {
|
||||
response = new JSONObject();
|
||||
response.put("text", "主模型交互出错: " + e.getLocalizedMessage());
|
||||
interactionContext.setFinished(true);
|
||||
break;
|
||||
}
|
||||
} finally {
|
||||
interactionContext.setCoreResponse(response);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,11 +2,12 @@ package work.slhaf.agent.gateway;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.java_websocket.WebSocket;
|
||||
import org.java_websocket.handshake.ClientHandshake;
|
||||
import org.java_websocket.server.WebSocketServer;
|
||||
import work.slhaf.agent.Agent;
|
||||
import work.slhaf.agent.core.interaction.InputReceiver;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionInputData;
|
||||
import work.slhaf.agent.core.interaction.data.InteractionOutputData;
|
||||
|
||||
@@ -17,12 +18,13 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
@Slf4j
|
||||
public class AgentWebSocketServer extends WebSocketServer implements MessageSender {
|
||||
|
||||
private final Agent agent;
|
||||
@ToString.Exclude
|
||||
private final InputReceiver receiver;
|
||||
private final ConcurrentHashMap<String, WebSocket> userSessions = new ConcurrentHashMap<>();
|
||||
|
||||
public AgentWebSocketServer(int port, Agent agent) {
|
||||
public AgentWebSocketServer(int port, InputReceiver receiver) {
|
||||
super(new InetSocketAddress(port));
|
||||
this.agent = agent;
|
||||
this.receiver = receiver;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -41,7 +43,7 @@ public class AgentWebSocketServer extends WebSocketServer implements MessageSend
|
||||
InteractionInputData inputData = JSONObject.parseObject(s, InteractionInputData.class);
|
||||
userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接
|
||||
try {
|
||||
agent.receiveUserInput(inputData);
|
||||
receiver.receiveInput(inputData);
|
||||
} catch (IOException | ClassNotFoundException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
@@ -16,6 +16,8 @@ import work.slhaf.agent.modules.memory.data.extractor.ExtractorResult;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static work.slhaf.agent.common.util.ExtractUtil.extractJson;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@Slf4j
|
||||
@@ -47,7 +49,7 @@ public class MemorySelectExtractor extends Model {
|
||||
.history(memoryManager.getChatMessages())
|
||||
.topic_tree(memoryManager.getTopicTree())
|
||||
.build();
|
||||
String responseStr = singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage();
|
||||
String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage());
|
||||
|
||||
ExtractorResult extractorResult;
|
||||
try {
|
||||
|
||||
@@ -6,11 +6,10 @@ import work.slhaf.agent.core.interaction.data.InteractionContext;
|
||||
import work.slhaf.agent.core.memory.MemoryManager;
|
||||
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatedSlice;
|
||||
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorInput;
|
||||
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorResult;
|
||||
import work.slhaf.agent.modules.memory.data.extractor.ExtractorMatchData;
|
||||
import work.slhaf.agent.modules.memory.data.extractor.ExtractorResult;
|
||||
import work.slhaf.agent.shared.memory.EvaluatedSlice;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.LocalDate;
|
||||
@@ -21,6 +20,29 @@ import java.util.List;
|
||||
public class MemorySelector implements InteractionModule {
|
||||
|
||||
private static MemorySelector memorySelector;
|
||||
public static final String modulePrompt = """
|
||||
新增输入字段:
|
||||
|
||||
"memory_slices": [{ //记忆切片,可能为多个
|
||||
"chatMessages": [{
|
||||
"role": "user"/"assistant", //该信息发送者
|
||||
"content": "消息内容"
|
||||
}],
|
||||
"date": "2024-03-20", //切片日期
|
||||
"summary": "切片总结"
|
||||
}],
|
||||
"static_memory": "对于该用户的常识性记忆,如爱好、住处、生日",
|
||||
"dialog_map": { //近两日的与所有用户的对话缓存
|
||||
"2023-01-01T11:30": "发生了...与用户A...、用户B谈到...",
|
||||
"2023-01-02T11:30": "发生了...与用户A...、用户B谈到..."
|
||||
}
|
||||
"user_dialog_map": { //与当前用户的近两日对话缓存
|
||||
"2023-01-01T11:30": "与用户讨论了...",
|
||||
"2023-01-02T11:30": "与用户讨论了..."
|
||||
}
|
||||
|
||||
无新增输出字段
|
||||
""";
|
||||
|
||||
private MemoryManager memoryManager;
|
||||
private SliceEvaluator sliceEvaluator;
|
||||
@@ -41,12 +63,13 @@ public class MemorySelector implements InteractionModule {
|
||||
|
||||
@Override
|
||||
public void execute(InteractionContext interactionContext) throws IOException, ClassNotFoundException, InterruptedException {
|
||||
String userId = memoryManager.getUserId(interactionContext.getUserInfo(), interactionContext.getUserNickname());
|
||||
//获取主题路径
|
||||
ExtractorResult extractorResult = memorySelectExtractor.execute(interactionContext);
|
||||
if (extractorResult.isRecall()) {
|
||||
//查找切片
|
||||
List<MemoryResult> memoryResultList = new ArrayList<>();
|
||||
setMemoryResultList(memoryResultList, extractorResult.getMatches(), interactionContext.getUserInfo(), interactionContext.getUserNickname());
|
||||
setMemoryResultList(memoryResultList, extractorResult.getMatches(),userId);
|
||||
//评估切片
|
||||
EvaluatorInput evaluatorInput = EvaluatorInput.builder()
|
||||
.input(interactionContext.getInput())
|
||||
@@ -54,13 +77,19 @@ public class MemorySelector implements InteractionModule {
|
||||
.messages(memoryManager.getChatMessages())
|
||||
.build();
|
||||
List<EvaluatedSlice> memorySlices = sliceEvaluator.execute(evaluatorInput);
|
||||
memoryManager.getActivatedSlices().put(userId,memorySlices);
|
||||
}
|
||||
|
||||
//设置上下文
|
||||
interactionContext.getModuleContext().put("memory_slices",memorySlices);
|
||||
interactionContext.getModuleContext().put("memory_slices",memoryManager.getActivatedSlices().get(userId));
|
||||
interactionContext.getModuleContext().put("static_memory",memoryManager.getStaticMemory(userId));
|
||||
interactionContext.getModuleContext().put("dialog_map",memoryManager.getDialogMap());
|
||||
interactionContext.getModuleContext().put("user_dialog_map",memoryManager.getUserDialogMap(userId));
|
||||
|
||||
interactionContext.getModulePrompt().put("memory", modulePrompt);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userInfo, String nickName) throws IOException, ClassNotFoundException {
|
||||
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) throws IOException, ClassNotFoundException {
|
||||
for (ExtractorMatchData match : matches) {
|
||||
MemoryResult memoryResult = switch (match.getType()) {
|
||||
case ExtractorMatchData.Constant.DATE -> memoryManager.selectMemory(match.getText());
|
||||
@@ -76,15 +105,14 @@ public class MemorySelector implements InteractionModule {
|
||||
//根据userInfo过滤是否为私人记忆
|
||||
for (MemoryResult memoryResult : memoryResultList) {
|
||||
//过滤终点记忆
|
||||
memoryResult.getMemorySliceResult().removeIf(m -> removeOrNot(m.getMemorySlice(), userInfo, nickName));
|
||||
memoryResult.getMemorySliceResult().removeIf(m -> removeOrNot(m.getMemorySlice(), userId));
|
||||
//过滤邻近记忆
|
||||
memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userInfo, nickName));
|
||||
memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userId));
|
||||
}
|
||||
}
|
||||
|
||||
private boolean removeOrNot(MemorySlice memorySlice, String userInfo, String nickName) {
|
||||
private boolean removeOrNot(MemorySlice memorySlice, String userId) {
|
||||
if (memorySlice.isPrivate()) {
|
||||
String userId = memoryManager.getUserId(userInfo, nickName);
|
||||
return memorySlice.getStartUserId().equals(userId);
|
||||
}
|
||||
return true;
|
||||
|
||||
@@ -30,6 +30,18 @@ public class MemoryUpdater implements InteractionModule {
|
||||
|
||||
@Override
|
||||
public void execute(InteractionContext interactionContext) {
|
||||
if (interactionContext.isFinished()){
|
||||
return;
|
||||
}
|
||||
//如果token 大于阈值,则更新记忆
|
||||
if (interactionContext.getModuleContext().getIntValue("total_token") > 24000) {
|
||||
executor.execute(() -> {
|
||||
|
||||
});
|
||||
}
|
||||
|
||||
//更新确定性记忆
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,11 @@ import work.slhaf.agent.core.memory.MemoryManager;
|
||||
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySliceResult;
|
||||
import work.slhaf.agent.modules.memory.data.evaluator.*;
|
||||
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorBatchInput;
|
||||
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorInput;
|
||||
import work.slhaf.agent.modules.memory.data.evaluator.EvaluatorResult;
|
||||
import work.slhaf.agent.modules.memory.data.evaluator.SliceSummary;
|
||||
import work.slhaf.agent.shared.memory.EvaluatedSlice;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
@@ -22,6 +26,8 @@ import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ConcurrentLinkedDeque;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static work.slhaf.agent.common.util.ExtractUtil.extractJson;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@Slf4j
|
||||
@@ -63,7 +69,7 @@ public class SliceEvaluator extends Model {
|
||||
.memory_slices(sliceSummaryList)
|
||||
.history(evaluatorInput.getMessages())
|
||||
.build();
|
||||
EvaluatorResult evaluatorResult = JSONObject.parseObject(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage(), EvaluatorResult.class);
|
||||
EvaluatorResult evaluatorResult = JSONObject.parseObject(extractJson(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage()), EvaluatorResult.class);
|
||||
for (Long result : evaluatorResult.getResults()) {
|
||||
SliceSummary sliceSummary = map.get(result);
|
||||
EvaluatedSlice evaluatedSlice = EvaluatedSlice.builder()
|
||||
@@ -73,7 +79,7 @@ public class SliceEvaluator extends Model {
|
||||
queue.offer(evaluatedSlice);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("切片评估: {}", e.getLocalizedMessage());
|
||||
log.error("切片评估出现错误: {}", e.getLocalizedMessage());
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
@@ -41,7 +41,10 @@ public class PreprocessExecutor {
|
||||
context.setModuleContext(new JSONObject());
|
||||
context.getModuleContext().put("text", inputData.getContent());
|
||||
context.getModuleContext().put("datetime", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME));
|
||||
context.getModuleContext().put("character",memoryManager.getMemoryGraph().getModelPrompt());
|
||||
context.getModuleContext().put("character",memoryManager.getCharacter());
|
||||
context.getModuleContext().put("user_nick", inputData.getUserNickName());
|
||||
|
||||
context.setModulePrompt(new JSONObject());
|
||||
|
||||
return context;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package work.slhaf.agent.modules.memory.data.evaluator;
|
||||
package work.slhaf.agent.shared.memory;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
@@ -10,7 +10,7 @@ import java.util.List;
|
||||
@Data
|
||||
@Builder
|
||||
public class EvaluatedSlice {
|
||||
// private List<Message> chatMessages;
|
||||
private List<Message> chatMessages;
|
||||
private LocalDate date;
|
||||
private String summary;
|
||||
}
|
||||
@@ -1,12 +1,16 @@
|
||||
package memory;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import work.slhaf.agent.common.chat.ChatClient;
|
||||
import work.slhaf.agent.common.chat.constant.ChatConstant;
|
||||
import work.slhaf.agent.common.chat.pojo.Message;
|
||||
import work.slhaf.agent.common.model.ModelConstant;
|
||||
import work.slhaf.agent.modules.memory.MemorySelector;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
||||
public class AITest {
|
||||
@@ -94,6 +98,44 @@ public class AITest {
|
||||
run(input,ModelConstant.SLICE_EVALUATOR_PROMPT);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void coreModelTest(){
|
||||
String input = """
|
||||
{
|
||||
"text": "在",
|
||||
"datetime": "2024-03-22T09:00",
|
||||
"character": "你是一个智能助手,专注于科技领域",
|
||||
"memory_slices": [
|
||||
{
|
||||
"chatMessages": [
|
||||
{"role": "user", "content": "量子计算近期的进展怎么样?"},
|
||||
{"role": "assistant", "content": "量子计算在硬件和算法上都取得了突破,IBM发布了433量子位处理器,Google也在量子优越性上取得了进展。"}
|
||||
],
|
||||
"date": "2024-03-20",
|
||||
"summary": "量子计算最新突破:IBM发布433量子位处理器,Google在量子优越性上取得进展。"
|
||||
}
|
||||
],
|
||||
"static_memory": "用户对量子计算技术非常感兴趣。",
|
||||
"dialog_map": {
|
||||
"2024-03-20T10:30": "与用户讨论了量子计算的最新进展"
|
||||
},
|
||||
"user_dialog_map": {
|
||||
"2024-03-20T10:30": "与用户讨论了量子计算的最新进展"
|
||||
}
|
||||
}
|
||||
|
||||
""";
|
||||
run(input,ModelConstant.CORE_MODEL_PROMPT + "\r\n" + MemorySelector.modulePrompt);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void map2jsonTest(){
|
||||
HashMap<LocalDate,String> map = new HashMap<>();
|
||||
map.put(LocalDate.now(),"hello");
|
||||
map.put(LocalDate.now().plusDays(1),"world");
|
||||
System.out.println(JSONUtil.toJsonPrettyStr(map));
|
||||
}
|
||||
|
||||
private void run(String input, String prompt) {
|
||||
ChatClient client = new ChatClient("https://open.bigmodel.cn/api/paas/v4/chat/completions", "3db444552530b7742b0c53425fb93dcc.LcVwYjByht9AC3N9", "glm-4-flash-250414");
|
||||
List<Message> messages = new ArrayList<>();
|
||||
|
||||
@@ -3,9 +3,9 @@ package memory;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import work.slhaf.agent.core.memory.MemoryGraph;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.agent.core.memory.node.MemoryNode;
|
||||
import work.slhaf.agent.core.memory.node.TopicNode;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.LocalDate;
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
package memory;
|
||||
|
||||
import cn.hutool.core.date.LocalDateTimeUtil;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import work.slhaf.agent.core.memory.MemoryGraph;
|
||||
import work.slhaf.agent.core.memory.node.TopicNode;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.HashMap;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
@@ -50,7 +48,7 @@ public void test1() {
|
||||
|
||||
// 输出
|
||||
graph.setTopicNodes(topicMap);
|
||||
graph.printTopicTree();
|
||||
System.out.println(graph.getTopicTree());
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -3,18 +3,17 @@ package memory;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import work.slhaf.agent.core.memory.MemoryGraph;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||
import work.slhaf.agent.core.memory.exception.UnExistedTopicException;
|
||||
import work.slhaf.agent.core.memory.node.MemoryNode;
|
||||
import work.slhaf.agent.core.memory.node.TopicNode;
|
||||
import work.slhaf.agent.core.memory.pojo.MemoryResult;
|
||||
import work.slhaf.agent.core.memory.pojo.MemorySlice;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.LocalDate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
class SearchTest {
|
||||
private MemoryGraph memoryGraph;
|
||||
|
||||
Reference in New Issue
Block a user