推进核心服务注册机制,并调整了Partner的模块结构

- 为了方便调试,将项目分为两个子模块,demo模块中进行新机制的开发工作,core模块为原来的Partner项目;
- 新增了多个注解,用于适配新的核心服务注册机制;
- 在`CapabilityRegisterFactory`中,将首先启动`statusCheck`,检查各个注解是否正常工作,包括以下内容:
   - `CapabilityCore`核心服务与`Capability`接口是否匹配
   - 核心服务中的`CapabilityMethod`是否与`Capability`接口中的方法匹配
   - 是否存在待协调方法`ToCoordinatedMethod`以及对应的存在于`BaseCognationManager`子类实现中
This commit is contained in:
2025-07-15 16:48:27 +08:00
parent 98d830d08b
commit dd10b00fb6
148 changed files with 1082 additions and 500 deletions

20
Partner-Core/pom.xml Normal file
View File

@@ -0,0 +1,20 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>work.slhaf</groupId>
<artifactId>Partner</artifactId>
<version>0.5.0</version>
</parent>
<artifactId>Partner-Core</artifactId>
<properties>
<maven.compiler.source>21</maven.compiler.source>
<maven.compiler.target>21</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
</project>

View File

@@ -0,0 +1,14 @@
package work.slhaf;
import work.slhaf.partner.Agent;
import java.io.IOException;
import java.util.Scanner;
public class Main {
public static void main(String[] args) throws IOException {
Agent.initialize();
Scanner scanner = new Scanner(System.in);
while (!scanner.nextLine().equals("exit")) ;
}
}

View File

@@ -0,0 +1,75 @@
package work.slhaf.partner;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.config.Config;
import work.slhaf.partner.common.monitor.DebugMonitor;
import work.slhaf.partner.core.InteractionHub;
import work.slhaf.partner.core.interaction.agent_interface.InputReceiver;
import work.slhaf.partner.core.interaction.agent_interface.TaskCallback;
import work.slhaf.partner.core.interaction.data.InteractionInputData;
import work.slhaf.partner.core.interaction.data.InteractionOutputData;
import work.slhaf.partner.gateway.AgentWebSocketServer;
import work.slhaf.partner.gateway.MessageSender;
import java.io.IOException;
import java.time.LocalDateTime;
@Data
@Slf4j
public class Agent implements TaskCallback, InputReceiver {
private static volatile Agent agent;
private InteractionHub interactionHub;
private MessageSender messageSender;
public static void initialize() throws IOException {
if (agent == null) {
synchronized (Agent.class) {
if (agent == null) {
//加载配置
Config config = Config.getConfig();
agent = new Agent();
agent.setInteractionHub(InteractionHub.initialize());
agent.registerTaskCallback();
AgentWebSocketServer server = new AgentWebSocketServer(config.getWebSocketConfig().getPort(), agent);
server.launch();
agent.setMessageSender(server);
log.info("Agent 加载完毕..");
//启动监测线程
DebugMonitor.initialize();
}
}
}
}
public static Agent getInstance() throws IOException {
initialize();
return agent;
}
/**
* 接收用户输入,包装为标准输入数据类
*/
public void receiveInput(InteractionInputData inputData) throws IOException, ClassNotFoundException {
inputData.setLocalDateTime(LocalDateTime.now());
interactionHub.call(inputData);
}
/**
* 向用户返回输出内容
*/
public void sendToUser(String userInfo, String output) {
messageSender.sendMessage(new InteractionOutputData(output, userInfo));
}
@Override
public void onTaskFinished(String userInfo, String output) {
sendToUser(userInfo, output);
}
private void registerTaskCallback() {
interactionHub.setCallback(this);
}
}

View File

@@ -0,0 +1,70 @@
package work.slhaf.partner.common.chat;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.NoArgsConstructor;
import work.slhaf.partner.common.chat.constant.ChatConstant;
import work.slhaf.partner.common.chat.pojo.ChatBody;
import work.slhaf.partner.common.chat.pojo.ChatResponse;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.common.chat.pojo.PrimaryChatResponse;
import java.util.List;
@Data
@NoArgsConstructor
public class ChatClient {
private String clientId;
private String url;
private String apikey;
private String model;
private double top_p;
private double temperature;
private int max_tokens;
public ChatClient(String url, String apikey, String model) {
this.url = url;
this.apikey = apikey;
this.model = model;
}
public ChatResponse runChat(List<Message> messages) {
HttpRequest request = HttpRequest.post(url);
request.header("Content-Type", "application/json");
request.header("Authorization", "Bearer " + apikey);
ChatBody body;
if (top_p > 0) {
body = ChatBody.builder()
.model(model)
.messages(messages)
.top_p(top_p)
.temperature(temperature)
.max_tokens(max_tokens)
.build();
} else {
body = ChatBody.builder()
.model(model)
.messages(messages)
.build();
}
HttpResponse response = request.body(JSONUtil.toJsonStr(body)).execute();
ChatResponse finalResponse;
PrimaryChatResponse primaryChatResponse = JSONUtil.toBean(response.body(), PrimaryChatResponse.class);
finalResponse = ChatResponse.builder()
.type(ChatConstant.Response.SUCCESS)
.message(primaryChatResponse.getChoices().get(0).getMessage().getContent())
.usageBean(primaryChatResponse.getUsage())
.build();
response.close();
return finalResponse;
}
}

View File

@@ -0,0 +1,22 @@
package work.slhaf.partner.common.chat.constant;
public class ChatConstant {
public static class Character {
public static final String USER = "user";
public static final String SYSTEM = "system";
public static final String ASSISTANT = "assistant";
}
public static class Model {
public static final String DEEP_SEEK_CHAT = "deepseek-chat";
public static final String GLM_4_FLASH = "glm-4_flash";
public static final String GLM_4_PLUS = "glm-4_plus";
public static final String GLM_4_0520 = "glm-4_0520";
}
public static class Response {
public static final String SUCCESS = "success";
public static final String ERROR = "error";
}
}

View File

@@ -0,0 +1,25 @@
package work.slhaf.partner.common.chat.pojo;
import lombok.*;
import java.util.List;
@Builder
@Data
@AllArgsConstructor
@NoArgsConstructor
public class ChatBody {
@NonNull
private String model;
@NonNull
private List<Message> messages;
@Builder.Default
private double temperature = 1;
@Builder.Default
private double top_p = 1;
private boolean stream;
@Builder.Default
private int max_tokens = 1024;
private int presence_penalty;
private int frequency_penalty;
}

View File

@@ -0,0 +1,16 @@
package work.slhaf.partner.common.chat.pojo;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class ChatResponse {
private String type;
private String message;
private PrimaryChatResponse.UsageBean usageBean;
}

View File

@@ -0,0 +1,22 @@
package work.slhaf.partner.common.chat.pojo;
import lombok.*;
import work.slhaf.partner.common.serialize.PersistableObject;
import java.io.Serial;
@EqualsAndHashCode(callSuper = true)
@Builder
@Data
@AllArgsConstructor
@NoArgsConstructor
public class Message extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
@NonNull
private String role;
@NonNull
private String content;
}

View File

@@ -0,0 +1,20 @@
package work.slhaf.partner.common.chat.pojo;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.serialize.PersistableObject;
import java.io.Serial;
@EqualsAndHashCode(callSuper = true)
@Data
@AllArgsConstructor
public class MetaMessage extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private Message userMessage;
private Message assistantMessage;
}

View File

@@ -0,0 +1,111 @@
package work.slhaf.partner.common.chat.pojo;
import lombok.Getter;
import lombok.Setter;
import java.util.List;
@Getter
@Setter
public class PrimaryChatResponse {
/**
* id
*/
private String id;
/**
* object
*/
private String object;
/**
* created
*/
private int created;
/**
* model
*/
private String model;
/**
* choices
*/
private List<ChoicesBean> choices;
/**
* usage
*/
private UsageBean usage;
/**
* system_fingerprint
*/
private String system_fingerprint;
@Setter
@Getter
public static class UsageBean {
/**
* prompt_tokens
*/
private int prompt_tokens;
/**
* completion_tokens
*/
private int completion_tokens;
/**
* total_tokens
*/
private int total_tokens;
/**
* prompt_cache_hit_tokens
*/
private int prompt_cache_hit_tokens;
/**
* prompt_cache_miss_tokens
*/
private int prompt_cache_miss_tokens;
@Override
public String toString() {
return "UsageBean{" +
"prompt_tokens=" + prompt_tokens +
", completion_tokens=" + completion_tokens +
", total_tokens=" + total_tokens +
", prompt_cache_hit_tokens=" + prompt_cache_hit_tokens +
", prompt_cache_miss_tokens=" + prompt_cache_miss_tokens +
'}';
}
}
@Setter
@Getter
public static class ChoicesBean {
/**
* index
*/
private int index;
/**
* message
*/
private MessageBean message;
/**
* logprobs
*/
private Object logprobs;
/**
* finish_reason
*/
private String finish_reason;
@Setter
@Getter
public static class MessageBean {
/**
* role
*/
private String role;
/**
* content
*/
private String content;
}
}
}

View File

@@ -0,0 +1,139 @@
package work.slhaf.partner.common.config;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONArray;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import work.slhaf.partner.module.modules.core.CoreModel;
import work.slhaf.partner.module.modules.memory.selector.MemorySelector;
import work.slhaf.partner.module.modules.memory.updater.MemoryUpdater;
import work.slhaf.partner.module.modules.process.PostprocessExecutor;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.Scanner;
@Data
@Slf4j
public class Config {
private static final String CONFIG_FILE_PATH = "./config/config.json";
private static final String LOG_FILE_PATH = "./data/log";
private static Config config;
private String agentId;
// private String basicCharacter;
private WebSocketConfig webSocketConfig;
private List<ModuleConfig> moduleConfigList;
private Config() {
}
public static Config getConfig() throws IOException {
if (config == null) {
File file = new File(CONFIG_FILE_PATH);
if (file.exists()) {
config = JSONUtil.readJSONObject(file, StandardCharsets.UTF_8).toBean(Config.class);
} else {
config = new Config();
Scanner scanner = new Scanner(System.in);
System.out.print("输入智能体名称: ");
config.setAgentId(scanner.nextLine());
System.out.println("(注意! 设定角色之后修改主配置文件将不会影响现有记忆除非同时更换agentId)");
System.out.println("\r\n--------模型配置--------\r\n");
generateModelConfig(scanner);
System.out.println("\r\n--------服务配置--------\r\n");
generateWsSocketConfig(scanner);
System.out.println("\r\n--------模块链配置--------\r\n");
generatePipelineConfig();
boolean launchOrNot = getLaunchOrNot(scanner);
//保存配置文件
String str = JSONUtil.toJsonPrettyStr(config);
FileUtils.writeStringToFile(file, str, StandardCharsets.UTF_8);
log.info("配置已保存");
if (!launchOrNot) {
System.exit(0);
}
}
config.generateCommonDirs();
}
return config;
}
private void generateCommonDirs() throws IOException {
Files.createDirectories(Paths.get(LOG_FILE_PATH));
}
private static boolean getLaunchOrNot(Scanner scanner) {
System.out.print("是否直接启动Partner?(y/n): ");
String input;
while (true) {
input = scanner.nextLine();
if (input.equals("y")) {
return true;
} else if (input.equals("n")) {
return false;
} else {
System.out.println("请输入y或n");
}
}
}
private static void generatePipelineConfig() {
List<ModuleConfig> moduleConfigList = List.of(
new ModuleConfig(MemorySelector.class.getName(), ModuleConfig.Constant.INTERNAL, null),
new ModuleConfig(CoreModel.class.getName(), ModuleConfig.Constant.INTERNAL, null),
new ModuleConfig(PostprocessExecutor.class.getName(),ModuleConfig.Constant.INTERNAL,null),
new ModuleConfig(MemoryUpdater.class.getName(), ModuleConfig.Constant.INTERNAL, null)
// new ModuleConfig(TaskScheduler.class.getName(), ModuleConfig.Constant.INTERNAL, null)
);
config.setModuleConfigList(moduleConfigList);
}
private static void generateWsSocketConfig(Scanner scanner) {
System.out.print("WebSocket port: ");
WebSocketConfig wsConfig = new WebSocketConfig();
wsConfig.setPort(scanner.nextInt());
config.setWebSocketConfig(wsConfig);
}
private static void generateModelConfig(Scanner scanner) throws IOException {
System.out.println("配置LLM APi:");
System.out.println("经测试, 目前只建议选择Qwen3: qwen-plus-latest或qwen-max-latest");
System.out.print("base_url: ");
String baseUrl = scanner.nextLine();
System.out.print("apikey: ");
String apikey = scanner.nextLine();
System.out.print("model: ");
String model = scanner.nextLine();
ModelConfig modelConfig = new ModelConfig();
modelConfig.setBaseUrl(baseUrl);
modelConfig.setApikey(apikey);
modelConfig.setModel(model);
InputStream stream = Config.class.getClassLoader().getResourceAsStream("modules/default_activated_model.json");
String content = new String(stream.readAllBytes(), StandardCharsets.UTF_8);
stream.close();
for (String s : JSONArray.parseArray(content, String.class)) {
modelConfig.generateConfig(s);
}
}
}

View File

@@ -0,0 +1,40 @@
package work.slhaf.partner.common.config;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import org.apache.commons.io.FileUtils;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
@Data
public class ModelConfig {
private static final String MODEL_CONFIG_DIR_PATH = "./config/model/";
private static final HashMap<String, ModelConfig> modelConfigMap = new HashMap<>();
private String apikey;
private String baseUrl;
private String model;
public void generateConfig(String filename) throws IOException {
String str = JSONUtil.toJsonPrettyStr(this);
File file = new File(MODEL_CONFIG_DIR_PATH + filename + ".json");
FileUtils.writeStringToFile(file, str, StandardCharsets.UTF_8);
}
public static ModelConfig load(String modelKey) {
if (!modelConfigMap.containsKey(modelKey)) {
modelConfigMap.put(modelKey,loadConfig(modelKey));
}
return modelConfigMap.get(modelKey);
}
private static ModelConfig loadConfig(String modelKey) {
File file = new File(MODEL_CONFIG_DIR_PATH+modelKey+".json");
return JSONUtil.readJSONObject(file,StandardCharsets.UTF_8).toBean(ModelConfig.class);
}
}

View File

@@ -0,0 +1,17 @@
package work.slhaf.partner.common.config;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class ModuleConfig {
private String className;
private String type;
private String path;
public static class Constant {
public static final String INTERNAL = "internal";
public static final String EXTERNAL = "external";
}
}

View File

@@ -0,0 +1,8 @@
package work.slhaf.partner.common.config;
import lombok.Data;
@Data
public class WebSocketConfig {
private Integer port;
}

View File

@@ -0,0 +1,43 @@
package work.slhaf.partner.common.exception_handler;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.exception_handler.pojo.GlobalException;
import work.slhaf.partner.common.exception_handler.pojo.GlobalExceptionData;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
@Slf4j
public class GlobalExceptionHandler {
private static final String EXCEPTION_STATIC_PATH = "./data/exception_snapshot/";
public static void writeExceptionState(GlobalException exception) {
GlobalExceptionData exceptionData = exception.getData();
Path filePath = Paths.get(EXCEPTION_STATIC_PATH, exceptionData.getExceptionTime() + ".dat");
try {
Files.createDirectories(Path.of(EXCEPTION_STATIC_PATH));
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(filePath.toFile()));
oos.writeObject(exceptionData);
oos.close();
log.warn("[GlobalExceptionHandler] 捕获异常, 已保存到: {}", filePath);
} catch (IOException e) {
log.error("[GlobalExceptionHandler] 捕获异常, 保存失败: ", e);
}
}
public static GlobalExceptionData readExceptionState(String filePath) {
try {
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(filePath));
GlobalExceptionData exceptionData = (GlobalExceptionData) ois.readObject();
ois.close();
log.info("[GlobalExceptionHandler] 已从: {} 读取异常快照", filePath);
return exceptionData;
} catch (IOException | ClassNotFoundException e) {
log.error("[GlobalExceptionHandler] 读取异常, 读取失败: ", e);
return null;
}
}
}

View File

@@ -0,0 +1,30 @@
package work.slhaf.partner.common.exception_handler.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.core.cognation.CognationManager;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.session.SessionManager;
@EqualsAndHashCode(callSuper = true)
@Slf4j
@Data
public class GlobalException extends RuntimeException {
private GlobalExceptionData data;
public GlobalException(String message) {
super(message);
try {
this.data = new GlobalExceptionData();
this.data.setExceptionTime(System.currentTimeMillis());
this.data.setSessionManager(SessionManager.getInstance());
this.data.setCognationManager(CognationManager.getInstance());
this.data.setContext(InteractionContext.getInstance());
} catch (Exception e) {
log.error("[GlobalException] 捕获异常, 获取数据失败");
}
}
}

View File

@@ -0,0 +1,26 @@
package work.slhaf.partner.common.exception_handler.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.serialize.PersistableObject;
import work.slhaf.partner.core.cognation.CognationManager;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.session.SessionManager;
import java.io.Serial;
import java.util.HashMap;
@EqualsAndHashCode(callSuper = true)
@Data
public class GlobalExceptionData extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private String exceptionMessage;
protected HashMap<String, InteractionContext> context;
protected SessionManager sessionManager;
protected CognationManager cognationManager;
protected Long exceptionTime;
}

View File

@@ -0,0 +1,36 @@
package work.slhaf.partner.common.monitor;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
@Slf4j
public class DebugMonitor {
private InteractionThreadPoolExecutor executor;
private static DebugMonitor debugMonitor;
public static void initialize() {
debugMonitor = new DebugMonitor();
debugMonitor.executor = InteractionThreadPoolExecutor.getInstance();
debugMonitor.runMonitor();
}
private void runMonitor() {
executor.execute(() -> {
while (true) {
try {
Thread.sleep(3000);
} catch (InterruptedException e) {
log.error("监测线程报错?");
}
}
});
}
public static DebugMonitor getInstance(){
if (debugMonitor == null) {
initialize();
}
return debugMonitor;
}
}

View File

@@ -0,0 +1,6 @@
package work.slhaf.partner.common.serialize;
import java.io.Serializable;
public abstract class PersistableObject implements Serializable {
}

View File

@@ -0,0 +1,45 @@
package work.slhaf.partner.common.thread;
import lombok.Getter;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
@Getter
public class InteractionThreadPoolExecutor {
private static InteractionThreadPoolExecutor interactionThreadPoolExecutor;
private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor();
public static InteractionThreadPoolExecutor getInstance() {
if (interactionThreadPoolExecutor == null) {
interactionThreadPoolExecutor = new InteractionThreadPoolExecutor();
}
return interactionThreadPoolExecutor;
}
public <T> void invokeAll(List<Callable<T>> tasks, int time, TimeUnit timeUnit) {
try {
executorService.invokeAll(tasks, time, timeUnit);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
public <T> void invokeAll(List<Callable<T>> tasks) {
try {
executorService.invokeAll(tasks);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
public void execute(Runnable runnable) {
executorService.execute(runnable);
}
}

View File

@@ -0,0 +1,42 @@
package work.slhaf.partner.common.util;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class ExtractUtil {
public static String extractJson(String jsonStr) {
jsonStr = jsonStr.replace("", "\"").replace("", "\"");
int start = jsonStr.indexOf("{");
int end = jsonStr.lastIndexOf("}");
if (start != -1 && end != -1 && start < end) {
return jsonStr.substring(start, end + 1);
}
return jsonStr;
}
public static String extractUserId(String messageContent) {
Pattern pattern = Pattern.compile("\\[.*\\(([^)]+)\\)\\]");
Matcher matcher = pattern.matcher(messageContent);
if (!matcher.find()) {
return null;
}
return matcher.group(1);
}
public static String fixTopicPath(String topicPath) {
String[] parts = topicPath.split("->");
List<String> cleanedParts = new ArrayList<>();
for (String part : parts) {
// 修正正则表达式,正确移除 [xxx] 部分
String cleaned = part.replaceAll("\\[[^\\]]*\\]", "").trim();
if (!cleaned.isEmpty()) { // 忽略空字符串
cleanedParts.add(cleaned);
}
}
return String.join("->", cleanedParts);
}
}

View File

@@ -0,0 +1,50 @@
package work.slhaf.partner.common.util;
import com.alibaba.fastjson2.JSONArray;
import work.slhaf.partner.Agent;
import work.slhaf.partner.common.chat.pojo.Message;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
public class ResourcesUtil {
private static final ClassLoader classloader = Agent.class.getClassLoader();
public static class Prompt {
private static final String SELF_AWARENESS_PATH = "prompt/self_awareness.json";
private static final String MODULE_PROMPT_PREFIX_PATH = "prompt/module/";
public static List<Message> loadPromptWithSelfAwareness(String modelKey, String promptType) {
//加载人格引导
List<Message> messages = new ArrayList<>(loadSelfAwareness());
//加载常规提示
String path = MODULE_PROMPT_PREFIX_PATH + promptType + "/" + modelKey + ".json";
messages.addAll(readPromptFromResources(path));
return messages;
}
public static List<Message> loadSelfAwareness() {
return readPromptFromResources(SELF_AWARENESS_PATH);
}
public static List<Message> loadPrompt(String modelKey,String promptType){
return new ArrayList<>(readPromptFromResources(MODULE_PROMPT_PREFIX_PATH+promptType+"/"+modelKey+".json"));
}
private static List<Message> readPromptFromResources(String filePath) {
try {
InputStream inputStream = classloader.getResourceAsStream(filePath);
String content = new String(inputStream.readAllBytes(), StandardCharsets.UTF_8);
JSONArray array = JSONArray.parse(content);
inputStream.close();
return array.toJavaList(Message.class);
} catch (Exception e) {
throw new RuntimeException("读取Resource失败: " + filePath, e);
}
}
}
}

View File

@@ -0,0 +1,56 @@
package work.slhaf.partner.core;
import lombok.Data;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.exception_handler.GlobalExceptionHandler;
import work.slhaf.partner.common.exception_handler.pojo.GlobalException;
import work.slhaf.partner.core.interaction.agent_interface.TaskCallback;
import work.slhaf.partner.core.interaction.data.InteractionInputData;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.interaction.module.InteractionModule;
import work.slhaf.partner.core.interaction.module.InteractionModulesLoader;
import work.slhaf.partner.module.modules.process.PreprocessExecutor;
import java.io.IOException;
import java.util.List;
@Data
@Slf4j
public class InteractionHub {
private static volatile InteractionHub interactionHub;
@ToString.Exclude
private TaskCallback callback;
private List<InteractionModule> interactionModules;
public static InteractionHub initialize() throws IOException {
if (interactionHub == null) {
synchronized (InteractionHub.class) {
if (interactionHub == null) {
interactionHub = new InteractionHub();
//加载模块
interactionHub.setInteractionModules(InteractionModulesLoader.getInstance().registerInteractionModules());
log.info("InteractionHub注册完毕...");
}
}
}
return interactionHub;
}
public void call(InteractionInputData inputData) throws IOException, ClassNotFoundException {
InteractionContext interactionContext = PreprocessExecutor.getInstance().execute(inputData);
try {
for (InteractionModule interactionModule : interactionModules) {
interactionModule.execute(interactionContext);
}
} catch (GlobalException e) {
GlobalExceptionHandler.writeExceptionState(e);
interactionContext.getCoreResponse().put("text", "[ERROR] " + e.getMessage());
} finally {
callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getString("text"));
interactionContext.clearUp();
}
}
}

View File

@@ -0,0 +1,323 @@
package work.slhaf.partner.core.cognation;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.chat.constant.ChatConstant;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.common.config.Config;
import work.slhaf.partner.common.exception_handler.GlobalExceptionHandler;
import work.slhaf.partner.common.exception_handler.pojo.GlobalException;
import work.slhaf.partner.common.serialize.PersistableObject;
import work.slhaf.partner.core.cognation.capability.ability.CacheCapability;
import work.slhaf.partner.core.cognation.capability.ability.CognationCapability;
import work.slhaf.partner.core.cognation.capability.ability.MemoryCapability;
import work.slhaf.partner.core.cognation.capability.ability.PerceiveCapability;
import work.slhaf.partner.core.cognation.cognation.CognationCore;
import work.slhaf.partner.core.cognation.cognation.exception.UserNotExistsException;
import work.slhaf.partner.core.cognation.cognation.pojo.ActiveData;
import work.slhaf.partner.core.cognation.common.pojo.MemoryResult;
import work.slhaf.partner.core.cognation.common.pojo.MemorySliceResult;
import work.slhaf.partner.core.cognation.submodule.cache.CacheCore;
import work.slhaf.partner.core.cognation.submodule.dispatch.DispatchCore;
import work.slhaf.partner.core.cognation.submodule.memory.MemoryCore;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import work.slhaf.partner.core.cognation.submodule.perceive.PerceiveCore;
import work.slhaf.partner.core.cognation.submodule.perceive.pojo.User;
import java.io.IOException;
import java.io.Serial;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import static work.slhaf.partner.common.util.ExtractUtil.extractUserId;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class CognationManager extends PersistableObject implements CacheCapability, MemoryCapability, PerceiveCapability, CognationCapability {
@Serial
private static final long serialVersionUID = 1L;
private static volatile CognationManager cognationManager;
private final Lock sliceInsertLock = new ReentrantLock();
@Getter
public final Lock messageLock = new ReentrantLock();
private CognationCore cognationCore;
private CacheCore cacheCore;
private MemoryCore memoryCore;
private PerceiveCore perceiveCore;
private DispatchCore dispatchCore;
private ActiveData activeData;
private CognationManager() {
}
public static CognationManager getInstance() throws IOException, ClassNotFoundException {
if (cognationManager == null) {
synchronized (CognationManager.class) {
if (cognationManager == null) {
Config config = Config.getConfig();
cognationManager = new CognationManager();
cognationManager.setCognationCore(CognationCore.getInstance(config.getAgentId()));
cognationManager.setCores();
cognationManager.setActiveData(new ActiveData());
cognationManager.setShutdownHook();
log.info("[CognationManager] MemoryManager注册完毕...");
}
}
}
return cognationManager;
}
private void setCores() {
this.setCacheCore(this.getCognationCore().getCacheCore());
this.setMemoryCore(this.getCognationCore().getMemoryCore());
this.setPerceiveCore(this.getCognationCore().getPerceiveCore());
}
private void setShutdownHook() {
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
cognationManager.save();
log.info("[CognationManager] MemoryGraph已保存");
} catch (IOException e) {
log.error("[CognationManager] 保存MemoryGraph失败: ", e);
}
}));
}
@Override
public MemoryResult selectMemory(String topicPathStr) {
MemoryResult memoryResult;
List<String> topicPath = List.of(topicPathStr.split("->"));
try {
List<String> path = new ArrayList<>(topicPath);
//每日刷新缓存
cacheCore.checkCacheDate();
//检测缓存并更新计数, 查看是否需要放入缓存
cacheCore.updateCacheCounter(path);
//查看是否存在缓存,如果存在,则直接返回
if ((memoryResult = cacheCore.selectCache(path)) != null) {
return memoryResult;
}
memoryResult = memoryCore.selectMemory(path);
//尝试更新缓存
cacheCore.updateCache(topicPath, memoryResult);
} catch (Exception e) {
log.error("[CognationManager] selectMemory error: ", e);
log.error("[CognationManager] 路径: {}", topicPathStr);
log.error("[CognationManager] 主题树: {}", getTopicTree());
memoryResult = new MemoryResult();
memoryResult.setRelatedMemorySliceResult(new ArrayList<>());
memoryResult.setMemorySliceResult(new CopyOnWriteArrayList<>());
GlobalExceptionHandler.writeExceptionState(new GlobalException(e.getLocalizedMessage()));
}
return cacheFilter(memoryResult);
}
@Override
public MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException {
return cacheFilter(memoryCore.selectMemory(date));
}
private MemoryResult cacheFilter(MemoryResult memoryResult) {
//过滤掉与缓存重复的切片
CopyOnWriteArrayList<MemorySliceResult> memorySliceResult = memoryResult.getMemorySliceResult();
List<MemorySlice> relatedMemorySliceResult = memoryResult.getRelatedMemorySliceResult();
getDialogMap().forEach((k, v) -> {
memorySliceResult.removeIf(m -> m.getMemorySlice().getSummary().equals(v));
relatedMemorySliceResult.removeIf(m -> m.getSummary().equals(v));
});
return memoryResult;
}
@Override
public void cleanSelectedSliceFilter() {
memoryCore.getSelectedSlices().clear();
}
@Override
public User getUser(String userInfo, String client) {
return perceiveCore.selectUser(userInfo, client);
}
@Override
public List<Message> getChatMessages() {
return cognationCore.getChatMessages();
}
@Override
public void setChatMessages(List<Message> chatMessages) {
cognationCore.setChatMessages(chatMessages);
}
@Override
public String getTopicTree() {
return memoryCore.getTopicTree();
}
@Override
public HashMap<LocalDateTime, String> getDialogMap() {
return cacheCore.getDialogMap();
}
@Override
public ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId) {
return cacheCore.getUserDialogMap().get(userId);
}
@Override
public void insertSlice(MemorySlice memorySlice, String topicPath) {
sliceInsertLock.lock();
List<String> topicPathList = Arrays.stream(topicPath.split("->")).toList();
try {
//检查是否存在当天对应的memorySlice并确定是否插入
//每日刷新缓存
cacheCore.checkCacheDate();
//如果topicPath在memorySliceCache中存在对应缓存由于进行的插入操作则需要移除该缓存但不清除相关计数
cacheCore.clearCacheByTopicPath(topicPathList);
memoryCore.insertMemory(topicPathList, memorySlice);
if (!memorySlice.isPrivate()) {
cacheCore.updateUserDialogMap(memorySlice);
}
} catch (Exception e) {
log.error("[CognationManager] 插入记忆时出错: ", e);
GlobalExceptionHandler.writeExceptionState(new GlobalException("插入记忆时出错: " + e.getLocalizedMessage()));
}
log.debug("[CognationManager] 插入切片: {}, 路径: {}", memorySlice, topicPath);
sliceInsertLock.unlock();
}
@Override
public void cleanMessage(List<Message> messages) {
messageLock.lock();
cognationCore.getChatMessages().removeAll(messages);
messageLock.unlock();
}
@Override
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
cacheCore.updateDialogMap(dateTime, newDialogCache);
}
private void save() throws IOException {
cognationCore.serialize();
}
@Override
public void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices) {
activeData.updateActivatedSlices(userId, memorySlices);
log.debug("[CognationManager] 已更新激活切片, userId: {}", userId);
}
@Override
public User getUser(String id) {
User user = perceiveCore.selectUser(id);
if (user == null) {
throw new UserNotExistsException("[CognationManager] 用户不存在: " + id);
}
return user;
}
@Override
public String getActivatedSlicesStr(String userId) {
return activeData.getActivatedSlicesStr(userId);
}
@Override
public String getDialogMapStr() {
StringBuilder str = new StringBuilder();
cacheCore.getDialogMap().forEach((dateTime, dialog) -> str.append("\n\n").append("[").append(dateTime).append("]\n")
.append(dialog));
return str.toString();
}
@Override
public String getUserDialogMapStr(String userId) {
if (cacheCore.getUserDialogMap().containsKey(userId)) {
StringBuilder str = new StringBuilder();
Collection<String> dialogMapValues = cacheCore.getDialogMap().values();
cacheCore.getUserDialogMap().get(userId).forEach((dateTime, dialog) -> {
if (dialogMapValues.contains(dialog)) {
return;
}
str.append("\n\n").append("[").append(dateTime).append("]\n")
.append(dialog);
});
return str.toString();
} else {
return null;
}
}
private boolean isCacheSingleUser() {
return cacheCore.getUserDialogMap().size() <= 1;
}
@Override
public boolean isSingleUser() {
return isCacheSingleUser() && isChatMessagesSingleUser();
}
private boolean isChatMessagesSingleUser() {
Set<String> userIdSet = new HashSet<>();
cognationManager.getChatMessages().forEach(m -> {
if (m.getRole().equals(ChatConstant.Character.ASSISTANT)) {
return;
}
String userId = extractUserId(m.getContent());
if (userId == null || userId.isEmpty()) {
return;
}
userIdSet.add(userId);
});
return userIdSet.size() <= 1;
}
@Override
public User addUser(String userInfo, String platform, String userNickName) {
return perceiveCore.addUser(userInfo, platform, userNickName);
}
@Override
public void updateUser(User tempUser) {
perceiveCore.updateUser(tempUser);
}
@Override
public HashMap<String, List<EvaluatedSlice>> getActivatedSlices() {
return activeData.getActivatedSlices();
}
@Override
public void clearActivatedSlices(String userId) {
activeData.clearActivatedSlices(userId);
}
@Override
public boolean hasActivatedSlices(String userId) {
return activeData.hasActivatedSlices(userId);
}
@Override
public int getActivatedSlicesSize(String userId) {
return activeData.getActivatedSlices().get(userId).size();
}
@Override
public List<EvaluatedSlice> getActivatedSlices(String userId) {
return activeData.getActivatedSlices().get(userId);
}
}

View File

@@ -0,0 +1,14 @@
package work.slhaf.partner.core.cognation.capability;
import org.reflections.Reflections;
import work.slhaf.partner.core.cognation.capability.exception.CapabilityRegisterFailedException;
import work.slhaf.partner.core.cognation.capability.interfaces.Capability;
import work.slhaf.partner.core.cognation.capability.interfaces.CapabilityCore;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
public class CapabilityRegisterFactory {
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.core.cognation.capability.ability;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
public interface CacheCapability {
HashMap<LocalDateTime, String> getDialogMap();
ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId);
void updateDialogMap(LocalDateTime dateTime, String newDialogCache);
String getDialogMapStr();
String getUserDialogMapStr(String userId);
}

View File

@@ -0,0 +1,23 @@
package work.slhaf.partner.core.cognation.capability.ability;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.EvaluatedSlice;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.locks.Lock;
public interface CognationCapability {
List<Message> getChatMessages();
void setChatMessages(List<Message> chatMessages);
void cleanMessage(List<Message> messages);
void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices);
String getActivatedSlicesStr(String userId);
HashMap<String, List<EvaluatedSlice>> getActivatedSlices();
void clearActivatedSlices(String userId);
boolean hasActivatedSlices(String userId);
int getActivatedSlicesSize(String userId);
List<EvaluatedSlice> getActivatedSlices(String userId);
boolean isSingleUser();
Lock getMessageLock();
}

View File

@@ -0,0 +1,4 @@
package work.slhaf.partner.core.cognation.capability.ability;
public interface DispatchCapability {
}

View File

@@ -0,0 +1,15 @@
package work.slhaf.partner.core.cognation.capability.ability;
import work.slhaf.partner.core.cognation.common.pojo.MemoryResult;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import java.io.IOException;
import java.time.LocalDate;
public interface MemoryCapability {
MemoryResult selectMemory(String topicPathStr);
MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException;
void insertSlice(MemorySlice memorySlice, String topicPath);
void cleanSelectedSliceFilter();
String getTopicTree();
}

View File

@@ -0,0 +1,10 @@
package work.slhaf.partner.core.cognation.capability.ability;
import work.slhaf.partner.core.cognation.submodule.perceive.pojo.User;
public interface PerceiveCapability {
User getUser(String userInfo, String client);
User getUser(String id);
User addUser(String userInfo, String platform, String userNickName);
void updateUser(User user);
}

View File

@@ -0,0 +1,7 @@
package work.slhaf.partner.core.cognation.capability.exception;
public class CapabilityRegisterFailedException extends RuntimeException {
public CapabilityRegisterFailedException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,15 @@
package work.slhaf.partner.core.cognation.capability.interfaces;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* 用于注解能力接口,需要与`@CapabilityCore`对应的`value`一致
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Capability {
String value();
}

View File

@@ -0,0 +1,15 @@
package work.slhaf.partner.core.cognation.capability.interfaces;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* 用于注解Core服务需标识一个value致用于核心服务发现
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface CapabilityCore {
String value();
}

View File

@@ -0,0 +1,15 @@
package work.slhaf.partner.core.cognation.capability.interfaces;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* 用于标注协调方法,`value`值需与对应的`@ToCoordinated`保持一致
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Coordinated {
String value();
}

View File

@@ -0,0 +1,14 @@
package work.slhaf.partner.core.cognation.capability.interfaces;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* 用于注入`Capability`
*/
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface InjectCapability {
}

View File

@@ -0,0 +1,16 @@
package work.slhaf.partner.core.cognation.capability.interfaces;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* 当`@Capability`所注接口中如果存在方法需要协调多个Core服务的调用可以通过该注解进行排除
* value值为方法对应标识需与协调实现处的方法标识保持一致
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface ToCoordinated {
String value();
}

View File

@@ -0,0 +1,107 @@
package work.slhaf.partner.core.cognation.cognation;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.common.serialize.PersistableObject;
import work.slhaf.partner.core.cognation.submodule.cache.CacheCore;
import work.slhaf.partner.core.cognation.submodule.memory.MemoryCore;
import work.slhaf.partner.core.cognation.submodule.perceive.PerceiveCore;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class CognationCore extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private static final String STORAGE_DIR = "./data/memory/";
private static volatile CognationCore cognationCore;
private String id;
private MemoryCore memoryCore = new MemoryCore();
private CacheCore cacheCore = new CacheCore();
private PerceiveCore perceiveCore = new PerceiveCore();
/**
* 主模型的聊天记录
*/
private List<Message> chatMessages = new ArrayList<>();
public CognationCore(String id) {
this.id = id;
}
public static CognationCore getInstance(String id) throws IOException, ClassNotFoundException {
if (cognationCore == null) {
synchronized (CognationCore.class) {
// 检查存储目录是否存在,不存在则创建
if (cognationCore == null) {
createStorageDirectory();
Path filePath = getFilePath(id);
if (Files.exists(filePath)) {
cognationCore = deserialize(id);
} else {
FileUtils.createParentDirectories(filePath.toFile().getParentFile());
cognationCore = new CognationCore(id);
cognationCore.serialize();
}
log.info("CognationCore注册完毕...");
}
}
}
return cognationCore;
}
public void serialize() throws IOException {
//先写入到临时文件,如果正常写入则覆盖原文件
Path filePath = getFilePath(this.id + "-temp");
Files.createDirectories(Path.of(STORAGE_DIR));
try {
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(filePath.toFile()));
oos.writeObject(this);
oos.close();
Path path = getFilePath(this.id);
Files.move(filePath, path, StandardCopyOption.REPLACE_EXISTING);
log.info("CognationCore 已保存到: {}", path);
} catch (IOException e) {
Files.delete(filePath);
log.error("序列化保存失败: {}", e.getMessage());
}
}
private static CognationCore deserialize(String id) throws IOException, ClassNotFoundException {
Path filePath = getFilePath(id);
try (ObjectInputStream ois = new ObjectInputStream(
new FileInputStream(filePath.toFile()))) {
CognationCore graph = (CognationCore) ois.readObject();
log.info("CognationCore 已从文件加载: {}", filePath);
return graph;
}
}
private static Path getFilePath(String id) {
return Paths.get(STORAGE_DIR, id + ".memory");
}
private static void createStorageDirectory() {
try {
Files.createDirectories(Paths.get(STORAGE_DIR));
} catch (IOException e) {
System.err.println("创建存储目录失败: " + e.getMessage());
}
}
}

View File

@@ -0,0 +1,7 @@
package work.slhaf.partner.core.cognation.cognation.exception;
public class UserNotExistsException extends RuntimeException {
public UserNotExistsException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,38 @@
package work.slhaf.partner.core.cognation.cognation.pojo;
import lombok.Data;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.EvaluatedSlice;
import java.util.HashMap;
import java.util.List;
@Data
public class ActiveData {
private HashMap<String, List<EvaluatedSlice>> activatedSlices;
public void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices) {
activatedSlices.put(userId, memorySlices);
}
public String getActivatedSlicesStr(String userId) {
if (activatedSlices.containsKey(userId)) {
StringBuilder str = new StringBuilder();
activatedSlices.get(userId).forEach(slice -> str.append("\n\n").append("[").append(slice.getDate()).append("]\n")
.append(slice.getSummary()));
return str.toString();
} else {
return null;
}
}
public void clearActivatedSlices(String userId) {
activatedSlices.remove(userId);
}
public boolean hasActivatedSlices(String userId) {
if (!activatedSlices.containsKey(userId)){
return false;
}
return !activatedSlices.get(userId).isEmpty();
}
}

View File

@@ -0,0 +1,27 @@
package work.slhaf.partner.core.cognation.common.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.serialize.PersistableObject;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import java.io.Serial;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemoryResult extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private CopyOnWriteArrayList<MemorySliceResult> memorySliceResult;
private List<MemorySlice> relatedMemorySliceResult;
public boolean isEmpty(){
boolean a = memorySliceResult == null || memorySliceResult.isEmpty();
boolean b = relatedMemorySliceResult == null || relatedMemorySliceResult.isEmpty();
return a && b;
}
}

View File

@@ -0,0 +1,25 @@
package work.slhaf.partner.core.cognation.common.pojo;
import com.alibaba.fastjson2.annotation.JSONField;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.serialize.PersistableObject;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import java.io.Serial;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemorySliceResult extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
@JSONField(serialize = false)
private MemorySlice sliceBefore;
private MemorySlice memorySlice;
@JSONField(serialize = false)
private MemorySlice sliceAfter;
}

View File

@@ -0,0 +1,132 @@
package work.slhaf.partner.core.cognation.submodule.cache;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.serialize.PersistableObject;
import work.slhaf.partner.core.cognation.common.pojo.MemoryResult;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import java.io.Serial;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class CacheCore extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
/**
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键总结为值
* 该部分作为'主LLM'system prompt常驻
* 该部分作为近两日的整体对话缓存, 不区分用户
*/
private HashMap<LocalDateTime, String> dialogMap = new HashMap<>();
/**
* 近两日的区分用户的对话总结缓存在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质
*/
private ConcurrentHashMap<String/*userId*/, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = new ConcurrentHashMap<>();
/**
* memorySliceCache计数器每日清空
*/
private ConcurrentHashMap<List<String> /*触发查询的主题列表*/, Integer> memoryNodeCacheCounter = new ConcurrentHashMap<>();
/**
* 记忆切片缓存,每日清空
* 用于记录作为终点节点调用次数最多的记忆节点的切片数据
*/
private ConcurrentHashMap<List<String> /*主题路径*/, MemoryResult /*切片列表*/> memorySliceCache = new ConcurrentHashMap<>();
/**
* 缓存日期
*/
private LocalDate cacheDate;
/**
* 已被选中的切片时间戳集合,需要及时清理
*/
private Set<Long> selectedSlices = new HashSet<>();
public void updateCacheCounter(List<String> topicPath) {
if (memoryNodeCacheCounter.containsKey(topicPath)) {
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
memoryNodeCacheCounter.put(topicPath, ++tempCount);
} else {
memoryNodeCacheCounter.put(topicPath, 1);
}
}
public void checkCacheDate() {
if (cacheDate == null || cacheDate.isBefore(LocalDate.now())) {
memorySliceCache.clear();
memoryNodeCacheCounter.clear();
cacheDate = LocalDate.now();
}
}
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
List<LocalDateTime> keysToRemove = new ArrayList<>();
dialogMap.forEach((k, v) -> {
if (dateTime.minusDays(2).isAfter(k)) {
keysToRemove.add(k);
}
});
for (LocalDateTime temp : keysToRemove) {
dialogMap.remove(temp);
}
keysToRemove.clear();
//放入新缓存
dialogMap.put(dateTime, newDialogCache);
}
public void updateCache(List<String> topicPath, MemoryResult memoryResult) {
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
if (tempCount == null) {
log.warn("[CacheCore] tempCount为null? memoryNodeCacheCounter: {}; topicPath: {}", memoryNodeCacheCounter, topicPath);
return;
}
if (tempCount >= 5) {
memorySliceCache.put(topicPath, memoryResult);
}
}
public void updateUserDialogMap(MemorySlice slice) {
String summary = slice.getSummary();
LocalDateTime now = LocalDateTime.now();
//更新userDialogMap
//移除两天前上下文缓存(切片总结)
List<LocalDateTime> keysToRemove = new ArrayList<>();
userDialogMap.forEach((k, v) -> v.forEach((i, j) -> {
if (now.minusDays(2).isAfter(i)) {
keysToRemove.add(i);
}
}));
for (LocalDateTime dateTime : keysToRemove) {
userDialogMap.forEach((k, v) -> v.remove(dateTime));
}
//放入新缓存
userDialogMap
.computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>())
.merge(now, summary, (oldVal, newVal) -> oldVal + " " + newVal);
}
public void clearCacheByTopicPath(List<String> topicPath) {
memorySliceCache.remove(topicPath);
}
public MemoryResult selectCache(List<String> path) {
if (memorySliceCache.containsKey(path)) {
return memorySliceCache.get(path);
}
return null;
}
}

View File

@@ -0,0 +1,32 @@
package work.slhaf.partner.core.cognation.submodule.dispatch;
import work.slhaf.partner.common.serialize.PersistableObject;
import work.slhaf.partner.core.cognation.submodule.dispatch.pojo.DispatchData;
import java.io.Serial;
public class DispatchCore extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
public static volatile DispatchCore dispatchCore;
public static DispatchCore getInstance() {
if (dispatchCore == null) {
synchronized (DispatchCore.class) {
if (dispatchCore == null) {
dispatchCore = new DispatchCore();
}
}
}
return dispatchCore;
}
public void dispatch(DispatchData dispatchData){
}
public void listDispatchData(){
}
}

View File

@@ -0,0 +1,15 @@
package work.slhaf.partner.core.cognation.submodule.dispatch.pojo;
import lombok.Data;
import java.time.LocalDateTime;
@Data
public class DispatchData {
private LocalDateTime dateTime;
private String userId;
private String comment;
//TODO 替换为<执行器>或者<插件>
private String executor;
}

View File

@@ -0,0 +1,317 @@
package work.slhaf.partner.core.cognation.submodule.memory;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.serialize.PersistableObject;
import work.slhaf.partner.core.cognation.common.pojo.MemoryResult;
import work.slhaf.partner.core.cognation.common.pojo.MemorySliceResult;
import work.slhaf.partner.core.cognation.submodule.memory.exception.UnExistedDateIndexException;
import work.slhaf.partner.core.cognation.submodule.memory.exception.UnExistedTopicException;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.node.MemoryNode;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.node.TopicNode;
import java.io.IOException;
import java.io.Serial;
import java.time.LocalDate;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemoryCore extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
/**
* key: 根主题名称 value: 根主题节点
*/
private HashMap<String, TopicNode> topicNodes = new HashMap<>();
/**
* 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值
* 该部分在'主题提取LLM'的system prompt中常驻
*/
private HashMap<String /*根主题名*/, LinkedHashSet<String> /*子主题列表*/> existedTopics = new HashMap<>();
/**
* 临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所
*/
private HashMap<String /*对话id, 即slice中的字段'memoryId'*/, List<MemorySlice>> currentDateDialogSlices = new HashMap<>();
/**
* 记忆节点的日期索引, 同一日期内按照对话id区分
*/
private HashMap<LocalDate, Set<String>> dateIndex = new HashMap<>();
/**
* 已被选中的切片时间戳集合,需要及时清理
*/
private Set<Long> selectedSlices = new HashSet<>();
private HashMap<String,List<String>> userIndex = new HashMap<>();
public MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException {
MemoryResult memoryResult = new MemoryResult();
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
//加载节点并获取记忆切片列表
List<List<MemorySlice>> currentDateDialogSlices = loadSlicesByDate(date);
for (List<MemorySlice> value : currentDateDialogSlices) {
for (MemorySlice memorySlice : value) {
if (selectedSlices.contains(memorySlice.getTimestamp())) {
continue;
}
MemorySliceResult memorySliceResult = new MemorySliceResult();
memorySliceResult.setMemorySlice(memorySlice);
targetSliceList.add(memorySliceResult);
selectedSlices.add(memorySlice.getTimestamp());
}
}
memoryResult.setMemorySliceResult(targetSliceList);
return memoryResult;
}
private List<List<MemorySlice>> loadSlicesByDate(LocalDate date) throws IOException, ClassNotFoundException {
if (!dateIndex.containsKey(date)) {
throw new UnExistedDateIndexException("不存在的日期索引: " + date);
}
List<List<MemorySlice>> list = new ArrayList<>();
for (String memoryNodeId : dateIndex.get(date)) {
MemoryNode memoryNode = new MemoryNode();
memoryNode.setMemoryNodeId(memoryNodeId);
list.add(memoryNode.loadMemorySliceList());
}
return list;
}
public String getTopicTree() {
StringBuilder stringBuilder = new StringBuilder();
for (Map.Entry<String, TopicNode> entry : topicNodes.entrySet()) {
String rootName = entry.getKey();
TopicNode rootNode = entry.getValue();
stringBuilder.append(rootName).append("[root]").append("\r\n");
printSubTopicsTreeFormat(rootNode, "", stringBuilder);
}
return stringBuilder.toString();
}
private void printSubTopicsTreeFormat(TopicNode node, String prefix, StringBuilder stringBuilder) {
if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return;
List<Map.Entry<String, TopicNode>> entries = new ArrayList<>(node.getTopicNodes().entrySet());
for (int i = 0; i < entries.size(); i++) {
boolean last = (i == entries.size() - 1);
Map.Entry<String, TopicNode> entry = entries.get(i);
stringBuilder.append(prefix).append(last ? "└── " : "├── ").append(entry.getKey()).append("[").append(entry.getValue().getMemoryNodes().size()).append("]").append("\r\n");
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : ""), stringBuilder);
}
}
public void insertMemory(List<String> topicPath, MemorySlice slice) throws IOException, ClassNotFoundException {
LocalDate now = LocalDate.now();
boolean hasSlice = false;
MemoryNode node = null;
TopicNode lastTopicNode = generateTopicPath(topicPath);
for (MemoryNode memoryNode : lastTopicNode.getMemoryNodes()) {
if (now.equals(memoryNode.getLocalDate())) {
hasSlice = true;
node = memoryNode;
break;
}
}
if (!hasSlice) {
node = new MemoryNode();
node.setLocalDate(now);
node.setMemoryNodeId(UUID.randomUUID().toString());
node.setMemorySliceList(new CopyOnWriteArrayList<>());
lastTopicNode.getMemoryNodes().add(node);
lastTopicNode.getMemoryNodes().sort(null);
}
node.loadMemorySliceList().add(slice);
//生成relatedTopicPath
for (List<String> relatedTopic : slice.getRelatedTopics()) {
generateTopicPath(relatedTopic);
}
updateSlicePrecedent(slice);
updateDateIndex(slice);
updateUserIndex(slice);
node.saveMemorySliceList();
}
private void updateUserIndex(MemorySlice slice) {
String memoryId = slice.getMemoryId();
String userId = slice.getStartUserId();
if (!userIndex.containsKey(userId)) {
List<String> memoryIdSet = new ArrayList<>();
memoryIdSet.add(memoryId);
userIndex.put(userId, memoryIdSet);
} else {
userIndex.get(userId).add(memoryId);
}
}
private TopicNode generateTopicPath(List<String> topicPath) {
topicPath = new ArrayList<>(topicPath);
//查看是否存在根主题节点
String rootTopic = topicPath.getFirst();
topicPath.removeFirst();
if (!topicNodes.containsKey(rootTopic)) {
synchronized (this) {
if (!topicNodes.containsKey(rootTopic)) {
TopicNode rootNode = new TopicNode();
topicNodes.put(rootTopic, rootNode);
existedTopics.put(rootTopic, new LinkedHashSet<>());
}
}
}
TopicNode current = topicNodes.get(rootTopic);
Set<String> existedTopicNodes = existedTopics.get(rootTopic);
for (String topic : topicPath) {
if (existedTopicNodes.contains(topic) && current.getTopicNodes().containsKey(topic)) {
current = current.getTopicNodes().get(topic);
} else {
TopicNode newNode = new TopicNode();
current.getTopicNodes().put(topic, newNode);
current = newNode;
current.setMemoryNodes(new CopyOnWriteArrayList<>());
current.setTopicNodes(new ConcurrentHashMap<>());
existedTopicNodes.add(topic);
}
}
return current;
}
private void updateSlicePrecedent(MemorySlice slice) {
String memoryId = slice.getMemoryId();
//查看是否切换了memoryId
if (!currentDateDialogSlices.containsKey(memoryId)) {
List<MemorySlice> memorySliceList = new ArrayList<>();
currentDateDialogSlices.clear();
currentDateDialogSlices.put(memoryId, memorySliceList);
}
//处理上下文关系
List<MemorySlice> memorySliceList = currentDateDialogSlices.get(memoryId);
if (memorySliceList.isEmpty()) {
memorySliceList.add(slice);
} else {
//排序
memorySliceList.sort(null);
MemorySlice tempSlice = memorySliceList.getLast();
//设置私密状态一致
tempSlice.setPrivate(slice.isPrivate());
//末尾切片添加当前切片的引用
tempSlice.setSliceAfter(slice);
//当前切片添加前序切片的引用
slice.setSliceBefore(tempSlice);
}
}
private void updateDateIndex(MemorySlice slice) {
String memoryId = slice.getMemoryId();
LocalDate date = LocalDate.now();
if (!dateIndex.containsKey(date)) {
HashSet<String> memoryIdSet = new HashSet<>();
memoryIdSet.add(memoryId);
dateIndex.put(date, memoryIdSet);
} else {
dateIndex.get(date).add(memoryId);
}
}
public MemoryResult selectMemory(List<String> path) throws IOException, ClassNotFoundException {
MemoryResult memoryResult = new MemoryResult();
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
String targetTopic = path.getLast();
TopicNode targetParentNode = getTargetParentNode(path, targetTopic);
List<List<String>> relatedTopics = new ArrayList<>();
//终点记忆节点
MemorySliceResult sliceResult = new MemorySliceResult();
for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) {
List<MemorySlice> endpointMemorySliceList = memoryNode.loadMemorySliceList();
for (MemorySlice memorySlice : endpointMemorySliceList) {
if (selectedSlices.contains(memorySlice.getTimestamp())) {
continue;
}
sliceResult.setSliceBefore(memorySlice.getSliceBefore());
sliceResult.setMemorySlice(memorySlice);
sliceResult.setSliceAfter(memorySlice.getSliceAfter());
targetSliceList.add(sliceResult);
selectedSlices.add(memorySlice.getTimestamp());
}
for (MemorySlice memorySlice : endpointMemorySliceList) {
if (memorySlice.getRelatedTopics() != null) {
relatedTopics.addAll(memorySlice.getRelatedTopics());
}
}
}
memoryResult.setMemorySliceResult(targetSliceList);
//邻近节点
List<MemorySlice> relatedMemorySlice = new ArrayList<>();
//邻近记忆节点 联系
for (List<String> relatedTopic : relatedTopics) {
List<String> tempTopicPath = new ArrayList<>(relatedTopic);
String tempTargetTopic = tempTopicPath.getLast();
TopicNode tempTargetParentNode = getTargetParentNode(tempTopicPath, tempTargetTopic);
//获取终点节点及其最新记忆节点
TopicNode tempTargetNode = tempTargetParentNode.getTopicNodes().get(tempTopicPath.getLast());
setRelatedMemorySlices(tempTargetNode, relatedMemorySlice);
}
//邻近记忆节点 父级
setRelatedMemorySlices(targetParentNode, relatedMemorySlice);
//将上述结果包装为MemoryResult
memoryResult.setRelatedMemorySliceResult(relatedMemorySlice);
return memoryResult;
}
private void setRelatedMemorySlices(TopicNode targetParentNode, List<MemorySlice> relatedMemorySlice) throws IOException, ClassNotFoundException {
List<MemoryNode> targetParentMemoryNodes = targetParentNode.getMemoryNodes();
if (!targetParentMemoryNodes.isEmpty()) {
for (MemorySlice memorySlice : targetParentMemoryNodes.getFirst().loadMemorySliceList()) {
if (selectedSlices.contains(memorySlice.getTimestamp())) {
continue;
}
relatedMemorySlice.add(memorySlice);
selectedSlices.add(memorySlice.getTimestamp());
}
}
}
private TopicNode getTargetParentNode(List<String> topicPath, String targetTopic) {
String topTopic = topicPath.getFirst();
if (!existedTopics.containsKey(topTopic)) {
throw new UnExistedTopicException("不存在的主题: " + topTopic);
}
TopicNode targetParentNode = topicNodes.get(topTopic);
topicPath.removeFirst();
for (String topic : topicPath) {
if (!existedTopics.get(topTopic).contains(topic)) {
throw new UnExistedTopicException("不存在的主题: " + topTopic);
}
}
//逐层查找目标主题
while (!targetParentNode.getTopicNodes().containsKey(targetTopic)) {
targetParentNode = targetParentNode.getTopicNodes().get(topicPath.getFirst());
topicPath.removeFirst();
}
return targetParentNode;
}
}

View File

@@ -0,0 +1,7 @@
package work.slhaf.partner.core.cognation.submodule.memory.exception;
public class NullSliceListException extends RuntimeException {
public NullSliceListException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,7 @@
package work.slhaf.partner.core.cognation.submodule.memory.exception;
public class UnExistedDateIndexException extends RuntimeException {
public UnExistedDateIndexException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,7 @@
package work.slhaf.partner.core.cognation.submodule.memory.exception;
public class UnExistedTopicException extends RuntimeException {
public UnExistedTopicException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,22 @@
package work.slhaf.partner.core.cognation.submodule.memory.pojo;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.serialize.PersistableObject;
import java.io.Serial;
import java.time.LocalDate;
@EqualsAndHashCode(callSuper = true)
@Data
@Builder
public class EvaluatedSlice extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
// private List<Message> chatMessages;
private LocalDate date;
private String summary;
}

View File

@@ -0,0 +1,83 @@
package work.slhaf.partner.core.cognation.submodule.memory.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.common.serialize.PersistableObject;
import java.io.Serial;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemorySlice extends PersistableObject implements Comparable<MemorySlice> {
@Serial
private static final long serialVersionUID = 1L;
/**
* 关联的完整对话的id
*/
private String memoryId;
/**
* 该切片在关联的完整对话中的顺序, 由时间戳确定
*/
private Long timestamp;
/**
* 格式为"<日期>.slice", 如2025-04-11.slice
*/
private String summary;
private List<Message> chatMessages;
/**
* 关联的其他主题, 即"邻近节点(联系)"
*/
private List<List<String>> relatedTopics;
/**
* 关联完整对话中的前序切片, 排序为键,完整路径为值
*/
@ToString.Exclude
private MemorySlice sliceBefore, sliceAfter;
/**
* 多用户设定
* 发起该切片对话的用户
*/
private String startUserId;
/**
* 该切片涉及到的用户uuid
*/
private List<String> involvedUserIds;
/**
* 是否仅供发起用户作为记忆参考
*/
private boolean isPrivate;
/**
* 摘要向量化结果
*/
private float[] summaryEmbedding;
/**
* 是否向量化
*/
private boolean embedded;
@Override
public int compareTo(MemorySlice memorySlice) {
if (memorySlice.getTimestamp() > this.getTimestamp()) {
return -1;
} else if (memorySlice.getTimestamp() < this.timestamp) {
return 1;
}
return 0;
}
}

View File

@@ -0,0 +1,82 @@
package work.slhaf.partner.core.cognation.submodule.memory.pojo.node;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.serialize.PersistableObject;
import work.slhaf.partner.core.cognation.submodule.memory.exception.NullSliceListException;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.LocalDate;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class MemoryNode extends PersistableObject implements Comparable<MemoryNode> {
@Serial
private static final long serialVersionUID = 1L;
private static String SLICE_DATA_DIR = "./data/memory/slice/";
/**
* 记忆节点唯一标识, 用于作为实际文件名, 如(xxxx-xxxxx-xxxxx.slice)
*/
private String memoryNodeId;
/**
* 记忆节点所属日期
*/
private LocalDate localDate;
/**
* 该日期对应的全部记忆切片
*/
private CopyOnWriteArrayList<MemorySlice> memorySliceList;
@Override
public int compareTo(MemoryNode memoryNode) {
if (memoryNode.getLocalDate().isAfter(this.localDate)) {
return -1;
} else if (memoryNode.getLocalDate().isBefore(this.localDate)) {
return 1;
}
return 0;
}
public List<MemorySlice> loadMemorySliceList() throws IOException, ClassNotFoundException {
//检查是否存在对应文件
File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
if (file.exists()){
this.memorySliceList = deserialize(file);
}else {
//逻辑正常的话这部分应该不会出现除非在insertMemory中进行save操作之前出现异常中断了方法但程序却没有结束
this.memorySliceList = new CopyOnWriteArrayList<>();
}
return this.memorySliceList;
}
public void saveMemorySliceList() throws IOException {
if (memorySliceList == null){
throw new NullSliceListException("memorySliceList为NULL! 检查实现逻辑!");
}
File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
Files.createDirectories(Path.of(SLICE_DATA_DIR));
try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(file))){
oos.writeObject(this.memorySliceList);
}
//取消切片挂载, 释放内存
this.memorySliceList = null;
}
private CopyOnWriteArrayList<MemorySlice> deserialize(File file) throws IOException, ClassNotFoundException {
try(ObjectInputStream ois = new ObjectInputStream(new FileInputStream(file))) {
return (CopyOnWriteArrayList<MemorySlice>) ois.readObject();
}
}
}

View File

@@ -0,0 +1,20 @@
package work.slhaf.partner.core.cognation.submodule.memory.pojo.node;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.serialize.PersistableObject;
import java.io.Serial;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
@EqualsAndHashCode(callSuper = true)
@Data
public class TopicNode extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private ConcurrentHashMap<String,TopicNode> topicNodes = new ConcurrentHashMap<>();
private CopyOnWriteArrayList<MemoryNode> memoryNodes = new CopyOnWriteArrayList<>();
}

View File

@@ -0,0 +1,88 @@
package work.slhaf.partner.core.cognation.submodule.perceive;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.serialize.PersistableObject;
import work.slhaf.partner.core.cognation.submodule.perceive.pojo.User;
import java.io.Serial;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.locks.ReentrantLock;
@EqualsAndHashCode(callSuper = true)
@Data
public class PerceiveCore extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private static volatile PerceiveCore perceiveCore = new PerceiveCore();
private static final ReentrantLock usersLock = new ReentrantLock();
/**
* 用户列表
*/
private List<User> users = new ArrayList<>();
public static PerceiveCore getInstance() {
if (perceiveCore == null) {
synchronized (PerceiveCore.class) {
if (perceiveCore == null) {
perceiveCore = new PerceiveCore();
}
}
}
return perceiveCore;
}
public User selectUser(String userInfo, String platform) {
User resultUser = null;
usersLock.lock();
for (User user : users) {
HashMap<String, String> info = user.getInfo();
if (info.containsKey(platform)) {
if (info.get(platform).equals(userInfo)) {
resultUser = user;
}
}
}
usersLock.unlock();
return resultUser;
}
public User addUser(String userInfo, String platform, String userNickName) {
User user = new User();
user.addInfo(platform, userInfo);
user.setNickName(userNickName);
user.setUuid(UUID.randomUUID().toString());
usersLock.lock();
users.add(user);
usersLock.unlock();
return user;
}
public User selectUser(String id) {
usersLock.lock();
for (User user : users) {
if (user.getUuid().equals(id)) {
return user;
}
}
usersLock.unlock();
return null;
}
public void updateUser(User temp) {
usersLock.lock();
User user = selectUser(temp.getUuid());
user.setRelation(temp.getRelation());
user.setImpressions(temp.getImpressions());
user.setAttitude(temp.getAttitude());
user.setStaticMemory(temp.getStaticMemory());
user.updateRelationChange(user.getRelationChange());
usersLock.unlock();
}
}

View File

@@ -0,0 +1,51 @@
package work.slhaf.partner.core.cognation.submodule.perceive.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.serialize.PersistableObject;
import java.io.Serial;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
public class User extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private String uuid;
private String nickName;
private HashMap<String/*platform*/, String> info = new HashMap<>();
private String relation = Constant.Relation.STRANGER;
// private HashMap<LocalDate, String> events = new HashMap<>();
private List<String> impressions = new ArrayList<>();
private List<String> attitude = new ArrayList<>();
private LinkedHashMap<LocalDate,String> relationChange = new LinkedHashMap<>();
private HashMap<String,String> staticMemory = new HashMap<>();
public void addInfo(String platform, String userInfo) {
this.info.put(platform, userInfo);
}
public void updateRelationChange(String changeReason){
relationChange.put(LocalDate.now(),changeReason);
}
public void updateRelationChange(LocalDate date, String changeReason){
relationChange.put(date,changeReason);
}
public void updateRelationChange(LinkedHashMap<LocalDate,String> tempRelationChange){
relationChange.putAll(tempRelationChange);
}
public static class Constant {
public static class Relation {
public static final String STRANGER = "陌生";
}
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.core.interaction.agent_interface;
import work.slhaf.partner.core.interaction.data.InteractionInputData;
import java.io.IOException;
public interface InputReceiver {
void receiveInput(InteractionInputData inputData) throws IOException, ClassNotFoundException;
}

View File

@@ -0,0 +1,5 @@
package work.slhaf.partner.core.interaction.agent_interface;
public interface TaskCallback {
void onTaskFinished(String userInfo,String output);
}

View File

@@ -0,0 +1,15 @@
package work.slhaf.partner.core.interaction.data;
import lombok.Data;
import java.time.LocalDateTime;
@Data
public class InteractionInputData {
private String userInfo;
private String userNickName;
private String content;
private LocalDateTime localDateTime;
private String platform;
private boolean single;
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.core.interaction.data;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class InteractionOutputData {
private String content;
private String userInfo;
}

View File

@@ -0,0 +1,62 @@
package work.slhaf.partner.core.interaction.data.context;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.serialize.PersistableObject;
import work.slhaf.partner.core.interaction.data.context.subcontext.CoreContext;
import work.slhaf.partner.core.interaction.data.context.subcontext.ModuleContext;
import work.slhaf.partner.module.common.AppendPromptData;
import java.io.Serial;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
public class InteractionContext extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private static HashMap<String, InteractionContext> activeContext = new HashMap<>();
protected String userId;
protected String userNickname;
protected String userInfo;
protected LocalDateTime dateTime;
protected boolean single;
protected String input;
protected CoreContext coreContext = new CoreContext();
protected ModuleContext moduleContext = new ModuleContext();
protected JSONObject coreResponse = new JSONObject();
public InteractionContext() {
activeContext.put(userId, this);
}
public void setFinished(boolean finished) {
moduleContext.setFinished(finished);
}
public boolean isFinished() {
return moduleContext.isFinished();
}
public void setAppendedPrompt(AppendPromptData appendedPrompt) {
List<AppendPromptData> appendPromptList = moduleContext.getAppendedPrompt();
appendPromptList.addFirst(appendedPrompt);
}
public static HashMap<String, InteractionContext> getInstance() {
return activeContext;
}
public void clearUp() {
activeContext.remove(userId);
}
}

View File

@@ -0,0 +1,36 @@
package work.slhaf.partner.core.interaction.data.context.subcontext;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.serialize.PersistableObject;
import java.io.Serial;
import java.util.HashMap;
@EqualsAndHashCode(callSuper = true)
@Data
public class CoreContext extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private String text;
private String dateTime;
private String userNick;
private String userId;
private HashMap<String, Boolean> activeModules = new HashMap<>();
@Override
public String toString() {
return JSONObject.toJSONString(this);
}
public void addActiveModule(String moduleName) {
activeModules.put(moduleName, false);
}
public void activateModule(String moduleName){
activeModules.put(moduleName, true);
}
}

View File

@@ -0,0 +1,23 @@
package work.slhaf.partner.core.interaction.data.context.subcontext;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.serialize.PersistableObject;
import work.slhaf.partner.module.common.AppendPromptData;
import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
public class ModuleContext extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private List<AppendPromptData> appendedPrompt = new ArrayList<>();
private JSONObject extraContext = new JSONObject();
private boolean finished = false;
}

View File

@@ -0,0 +1,9 @@
package work.slhaf.partner.core.interaction.module;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import java.io.IOException;
public interface InteractionModule {
void execute(InteractionContext context) throws IOException, ClassNotFoundException;
}

View File

@@ -0,0 +1,60 @@
package work.slhaf.partner.core.interaction.module;
import work.slhaf.partner.common.config.Config;
import work.slhaf.partner.common.config.ModuleConfig;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.List;
public class InteractionModulesLoader {
private static InteractionModulesLoader interactionModulesLoader;
public static InteractionModulesLoader getInstance(){
if (interactionModulesLoader == null) {
interactionModulesLoader = new InteractionModulesLoader();
}
return interactionModulesLoader;
}
public List<InteractionModule> registerInteractionModules() throws IOException {
List<InteractionModule> moduleList = new ArrayList<>();
List<ModuleConfig> moduleConfigList = Config.getConfig().getModuleConfigList();
for (ModuleConfig moduleConfig : moduleConfigList) {
if (ModuleConfig.Constant.INTERNAL.equals(moduleConfig.getType())) {
moduleList.add(loadInternalModule(moduleConfig.getClassName()));
} else if (ModuleConfig.Constant.EXTERNAL.equals(moduleConfig.getType())) {
moduleList.add(loadExternalModule(moduleConfig.getClassName(),moduleConfig.getPath()));
}
}
return moduleList;
}
private InteractionModule loadExternalModule(String className, String path) {
try {
URL jarUrl = new File(path).toURI().toURL();
URLClassLoader loader = new URLClassLoader(new URL[]{jarUrl}, this.getClass().getClassLoader());
Class<?> clazz = loader.loadClass(className);
loader.close();
return (InteractionModule) clazz.getMethod("getInstance").invoke(null);
} catch (ClassNotFoundException | InvocationTargetException | IllegalAccessException |
NoSuchMethodException | IOException e) {
throw new RuntimeException("Fail to load internal module: " + className ,e);
}
}
private static InteractionModule loadInternalModule(String className) {
try {
Class<?> clazz = Class.forName(className);
return (InteractionModule) clazz.getMethod("getInstance").invoke(null);
} catch (ClassNotFoundException | InvocationTargetException | IllegalAccessException | NoSuchMethodException e) {
throw new RuntimeException("Fail to load internal module: " + className,e);
}
}
}

View File

@@ -0,0 +1,130 @@
package work.slhaf.partner.core.session;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.common.chat.pojo.MetaMessage;
import work.slhaf.partner.common.config.Config;
import work.slhaf.partner.common.serialize.PersistableObject;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class SessionManager extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private static final String STORAGE_DIR = "./data/session/";
private static volatile SessionManager sessionManager;
private String id;
private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap;
private String currentMemoryId;
private long lastUpdatedTime;
public static SessionManager getInstance() throws IOException, ClassNotFoundException {
if (sessionManager == null) {
synchronized (SessionManager.class) {
if (sessionManager == null) {
String id = Config.getConfig().getAgentId();
Path filePath = Paths.get(STORAGE_DIR, id + ".session");
if (Files.exists(filePath)) {
sessionManager = deserialize(id);
} else {
sessionManager = new SessionManager();
sessionManager.setSingleMetaMessageMap(new HashMap<>());
sessionManager.id = id;
sessionManager.setShutdownHook();
sessionManager.lastUpdatedTime = 0;
}
}
}
}
return sessionManager;
}
private void setShutdownHook() {
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
sessionManager.serialize();
log.info("[SessionManager] SessionManager 已保存");
} catch (IOException e) {
log.error("[SessionManager] 保存 SessionManager 失败: ", e);
}
}));
}
public void addMetaMessage(String userId, MetaMessage metaMessage) {
log.debug("[SessionManager] 当前会话历史: {}", JSONObject.toJSONString(singleMetaMessageMap));
if (singleMetaMessageMap.containsKey(userId)) {
singleMetaMessageMap.get(userId).add(metaMessage);
} else {
singleMetaMessageMap.put(userId, new java.util.ArrayList<>());
singleMetaMessageMap.get(userId).add(metaMessage);
}
log.debug("[SessionManager] 会话历史更新: {}", JSONObject.toJSONString(singleMetaMessageMap));
}
public List<Message> unpackAndClear(String userId) {
List<Message> messages = new ArrayList<>();
for (MetaMessage metaMessage : singleMetaMessageMap.get(userId)) {
messages.add(metaMessage.getUserMessage());
messages.add(metaMessage.getAssistantMessage());
}
singleMetaMessageMap.remove(userId);
return messages;
}
public void refreshMemoryId() {
currentMemoryId = UUID.randomUUID().toString();
}
public void serialize() throws IOException {
//先写入到临时文件,如果正常写入,则覆盖正式文件;否则删除临时文件
Path filePath = getFilePath(this.id + "-temp");
Files.createDirectories(Path.of(STORAGE_DIR));
try {
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(filePath.toFile()));
oos.writeObject(this);
oos.close();
Path path = getFilePath(this.id);
Files.move(filePath, path, StandardCopyOption.REPLACE_EXISTING);
log.info("[SessionManager] SessionManager 已保存到: {}", path);
} catch (IOException e) {
Files.delete(filePath);
log.error("[SessionManager] 序列化保存失败: {}", e.getMessage());
}
}
private static SessionManager deserialize(String id) throws IOException, ClassNotFoundException {
Path filePath = getFilePath(id);
try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(filePath.toFile()))) {
SessionManager sessionManager = (SessionManager) ois.readObject();
log.info("[SessionManager] SessionManager 已从文件加载: {}", filePath);
return sessionManager;
}
}
public void resetLastUpdatedTime() {
lastUpdatedTime = System.currentTimeMillis();
}
private static Path getFilePath(String id) {
return Paths.get(STORAGE_DIR, id + ".session");
}
}

View File

@@ -0,0 +1,136 @@
package work.slhaf.partner.gateway;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.WebSocket;
import org.java_websocket.framing.Framedata;
import org.java_websocket.handshake.ClientHandshake;
import org.java_websocket.server.WebSocketServer;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.core.interaction.agent_interface.InputReceiver;
import work.slhaf.partner.core.interaction.data.InteractionInputData;
import work.slhaf.partner.core.interaction.data.InteractionOutputData;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.concurrent.ConcurrentHashMap;
@Slf4j
public class AgentWebSocketServer extends WebSocketServer implements MessageSender {
private static final long HEARTBEAT_INTERVAL = 10_000;
@ToString.Exclude
private final InputReceiver receiver;
private final ConcurrentHashMap<String, WebSocket> userSessions = new ConcurrentHashMap<>();
private final InteractionThreadPoolExecutor executor;
// 记录最后一次收到Pong的时间
private final ConcurrentHashMap<WebSocket, Long> lastPongTimes = new ConcurrentHashMap<>();
public AgentWebSocketServer(int port, InputReceiver receiver) {
super(new InetSocketAddress(port));
this.receiver = receiver;
this.executor = InteractionThreadPoolExecutor.getInstance();
}
public void launch() {
this.start();
setShutDownHook();
startHeartbeatThread();
}
private void startHeartbeatThread() {
executor.execute(() -> {
while (!Thread.interrupted()){
try{
Thread.sleep(HEARTBEAT_INTERVAL);
checkConnections();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
});
}
private void checkConnections() {
long now = System.currentTimeMillis();
for (WebSocket conn : getConnections()) {
if (conn.isOpen()) {
// 发送Ping
conn.sendPing();
log.debug("Sent Ping to {}", conn.getRemoteSocketAddress());
// 检查上次Pong响应是否超时2倍心跳间隔
Long lastPong = lastPongTimes.get(conn);
if (lastPong != null && now - lastPong > HEARTBEAT_INTERVAL * 2) {
log.warn("Connection {} timed out, closing...", conn.getRemoteSocketAddress());
conn.close(1001, "No Pong response");
}
}
}
}
@Override
public void onWebsocketPong(WebSocket conn, Framedata f) {
lastPongTimes.put(conn, System.currentTimeMillis());
log.debug("Received Pong from {}", conn.getRemoteSocketAddress());
}
private void setShutDownHook() {
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
//关闭WebSocketServer
this.stop();
log.info("WebSocketServer 已关闭");
} catch (Exception e) {
log.error("WebSocketServer关闭失败: ", e);
}
}));
}
@Override
public void onOpen(WebSocket webSocket, ClientHandshake clientHandshake) {
log.info("新连接: {}", webSocket.getRemoteSocketAddress());
}
@Override
public void onClose(WebSocket webSocket, int i, String s, boolean b) {
log.info("连接关闭: {}", webSocket.getRemoteSocketAddress());
lastPongTimes.remove(webSocket);
userSessions.values().removeIf(session -> session.equals(webSocket));
}
@Override
public void onMessage(WebSocket webSocket, String s) {
InteractionInputData inputData = JSONObject.parseObject(s, InteractionInputData.class);
userSessions.put(inputData.getUserInfo(), webSocket); // 注册连接
try {
receiver.receiveInput(inputData);
} catch (IOException | ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
@Override
public void onError(WebSocket webSocket, Exception e) {
log.error(e.getLocalizedMessage());
}
@Override
public void onStart() {
log.info("WebSocketServer 已启动...");
}
@Override
public void sendMessage(InteractionOutputData outputData) {
WebSocket webSocket = userSessions.get(outputData.getUserInfo());
if (webSocket != null && webSocket.isOpen()) {
webSocket.send(JSONUtil.toJsonStr(outputData));
} else {
log.warn("用户不在线: {}", outputData.getUserInfo());
}
}
}

View File

@@ -0,0 +1,7 @@
package work.slhaf.partner.gateway;
import work.slhaf.partner.core.interaction.data.InteractionOutputData;
public interface MessageSender {
void sendMessage(InteractionOutputData outputData);
}

View File

@@ -0,0 +1,19 @@
package work.slhaf.partner.module.common;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.serialize.PersistableObject;
import java.io.Serial;
import java.util.HashMap;
@EqualsAndHashCode(callSuper = true)
@Data
public class AppendPromptData extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private String moduleName;
private HashMap<String,String> appendedPrompt;
}

View File

@@ -0,0 +1,47 @@
package work.slhaf.partner.module.common;
import lombok.Data;
import work.slhaf.partner.common.chat.ChatClient;
import work.slhaf.partner.common.chat.constant.ChatConstant;
import work.slhaf.partner.common.chat.pojo.ChatResponse;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.common.config.ModelConfig;
import work.slhaf.partner.common.util.ResourcesUtil;
import java.util.ArrayList;
import java.util.List;
@Data
public abstract class Model {
protected ChatClient chatClient;
protected List<Message> chatMessages;
protected List<Message> baseMessages;
protected static void setModel(Model model, String promptModule, boolean withAwareness) {
String model_key = model.modelKey();
ModelConfig modelConfig = ModelConfig.load(model_key);
model.setBaseMessages(withAwareness ? ResourcesUtil.Prompt.loadPromptWithSelfAwareness(model_key, promptModule) : ResourcesUtil.Prompt.loadPrompt(model_key, promptModule));
model.setChatClient(new ChatClient(modelConfig.getBaseUrl(), modelConfig.getApikey(), modelConfig.getModel()));
}
protected ChatResponse chat() {
List<Message> temp = new ArrayList<>();
temp.addAll(this.baseMessages);
temp.addAll(this.chatMessages);
return this.chatClient.runChat(temp);
}
protected ChatResponse singleChat(String input) {
List<Message> temp = new ArrayList<>(baseMessages);
temp.add( new Message(ChatConstant.Character.USER, input));
return this.chatClient.runChat(temp);
}
protected void updateChatClientSettings() {
this.chatClient.setTemperature(0.4);
this.chatClient.setTop_p(0.8);
}
protected abstract String modelKey();
}

View File

@@ -0,0 +1,16 @@
package work.slhaf.partner.module.common;
public class ModelConstant {
public static class Prompt {
public static final String MEMORY = "memory";
public static final String SCHEDULE = "schedule";
public static final String CORE = "core";
public static final String PERCEIVE = "perceive";
}
public static class CharacterPrefix {
public static final String SYSTEM = "[SYSTEM] ";
}
}

View File

@@ -0,0 +1,27 @@
package work.slhaf.partner.module.common;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.interaction.module.InteractionModule;
import java.util.HashMap;
/**
* 前置模块抽象类
*/
public abstract class PreModule implements InteractionModule {
protected void setAppendedPrompt(InteractionContext context) {
AppendPromptData data = new AppendPromptData();
data.setModuleName(moduleName());
HashMap<String, String> map = getPromptDataMap(context.getUserId());
data.setAppendedPrompt(map);
context.getModuleContext().getAppendedPrompt().add(data);
}
protected void setActiveModule(InteractionContext context) {
context.getCoreContext().addActiveModule(moduleName());
}
protected abstract HashMap<String, String> getPromptDataMap(String userId);
protected abstract String moduleName();
}

View File

@@ -0,0 +1,245 @@
package work.slhaf.partner.module.modules.core;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.chat.constant.ChatConstant;
import work.slhaf.partner.common.chat.pojo.ChatResponse;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.common.chat.pojo.MetaMessage;
import work.slhaf.partner.core.cognation.CognationManager;
import work.slhaf.partner.core.cognation.capability.ability.CognationCapability;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.interaction.module.InteractionModule;
import work.slhaf.partner.core.session.SessionManager;
import work.slhaf.partner.module.common.AppendPromptData;
import work.slhaf.partner.module.common.Model;
import work.slhaf.partner.module.common.ModelConstant;
import java.io.IOException;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class CoreModel extends Model implements InteractionModule {
private static volatile CoreModel coreModel;
private CognationCapability cognationCapability;
private SessionManager sessionManager;
private List<Message> appendedMessages;
private CoreModel() {
}
public static CoreModel getInstance() throws IOException, ClassNotFoundException {
if (coreModel == null) {
synchronized (CoreModel.class) {
if (coreModel == null) {
coreModel = new CoreModel();
coreModel.cognationCapability = CognationManager.getInstance();
coreModel.chatMessages = coreModel.cognationCapability.getChatMessages();
coreModel.appendedMessages = new ArrayList<>();
coreModel.sessionManager = SessionManager.getInstance();
setModel(coreModel, ModelConstant.Prompt.CORE, true);
coreModel.updateChatClientSettings();
log.info("[CoreModel] CoreModel注册完毕...");
}
}
}
return coreModel;
}
@Override
protected void updateChatClientSettings() {
this.chatClient.setTemperature(0.3);
this.chatClient.setTop_p(0.7);
}
@Override
protected String modelKey() {
return "core_model";
}
@Override
public void execute(InteractionContext interactionContext) {
String userId = interactionContext.getUserId();
log.debug("[CoreModel] 主对话流程开始: {}", userId);
List<AppendPromptData> appendedPrompt = interactionContext.getModuleContext().getAppendedPrompt();
int appendedPromptSize = getAppendedPromptSize(appendedPrompt);
if (appendedPromptSize > 0) {
setAppendedPromptMessage(appendedPrompt);
}
activateModule(interactionContext);
setMessageCount(interactionContext);
log.debug("[CoreModel] 当前消息列表大小: {}", this.chatMessages.size());
log.debug("[CoreModel] 当前核心prompt内容: {}", interactionContext.getCoreContext().toString());
setMessage(interactionContext.getCoreContext().toString());
JSONObject response = new JSONObject();
int count = 0;
while (true) {
try {
ChatResponse chatResponse = this.chat();
try {
response.putAll(JSONObject.parse(extractJson(chatResponse.getMessage())));
} catch (Exception e) {
log.warn("主模型回复格式出错, 将直接作为消息返回, 建议尝试更换主模型...");
handleExceptionResponse(response, chatResponse.getMessage());
}
log.debug("[CoreModel] CoreModel 响应内容: {}", response);
updateModuleContextAndChatMessages(interactionContext, response.getString("text"), chatResponse);
break;
} catch (Exception e) {
count++;
log.error("[CoreModel] CoreModel执行异常: {}", e.getLocalizedMessage());
if (count > 3) {
handleExceptionResponse(response, "主模型交互出错: " + e.getLocalizedMessage());
this.chatMessages.removeLast();
break;
}
} finally {
updateCoreResponse(interactionContext, response);
resetAppendedMessages();
log.debug("[CoreModel] 消息列表更新大小: {}", this.chatMessages.size());
}
}
log.debug("[CoreModel] 主对话流程({})结束...", userId);
}
private int getAppendedPromptSize(List<AppendPromptData> appendedPrompt) {
int size = 0;
for (AppendPromptData data : appendedPrompt) {
size += data.getAppendedPrompt().size();
}
return size;
}
private void activateModule(InteractionContext context) {
for (AppendPromptData data : context.getModuleContext().getAppendedPrompt()) {
if (data.getAppendedPrompt().isEmpty()) continue;
context.getCoreContext().activateModule(data.getModuleName());
}
}
private void updateCoreResponse(InteractionContext interactionContext, JSONObject response) {
interactionContext.getCoreResponse().put("text", response.getString("text"));
}
private void resetAppendedMessages() {
this.appendedMessages.clear();
}
@Override
protected ChatResponse chat() {
List<Message> temp = new ArrayList<>(baseMessages.subList(0, baseMessages.size() - 2));
temp.addAll(appendedMessages);
temp.addAll(baseMessages.subList(baseMessages.size() - 2, baseMessages.size()));
temp.addAll(chatMessages);
return this.chatClient.runChat(temp);
}
private void updateModuleContextAndChatMessages(InteractionContext interactionContext, String response, ChatResponse chatResponse) {
cognationCapability.getMessageLock().lock();
this.chatMessages.removeIf(m -> {
if (m.getRole().equals(ChatConstant.Character.ASSISTANT)) {
return false;
}
try {
JSONObject.parseObject(extractJson(m.getContent()));
return true;
} catch (Exception e) {
return false;
}
});
//添加时间标志
String dateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("\r\n**[yyyy-MM-dd HH:mm:ss]"));
Message primaryUserMessage = new Message(ChatConstant.Character.USER, interactionContext.getCoreContext().getText() + dateTime);
this.chatMessages.add(primaryUserMessage);
Message assistantMessage = new Message(ChatConstant.Character.ASSISTANT, response);
this.chatMessages.add(assistantMessage);
cognationCapability.getMessageLock().unlock();
//设置上下文
interactionContext.getModuleContext().getExtraContext().put("total_token", chatResponse.getUsageBean().getTotal_tokens());
//区分单人聊天场景
if (interactionContext.isSingle()) {
MetaMessage metaMessage = new MetaMessage(primaryUserMessage, assistantMessage);
sessionManager.addMetaMessage(interactionContext.getUserId(), metaMessage);
}
}
private void setMessage(String coreContextStr) {
Message userMessage = new Message(ChatConstant.Character.USER, coreContextStr);
this.chatMessages.add(userMessage);
}
private void handleExceptionResponse(JSONObject response, String chatResponse) {
response.put("text", chatResponse);
// interactionContext.setFinished(true);
}
private void setMessageCount(InteractionContext interactionContext) {
interactionContext.getModuleContext().getExtraContext().put("message_count", chatMessages.size());
}
private void setAppendedPromptMessage(List<AppendPromptData> appendPrompt) {
Message appendDeclareMessage = Message.builder()
.role(ChatConstant.Character.USER)
.content(ModelConstant.CharacterPrefix.SYSTEM + "认知补充开始")
.build();
this.appendedMessages.add(appendDeclareMessage);
for (AppendPromptData data : appendPrompt) {
setStartMessage(data);
setContentMessage(data);
setEndMessage(data);
setAssistantMessage();
}
Message appendEndMessage = Message.builder()
.role(ChatConstant.Character.USER)
.content(ModelConstant.CharacterPrefix.SYSTEM + "认知补充结束")
.build();
this.appendedMessages.add(appendEndMessage);
}
private void setAssistantMessage() {
appendedMessages.add(Message.builder()
.role(ChatConstant.Character.ASSISTANT)
.content("嗯,明白了")
.build());
}
private void setEndMessage(AppendPromptData data) {
Message endMessage = Message.builder()
.role(ChatConstant.Character.USER)
.content(ModelConstant.CharacterPrefix.SYSTEM + data.getModuleName() + "认知补充结束.")
.build();
appendedMessages.add(endMessage);
}
private void setContentMessage(AppendPromptData data) {
data.getAppendedPrompt().forEach((k, v) -> {
Message contentMessage = Message.builder()
.role(ChatConstant.Character.USER)
.content(ModelConstant.CharacterPrefix.SYSTEM + k + v + "\r\n")
.build();
appendedMessages.add(contentMessage);
});
}
private void setStartMessage(AppendPromptData data) {
Message startMessage = Message.builder()
.role(ChatConstant.Character.USER)
.content(ModelConstant.CharacterPrefix.SYSTEM + data.getModuleName() + "以下为" + data.getModuleName() + "相关认知.")
.build();
appendedMessages.add(startMessage);
}
}

View File

@@ -0,0 +1,176 @@
package work.slhaf.partner.module.modules.memory.selector;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.core.cognation.capability.ability.CognationCapability;
import work.slhaf.partner.core.cognation.CognationManager;
import work.slhaf.partner.core.cognation.submodule.memory.exception.UnExistedDateIndexException;
import work.slhaf.partner.core.cognation.submodule.memory.exception.UnExistedTopicException;
import work.slhaf.partner.core.cognation.common.pojo.MemoryResult;
import work.slhaf.partner.core.cognation.capability.ability.CacheCapability;
import work.slhaf.partner.core.cognation.capability.ability.MemoryCapability;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.session.SessionManager;
import work.slhaf.partner.module.common.PreModule;
import work.slhaf.partner.module.modules.memory.selector.evaluator.SliceSelectEvaluator;
import work.slhaf.partner.module.modules.memory.selector.evaluator.data.EvaluatorInput;
import work.slhaf.partner.module.modules.memory.selector.extractor.MemorySelectExtractor;
import work.slhaf.partner.module.modules.memory.selector.extractor.data.ExtractorMatchData;
import work.slhaf.partner.module.modules.memory.selector.extractor.data.ExtractorResult;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.EvaluatedSlice;
import java.io.IOException;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
@Data
@Slf4j
public class MemorySelector extends PreModule {
private static volatile MemorySelector memorySelector;
private CacheCapability cacheCapability;
private MemoryCapability memoryCapability;
private CognationCapability cognationCapability;
private SliceSelectEvaluator sliceSelectEvaluator;
private MemorySelectExtractor memorySelectExtractor;
private SessionManager sessionManager;
private MemorySelector() {
}
public static MemorySelector getInstance() throws IOException, ClassNotFoundException {
if (memorySelector == null) {
synchronized (MemorySelector.class) {
if (memorySelector == null) {
memorySelector = new MemorySelector();
memorySelector.setCacheCapability(CognationManager.getInstance());
memorySelector.setMemoryCapability(CognationManager.getInstance());
memorySelector.setCognationCapability(CognationManager.getInstance());
memorySelector.setSliceSelectEvaluator(SliceSelectEvaluator.getInstance());
memorySelector.setMemorySelectExtractor(MemorySelectExtractor.getInstance());
memorySelector.setSessionManager(SessionManager.getInstance());
}
}
}
return memorySelector;
}
@Override
public void execute(InteractionContext interactionContext) throws IOException, ClassNotFoundException {
log.debug("[MemorySelector] 记忆回溯流程开始...");
String userId = interactionContext.getUserId();
//获取主题路径
ExtractorResult extractorResult = memorySelectExtractor.execute(interactionContext);
if (extractorResult.isRecall() || !extractorResult.getMatches().isEmpty()) {
cognationCapability.clearActivatedSlices(userId);
List<EvaluatedSlice> evaluatedSlices = selectAndEvaluateMemory(interactionContext, extractorResult);
cognationCapability.updateActivatedSlices(userId, evaluatedSlices);
}
//设置追加提示词
setAppendedPrompt(interactionContext);
setModuleContextRecall(interactionContext);
setActiveModule(interactionContext);
log.debug("[MemorySelector] 记忆回溯完成...");
}
private List<EvaluatedSlice> selectAndEvaluateMemory(InteractionContext interactionContext, ExtractorResult extractorResult) throws IOException, ClassNotFoundException {
log.debug("[MemorySelector] 触发记忆回溯...");
//查找切片
String userId = interactionContext.getUserId();
List<MemoryResult> memoryResultList = new ArrayList<>();
setMemoryResultList(memoryResultList, extractorResult.getMatches(), userId);
//评估切片
EvaluatorInput evaluatorInput = EvaluatorInput.builder()
.input(interactionContext.getInput())
.memoryResults(memoryResultList)
.messages(cognationCapability.getChatMessages())
.build();
log.debug("[MemorySelector] 切片评估输入: {}", JSONObject.toJSONString(evaluatorInput));
List<EvaluatedSlice> memorySlices = sliceSelectEvaluator.execute(evaluatorInput);
log.debug("[MemorySelector] 切片评估结果: {}", JSONObject.toJSONString(memorySlices));
return memorySlices;
}
private void setModuleContextRecall(InteractionContext interactionContext) {
String userId = interactionContext.getUserId();
boolean recall = cognationCapability.hasActivatedSlices(userId);
interactionContext.getModuleContext().getExtraContext().put("recall", recall);
if (recall) {
interactionContext.getModuleContext().getExtraContext().put("recall_count", cognationCapability.getActivatedSlicesSize(userId));
}
}
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) throws IOException, ClassNotFoundException {
for (ExtractorMatchData match : matches) {
try {
MemoryResult memoryResult = switch (match.getType()) {
case ExtractorMatchData.Constant.TOPIC -> memoryCapability.selectMemory(match.getText());
case ExtractorMatchData.Constant.DATE ->
memoryCapability.selectMemory(LocalDate.parse(match.getText()));
default -> null;
};
if (memoryResult == null || memoryResult.isEmpty()) continue;
removeDuplicateSlice(memoryResult);
memoryResultList.add(memoryResult);
} catch (UnExistedDateIndexException | UnExistedTopicException e) {
log.error("[MemorySelector] 不存在的记忆索引! 请尝试更换更合适的主题提取LLM!", e);
log.error("[MemorySelector] 错误索引: {}", match.getText());
}
}
//清理切片记录
memoryCapability.cleanSelectedSliceFilter();
//根据userInfo过滤是否为私人记忆
for (MemoryResult memoryResult : memoryResultList) {
//过滤终点记忆
memoryResult.getMemorySliceResult().removeIf(m -> removeOrNot(m.getMemorySlice(), userId));
//过滤邻近记忆
memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userId));
}
}
private void removeDuplicateSlice(MemoryResult memoryResult) {
Collection<String> values = cacheCapability.getDialogMap().values();
memoryResult.getRelatedMemorySliceResult().removeIf(m -> values.contains(m.getSummary()));
memoryResult.getMemorySliceResult().removeIf(m -> values.contains(m.getMemorySlice().getSummary()));
}
private boolean removeOrNot(MemorySlice memorySlice, String userId) {
if (memorySlice.isPrivate()) {
return memorySlice.getStartUserId().equals(userId);
}
return false;
}
@Override
public String moduleName() {
return "[记忆模块]";
}
protected HashMap<String, String> getPromptDataMap(String userId) {
HashMap<String, String> map = new HashMap<>();
String dialogMapStr = cacheCapability.getDialogMapStr();
if (!dialogMapStr.isEmpty()) {
map.put("[记忆缓存] <你最近两日和所有聊天者的对话记忆印象>", dialogMapStr);
}
String userDialogMapStr = cacheCapability.getUserDialogMapStr(userId);
if (userDialogMapStr != null && !userDialogMapStr.isEmpty() && !cognationCapability.isSingleUser()) {
map.put("[用户记忆缓存] <与最新一条消息的发送者的近两天对话记忆印象, 可能与[记忆缓存]稍有重复>", userDialogMapStr);
}
String sliceStr = cognationCapability.getActivatedSlicesStr(userId);
if (sliceStr != null && !sliceStr.isEmpty()) {
map.put("[记忆切片] <你与最新一条消息的发送者的相关回忆, 不会与[记忆缓存]重复, 如果有重复你也可以指出来()>", sliceStr);
}
return map;
}
}

View File

@@ -0,0 +1,141 @@
package work.slhaf.partner.module.modules.memory.selector.evaluator;
import cn.hutool.core.date.DateUtil;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.core.cognation.common.pojo.MemoryResult;
import work.slhaf.partner.core.cognation.common.pojo.MemorySliceResult;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import work.slhaf.partner.module.common.Model;
import work.slhaf.partner.module.common.ModelConstant;
import work.slhaf.partner.module.modules.memory.selector.evaluator.data.EvaluatorBatchInput;
import work.slhaf.partner.module.modules.memory.selector.evaluator.data.EvaluatorInput;
import work.slhaf.partner.module.modules.memory.selector.evaluator.data.EvaluatorResult;
import work.slhaf.partner.module.modules.memory.selector.evaluator.data.SliceSummary;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.EvaluatedSlice;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class SliceSelectEvaluator extends Model {
private static volatile SliceSelectEvaluator sliceSelectEvaluator;
private InteractionThreadPoolExecutor executor;
private SliceSelectEvaluator() {
}
public static SliceSelectEvaluator getInstance() throws IOException, ClassNotFoundException {
if (sliceSelectEvaluator == null) {
synchronized (SliceSelectEvaluator.class) {
if (sliceSelectEvaluator == null) {
sliceSelectEvaluator = new SliceSelectEvaluator();
sliceSelectEvaluator.setExecutor(InteractionThreadPoolExecutor.getInstance());
setModel(sliceSelectEvaluator, ModelConstant.Prompt.MEMORY, false);
log.info("SliceEvaluator注册完毕...");
}
}
}
return sliceSelectEvaluator;
}
public List<EvaluatedSlice> execute(EvaluatorInput evaluatorInput) {
log.debug("[SliceSelectEvaluator] 切片评估模块开始...");
List<MemoryResult> memoryResultList = evaluatorInput.getMemoryResults();
List<Callable<Void>> tasks = new ArrayList<>();
Queue<EvaluatedSlice> queue = new ConcurrentLinkedDeque<>();
AtomicInteger count = new AtomicInteger(0);
for (MemoryResult memoryResult : memoryResultList) {
if (memoryResult.getMemorySliceResult().isEmpty() && memoryResult.getRelatedMemorySliceResult().isEmpty()) {
continue;
}
tasks.add(() -> {
int thisCount = count.incrementAndGet();
log.debug("[SliceSelectEvaluator] 评估[{}]开始", thisCount);
List<SliceSummary> sliceSummaryList = new ArrayList<>();
//映射查找键值
Map<Long, SliceSummary> map = new HashMap<>();
try {
setSliceSummaryList(memoryResult, sliceSummaryList, map);
EvaluatorBatchInput batchInput = EvaluatorBatchInput.builder()
.text(evaluatorInput.getInput())
.memory_slices(sliceSummaryList)
.history(evaluatorInput.getMessages())
.build();
log.debug("[SliceSelectEvaluator] 评估[{}]输入: {}", thisCount, JSONObject.toJSONString(batchInput));
EvaluatorResult evaluatorResult = JSONObject.parseObject(extractJson(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage()), EvaluatorResult.class);
log.debug("[SliceSelectEvaluator] 评估[{}]结果: {}", thisCount, JSONObject.toJSONString(evaluatorResult));
for (Long result : evaluatorResult.getResults()) {
SliceSummary sliceSummary = map.get(result);
EvaluatedSlice evaluatedSlice = EvaluatedSlice.builder()
.summary(sliceSummary.getSummary())
.date(sliceSummary.getDate())
.build();
// setEvaluatedSliceMessages(evaluatedSlice, memoryResult, sliceSummary.getId());
queue.offer(evaluatedSlice);
}
} catch (Exception e) {
log.error("[SliceSelectEvaluator] 评估[{}]出现错误: {}", thisCount, e.getLocalizedMessage());
}
return null;
});
}
executor.invokeAll(tasks, 30, TimeUnit.SECONDS);
log.debug("[SliceSelectEvaluator] 评估模块结束, 输出队列: {}", queue);
List<EvaluatedSlice> temp = queue.stream().toList();
return new ArrayList<>(temp);
}
private void setSliceSummaryList(MemoryResult memoryResult, List<SliceSummary> sliceSummaryList, Map<Long, SliceSummary> map) {
for (MemorySliceResult memorySliceResult : memoryResult.getMemorySliceResult()) {
SliceSummary sliceSummary = new SliceSummary();
sliceSummary.setId(memorySliceResult.getMemorySlice().getTimestamp());
StringBuilder stringBuilder = new StringBuilder();
if (memorySliceResult.getSliceBefore() != null) {
stringBuilder.append(memorySliceResult.getSliceBefore().getSummary())
.append("\r\n");
}
stringBuilder.append(memorySliceResult.getMemorySlice().getSummary());
if (memorySliceResult.getSliceAfter() != null) {
stringBuilder.append("\r\n")
.append(memorySliceResult.getSliceAfter().getSummary())
.append("\r\n");
}
sliceSummary.setSummary(stringBuilder.toString());
Long timestamp = memorySliceResult.getMemorySlice().getTimestamp();
sliceSummary.setDate(DateUtil.date(timestamp).toLocalDateTime().toLocalDate());
sliceSummaryList.add(sliceSummary);
map.put(timestamp, sliceSummary);
}
for (MemorySlice memorySlice : memoryResult.getRelatedMemorySliceResult()) {
SliceSummary sliceSummary = new SliceSummary();
sliceSummary.setId(memorySlice.getTimestamp());
sliceSummary.setSummary(memorySlice.getSummary());
sliceSummaryList.add(sliceSummary);
map.put(memorySlice.getTimestamp(), sliceSummary);
}
}
@Override
protected String modelKey() {
return "slice_evaluator";
}
}

View File

@@ -0,0 +1,15 @@
package work.slhaf.partner.module.modules.memory.selector.evaluator.data;
import lombok.Builder;
import lombok.Data;
import work.slhaf.partner.common.chat.pojo.Message;
import java.util.List;
@Data
@Builder
public class EvaluatorBatchInput {
private String text;
private List<Message> history;
private List<SliceSummary> memory_slices;
}

View File

@@ -0,0 +1,16 @@
package work.slhaf.partner.module.modules.memory.selector.evaluator.data;
import lombok.Builder;
import lombok.Data;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.core.cognation.common.pojo.MemoryResult;
import java.util.List;
@Data
@Builder
public class EvaluatorInput {
private String input;
private List<Message> messages;
private List<MemoryResult> memoryResults;
}

View File

@@ -0,0 +1,10 @@
package work.slhaf.partner.module.modules.memory.selector.evaluator.data;
import lombok.Data;
import java.util.List;
@Data
public class EvaluatorResult {
private List<Long> results;
}

View File

@@ -0,0 +1,12 @@
package work.slhaf.partner.module.modules.memory.selector.evaluator.data;
import lombok.Data;
import java.time.LocalDate;
@Data
public class SliceSummary {
private String summary;
private Long id;
private LocalDate date;
}

View File

@@ -0,0 +1,112 @@
package work.slhaf.partner.module.modules.memory.selector.extractor;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.common.chat.pojo.MetaMessage;
import work.slhaf.partner.common.exception_handler.GlobalExceptionHandler;
import work.slhaf.partner.common.exception_handler.pojo.GlobalException;
import work.slhaf.partner.core.cognation.capability.ability.CognationCapability;
import work.slhaf.partner.core.cognation.CognationManager;
import work.slhaf.partner.core.cognation.capability.ability.MemoryCapability;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.session.SessionManager;
import work.slhaf.partner.module.common.Model;
import work.slhaf.partner.module.common.ModelConstant;
import work.slhaf.partner.module.modules.memory.selector.extractor.data.ExtractorInput;
import work.slhaf.partner.module.modules.memory.selector.extractor.data.ExtractorMatchData;
import work.slhaf.partner.module.modules.memory.selector.extractor.data.ExtractorResult;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.EvaluatedSlice;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class MemorySelectExtractor extends Model {
private static volatile MemorySelectExtractor memorySelectExtractor;
private MemoryCapability memoryCapability;
private CognationCapability cognationCapability;
private SessionManager sessionManager;
private MemorySelectExtractor() {
}
public static MemorySelectExtractor getInstance() throws IOException, ClassNotFoundException {
if (memorySelectExtractor == null) {
synchronized (MemorySelectExtractor.class) {
if (memorySelectExtractor == null) {
memorySelectExtractor = new MemorySelectExtractor();
memorySelectExtractor.setMemoryCapability(CognationManager.getInstance());
memorySelectExtractor.setCognationCapability(CognationManager.getInstance());
memorySelectExtractor.setSessionManager(SessionManager.getInstance());
setModel(memorySelectExtractor, ModelConstant.Prompt.MEMORY, false);
}
}
}
return memorySelectExtractor;
}
public ExtractorResult execute(InteractionContext context) {
log.debug("[MemorySelectExtractor] 主题提取模块开始...");
//结构化为指定格式
List<Message> chatMessages = new ArrayList<>();
List<MetaMessage> metaMessages = sessionManager.getSingleMetaMessageMap().get(context.getUserId());
if (metaMessages == null) {
sessionManager.getSingleMetaMessageMap().put(context.getUserId(), new ArrayList<>());
} else {
for (MetaMessage metaMessage : metaMessages) {
chatMessages.add(metaMessage.getUserMessage());
chatMessages.add(metaMessage.getAssistantMessage());
}
}
ExtractorResult extractorResult;
try {
List<EvaluatedSlice> activatedMemorySlices = cognationCapability.getActivatedSlices(context.getUserId());
ExtractorInput extractorInput = ExtractorInput.builder()
.text(context.getInput())
.date(context.getDateTime().toLocalDate())
.history(chatMessages)
.topic_tree(memoryCapability.getTopicTree())
.activatedMemorySlices(activatedMemorySlices)
.build();
log.debug("[MemorySelectExtractor] 主题提取输入: {}", JSONObject.toJSONString(extractorInput));
String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage());
extractorResult = JSONObject.parseObject(responseStr, ExtractorResult.class);
log.debug("[MemorySelectExtractor] 主题提取结果: {}", extractorResult);
} catch (Exception e) {
log.error("[MemorySelectExtractor] 主题提取出错: ", e);
GlobalExceptionHandler.writeExceptionState(new GlobalException(e.getLocalizedMessage()));
extractorResult = new ExtractorResult();
extractorResult.setRecall(false);
extractorResult.setMatches(List.of());
}
return fix(extractorResult);
}
private ExtractorResult fix(ExtractorResult extractorResult) {
extractorResult.getMatches().forEach(m -> {
if (m.getType().equals(ExtractorMatchData.Constant.DATE)) {
return;
}
m.setText(fixTopicPath(m.getText()));
});
extractorResult.getMatches().removeIf(m -> m.getText().split("->")[0].isEmpty());
return extractorResult;
}
@Override
protected String modelKey() {
return "topic_extractor";
}
}

View File

@@ -0,0 +1,19 @@
package work.slhaf.partner.module.modules.memory.selector.extractor.data;
import lombok.Builder;
import lombok.Data;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.EvaluatedSlice;
import java.time.LocalDate;
import java.util.List;
@Data
@Builder
public class ExtractorInput {
private String text;
private String topic_tree;
private LocalDate date;
private List<Message> history;
private List<EvaluatedSlice> activatedMemorySlices;
}

View File

@@ -0,0 +1,14 @@
package work.slhaf.partner.module.modules.memory.selector.extractor.data;
import lombok.Data;
@Data
public class ExtractorMatchData {
private String type;
private String text;
public static class Constant {
public static final String DATE = "date";
public static final String TOPIC = "topic";
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.module.modules.memory.selector.extractor.data;
import lombok.Data;
import java.util.List;
@Data
public class ExtractorResult {
private boolean recall;
private List<ExtractorMatchData> matches;
}

View File

@@ -0,0 +1,278 @@
package work.slhaf.partner.module.modules.memory.updater;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.chat.constant.ChatConstant;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.core.cognation.capability.ability.CognationCapability;
import work.slhaf.partner.core.cognation.CognationManager;
import work.slhaf.partner.core.cognation.capability.ability.CacheCapability;
import work.slhaf.partner.core.cognation.capability.ability.MemoryCapability;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import work.slhaf.partner.core.cognation.capability.ability.PerceiveCapability;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.interaction.module.InteractionModule;
import work.slhaf.partner.core.session.SessionManager;
import work.slhaf.partner.module.modules.memory.selector.extractor.MemorySelectExtractor;
import work.slhaf.partner.module.modules.memory.updater.summarizer.MemorySummarizer;
import work.slhaf.partner.module.modules.memory.updater.summarizer.data.SummarizeInput;
import work.slhaf.partner.module.modules.memory.updater.summarizer.data.SummarizeResult;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger;
import static work.slhaf.partner.common.util.ExtractUtil.extractUserId;
@Data
@Slf4j
public class MemoryUpdater implements InteractionModule {
private static volatile MemoryUpdater memoryUpdater;
private static final long SCHEDULED_UPDATE_INTERVAL = 10 * 1000;
private static final long UPDATE_TRIGGER_INTERVAL = 60 * 60 * 1000;
private CognationCapability cognationCapability;
private MemoryCapability memoryCapability;
private CacheCapability cacheCapability;
private PerceiveCapability perceiveCapability;
private InteractionThreadPoolExecutor executor;
private MemorySelectExtractor memorySelectExtractor;
private MemorySummarizer memorySummarizer;
private SessionManager sessionManager;
/**
* 用于临时存储完整对话记录在MemoryManager的分离后
*/
private List<Message> tempMessage;
private MemoryUpdater() {
}
public static MemoryUpdater getInstance() throws IOException, ClassNotFoundException {
if (memoryUpdater == null) {
synchronized (MemoryUpdater.class) {
if (memoryUpdater == null) {
memoryUpdater = new MemoryUpdater();
memoryUpdater.setCognationCapability(CognationManager.getInstance());
memoryUpdater.setMemoryCapability(CognationManager.getInstance());
memoryUpdater.setCacheCapability(CognationManager.getInstance());
memoryUpdater.setPerceiveCapability(CognationManager.getInstance());
memoryUpdater.setMemorySelectExtractor(MemorySelectExtractor.getInstance());
memoryUpdater.setMemorySummarizer(MemorySummarizer.getInstance());
memoryUpdater.setSessionManager(SessionManager.getInstance());
memoryUpdater.setExecutor(InteractionThreadPoolExecutor.getInstance());
memoryUpdater.setScheduledUpdater();
}
}
}
return memoryUpdater;
}
private void setScheduledUpdater() {
executor.execute(() -> {
log.info("[MemoryUpdater] 记忆自动更新线程启动");
while (!Thread.interrupted()) {
try {
long currentTime = System.currentTimeMillis();
long lastUpdatedTime = sessionManager.getLastUpdatedTime();
int chatCount = cognationCapability.getChatMessages().size();
if (lastUpdatedTime != 0 && currentTime - lastUpdatedTime > UPDATE_TRIGGER_INTERVAL && chatCount > 1) {
updateMemory();
cognationCapability.getChatMessages().clear();
//重置MemoryId
sessionManager.refreshMemoryId();
log.info("[MemoryUpdater] 记忆更新: 自动触发");
}
Thread.sleep(SCHEDULED_UPDATE_INTERVAL);
} catch (Exception e) {
log.error("[MemoryUpdater] 记忆自动更新线程出错: ", e);
}
}
log.info("[MemoryUpdater] 记忆自动更新线程结束");
});
}
@Override
public void execute(InteractionContext context) {
if (context.isFinished()) {
log.warn("[MemoryUpdater] 流程强制结束, 不触发记忆被动更新机制");
return;
}
executor.execute(() -> {
//如果token 大于阈值,则更新记忆
JSONObject moduleContext = context.getModuleContext().getExtraContext();
boolean recall = moduleContext.getBoolean("recall");
if (recall) {
log.debug("[MemoryUpdater] 存在回忆");
int recallCount = moduleContext.getIntValue("recall_count");
log.debug("[MemoryUpdater] 记忆切片数量 [{}]", recallCount);
}
boolean trigger = context.getModuleContext().getExtraContext().getBoolean("post_process_trigger");
if (!trigger) {
return;
}
try {
log.debug("[MemoryUpdater] 记忆更新触发");
updateMemory();
//清空chatMessages
clearChatMessages();
} catch (Exception e) {
log.error("[MemoryUpdater] 记忆更新线程出错: ", e);
}
});
}
private void updateMemory() {
log.debug("[MemoryUpdater] 记忆更新流程开始...");
tempMessage = new ArrayList<>(cognationCapability.getChatMessages());
HashMap<String, String> singleMemorySummary = new HashMap<>();
//更新单聊记忆同时从chatMessages中去掉单聊记忆
updateSingleChatSlices(singleMemorySummary);
//更新多人场景下的记忆及相关的确定性记忆
updateMultiChatSlices(singleMemorySummary);
sessionManager.resetLastUpdatedTime();
log.debug("[MemoryUpdater] 记忆更新流程结束...");
}
private void updateMultiChatSlices(HashMap<String, String> singleMemorySummary) {
//此时chatMessages中不再包含单聊记录直接执行摘要以及切片插入
//对剩下的多人聊天记录进行进行摘要
Callable<Void> task = () -> {
log.debug("[MemoryUpdater] 多人聊天记忆更新流程开始...");
List<Message> chatMessages;
cognationCapability.getMessageLock().lock();
chatMessages = new ArrayList<>(cognationCapability.getChatMessages());
cognationCapability.getMessageLock().unlock();
cleanMessage(chatMessages);
if (!chatMessages.isEmpty()) {
log.debug("[MemoryUpdater] 存在多人聊天记录, 流程正常进行...");
//以第一条user对应的id为发起用户
String userId = extractUserId(chatMessages.getFirst().getContent());
if (userId == null) {
throw new RuntimeException("未匹配到 userId!");
}
SummarizeInput summarizeInput = new SummarizeInput(chatMessages, memoryCapability.getTopicTree());
log.debug("[MemoryUpdater] 多人聊天记忆更新-总结流程-输入: {}", summarizeInput);
SummarizeResult summarizeResult = memorySummarizer.execute(summarizeInput);
log.debug("[MemoryUpdater] 多人聊天记忆更新-总结流程-输出: {}", summarizeResult);
MemorySlice memorySlice = getMemorySlice(userId, summarizeResult, chatMessages);
//设置involvedUserId
setInvolvedUserId(userId, memorySlice, chatMessages);
memoryCapability.insertSlice(memorySlice, summarizeResult.getTopicPath());
cacheCapability.updateDialogMap(LocalDateTime.now(), summarizeResult.getSummary());
} else {
log.debug("[MemoryUpdater] 不存在多人聊天记录, 将以单聊总结为对话缓存的主要输入: {}", singleMemorySummary);
cacheCapability.updateDialogMap(LocalDateTime.now(), memorySummarizer.executeTotalSummary(singleMemorySummary));
}
log.debug("[MemoryUpdater] 对话缓存更新完毕");
log.debug("[MemoryUpdater] 多人聊天记忆更新流程结束...");
return null;
};
executor.invokeAll(List.of(task));
}
private void cleanMessage(List<Message> chatMessages) {
//清理时间标识
for (Message message : chatMessages) {
if (message.getRole().equals(ChatConstant.Character.ASSISTANT)) {
continue;
}
String time = Arrays.stream(message.getContent().split("\\*\\*")).toList().getLast();
message.setContent(message.getContent().replace("\r\n**" + time, ""));
}
}
private void clearChatMessages() {
//不全部清空,保留一部分输入防止上下文割裂
cognationCapability.getMessageLock().lock();
List<Message> temp = new ArrayList<>(tempMessage.subList(tempMessage.size() - tempMessage.size() / 6, tempMessage.size()));
cognationCapability.getChatMessages().removeAll(tempMessage);
cognationCapability.getChatMessages().addAll(0, temp);
cognationCapability.getMessageLock().unlock();
}
private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List<Message> chatMessages) {
for (Message chatMessage : chatMessages) {
if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) {
continue;
}
//匹配userId
String userId = extractUserId(chatMessage.getContent());
if (userId == null) {
continue;
}
if (userId.equals(startUserId)) {
continue;
}
memorySlice.setInvolvedUserIds(new ArrayList<>());
memorySlice.getInvolvedUserIds().add(userId);
}
}
private void updateSingleChatSlices(HashMap<String, String> singleMemorySummary) {
log.debug("[MemoryUpdater] 单聊记忆更新流程开始...");
//更新单聊记忆同时从chatMessages中去掉单聊记忆
Set<String> userIdSet = new HashSet<>(sessionManager.getSingleMetaMessageMap().keySet());
List<Callable<Void>> tasks = new ArrayList<>();
//多人聊天?
AtomicInteger count = new AtomicInteger(0);
for (String id : userIdSet) {
List<Message> messages = sessionManager.unpackAndClear(id);
tasks.add(() -> {
int thisCount = count.incrementAndGet();
log.debug("[MemoryUpdater] 单聊记忆[{}]更新: {}", thisCount, id);
try {
//单聊记忆更新
SummarizeInput summarizeInput = new SummarizeInput(messages, memoryCapability.getTopicTree());
log.debug("[MemoryUpdater] 单聊记忆[{}]更新-总结流程-输入: {}", thisCount, JSONObject.toJSONString(summarizeInput));
SummarizeResult summarizeResult = memorySummarizer.execute(summarizeInput);
log.debug("[MemoryUpdater] 单聊记忆[{}]更新-总结流程-输出: {}", thisCount, JSONObject.toJSONString(summarizeResult));
MemorySlice memorySlice = getMemorySlice(id, summarizeResult, messages);
//插入时userDialogMap已经进行更新
memoryCapability.insertSlice(memorySlice, summarizeResult.getTopicPath());
//从chatMessages中移除单聊记录
cognationCapability.cleanMessage(messages);
//添加至singleMemorySummary
String key = perceiveCapability.getUser(id).getNickName() + "[" + id + "]";
singleMemorySummary.put(key, summarizeResult.getSummary());
log.debug("[MemoryUpdater] 单聊记忆[{}]更新成功: ", thisCount);
} catch (Exception e) {
log.error("[MemoryUpdater] 单聊记忆[{}]更新出错: ", thisCount, e);
}
return null;
});
}
executor.invokeAll(tasks);
log.debug("[MemoryUpdater] 单聊记忆更新结束...");
}
private MemorySlice getMemorySlice(String userId, SummarizeResult summarizeResult, List<Message> chatMessages) {
MemorySlice memorySlice = new MemorySlice();
//设置 memoryId,timestamp
memorySlice.setMemoryId(sessionManager.getCurrentMemoryId());
memorySlice.setTimestamp(System.currentTimeMillis());
//补充信息
memorySlice.setPrivate(summarizeResult.isPrivate());
memorySlice.setSummary(summarizeResult.getSummary());
memorySlice.setChatMessages(chatMessages);
memorySlice.setStartUserId(userId);
List<List<String>> relatedTopicPathList = new ArrayList<>();
for (String string : summarizeResult.getRelatedTopicPath()) {
List<String> list = Arrays.stream(string.split("->")).toList();
relatedTopicPathList.add(list);
}
memorySlice.setRelatedTopics(relatedTopicPathList);
return memorySlice;
}
}

View File

@@ -0,0 +1,7 @@
package work.slhaf.partner.module.modules.memory.updater.exception;
public class UnExpectedMessageCountException extends RuntimeException {
public UnExpectedMessageCountException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,48 @@
package work.slhaf.partner.module.modules.memory.updater.summarizer;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.module.modules.memory.updater.summarizer.data.SummarizeInput;
import work.slhaf.partner.module.modules.memory.updater.summarizer.data.SummarizeResult;
import java.util.HashMap;
@Data
@Slf4j
public class MemorySummarizer {
private static volatile MemorySummarizer memorySummarizer;
public static final String MODEL_KEY = "memory_summarizer";
private InteractionThreadPoolExecutor executor;
private SingleSummarizer singleSummarizer;
private MultiSummarizer multiSummarizer;
private TotalSummarizer totalSummarizer;
public static MemorySummarizer getInstance() {
if (memorySummarizer == null) {
synchronized (MemorySummarizer.class) {
if (memorySummarizer == null) {
memorySummarizer = new MemorySummarizer();
memorySummarizer.setExecutor(InteractionThreadPoolExecutor.getInstance());
memorySummarizer.setSingleSummarizer(SingleSummarizer.getInstance());
memorySummarizer.setMultiSummarizer(MultiSummarizer.getInstance());
memorySummarizer.setTotalSummarizer(TotalSummarizer.getInstance());
}
}
}
return memorySummarizer;
}
public SummarizeResult execute(SummarizeInput input) {
//进行长文本批量摘要
singleSummarizer.execute(input.getChatMessages());
//进行整体摘要并返回结果
return multiSummarizer.execute(input);
}
public String executeTotalSummary(HashMap<String, String> singleMemorySummary) {
return totalSummarizer.execute(singleMemorySummary);
}
}

View File

@@ -0,0 +1,67 @@
package work.slhaf.partner.module.modules.memory.updater.summarizer;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.chat.pojo.ChatResponse;
import work.slhaf.partner.module.common.Model;
import work.slhaf.partner.module.common.ModelConstant;
import work.slhaf.partner.module.modules.memory.updater.summarizer.data.SummarizeInput;
import work.slhaf.partner.module.modules.memory.updater.summarizer.data.SummarizeResult;
import java.util.ArrayList;
import java.util.List;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class MultiSummarizer extends Model {
private static volatile MultiSummarizer multiSummarizer;
public static MultiSummarizer getInstance() {
if (multiSummarizer == null) {
synchronized (MultiSummarizer.class) {
if (multiSummarizer == null) {
multiSummarizer = new MultiSummarizer();
setModel(multiSummarizer, ModelConstant.Prompt.MEMORY, true);
multiSummarizer.updateChatClientSettings();
}
}
}
return multiSummarizer;
}
public SummarizeResult execute(SummarizeInput input) {
log.debug("[MemorySummarizer] 整体摘要开始...");
ChatResponse response = this.singleChat(JSONUtil.toJsonPrettyStr(input));
log.debug("[MemorySummarizer] 整体摘要结果: {}", JSONObject.toJSONString(response));
SummarizeResult result = JSONObject.parseObject(extractJson(response.getMessage()), SummarizeResult.class);
return fix(result);
}
private SummarizeResult fix(SummarizeResult result) {
if (result == null || result.getTopicPath() == null || result.getTopicPath().isEmpty()) {
return result;
}
String topicPath = fixTopicPath(result.getTopicPath());
List<String> relatedTopicPath = new ArrayList<>();
for (String s : result.getRelatedTopicPath()) {
relatedTopicPath.add(fixTopicPath(s));
}
result.setTopicPath(topicPath);
result.setRelatedTopicPath(relatedTopicPath);
return result;
}
@Override
protected String modelKey() {
return "multi_summarizer";
}
}

View File

@@ -0,0 +1,78 @@
package work.slhaf.partner.module.modules.memory.updater.summarizer;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.chat.constant.ChatConstant;
import work.slhaf.partner.common.chat.pojo.ChatResponse;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.module.common.Model;
import work.slhaf.partner.module.common.ModelConstant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
@EqualsAndHashCode(callSuper = true)
@Slf4j
@Data
public class SingleSummarizer extends Model {
private static volatile SingleSummarizer singleSummarizer;
private InteractionThreadPoolExecutor executor;
public static SingleSummarizer getInstance() {
if (singleSummarizer == null) {
synchronized (SingleSummarizer.class) {
if (singleSummarizer == null) {
singleSummarizer = new SingleSummarizer();
singleSummarizer.setExecutor(InteractionThreadPoolExecutor.getInstance());
setModel(singleSummarizer, ModelConstant.Prompt.MEMORY, false);
}
}
}
return singleSummarizer;
}
public void execute(List<Message> chatMessages) {
log.debug("[MemorySummarizer] 长文本摘要开始...");
List<Callable<Void>> tasks = new ArrayList<>();
AtomicInteger counter = new AtomicInteger();
for (Message chatMessage : chatMessages) {
if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) {
String content = chatMessage.getContent();
if (chatMessage.getContent().length() > 500) {
tasks.add(() -> {
int thisCount = counter.incrementAndGet();
log.debug("[MemorySummarizer] 长文本摘要[{}]启动", thisCount);
chatMessage.setContent(singleExecute(JSONObject.of("content", content).toString()));
log.debug("[MemorySummarizer] 长文本摘要[{}]完成", thisCount);
return null;
});
}
}
}
executor.invokeAll(tasks, 30, TimeUnit.SECONDS);
log.debug("[MemorySummarizer] 长文本摘要结束");
}
private String singleExecute(String primaryContent) {
try {
ChatResponse response = this.singleChat(primaryContent);
return response.getMessage();
} catch (Exception e) {
log.error("[SingleSummarizer] 单消息总结出错: ", e);
return primaryContent;
}
}
@Override
protected String modelKey() {
return "single_summarizer";
}
}

View File

@@ -0,0 +1,45 @@
package work.slhaf.partner.module.modules.memory.updater.summarizer;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.chat.pojo.ChatResponse;
import work.slhaf.partner.module.common.Model;
import work.slhaf.partner.module.common.ModelConstant;
import java.util.HashMap;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class TotalSummarizer extends Model {
private static volatile TotalSummarizer totalSummarizer;
public static TotalSummarizer getInstance() {
if (totalSummarizer == null) {
synchronized (TotalSummarizer.class) {
if (totalSummarizer == null) {
totalSummarizer = new TotalSummarizer();
setModel(totalSummarizer, ModelConstant.Prompt.MEMORY, true);
totalSummarizer.updateChatClientSettings();
}
}
}
return totalSummarizer;
}
public String execute(HashMap<String, String> singleMemorySummary){
ChatResponse response = this.singleChat(JSONUtil.toJsonPrettyStr(singleMemorySummary));
return JSONObject.parseObject(extractJson(response.getMessage())).getString("content");
}
@Override
protected String modelKey() {
return "total_summarizer";
}
}

View File

@@ -0,0 +1,14 @@
package work.slhaf.partner.module.modules.memory.updater.summarizer.data;
import lombok.AllArgsConstructor;
import lombok.Data;
import work.slhaf.partner.common.chat.pojo.Message;
import java.util.List;
@AllArgsConstructor
@Data
public class SummarizeInput {
private List<Message> chatMessages;
private String topicTree;
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.module.modules.memory.updater.summarizer.data;
import lombok.Data;
import java.util.List;
@Data
public class SummarizeResult {
private String summary;
private String topicPath;
private List<String> relatedTopicPath;
private boolean isPrivate;
}

View File

@@ -0,0 +1,58 @@
package work.slhaf.partner.module.modules.perceive.selector;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.core.cognation.CognationManager;
import work.slhaf.partner.core.cognation.capability.ability.PerceiveCapability;
import work.slhaf.partner.core.cognation.submodule.perceive.pojo.User;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.module.common.PreModule;
import java.io.IOException;
import java.util.HashMap;
@Slf4j
@Setter
public class PerceiveSelector extends PreModule {
private static volatile PerceiveSelector perceiveSelector;
private PerceiveCapability perceiveCapability;
public static PerceiveSelector getInstance() throws IOException, ClassNotFoundException {
if (perceiveSelector == null) {
synchronized (PerceiveSelector.class) {
if (perceiveSelector == null) {
perceiveSelector = new PerceiveSelector();
perceiveSelector.setPerceiveCapability(CognationManager.getInstance());
}
}
}
return perceiveSelector;
}
@Override
public void execute(InteractionContext context) throws IOException, ClassNotFoundException {
log.debug("[PerceiveSelector] 感知模块处理流程开始...");
//处理思路: 根据用户id,查询用户相关身份感知数据直接添加到appendPrompt中这直接执行appendPrompt方法应该可以
setAppendedPrompt(context);
setActiveModule(context);
log.debug("[PerceiveSelector] 感知模块处理流程结束...");
}
@Override
protected HashMap<String, String> getPromptDataMap(String userId) {
HashMap<String, String> map = new HashMap<>();
User user = perceiveCapability.getUser(userId);
map.put("[关系] <你与最新聊天用户的关系>", user.getRelation());
map.put("[态度] <你对于最新聊天用户的态度>", user.getAttitude().toString());
map.put("[印象] <你对于最新聊天用户的印象>", user.getImpressions().toString());
map.put("[静态记忆] <你关于最新聊天用户的静态记忆>", user.getStaticMemory().toString());
return map;
}
@Override
public String moduleName() {
return "[感知模块]";
}
}

View File

@@ -0,0 +1,103 @@
package work.slhaf.partner.module.modules.perceive.updater;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.core.cognation.capability.ability.CognationCapability;
import work.slhaf.partner.core.cognation.CognationManager;
import work.slhaf.partner.core.cognation.capability.ability.PerceiveCapability;
import work.slhaf.partner.core.cognation.submodule.perceive.pojo.User;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.interaction.module.InteractionModule;
import work.slhaf.partner.module.common.Model;
import work.slhaf.partner.module.common.ModelConstant;
import work.slhaf.partner.module.modules.perceive.updater.relation_extractor.pojo.RelationExtractResult;
import work.slhaf.partner.module.modules.perceive.updater.relation_extractor.RelationExtractor;
import work.slhaf.partner.module.modules.perceive.updater.static_extractor.StaticMemoryExtractor;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.locks.ReentrantLock;
/**
* 感知更新,异步
*/
@EqualsAndHashCode(callSuper = true)
@Slf4j
@Data
public class PerceiveUpdater extends Model implements InteractionModule {
private static volatile PerceiveUpdater perceiveUpdater;
private PerceiveCapability perceiveCapability;
private CognationCapability cognationCapability;
private InteractionThreadPoolExecutor executor;
private RelationExtractor relationExtractor;
private StaticMemoryExtractor staticMemoryExtractor;
public static PerceiveUpdater getInstance() throws IOException, ClassNotFoundException {
if (perceiveUpdater == null) {
synchronized (PerceiveUpdater.class) {
if (perceiveUpdater == null) {
perceiveUpdater = new PerceiveUpdater();
perceiveUpdater.setPerceiveCapability(CognationManager.getInstance());
perceiveUpdater.setCognationCapability(CognationManager.getInstance());
perceiveUpdater.setExecutor(InteractionThreadPoolExecutor.getInstance());
perceiveUpdater.setRelationExtractor(RelationExtractor.getInstance());
perceiveUpdater.setStaticMemoryExtractor(StaticMemoryExtractor.getInstance());
setModel(perceiveUpdater, ModelConstant.Prompt.PERCEIVE, true);
}
}
}
return perceiveUpdater;
}
@Override
public void execute(InteractionContext context) throws IOException, ClassNotFoundException {
executor.execute(() -> {
boolean trigger = context.getModuleContext().getExtraContext().getBoolean("perceive_updater");
if (!trigger){
return;
}
ReentrantLock userLock = new ReentrantLock();
User user = new User();
user.setUuid(context.getUserId());
List<Callable<Void>> tasks = new ArrayList<>();
tasks.add(() -> {
runStaticExtractorAction(context, userLock, user);
return null;
});
tasks.add(() -> {
runRelationExtractorAction(context, userLock, user);
return null;
});
executor.invokeAll(tasks);
perceiveCapability.updateUser(user);
});
}
private void runRelationExtractorAction(InteractionContext context, ReentrantLock userLock, User user) {
RelationExtractResult relationExtractResult = relationExtractor.execute(context);
userLock.lock();
user.setRelation(relationExtractResult.getRelation());
user.setImpressions(relationExtractResult.getImpressions());
user.setAttitude(relationExtractResult.getAttitude());
user.updateRelationChange(relationExtractResult.getRelationChangeHistory());
userLock.unlock();
}
private void runStaticExtractorAction(InteractionContext context, ReentrantLock userLock, User user) {
HashMap<String, String> newStaticMemory = staticMemoryExtractor.execute(context);
userLock.lock();
user.setStaticMemory(newStaticMemory);
userLock.unlock();
}
@Override
protected String modelKey() {
return "perceive_updater";
}
}

View File

@@ -0,0 +1,90 @@
package work.slhaf.partner.module.modules.perceive.updater.relation_extractor;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import work.slhaf.partner.common.chat.pojo.ChatResponse;
import work.slhaf.partner.common.chat.pojo.Message;
import work.slhaf.partner.core.cognation.capability.ability.CognationCapability;
import work.slhaf.partner.core.cognation.CognationManager;
import work.slhaf.partner.core.cognation.capability.ability.PerceiveCapability;
import work.slhaf.partner.core.cognation.submodule.perceive.pojo.User;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.module.common.Model;
import work.slhaf.partner.module.common.ModelConstant;
import work.slhaf.partner.module.modules.perceive.updater.relation_extractor.pojo.RelationExtractInput;
import work.slhaf.partner.module.modules.perceive.updater.relation_extractor.pojo.RelationExtractResult;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@Data
public class RelationExtractor extends Model {
private static volatile RelationExtractor relationExtractor;
private CognationCapability cognationCapability;
private PerceiveCapability perceiveCapability;
private List<Message> tempMessages;
public static RelationExtractor getInstance() throws IOException, ClassNotFoundException {
if (relationExtractor == null) {
synchronized (RelationExtractor.class) {
if (relationExtractor == null) {
relationExtractor = new RelationExtractor();
relationExtractor.setCognationCapability(CognationManager.getInstance());
relationExtractor.setPerceiveCapability(CognationManager.getInstance());
setModel(relationExtractor, ModelConstant.Prompt.PERCEIVE, true);
}
}
}
return relationExtractor;
}
public RelationExtractResult execute(InteractionContext context){
tempMessages = new ArrayList<>(cognationCapability.getChatMessages());
String userId = context.getUserId();
RelationExtractInput input = getRelationInput(userId);
RelationExtractResult relationExtractResult = getRelationResult(input);
User user = getTempUser(context, relationExtractResult);
perceiveCapability.updateUser(user);
return relationExtractResult;
}
private User getTempUser(InteractionContext context, RelationExtractResult relationExtractResult) {
User user = new User();
user.setUuid(context.getUserId());
user.setRelation(relationExtractResult.getRelation());
user.setImpressions(relationExtractResult.getImpressions());
user.setAttitude(relationExtractResult.getAttitude());
return user;
}
private RelationExtractResult getRelationResult(RelationExtractInput input) {
ChatResponse response = singleChat(JSONObject.toJSONString(input));
return JSONObject.parseObject(response.getMessage(), RelationExtractResult.class);
}
private RelationExtractInput getRelationInput(String userId) {
HashMap<String,String> map = new HashMap<>();
User user = perceiveCapability.getUser(userId);
map.put("[用户昵称] <用户的昵称信息>",user.getNickName());
map.put("[关系] <你与用户的关系>", user.getRelation());
map.put("[态度] <你对于用户的态度>", user.getAttitude().toString());
map.put("[印象] <你对于用户的印象>", user.getImpressions().toString());
map.put("[静态记忆] <你对该用户的事实性记忆>", user.getStaticMemory().toString());
RelationExtractInput input = new RelationExtractInput();
input.setPrimaryUserPerceive(map);
input.setChatMessages(tempMessages);
return input;
}
@Override
protected String modelKey() {
return "relation_extractor";
}
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.module.modules.perceive.updater.relation_extractor.pojo;
import lombok.Data;
import work.slhaf.partner.common.chat.pojo.Message;
import java.util.HashMap;
import java.util.List;
@Data
public class RelationExtractInput {
private HashMap<String,String> primaryUserPerceive;
private List<Message> chatMessages;
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.module.modules.perceive.updater.relation_extractor.pojo;
import lombok.Data;
import java.util.List;
@Data
public class RelationExtractResult {
private String relation;
private List<String> impressions;
private List<String> attitude;
private String relationChangeHistory;
}

View File

@@ -0,0 +1,60 @@
package work.slhaf.partner.module.modules.perceive.updater.static_extractor;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.common.chat.pojo.ChatResponse;
import work.slhaf.partner.core.cognation.capability.ability.CognationCapability;
import work.slhaf.partner.core.cognation.CognationManager;
import work.slhaf.partner.core.cognation.capability.ability.PerceiveCapability;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.module.common.Model;
import work.slhaf.partner.module.common.ModelConstant;
import work.slhaf.partner.module.modules.perceive.updater.static_extractor.data.StaticMemoryExtractInput;
import java.io.IOException;
import java.util.HashMap;
@EqualsAndHashCode(callSuper = true)
@Data
public class StaticMemoryExtractor extends Model {
private static volatile StaticMemoryExtractor staticMemoryExtractor;
private CognationCapability cognationCapability;
private PerceiveCapability perceiveCapability;
public static StaticMemoryExtractor getInstance() throws IOException, ClassNotFoundException {
if (staticMemoryExtractor == null) {
synchronized (StaticMemoryExtractor.class) {
if (staticMemoryExtractor == null) {
staticMemoryExtractor = new StaticMemoryExtractor();
staticMemoryExtractor.setCognationCapability(CognationManager.getInstance());
staticMemoryExtractor.setPerceiveCapability(CognationManager.getInstance());
setModel(staticMemoryExtractor, ModelConstant.Prompt.MEMORY, true);
}
}
}
return staticMemoryExtractor;
}
public HashMap<String, String> execute(InteractionContext context) {
StaticMemoryExtractInput input = StaticMemoryExtractInput.builder()
.userId(context.getUserId())
.messages(cognationCapability.getChatMessages())
.existedStaticMap(perceiveCapability.getUser(context.getUserId()).getStaticMemory())
.build();
ChatResponse response = singleChat(JSONUtil.toJsonPrettyStr(input));
JSONObject jsonObject = JSONObject.parseObject(response.getMessage());
HashMap<String, String> result = new HashMap<>();
jsonObject.forEach((k, v) -> result.put(k, (String) v));
return result;
}
@Override
protected String modelKey() {
return "static_extractor";
}
}

View File

@@ -0,0 +1,16 @@
package work.slhaf.partner.module.modules.perceive.updater.static_extractor.data;
import lombok.Builder;
import lombok.Data;
import work.slhaf.partner.common.chat.pojo.Message;
import java.util.List;
import java.util.Map;
@Data
@Builder
public class StaticMemoryExtractInput {
private String userId;
private List<Message> messages;
private Map<String,String> existedStaticMap;
}

View File

@@ -0,0 +1,39 @@
package work.slhaf.partner.module.modules.process;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.core.cognation.capability.ability.CognationCapability;
import work.slhaf.partner.core.cognation.CognationManager;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.interaction.module.InteractionModule;
import java.io.IOException;
@Slf4j
@Data
public class PostprocessExecutor implements InteractionModule {
private static volatile PostprocessExecutor postprocessExecutor;
private static final int POST_PROCESS_TRIGGER_ROLL_LIMIT = 36;
private CognationCapability cognationCapability;
public static PostprocessExecutor getInstance() throws IOException, ClassNotFoundException {
if (postprocessExecutor == null) {
synchronized (PostprocessExecutor.class) {
if (postprocessExecutor == null) {
postprocessExecutor = new PostprocessExecutor();
postprocessExecutor.setCognationCapability(CognationManager.getInstance());
}
}
}
return postprocessExecutor;
}
@Override
public void execute(InteractionContext context) throws IOException, ClassNotFoundException {
boolean trigger = cognationCapability.getChatMessages().size() >= POST_PROCESS_TRIGGER_ROLL_LIMIT;
context.getModuleContext().getExtraContext().put("post_process_trigger", trigger);
log.debug("[PostprocessExecutor] 是否执行后处理: {}", trigger);
}
}

View File

@@ -0,0 +1,104 @@
package work.slhaf.partner.module.modules.process;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.core.cognation.capability.ability.CognationCapability;
import work.slhaf.partner.core.cognation.CognationManager;
import work.slhaf.partner.core.cognation.capability.ability.PerceiveCapability;
import work.slhaf.partner.core.cognation.submodule.perceive.pojo.User;
import work.slhaf.partner.core.interaction.data.InteractionInputData;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.session.SessionManager;
import work.slhaf.partner.module.common.AppendPromptData;
import java.io.IOException;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
@Data
@Slf4j
public class PreprocessExecutor {
private static volatile PreprocessExecutor preprocessExecutor;
private CognationCapability cognationCapability;
private PerceiveCapability perceiveCapability;
private SessionManager sessionManager;
private PreprocessExecutor() {
}
public static PreprocessExecutor getInstance() throws IOException, ClassNotFoundException {
if (preprocessExecutor == null) {
synchronized (PreprocessExecutor.class) {
if (preprocessExecutor == null) {
preprocessExecutor = new PreprocessExecutor();
preprocessExecutor.setCognationCapability(CognationManager.getInstance());
preprocessExecutor.setPerceiveCapability(CognationManager.getInstance());
preprocessExecutor.setSessionManager(SessionManager.getInstance());
}
}
}
return preprocessExecutor;
}
public InteractionContext execute(InteractionInputData inputData) {
checkAndSetMemoryId();
return getInteractionContext(inputData);
}
private void checkAndSetMemoryId() {
String currentMemoryId = sessionManager.getCurrentMemoryId();
if (currentMemoryId == null || cognationCapability.getChatMessages().isEmpty()) {
sessionManager.refreshMemoryId();
}
}
private InteractionContext getInteractionContext(InteractionInputData inputData) {
log.debug("[PreprocessExecutor] 预处理原始输入: {}", inputData);
InteractionContext context = new InteractionContext();
User user = perceiveCapability.getUser(inputData.getUserInfo(), inputData.getPlatform());
if (user == null) {
user = perceiveCapability.addUser(inputData.getUserInfo(), inputData.getPlatform(), inputData.getUserNickName());
}
String userId = user.getUuid();
context.setUserId(userId);
context.setUserNickname(inputData.getUserNickName());
context.setUserInfo(inputData.getUserInfo());
context.setDateTime(inputData.getLocalDateTime());
context.setSingle(inputData.isSingle());
String userStr = "[" + inputData.getUserNickName() + "(" + userId + ")]";
String input = userStr + " " + inputData.getContent();
context.setInput(input);
setAppendedPrompt(context);
setCoreContext(inputData, context, input, userId);
log.debug("[PreprocessExecutor] 预处理结果: {}", context);
return context;
}
private void setAppendedPrompt(InteractionContext context) {
HashMap<String, String> map = new HashMap<>();
map.put("text", "这部分才是真正的用户输入内容, 就像你之前收到过的输入一样。但...不会是'同一个人'。");
map.put("datetime", "本次用户输入对应的当前时间");
map.put("user_nick", "用户昵称");
map.put("user_id", "用户id, 与user_nick区分, 这是用户的唯一标识");
map.put("active_modules", "已激活的模块, 为false时为激活但未活跃; 为true时为激活且活跃");
map.put("其他", "历史对话中将在用户消息的最后一行标注时间");
AppendPromptData data = new AppendPromptData();
data.setModuleName("[基础模块]");
data.setAppendedPrompt(map);
context.setAppendedPrompt(data);
}
private void setCoreContext(InteractionInputData inputData, InteractionContext context, String input, String userId) {
context.getCoreContext().setText(input);
context.getCoreContext().setDateTime(LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME));
context.getCoreContext().setUserNick(inputData.getUserNickName());
context.getCoreContext().setUserId(userId);
}
}

View File

@@ -0,0 +1,27 @@
package work.slhaf.partner.module.modules.task;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.module.common.Model;
import work.slhaf.partner.module.common.ModelConstant;
@EqualsAndHashCode(callSuper = true)
@Data
public class TaskEvaluator extends Model {
private static TaskEvaluator taskEvaluator;
private TaskEvaluator (){}
public static TaskEvaluator getInstance() {
if (taskEvaluator == null) {
taskEvaluator = new TaskEvaluator();
setModel(taskEvaluator, ModelConstant.Prompt.SCHEDULE,true);
}
return taskEvaluator;
}
@Override
protected String modelKey() {
return "task_evaluator";
}
}

View File

@@ -0,0 +1,20 @@
package work.slhaf.partner.module.modules.task;
import lombok.Data;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
@Data
public class TaskExecutor {
private static TaskExecutor taskExecutor;
private InteractionThreadPoolExecutor executor;
private TaskExecutor(){}
public static TaskExecutor getInstance(){
if (taskExecutor == null){
taskExecutor = new TaskExecutor();
taskExecutor.setExecutor(InteractionThreadPoolExecutor.getInstance());
}
return taskExecutor;
}
}

Some files were not shown because too many files have changed in this diff Show More