10 Commits

Author SHA1 Message Date
85818556f8 将记忆模块的缓存逻辑迁移至 MemoryCore; 移除了 CacheCore,并将 CoordinatedManager 中原记忆模块与缓存模块中的逻辑迁移至现记忆模块中,确保语义正确 2025-10-16 15:22:19 +08:00
cb1a25e9d5 移除 ActiveData ,其逻辑回归至 CacheCore,下一步将对 CacheCore 及 CoordinateManager 中的 cognation 相关内容进行拆分 2025-10-16 11:40:55 +08:00
a10a149edb 开始推进行动模块(ActionModule); 针对框架与本体分别进行了一系列架构优化。
框架:
- 调整模块注册以及AgentRunningFlow的相关逻辑,以支持同组模块并发执行,将以@AgentModule注解中的order属性区分组间顺序先后及是否同组
- 针对@CoordinateManager注解新增了Core的自动注入处理,以便更好的协调不同Core的逻辑

本体: - 开始推进行动模块。将采取类似记忆模块的分层思路,分为ActionPlanner与ActionDispatcher两个主要模块,再各自细分子模块划分
- 将CognationCore从核心统筹的身份下降至与其他核心平级,同时将其中的序列化逻辑抽取至统一的PartnerCore父类,所有核心都将继承该类以获得序列化能力,不同core的内容将序列化至各自的memory文件
- 将SessionManager移除,相关逻辑迁移至CognationCore,统一序列化逻辑的同时又保证语义正确
- 将CognationCore中的某些缓存性质逻辑移动至CacheCore,确保语义正确
- 调整了目录结构以适应优化过的架构
2025-10-12 16:23:11 +08:00
41bf19f43e 将 .java 重命名为 .kt 2025-10-12 16:23:11 +08:00
941943f696 Partner 主体与框架适配完成! 完整逻辑已达到适配框架之前的完成度。发现并修复了不少问题,以及更新了README
框架:
- 由于`Gateway`的启动属于`Agent`启动流程的子线程,而主线程可能由于逻辑执行结束时机早于`Gateway`创建完成时机而报错,故引入`CountDownLatch`进行阻塞
- 在`AgentRunningModule`与`AgentRunningSubModule`中添加日志hook,记录模块执行的起始与截止时机
- 修复了`AgentUtil`中收集继承链时遗忘起始类的错误
- 在`CapabilityCheckFactory`中针对`CoordinateManager`无参构造方法的实现检验
- 在`CapabilityRegisterFactory`中添加了收集模块之外的CapabilityHolder的逻辑,与`@InjectCapability`的校验与注入逻辑保持一致
- 修复了‘生成模块启用配置时,多余局部变量导致无法执行流正确读取启用情况’的错误
- 在GlobalExceptionHandler中添加了对于未知异常的处理逻辑,确保不会导致程序异常终止
- 发现`ModuleProxyFactory`中使用`record`类型会导致`ByteBuddy`无法正确创建代理类,已修复,替换成普通类

本体:
- `ActiveData`由于`CognationCore`的引用,也需要实现序列化,已修复
- 修复了`MemorySelectExtractor`中由于匹配到的主题列表为空导致的空指针异常
- 将后置模块的trigger判定抽取到新的父类中,统一判断
- 修复了`WebSocketServer`如果存在过ws连接,关闭后短时间再次启动内仍提示端口占用的情况,设置允许端口重用
- 在`WebSocketGateway`新增了断开ws客户端连接的逻辑
2025-09-30 15:46:05 +08:00
a7d54349e4 进行 框架-主题 的适配测试,发现了一些问题并进行了修复
框架:
- 去除了 ActivateModel 中 modelKey() 方法的默认实现,对于特殊的 AgentModule 继承者(CoreModule)而言,直接获取注解信息不可行,如果保持,则需要另加判断逻辑。这是没有必要的
- 发现 Agent 启动流程中,由于 Gateway 的启动可能依赖配置文件的加载,故将 AgentConfigManager 与 AgentGateway 的指定替换为类型指定,在合适的时机通过反射进行实例化
- 在 AgentUtil 中新增了链式判断指定类的注解链上是否存在指定注解的方法,目前用于 CapabilityHolder 的持有实例判定
- 发现 CapabilityFactoryContext 中 cores、capabilities 未赋值导致空指针异常,已修复
- 将 AgentConfigManager 中的检验逻辑进行抽离,放到了 ConfigLoaderFactory 中,避免职责混淆
- 发现 CoreModule 的注解使用错误,`@Retention(RetentionPolicy.RUNTIME)`元注解可以使得注解在代码运行时能够被反射扫描
- 在 ModuleCheckFactory 中添加了对于 Module 与 SubModule 的注解、继承使用是否匹配的检验
- 发现对于一个类来说,无法直接通过一层反射获取到‘注解的注解’,故在 ModuleRegisterFactory 中针对 CoreModule 的注册做了特殊处理

主体:
- 发现一些类缺少必要注解,已修复
- 发现存在有些必要的类未公开化无参构造函数,已修复,并在框架部分增加校验逻辑

其他:
- 由于项目的启动流程与完整的配置文件密不可分,所以开始尝试编写启动说明,目前只写了开头
2025-09-21 23:29:45 +08:00
3c2ac32708 完成了本体与框架的适配工作,并修复了某些问题。需要进一步进行测试
- 修复了 CognationCapability 相关的注解使用错误
- 将前置模块中的 setAppendedPrompt 与 setActiveModule 方法抽取到 execute 模板方法中
- 完善了已有模块的适配工作, 并去除了不必要的单例配置
2025-09-18 16:03:59 +08:00
7f9d007f07 适配框架时发现工厂注册链上存在一些执行顺序上的错误,于是尝试修复问题,为Agent启动链添加了完整的注释,并做出了必要的修复与调整 2025-09-13 23:37:35 +08:00
c1018d6b54 进行 Partner 框架层的部分调整
- 新增 AgentSubModule 注解,用于标识子模块
- 新增 MetaSubModule 类,用于存储子模块元信息
- 支持子模块初始化和注入逻辑,不再使用单例模式为执行模块提供子模块服务
- 重构模块初始化流程,支持模块和子模块的初始化
- 优化模块注册流程,分别处理模块和子模块
2025-09-11 13:07:48 +08:00
47684c78e0 进行 Partner 本体对于框架的适配,以及框架层的部分调整
框架:

- 调整 ActivateModel 中模型初始化设置的 initHook 权重为-1(最优先)
- 为 AgentGateway 中 receive 操作提供模板方法,子类需实现发送逻辑并提供适配器
- 取消了 AgentInteractionAdapter 的单例配置
- 调整 RunningFlow 的异常处理,并在RunningFlowContext中提供错误码进行判断
- 调整模块基类
-

本体:
- 新增配置加载异常,继承自Agent启动异常
- 修改 GlobalExceptionData 获取逻辑
- 移除 MessageSender 等交互接口,适配框架的交互逻辑
- 异常处理已适配
- 配置加载逻辑已适配
- Gateway 已适配
- CoreModel 已适配
2025-09-09 20:42:28 +08:00
169 changed files with 2896 additions and 2570 deletions

7
.gitignore vendored
View File

@@ -36,8 +36,8 @@ build/
### Mac OS ### ### Mac OS ###
.DS_Store .DS_Store
/data/ /backup/data/
/config/ /backup/config/
/Partner-Core/src/main/java/src/test/java/memory/test.json /Partner-Core/src/main/java/src/test/java/memory/test.json
/Partner-Core/src/main/java/src/test/java/memory/result/input1.json /Partner-Core/src/main/java/src/test/java/memory/result/input1.json
/Partner-Core/src/main/java/src/test/java/memory/result/input2.json /Partner-Core/src/main/java/src/test/java/memory/result/input2.json
@@ -51,3 +51,6 @@ build/
/backup/ /backup/
/Partner-Main/src/test/java/text/test.json /Partner-Main/src/test/java/text/test.json
/CLAUDE.md /CLAUDE.md
/config/
/data/
/generated-classes/

6
.idea/kotlinc.xml generated Normal file
View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="KotlinJpsPluginSettings">
<option name="version" value="2.2.0" />
</component>
</project>

22
.idea/misc.xml generated
View File

@@ -1,13 +1,21 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="EntryPointsManager"> <component name="EntryPointsManager">
<list size="6"> <list size="14">
<item index="0" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.Capability" /> <item index="0" class="java.lang.String" itemvalue="lombok.Data" />
<item index="1" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore" /> <item index="1" class="java.lang.String" itemvalue="net.bytebuddy.implementation.bind.annotation.RuntimeType" />
<item index="2" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod" /> <item index="2" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.Capability" />
<item index="3" class="java.lang.String" itemvalue="work.slhaf.partner.api.capability.annotation.CapabilityMethod" /> <item index="3" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore" />
<item index="4" class="java.lang.String" itemvalue="work.slhaf.partner.api.capability.annotation.CoordinateManager" /> <item index="4" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod" />
<item index="5" class="java.lang.String" itemvalue="work.slhaf.partner.api.register.capability.annotation.Capability" /> <item index="5" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.CoordinateManager" />
<item index="6" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.Coordinated" />
<item index="7" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute" />
<item index="8" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.module.annotation.AgentModule" />
<item index="9" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute" />
<item index="10" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.module.annotation.Init" />
<item index="11" class="java.lang.String" itemvalue="work.slhaf.partner.api.capability.annotation.CapabilityMethod" />
<item index="12" class="java.lang.String" itemvalue="work.slhaf.partner.api.capability.annotation.CoordinateManager" />
<item index="13" class="java.lang.String" itemvalue="work.slhaf.partner.api.register.capability.annotation.Capability" />
</list> </list>
</component> </component>
<component name="ExternalStorageConfigurationManager" enabled="true" /> <component name="ExternalStorageConfigurationManager" enabled="true" />

1
.idea/vcs.xml generated
View File

@@ -2,5 +2,6 @@
<project version="4"> <project version="4">
<component name="VcsDirectoryMappings"> <component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" /> <mapping directory="$PROJECT_DIR$" vcs="Git" />
<mapping directory="$USER_HOME$/Projects/IdeaProjects/Projects/Partner" vcs="Git" />
</component> </component>
</project> </project>

View File

@@ -0,0 +1,29 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<parent>
<artifactId>Partner</artifactId>
<groupId>work.slhaf</groupId>
<version>0.5.0</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>Partner-Api</artifactId>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.2</version>
<scope>test</scope>
<exclusions>
<exclusion>
<artifactId>hamcrest-core</artifactId>
<groupId>org.hamcrest</groupId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
<properties>
<maven.compiler.target>21</maven.compiler.target>
<maven.compiler.source>21</maven.compiler.source>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
</project>

View File

@@ -1,5 +1,6 @@
package work.slhaf.partner.api.agent; package work.slhaf.partner.api.agent;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.AgentRegisterFactory; import work.slhaf.partner.api.agent.factory.AgentRegisterFactory;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager; import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
import work.slhaf.partner.api.agent.runtime.exception.AgentExceptionCallback; import work.slhaf.partner.api.agent.runtime.exception.AgentExceptionCallback;
@@ -9,20 +10,30 @@ import work.slhaf.partner.api.agent.runtime.interaction.AgentGateway;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
/**
* <h2>Agent 启动入口</h2>
* 详细启动流程请参阅{@link AgentRegisterFactory}
*/
@Slf4j
public final class Agent { public final class Agent {
public static AgentGatewayStep newAgent(Class<?> clazz) { public static AgentConfigManagerStep newAgent(Class<?> clazz) {
if (clazz == null) { if (clazz == null) {
throw new AgentLaunchFailedException("Agent class 和 interaction flow context 不能为 null"); throw new AgentLaunchFailedException("Agent class 和 interaction flow context 不能为 null");
} }
return new AgentApp(clazz); return new AgentApp(clazz);
} }
public interface AgentConfigManagerStep {
AgentGatewayStep setAgentConfigManager(Class<? extends AgentConfigManager> agentConfigManager);
}
public interface AgentGatewayStep { public interface AgentGatewayStep {
AgentStep setGateway(AgentGateway gateway); AgentStep setGateway(Class<? extends AgentGateway> gateway);
} }
public interface AgentStep { public interface AgentStep {
@@ -30,9 +41,7 @@ public final class Agent {
AgentStep addAfterLaunchRunners(Runnable... runners); AgentStep addAfterLaunchRunners(Runnable... runners);
AgentStep setAgentConfigManager(AgentConfigManager agentConfigManager); AgentStep setAgentExceptionCallback(Class<? extends AgentExceptionCallback> agentExceptionCallback);
AgentStep setAgentExceptionCallback(AgentExceptionCallback agentExceptionCallback);
AgentStep addScanPackage(String packageName); AgentStep addScanPackage(String packageName);
@@ -42,21 +51,26 @@ public final class Agent {
} }
public static class AgentApp implements AgentStep, AgentGatewayStep { public static class AgentApp implements AgentStep, AgentGatewayStep, AgentConfigManagerStep {
private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor(); private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor();
private final List<Runnable> beforeLaunchRunners = new ArrayList<>(); private final List<Runnable> beforeLaunchRunners = new ArrayList<>();
private final List<Runnable> afterLaunchRunners = new ArrayList<>(); private final List<Runnable> afterLaunchRunners = new ArrayList<>();
private AgentGateway gateway; private AgentGateway gateway;
private final Class<?> applicationClass; private final Class<?> applicationClass;
private Class<? extends AgentConfigManager> agentConfigManagerClass;
private Class<? extends AgentGateway> gatewayClass;
private Class<? extends AgentExceptionCallback> agentExceptionCallbackClass;
private final CountDownLatch latch = new CountDownLatch(1);
private AgentApp(Class<?> clazz) { private AgentApp(Class<?> clazz) {
this.applicationClass = clazz; this.applicationClass = clazz;
} }
@Override @Override
public AgentStep setGateway(AgentGateway gateway) { public AgentStep setGateway(Class<? extends AgentGateway> gateway) {
this.gateway = gateway; this.gatewayClass = gateway;
return this; return this;
} }
@@ -73,14 +87,14 @@ public final class Agent {
} }
@Override @Override
public AgentStep setAgentConfigManager(AgentConfigManager agentConfigManager) { public AgentGatewayStep setAgentConfigManager(Class<? extends AgentConfigManager> agentConfigManager) {
AgentConfigManager.setINSTANCE(agentConfigManager); this.agentConfigManagerClass = agentConfigManager;
return this; return this;
} }
@Override @Override
public AgentStep setAgentExceptionCallback(AgentExceptionCallback agentExceptionCallback) { public AgentStep setAgentExceptionCallback(Class<? extends AgentExceptionCallback> agentExceptionCallback) {
GlobalExceptionHandler.setExceptionCallback(agentExceptionCallback); agentExceptionCallbackClass = agentExceptionCallback;
return this; return this;
} }
@@ -98,10 +112,38 @@ public final class Agent {
@Override @Override
public void launch() { public void launch() {
launchRunners(beforeLaunchRunners); beforeLaunch();
AgentRegisterFactory.launch(applicationClass.getPackageName()); AgentRegisterFactory.launch(applicationClass.getPackageName());
executorService.execute(() -> gateway.launch()); afterLaunch();
}
private void afterLaunch() {
try {
this.gateway = gatewayClass.getDeclaredConstructor().newInstance();
executorService.execute(() -> {
gateway.launch();
latch.countDown();
log.info("Gateway 启动完毕: {}", gatewayClass.getSimpleName());
});
latch.await();
launchRunners(afterLaunchRunners); launchRunners(afterLaunchRunners);
log.info("后置任务启动完毕");
} catch (Exception e) {
throw new AgentLaunchFailedException("Agent 后置任务启动失败", e);
}
}
private void beforeLaunch() {
try {
AgentConfigManager.setINSTANCE(agentConfigManagerClass.getDeclaredConstructor().newInstance());
log.info("配置管理器设置完毕: {}",agentConfigManagerClass.getSimpleName());
GlobalExceptionHandler.setExceptionCallback(agentExceptionCallbackClass.getDeclaredConstructor().newInstance());
log.info("异常处理回调设置完毕: {}",agentExceptionCallbackClass.getSimpleName());
launchRunners(beforeLaunchRunners);
log.info("前置任务启动完毕");
} catch (Exception e) {
throw new AgentLaunchFailedException("Agent 前置任务启动失败", e);
}
} }
private void launchRunners(List<Runnable> runners) { private void launchRunners(List<Runnable> runners) {

View File

@@ -14,14 +14,22 @@ import work.slhaf.partner.api.agent.factory.module.ModuleInitHookExecuteFactory;
import work.slhaf.partner.api.agent.factory.module.ModuleProxyFactory; import work.slhaf.partner.api.agent.factory.module.ModuleProxyFactory;
import work.slhaf.partner.api.agent.factory.module.ModuleRegisterFactory; import work.slhaf.partner.api.agent.factory.module.ModuleRegisterFactory;
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule; import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
import work.slhaf.partner.api.agent.runtime.data.AgentContext;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager; import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
import work.slhaf.partner.api.agent.runtime.data.AgentContext;
import work.slhaf.partner.api.agent.runtime.interaction.flow.AgentRunningFlow;
import java.io.File; import java.io.File;
import java.net.URL; import java.net.URL;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
/**
* <h2>Agent 注册工厂</h2>
*
* <p>
* 具体流程依次按照 {@link AgentRegisterFactory#launch(String)} 方法顺序执行,最终将执行模块列表对应实例交给 {@link AgentConfigManager} ,传递给 {@link AgentRunningFlow} 针对交互做出调用
* <p/>
*/
public class AgentRegisterFactory { public class AgentRegisterFactory {
private static final List<URL> urls = new ArrayList<>(); private static final List<URL> urls = new ArrayList<>();
@@ -35,20 +43,20 @@ public class AgentRegisterFactory {
//流程 //流程
//0. 加载配置 //0. 加载配置
new ConfigLoaderFactory().execute(registerContext); new ConfigLoaderFactory().execute(registerContext);
//1. 注册并检查Capability //1. 注册并检查Module
new CapabilityRegisterFactory().execute(registerContext);
new CapabilityCheckFactory().execute(registerContext);
//2. 注册并检查Module
new ModuleCheckFactory().execute(registerContext); new ModuleCheckFactory().execute(registerContext);
new ModuleRegisterFactory().execute(registerContext); new ModuleRegisterFactory().execute(registerContext);
//3. 为module通过动态代理添加PostHook逻辑并进行实例化 //2. 为module通过动态代理添加PostHook逻辑并进行实例化
new ModuleProxyFactory().execute(registerContext); new ModuleProxyFactory().execute(registerContext);
//3. 加载检查Capability层内容后进行能力层的内容注册
new CapabilityCheckFactory().execute(registerContext);
new CapabilityRegisterFactory().execute(registerContext);
//. 先一步注入Capability,避免因前hook逻辑存在针对能力的引用而报错 //. 先一步注入Capability,避免因前hook逻辑存在针对能力的引用而报错
new CapabilityInjectFactory().execute(registerContext); new CapabilityInjectFactory().execute(registerContext);
//. 执行模块PreHook逻辑 //. 执行模块PreHook逻辑
new ModuleInitHookExecuteFactory().execute(registerContext); new ModuleInitHookExecuteFactory().execute(registerContext);
List<MetaModule> moduleList = registerContext.getModuleFactoryContext().getModuleList(); List<MetaModule> moduleList = registerContext.getModuleFactoryContext().getAgentModuleList();
AgentConfigManager.INSTANCE.moduleEnabledStatusFilterAndRecord(moduleList); AgentConfigManager.INSTANCE.moduleEnabledStatusFilterAndRecord(moduleList);
BeanUtil.copyProperties(registerContext, AgentContext.INSTANCE); BeanUtil.copyProperties(registerContext, AgentContext.INSTANCE);

View File

@@ -1,24 +1,52 @@
package work.slhaf.partner.api.agent.factory.capability; package work.slhaf.partner.api.agent.factory.capability;
import cn.hutool.core.util.ClassUtil;
import org.reflections.Reflections; import org.reflections.Reflections;
import work.slhaf.partner.api.agent.factory.AgentBaseFactory; import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
import work.slhaf.partner.api.agent.factory.capability.annotation.*; import work.slhaf.partner.api.agent.factory.capability.annotation.*;
import work.slhaf.partner.api.agent.factory.capability.exception.DuplicateCapabilityException; import work.slhaf.partner.api.agent.factory.capability.exception.*;
import work.slhaf.partner.api.agent.factory.capability.exception.UnMatchedCapabilityException;
import work.slhaf.partner.api.agent.factory.capability.exception.UnMatchedCapabilityMethodException;
import work.slhaf.partner.api.agent.factory.capability.exception.UnMatchedCoordinatedMethodException;
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext; import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
import work.slhaf.partner.api.agent.factory.context.CapabilityFactoryContext; import work.slhaf.partner.api.agent.factory.context.CapabilityFactoryContext;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.util.AgentUtil; import work.slhaf.partner.api.agent.util.AgentUtil;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static work.slhaf.partner.api.agent.util.AgentUtil.isAssignableFromAnnotation;
import static work.slhaf.partner.api.agent.util.AgentUtil.methodSignature; import static work.slhaf.partner.api.agent.util.AgentUtil.methodSignature;
/** /**
* 执行<code>Capability</code>相关检查 * <h2>Agent启动流程 4</h2>
*
* <p>负责通过反射收集 {@link Capability} 和 {@link CapabilityCore} 注解所在类,并判断是否存在被错误忽略的方法</p>
*
* <ol>
* <li>
* <p>{@link CapabilityCheckFactory#loadCoresAndCapabilities()}</p>
* 通过反射收集 {@link Capability} 和 {@link CapabilityCore} 注解所在类为对应集合
* </li>
* <li>
* <p>{@link CapabilityCheckFactory#checkCountAndCapabilities()}</p>
* 检测 {@link Capability} 与 {@link CapabilityCore} 的数量、对应的能力是否相等。每一个core都将对应一个capability并通过value属性进行匹配
* </li>
* <li>
* <p>{@link CapabilityCheckFactory#checkCapabilityMethods()}</p>
* 检测在 {@link Capability} 与 {@link CapabilityCore} 中是否存在对方尚未实现/注册的方法
* </li>
* <li>
* <p>{@link CapabilityCheckFactory#checkCoordinatedMethods()}</p>
* 检查是否包含协调方法({@link ToCoordinated}),如果存在,则进一步检查在 {@link CoordinateManager} 所注类中是否有提供对应的实现
* </li>
* <li>
* <p>{@link CapabilityCheckFactory#checkInjectCapability()}</p>
* 检查 {@link InjectCapability} 注解是否只用在 {@link CapabilityHolder} 所标识类的字段上。{@link AgentModule} 与 {@link AgentSubModule} 已经被 {@link CapabilityHolder} 标注
* </li>
* </ol>
*
* <p>下一步流程请参阅{@link CapabilityRegisterFactory}</p>
*/ */
public class CapabilityCheckFactory extends AgentBaseFactory { public class CapabilityCheckFactory extends AgentBaseFactory {
@@ -37,19 +65,42 @@ public class CapabilityCheckFactory extends AgentBaseFactory {
@Override @Override
protected void run() { protected void run() {
loadCoresAndCapabilities();
checkCountAndCapabilities(); checkCountAndCapabilities();
checkCapabilityMethods(); checkCapabilityMethods();
checkCoordinatedMethods(); checkCoordinatedMethods();
checkCoordinatedManager();
checkInjectCapability(); checkInjectCapability();
} }
private void checkCoordinatedManager() {
reflections.getTypesAnnotatedWith(CoordinateManager.class)
.stream()
.filter(ClassUtil::isNormalClass)
.forEach(managerClass -> {
try {
if (!managerClass.getDeclaredConstructor().canAccess(null)) {
throw new CapabilityCheckFailedException("CoordinateManager 所注类的无参构造方法未公开!");
}
} catch (NoSuchMethodException e) {
throw new CapabilityCheckFailedException("CoordinateManager 所注类缺少无参构造方法!");
}
});
}
private void loadCoresAndCapabilities() {
cores.addAll(reflections.getTypesAnnotatedWith(CapabilityCore.class));
capabilities.addAll(reflections.getTypesAnnotatedWith(Capability.class));
}
/** /**
* 检查<code>@InjectCapability</code>注解是否只用在<code>@CapabilityHolder</code>所标识类的字段上 * 检查<code>@InjectCapability</code>注解是否只用在<code>@CapabilityHolder</code>所标识类的字段上
*/ */
private void checkInjectCapability() { private void checkInjectCapability() {
reflections.getFieldsAnnotatedWith(InjectCapability.class).forEach(field -> { reflections.getFieldsAnnotatedWith(InjectCapability.class).forEach(field -> {
if (!field.getDeclaringClass().isAssignableFrom(CapabilityHolder.class)) { Class<?> declaringClass = field.getDeclaringClass();
throw new UnMatchedCapabilityException("InjectCapability 注解只能用于 CapabilityHolder 注解所在类"); if (!isAssignableFromAnnotation(declaringClass, CapabilityHolder.class)) {
throw new UnMatchedCapabilityException("InjectCapability 注解只能用于 CapabilityHolder 注解所在类,检查该类是否使用了@CapabilityHolder注解或者受其标注的注解或父类: " + declaringClass);
} }
}); });
} }

View File

@@ -8,6 +8,9 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.ToCoordinated;
import work.slhaf.partner.api.agent.factory.capability.exception.ProxySetFailedExceptionCapability; import work.slhaf.partner.api.agent.factory.capability.exception.ProxySetFailedExceptionCapability;
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext; import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
import work.slhaf.partner.api.agent.factory.context.CapabilityFactoryContext; import work.slhaf.partner.api.agent.factory.context.CapabilityFactoryContext;
import work.slhaf.partner.api.agent.factory.module.ModuleInitHookExecuteFactory;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.lang.reflect.Proxy; import java.lang.reflect.Proxy;
@@ -18,9 +21,23 @@ import java.util.function.Function;
import static work.slhaf.partner.api.agent.util.AgentUtil.methodSignature; import static work.slhaf.partner.api.agent.util.AgentUtil.methodSignature;
/** /**
* 负责执行<code>Capability</code>的注入逻辑 * <h2>Agent启动流程 6</h2>
*/ *
public class CapabilityInjectFactory extends AgentBaseFactory { * <p>负责执行 {@link Capability} 的注入逻辑。</p>
*
* <p>实现方式:</p>
* <ol>
* <li>通过动态代理,为 {@link AgentModule} 与 {@link AgentSubModule} 中待注入的
* <b>能力接口</b> 类型(即 {@link Capability} 标注的接口类)生成代理对象。
* </li>
* <li>在代理对象内部,根据调用方法的签名确定路由,将调用转发至对应的具体函数。
* </li>
* <li>通过此机制,实现了 {@link Capability} 单一语义层面上普通方法与协调方法的统一入口。
* </li>
* </ol>
*
* <p>下一步流程请参阅 {@link ModuleInitHookExecuteFactory}</p>
*/public class CapabilityInjectFactory extends AgentBaseFactory {
private Reflections reflections; private Reflections reflections;
private HashMap<String, Function<Object[], Object>> coordinatedMethodsRouterTable; private HashMap<String, Function<Object[], Object>> coordinatedMethodsRouterTable;

View File

@@ -1,5 +1,6 @@
package work.slhaf.partner.api.agent.factory.capability; package work.slhaf.partner.api.agent.factory.capability;
import cn.hutool.core.util.ClassUtil;
import org.reflections.Reflections; import org.reflections.Reflections;
import work.slhaf.partner.api.agent.factory.AgentBaseFactory; import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
import work.slhaf.partner.api.agent.factory.capability.annotation.*; import work.slhaf.partner.api.agent.factory.capability.annotation.*;
@@ -8,28 +9,61 @@ import work.slhaf.partner.api.agent.factory.capability.exception.CoreInstancesCr
import work.slhaf.partner.api.agent.factory.capability.exception.DuplicateMethodException; import work.slhaf.partner.api.agent.factory.capability.exception.DuplicateMethodException;
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext; import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
import work.slhaf.partner.api.agent.factory.context.CapabilityFactoryContext; import work.slhaf.partner.api.agent.factory.context.CapabilityFactoryContext;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Set; import java.util.Set;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors;
import static cn.hutool.core.util.ClassUtil.isNormalClass;
import static work.slhaf.partner.api.agent.util.AgentUtil.methodSignature; import static work.slhaf.partner.api.agent.util.AgentUtil.methodSignature;
/** /**
* 负责获取<code>@Capability</code>和<code>@CapabilityCore</code>标识的类,并生成函数路由表、设置<code>Core</code>实例用于后续注入 * <h2>Agent启动流程 5</h2>
*
* <p>
* 负责收集注解 {@link Capability} 和 {@link CapabilityCore} 标识的类并生成函数路由表、创建core、capability实例以及放入instanceMap供后续进行注入操作
* </p>
*
* <ol>
* <li>
* <p>{@link CapabilityRegisterFactory#setCoreInstances()}</p>
* 通过反射调用无参构造函数创建core实例并将实例放入instanceMap供后续使用
* </li>
* <li>
* <p>{@link CapabilityRegisterFactory#generateRouterTable()}</p>
* 生成函数路由表:
* <ul>
* <li>
* <p>{@link CapabilityRegisterFactory#generateMethodsRouterTable()}</p>
* 生成普通方法对应的函数路由表
* </li>
* <li>
* <p>{@link CapabilityRegisterFactory#generateCoordinatedMethodsRouterTable()}</p>
* 生成协调方法对应的函数路由表
* </li>
* </ul>
* </li>
* <li>
* 函数路由表生成完毕、core实例创建完毕之后将交由下一工厂完成能力(Capability)注入操作,注入到 {@link AgentModule} 与 {@link AgentSubModule} 对应的实例中
* </li>
* </ol>
*
* <p>下一步流程请参阅{@link CapabilityInjectFactory}</p>
*/ */
public final class CapabilityRegisterFactory extends AgentBaseFactory { public class CapabilityRegisterFactory extends AgentBaseFactory {
private Reflections reflections; private Reflections reflections;
private HashMap<String, Function<Object[], Object>> methodsRouterTable; private HashMap<String, Function<Object[], Object>> methodsRouterTable;
private HashMap<String, Function<Object[], Object>> coordinatedMethodsRouterTable; private HashMap<String, Function<Object[], Object>> coordinatedMethodsRouterTable;
private HashMap<Class<?>, Object> capabilityCoreInstances; private HashMap<Class<?>, Object> coreInstances;
private HashMap<Class<?>, Object> capabilityHolderInstances; private HashMap<Class<?>, Object> capabilityHolderInstances;
private Set<Class<?>> cores; private Set<Class<?>> cores;
private Set<Class<?>> capabilities; private Set<Class<?>> capabilities;
@@ -40,35 +74,35 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
reflections = context.getReflections(); reflections = context.getReflections();
methodsRouterTable = factoryContext.getMethodsRouterTable(); methodsRouterTable = factoryContext.getMethodsRouterTable();
coordinatedMethodsRouterTable = factoryContext.getCoordinatedMethodsRouterTable(); coordinatedMethodsRouterTable = factoryContext.getCoordinatedMethodsRouterTable();
capabilityCoreInstances = factoryContext.getCapabilityCoreInstances(); coreInstances = factoryContext.getCapabilityCoreInstances();
cores = factoryContext.getCores(); cores = factoryContext.getCores();
capabilities = factoryContext.getCapabilities(); capabilities = factoryContext.getCapabilities();
capabilityHolderInstances = factoryContext.getCapabilityHolderInstances(); capabilityHolderInstances = factoryContext.getCapabilityHolderInstances();
} }
@Override @Override
protected void run() throws InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException { protected void run() {
setCapabilityCoreInstances(); setCapabilityHolderInstances();
setAnnotatedClasses(); setCoreInstances();
generateRouterTable(); generateRouterTable();
} }
/** private void setCapabilityHolderInstances() {
* 设置<code>CapabilityCore</code>、<code>Capability</code>注解标识类 Set<Class<?>> collect = reflections.getTypesAnnotatedWith(CapabilityHolder.class).stream()
*/ .filter(ClassUtil::isNormalClass)
private void setAnnotatedClasses() throws InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException { .filter(clazz -> !capabilityHolderInstances.containsKey(clazz))
cores.addAll(reflections.getTypesAnnotatedWith(CapabilityCore.class)); .collect(Collectors.toSet());
capabilities.addAll(reflections.getTypesAnnotatedWith(Capability.class)); for (Class<?> clazz : collect) {
setCapabilityHolderInstances(); try {
Constructor<?> constructor = clazz.getDeclaredConstructor();
if (constructor.canAccess(null)) {
throw new CapabilityFactoryExecuteFailedException("缺少无参构造方法的类: " + clazz);
} }
Object o = constructor.newInstance();
private void setCapabilityHolderInstances() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
for (Class<?> clazz : reflections.getTypesAnnotatedWith(CapabilityHolder.class)) {
if (!isNormalClass(clazz)){
continue;
}
Object o = clazz.getDeclaredConstructor().newInstance();
capabilityHolderInstances.put(clazz, o); capabilityHolderInstances.put(clazz, o);
} catch (Exception e) {
throw new CapabilityFactoryExecuteFailedException("创建代理对象失败: " + clazz, e);
}
} }
} }
@@ -116,7 +150,7 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
for (Class<?> c : reflections.getTypesAnnotatedWith(CoordinateManager.class)) { for (Class<?> c : reflections.getTypesAnnotatedWith(CoordinateManager.class)) {
Constructor<?> constructor = c.getDeclaredConstructor(); Constructor<?> constructor = c.getDeclaredConstructor();
Object instance = constructor.newInstance(); Object instance = constructor.newInstance();
setCores(instance, c);
Arrays.stream(c.getMethods()) Arrays.stream(c.getMethods())
.filter(method -> method.isAnnotationPresent(Coordinated.class)) .filter(method -> method.isAnnotationPresent(Coordinated.class))
.forEach(method -> { .forEach(method -> {
@@ -127,18 +161,26 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
return map; return map;
} }
private void setCores(Object cmInstance, Class<?> cmClazz) throws IllegalAccessException {
for (Field field : cmClazz.getFields()) {
if (field.getType().isAnnotationPresent(CapabilityCore.class)) {
field.setAccessible(true);
field.set(cmInstance, coreInstances.get(field.getType()));
}
}
}
/** /**
* 生成普通方法对应的函数路由表 * 扫描`@Capability`与`@CapabilityMethod`注解的类与方法
* 将`capabilityValue.methodSignature`作为key,函数对象为通过反射拿到的core实例对应的方法
*/ */
private void generateMethodsRouterTable() { private void generateMethodsRouterTable() {
//扫描`@Capability`与`@CapabilityMethod`注解的类与方法
//将`capabilityValue.methodSignature`作为key,函数对象为通过反射拿到的core实例对应的方法
cores.forEach(core -> Arrays.stream(core.getMethods()) cores.forEach(core -> Arrays.stream(core.getMethods())
.filter(method -> method.isAnnotationPresent(CapabilityMethod.class)) .filter(method -> method.isAnnotationPresent(CapabilityMethod.class))
.forEach(method -> { .forEach(method -> {
Function<Object[], Object> function = args -> { Function<Object[], Object> function = args -> {
try { try {
return method.invoke(capabilityCoreInstances.get(core), args); return method.invoke(coreInstances.get(core), args);
} catch (IllegalAccessException | InvocationTargetException e) { } catch (IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@@ -154,12 +196,12 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
/** /**
* 反射获取<code>CapabilityCore</code>实例 * 反射获取<code>CapabilityCore</code>实例
*/ */
private void setCapabilityCoreInstances() { private void setCoreInstances() {
try { try {
for (Class<?> core : cores) { for (Class<?> core : cores) {
Constructor<?> constructor = core.getDeclaredConstructor(); Constructor<?> constructor = core.getDeclaredConstructor();
constructor.setAccessible(true); constructor.setAccessible(true);
capabilityCoreInstances.put(core, constructor.newInstance()); coreInstances.put(core, constructor.newInstance());
} }
} catch (InvocationTargetException | NoSuchMethodException | InstantiationException | } catch (InvocationTargetException | NoSuchMethodException | InstantiationException |
IllegalAccessException e) { IllegalAccessException e) {

View File

@@ -5,6 +5,9 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
/**
* Core的协调类该注解的实现类中如果存在任何{@link CapabilityCore}实例的引用,都将被自动注入
*/
@Target(ElementType.TYPE) @Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
public @interface CoordinateManager { public @interface CoordinateManager {

View File

@@ -1,16 +1,31 @@
package work.slhaf.partner.api.agent.factory.config; package work.slhaf.partner.api.agent.factory.config;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.AgentBaseFactory; import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException;
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException;
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig; import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext; import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
import work.slhaf.partner.api.agent.factory.context.ConfigFactoryContext; import work.slhaf.partner.api.agent.factory.context.ConfigFactoryContext;
import work.slhaf.partner.api.agent.factory.module.ModuleCheckFactory;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager; import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
import work.slhaf.partner.api.agent.runtime.config.DefaultAgentConfigManager; import work.slhaf.partner.api.agent.runtime.config.FileAgentConfigManager;
import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.api.chat.pojo.Message;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set;
/**
* <h2>Agent启动流程 0</h2>
* <p>
* 通过指定的 {@link AgentConfigManager} 或者默认的 {@link FileAgentConfigManager} 加载配置文件
* <p/>
*
* <p>下一步流程请参阅{@link ModuleCheckFactory}</p>
*/
@Slf4j
public class ConfigLoaderFactory extends AgentBaseFactory { public class ConfigLoaderFactory extends AgentBaseFactory {
private AgentConfigManager agentConfigManager; private AgentConfigManager agentConfigManager;
@@ -23,8 +38,8 @@ public class ConfigLoaderFactory extends AgentBaseFactory {
modelConfigMap = factoryContext.getModelConfigMap(); modelConfigMap = factoryContext.getModelConfigMap();
modelPromptMap = factoryContext.getModelPromptMap(); modelPromptMap = factoryContext.getModelPromptMap();
if (AgentConfigManager.INSTANCE == null){ if (AgentConfigManager.INSTANCE == null) {
AgentConfigManager.setINSTANCE(new DefaultAgentConfigManager()); AgentConfigManager.setINSTANCE(new FileAgentConfigManager());
} }
agentConfigManager = AgentConfigManager.INSTANCE; agentConfigManager = AgentConfigManager.INSTANCE;
@@ -33,9 +48,30 @@ public class ConfigLoaderFactory extends AgentBaseFactory {
@Override @Override
protected void run() { protected void run() {
agentConfigManager.load(); agentConfigManager.load();
agentConfigManager.check();
modelConfigMap.putAll(agentConfigManager.getModelConfigMap()); modelConfigMap.putAll(agentConfigManager.getModelConfigMap());
modelPromptMap.putAll(agentConfigManager.getModelPromptMap()); modelPromptMap.putAll(agentConfigManager.getModelPromptMap());
check();
} }
/**
* 对模型Config与Prompt分别进行检验,除了都必须包含default外还需要确保数量、key一致毕竟是模型配置与提示词
*/
private void check() {
log.info("执行config与prompt检测...");
if (!modelConfigMap.containsKey("default")) {
throw new ConfigNotExistException("缺少默认配置! 需确保存在一个模型配置的key为`default`");
}
if (!modelPromptMap.containsKey("basic")) {
throw new PromptNotExistException("缺少基础Prompt! 需要确保存在key为basic的Prompt文件它将与其他Prompt共同作用于模块节点。");
}
Set<String> configKeySet = new HashSet<>(modelConfigMap.keySet());
configKeySet.remove("default");
Set<String> promptKeySet = new HashSet<>(modelPromptMap.keySet());
promptKeySet.remove("basic");
if (!promptKeySet.containsAll(configKeySet)) {
log.warn("存在未被提示词包含的模型配置,该配置将无法生效!");
}
//检查提示词数量与`ActivateModel`的实现数量是否一致
log.info("检测完毕.");
}
} }

View File

@@ -3,6 +3,7 @@ package work.slhaf.partner.api.agent.factory.context;
import lombok.Data; import lombok.Data;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.Set; import java.util.Set;
import java.util.function.Function; import java.util.function.Function;
@@ -12,6 +13,6 @@ public class CapabilityFactoryContext {
private final HashMap<String, Function<Object[], Object>> coordinatedMethodsRouterTable = new HashMap<>(); private final HashMap<String, Function<Object[], Object>> coordinatedMethodsRouterTable = new HashMap<>();
private final HashMap<Class<?>, Object> capabilityCoreInstances = new HashMap<>(); private final HashMap<Class<?>, Object> capabilityCoreInstances = new HashMap<>();
private final HashMap<Class<?>, Object> capabilityHolderInstances = new HashMap<>(); private final HashMap<Class<?>, Object> capabilityHolderInstances = new HashMap<>();
private Set<Class<?>> cores; private Set<Class<?>> cores = new HashSet<>();
private Set<Class<?>> capabilities; private Set<Class<?>> capabilities = new HashSet<>();
} }

View File

@@ -2,11 +2,13 @@ package work.slhaf.partner.api.agent.factory.context;
import lombok.Data; import lombok.Data;
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule; import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
import work.slhaf.partner.api.agent.factory.module.pojo.MetaSubModule;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@Data @Data
public class ModuleFactoryContext { public class ModuleFactoryContext {
private List<MetaModule> moduleList = new ArrayList<>(); private List<MetaModule> agentModuleList = new ArrayList<>();
private List<MetaSubModule> agentSubModuleList = new ArrayList<>();
} }

View File

@@ -6,8 +6,8 @@ import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext; import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute; import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule; import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute; import work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.exception.ModuleCheckException; import work.slhaf.partner.api.agent.factory.module.exception.ModuleCheckException;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager; import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel; import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
@@ -19,6 +19,32 @@ import java.util.HashSet;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static work.slhaf.partner.api.agent.util.AgentUtil.getMethodAnnotationTypeSet;
/**
* <h2>Agent启动流程 1</h2>
*
* <p>
* 检查模块部分抽象类与注解、接口的使用方式
* </p>
*
* <ol>
* <li>
* <p>{@link ModuleCheckFactory#annotationAbstractCheck(Set, Class)}</p>
* 所有添加了 {@link AgentModule} 注解的类都将作为Agent的执行模块为规范模块入口都必须实现抽象类: {@link AgentRunningModule}; {@link AgentSubModule} 注解所在类则必须实现 {@link AgentRunningSubModule}
* </li>
* <li>
* <p>{@link ModuleCheckFactory#moduleConstructorsCheck(Set)}</p>
* 所有 {@link AgentModule} 与 {@link AgentSubModule} 注解所在类都必须具备空参构造方法,初始化逻辑可放在 @Init 注解所处方法中,将在 Capability 与 subModules 注入后才会执行
* </li>
* <li>
* <p>{@link ModuleCheckFactory#activateModelImplCheck()}</p>
* 检查实现了 {@link ActivateModel} 的模块数量、名称与prompt是否一致
* </li>
* </ol>
*
* <p>下一步流程请参阅{@link ModuleRegisterFactory}</p>
*/
public class ModuleCheckFactory extends AgentBaseFactory { public class ModuleCheckFactory extends AgentBaseFactory {
private Reflections reflections; private Reflections reflections;
@@ -30,15 +56,43 @@ public class ModuleCheckFactory extends AgentBaseFactory {
@Override @Override
protected void run() { protected void run() {
Set<Class<?>> types = reflections.getTypesAnnotatedWith(AgentModule.class); AnnotatedModules annotatedModules = getAnnotatedModules();
//检查注解AgentModule所在类是否继承了AgentInteractionModule ExtendedModules extendedModules = getExtendedModules();
agentModuleAnnotationCheck(types); checkIfClassCorresponds(annotatedModules, extendedModules);
//检查注解AgentModule或AgentSubModule所在类是否继承了对应的抽象类
annotationAbstractCheck(annotatedModules.moduleTypes(), AgentRunningModule.class);
annotationAbstractCheck(annotatedModules.subModuleTypes(), AgentRunningSubModule.class);
//检查AgentModule是否具备无参构造方法 //检查AgentModule是否具备无参构造方法
moduleConstructorsCheck(types); moduleConstructorsCheck(annotatedModules.moduleTypes());
//检查hook注解所在方法是否位于AgentInteractionModule子类/AgentInteractionSubModule子类/ActivateModel子类 moduleConstructorsCheck(annotatedModules.subModuleTypes());
hookLocationCheck();
//检查实现了ActivateModel的模块数量、名称与prompt是否一致 //检查实现了ActivateModel的模块数量、名称与prompt是否一致
activateModelImplCheck(); activateModelImplCheck();
//检查hook注解所在位置是否正确
hookLocationCheck();
}
private ExtendedModules getExtendedModules() {
Set<Class<?>> moduleTypes = reflections.getSubTypesOf(AgentRunningModule.class)
.stream()
.filter(ClassUtil::isNormalClass)
.collect(Collectors.toSet());
Set<Class<?>> subModuleTypes = reflections.getSubTypesOf(AgentRunningSubModule.class)
.stream()
.filter(ClassUtil::isNormalClass)
.collect(Collectors.toSet());
return new ExtendedModules(moduleTypes, subModuleTypes);
}
private AnnotatedModules getAnnotatedModules() {
Set<Class<?>> moduleTypes = reflections.getTypesAnnotatedWith(AgentModule.class)
.stream()
.filter(ClassUtil::isNormalClass)
.collect(Collectors.toSet());
Set<Class<?>> subModuleTypes = reflections.getTypesAnnotatedWith(AgentSubModule.class)
.stream()
.filter(ClassUtil::isNormalClass)
.collect(Collectors.toSet());
return new AnnotatedModules(moduleTypes, subModuleTypes);
} }
private void moduleConstructorsCheck(Set<Class<?>> types) { private void moduleConstructorsCheck(Set<Class<?>> types) {
@@ -76,24 +130,10 @@ public class ModuleCheckFactory extends AgentBaseFactory {
preHookLocationCheck(); preHookLocationCheck();
//检查@Init注解 //检查@Init注解
initHookLocationCheck(); initHookLocationCheck();
//检查@AgentModule注解是否只位于普通类上
agentModuleLocationCheck();
}
private void agentModuleLocationCheck() {
Set<Class<?>> types = reflections.getTypesAnnotatedWith(AgentModule.class);
for (Class<?> type : types) {
if (!ClassUtil.isNormalClass(type)) {
throw new ModuleCheckException("AgentModule 注解仅能位于普通类上! 异常类信息: " + type.getSimpleName());
}
}
} }
private void initHookLocationCheck() { private void initHookLocationCheck() {
Set<Method> methods = reflections.getMethodsAnnotatedWith(Init.class); Set<Class<?>> types = getMethodAnnotationTypeSet(AgentModule.class, reflections);
Set<Class<?>> types = methods.stream()
.map(Method::getDeclaringClass)
.collect(Collectors.toSet());
checkLocation(types); checkLocation(types);
} }
@@ -129,15 +169,59 @@ public class ModuleCheckFactory extends AgentBaseFactory {
} }
} }
private void agentModuleAnnotationCheck(Set<Class<?>> types) { private void annotationAbstractCheck(Set<Class<?>> types, Class<?> clazz) {
for (Class<?> type : types) { for (Class<?> type : types) {
if (type.isAnnotation()) { if (type.isAnnotation()) {
continue; continue;
} }
if (AgentRunningModule.class.isAssignableFrom(type) && ClassUtil.isNormalClass(type)) { if (clazz.isAssignableFrom(type) && ClassUtil.isNormalClass(type)) {
continue; continue;
} }
throw new ModuleCheckException("存在未继承AgentInteractionModule.class的AgentModule实现: " + type.getSimpleName()); throw new ModuleCheckException("存在未继承AgentInteractionModule.class的AgentModule实现: " + type.getSimpleName());
} }
} }
private void checkIfClassCorresponds(AnnotatedModules annotatedModules, ExtendedModules extendedModules) {
// 检查是否有被@AgentModule注解但没有继承AgentRunningModule的类
checkSets(annotatedModules.moduleTypes(), extendedModules.moduleTypes(),
"存在被@AgentModule注解但未继承AgentRunningModule的类");
// 检查是否有继承AgentRunningModule但没有被@AgentModule注解的类
checkSets(extendedModules.moduleTypes(), annotatedModules.moduleTypes(),
"存在继承AgentRunningModule但未被@AgentModule注解的类");
// 检查是否有被@AgentSubModule注解但没有继承AgentRunningSubModule的类
checkSets(annotatedModules.subModuleTypes(), extendedModules.subModuleTypes(),
"存在被@AgentSubModule注解但未继承AgentRunningSubModule的类");
// 检查是否有继承AgentRunningSubModule但没有被@AgentSubModule注解的类
checkSets(extendedModules.subModuleTypes(), annotatedModules.subModuleTypes(),
"存在继承AgentRunningSubModule但未被@AgentSubModule注解的类");
}
/**
* 检查源集合中是否有不在目标集合中的元素
* @param source 源集合
* @param target 目标集合
* @param errorMessage 错误信息前缀
*/
private void checkSets(Set<Class<?>> source, Set<Class<?>> target, String errorMessage) {
// 只有在需要时才创建HashSet以节省内存
if (!target.containsAll(source)) {
// 使用流式处理找出差异部分,避免创建完整的中间集合
String classNames = source.stream()
.filter(clazz -> !target.contains(clazz))
.map(Class::getSimpleName)
.limit(10) // 限制显示数量,避免信息泄露
.collect(Collectors.joining(", ", "[", "]"));
throw new ModuleCheckException(errorMessage + ": " + classNames);
}
}
private record AnnotatedModules(Set<Class<?>> moduleTypes, Set<Class<?>> subModuleTypes) {
}
private record ExtendedModules(Set<Class<?>> moduleTypes, Set<Class<?>> subModuleTypes) {
}
} }

View File

@@ -1,13 +1,19 @@
package work.slhaf.partner.api.agent.factory.module; package work.slhaf.partner.api.agent.factory.module;
import work.slhaf.partner.api.agent.factory.AgentBaseFactory; import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
import work.slhaf.partner.api.agent.factory.AgentRegisterFactory;
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext; import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
import work.slhaf.partner.api.agent.factory.context.ModuleFactoryContext; import work.slhaf.partner.api.agent.factory.context.ModuleFactoryContext;
import work.slhaf.partner.api.agent.factory.module.annotation.Init; import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.exception.ModuleInitHookExecuteFailedException; import work.slhaf.partner.api.agent.factory.module.exception.ModuleInitHookExecuteFailedException;
import work.slhaf.partner.api.agent.factory.module.pojo.BaseMetaModule;
import work.slhaf.partner.api.agent.factory.module.pojo.MetaMethod; import work.slhaf.partner.api.agent.factory.module.pojo.MetaMethod;
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule; import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
import work.slhaf.partner.api.agent.factory.module.pojo.MetaSubModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule; import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.Module;
import work.slhaf.partner.api.agent.util.AgentUtil;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
import java.util.Arrays; import java.util.Arrays;
@@ -20,39 +26,61 @@ import static work.slhaf.partner.api.agent.util.AgentUtil.collectExtendedClasses
import static work.slhaf.partner.api.agent.util.AgentUtil.methodSignature; import static work.slhaf.partner.api.agent.util.AgentUtil.methodSignature;
/** /**
* 负责执行前hook逻辑 * <h2>Agent启动流程 7</h2>
*
* <p>负责执行初始化hook逻辑即 {@link Init} 注解所在方法</p>
*
* <ol>
* <li>
* <p>{@link ModuleInitHookExecuteFactory#collectInitHookMethods(Class, Class)}</p>
* 分别遍历前置模块拿到的模块列表({@link ModuleInitHookExecuteFactory#moduleList}, {@link ModuleInitHookExecuteFactory#subModuleList}),通过 {@link AgentUtil#collectExtendedClasses(Class, Class)} 收集到当前模块类的继承链上的所有类后,收集其所有带有 {@link Init} 注解的方法
* </li>
* <li>
* <p>{@link ModuleInitHookExecuteFactory#proceedInitMethods(BaseMetaModule, List)}</p>
* 收集好初始化方法后,将通过反射执行该方法,所用实例即为前置模块中收集到的执行模块与子模块的 {@link MetaModule} 与 {@link MetaSubModule} 内容
* </li>
* </ol>
*
* <p>Agent启动流程到此进行完毕。整个工厂执行链中均为针对 {@link AgentRegisterContext} 进行的操作,在 {@link AgentRegisterFactory} 中,将进行最终处理以及将必要内容进行传递。</p>
*/ */
public class ModuleInitHookExecuteFactory extends AgentBaseFactory { public class ModuleInitHookExecuteFactory extends AgentBaseFactory {
private List<MetaModule> moduleList; private List<MetaModule> moduleList;
private List<MetaSubModule> subModuleList;
@Override @Override
protected void setVariables(AgentRegisterContext context) { protected void setVariables(AgentRegisterContext context) {
ModuleFactoryContext factoryContext = context.getModuleFactoryContext(); ModuleFactoryContext factoryContext = context.getModuleFactoryContext();
moduleList = factoryContext.getModuleList(); moduleList = factoryContext.getAgentModuleList();
subModuleList = factoryContext.getAgentSubModuleList();
} }
@Override @Override
protected void run() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { protected void run() {
//遍历模块列表,并向上查找@Init注解 //遍历模块列表,并向上查找@Init注解
for (MetaSubModule metaSubModule : subModuleList) {
List<MetaMethod> initHookMethods = collectInitHookMethods(metaSubModule.getClazz(),AgentRunningModule.class);
proceedInitMethods(metaSubModule, initHookMethods);
}
for (MetaModule metaModule : moduleList) { for (MetaModule metaModule : moduleList) {
List<MetaMethod> initHookMethods = collectInitHookMethods(metaModule.getClazz()); List<MetaMethod> initHookMethods = collectInitHookMethods(metaModule.getClazz(), AgentRunningSubModule.class);
proceedInitMethods(metaModule, initHookMethods); proceedInitMethods(metaModule, initHookMethods);
} }
} }
private void proceedInitMethods(MetaModule metaModule, List<MetaMethod> initHookMethods) { private void proceedInitMethods(BaseMetaModule metaModule, List<MetaMethod> initHookMethods) {
for (MetaMethod metaMethod : initHookMethods) { for (MetaMethod metaMethod : initHookMethods) {
try { try {
metaMethod.getMethod().invoke(metaModule.getInstance()); metaMethod.getMethod().invoke(metaModule.getInstance());
} catch (IllegalAccessException | InvocationTargetException e) { } catch (IllegalAccessException | InvocationTargetException e) {
throw new ModuleInitHookExecuteFailedException("模块的init hook方法执行失败! 模块: " + metaModule.getName() + " 方法签名: " + methodSignature(metaMethod.getMethod()), e); throw new ModuleInitHookExecuteFailedException("模块的init hook方法执行失败! 模块: " + metaModule.getClazz().getSimpleName() + " 方法签名: " + methodSignature(metaMethod.getMethod()), e);
} }
} }
} }
private List<MetaMethod> collectInitHookMethods(Class<?> clazz) { private List<MetaMethod> collectInitHookMethods(Class<?> clazz, Class<? extends Module> target) {
Set<Class<?>> classes = collectExtendedClasses(clazz, AgentRunningModule.class); Set<Class<?>> classes = collectExtendedClasses(clazz, target);
return classes.stream() return classes.stream()
.map(Class::getDeclaredMethods) .map(Class::getDeclaredMethods)
.flatMap(Arrays::stream) .flatMap(Arrays::stream)

View File

@@ -1,21 +1,29 @@
package work.slhaf.partner.api.agent.factory.module; package work.slhaf.partner.api.agent.factory.module;
import lombok.Getter;
import net.bytebuddy.ByteBuddy; import net.bytebuddy.ByteBuddy;
import net.bytebuddy.implementation.MethodDelegation; import net.bytebuddy.implementation.MethodDelegation;
import net.bytebuddy.implementation.bind.annotation.*; import net.bytebuddy.implementation.bind.annotation.*;
import net.bytebuddy.matcher.ElementMatchers; import net.bytebuddy.matcher.ElementMatchers;
import work.slhaf.partner.api.agent.factory.AgentBaseFactory; import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
import work.slhaf.partner.api.agent.factory.capability.CapabilityCheckFactory;
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext; import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
import work.slhaf.partner.api.agent.factory.context.CapabilityFactoryContext;
import work.slhaf.partner.api.agent.factory.context.ModuleFactoryContext; import work.slhaf.partner.api.agent.factory.context.ModuleFactoryContext;
import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute; import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute;
import work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute; import work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.api.agent.factory.module.exception.ModuleInstanceGenerateFailedException; import work.slhaf.partner.api.agent.factory.module.exception.ModuleInstanceGenerateFailedException;
import work.slhaf.partner.api.agent.factory.module.exception.ModuleProxyGenerateFailedException; import work.slhaf.partner.api.agent.factory.module.exception.ModuleProxyGenerateFailedException;
import work.slhaf.partner.api.agent.factory.module.exception.ProxiedModuleRunningException;
import work.slhaf.partner.api.agent.factory.module.pojo.BaseMetaModule;
import work.slhaf.partner.api.agent.factory.module.pojo.MetaMethod; import work.slhaf.partner.api.agent.factory.module.pojo.MetaMethod;
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule; import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
import work.slhaf.partner.api.agent.factory.module.pojo.MetaSubModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule; import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.Module;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.*; import java.util.*;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
@@ -24,50 +32,124 @@ import java.util.stream.Collectors;
import static work.slhaf.partner.api.agent.util.AgentUtil.collectExtendedClasses; import static work.slhaf.partner.api.agent.util.AgentUtil.collectExtendedClasses;
/** /**
* 通过扫描注解<code>@BeforeExecute</code>获取到各个模块的后hook逻辑并通过动态代理添加到执行逻辑之后 * <h2>Agent启动流程 3</h2>
*
* <p>
* 扫描前置模块各个hook注解生成代理对象放入对应的list中并按照类型为键放入 {@link ModuleProxyFactory#capabilityHolderInstances} 中供后续完成能力(capability)注入
* <p/>
*
* <ol>
*
* <li>
* <p>{@link ModuleProxyFactory#createProxiedInstances()}</p>
* 根据moduleList中的类型信息向上查找继承链获取所有hook方法收集为{@link MethodsListRecord}然后通过ByteBuddy根据收集到的preHook与postHook生成代理对象放入对应的 {@link MetaModule} 对象以及 instanceMap 中
* </li>
* <li>
* <p>{@link ModuleProxyFactory#injectSubModule()}</p>
* 通过反射将子模块实例注入到执行模块中带有注解 {@link InjectModule} 的字段
* </li>
* </ol>
*
* <p>下一步流程请参阅{@link CapabilityCheckFactory}</p>
*/ */
public class ModuleProxyFactory extends AgentBaseFactory { public class ModuleProxyFactory extends AgentBaseFactory {
private List<MetaModule> moduleList; private List<MetaModule> moduleList;
private List<MetaSubModule> subModuleList;
private HashMap<Class<?>, Object> capabilityHolderInstances;
private final HashMap<Class<?>, Object> subModuleInstances = new HashMap<>();
private final HashMap<Class<?>, Object> moduleInstances = new HashMap<>();
@Override @Override
protected void setVariables(AgentRegisterContext context) { protected void setVariables(AgentRegisterContext context) {
ModuleFactoryContext factoryContext = context.getModuleFactoryContext(); ModuleFactoryContext factoryContext = context.getModuleFactoryContext();
moduleList = factoryContext.getModuleList(); CapabilityFactoryContext capabilityFactoryContext = context.getCapabilityFactoryContext();
moduleList = factoryContext.getAgentModuleList();
subModuleList = factoryContext.getAgentSubModuleList();
capabilityHolderInstances = capabilityFactoryContext.getCapabilityHolderInstances();
} }
@Override @Override
protected void run() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { protected void run() {
generateInstances(); createProxiedInstances();
setHookProxy(); injectSubModule();
} }
private void setHookProxy() { private void injectSubModule() {
for (MetaModule module : moduleList) { for (MetaModule module : moduleList) {
//因为实际上ByteBuddy生成的是module.getClazz()的子类所以应当使用getDeclaredFields()获取字段
Arrays.stream(module.getClazz().getDeclaredFields())
.filter(field -> field.isAnnotationPresent(InjectModule.class))
.forEach(field -> {
try {
field.setAccessible(true);
field.set(
moduleInstances.get(module.getClazz()),
subModuleInstances.get(field.getType())
);
} catch (IllegalAccessException e) {
throw new ModuleInstanceGenerateFailedException("模块实例注入失败", e);
}
});
}
}
private void createProxiedInstances() {
generateModuleProxy(moduleList, AgentRunningModule.class);
generateModuleProxy(subModuleList, AgentRunningSubModule.class);
updateInstanceMap(moduleInstances, moduleList);
updateInstanceMap(subModuleInstances, subModuleList);
updateCapabilityHolderInstances();
}
private void updateCapabilityHolderInstances() {
capabilityHolderInstances.putAll(moduleInstances);
capabilityHolderInstances.putAll(subModuleInstances);
}
private void updateInstanceMap(HashMap<Class<?>, Object> instanceMap, List<? extends BaseMetaModule> list) {
for (BaseMetaModule baseMetaModule : list) {
instanceMap.put(baseMetaModule.getClazz(), baseMetaModule.getInstance());
}
}
private void generateModuleProxy(List<? extends BaseMetaModule> list, Class<? extends Module> overrideSource) {
for (BaseMetaModule module : list) {
Class<?> clazz = module.getClazz(); Class<?> clazz = module.getClazz();
try { try {
MethodsListRecord record = collectHookMethods(clazz); MethodsListRecord record = collectHookMethods(clazz);
//生成实例 //生成实例
generateProxiedInstances(record, module); generateProxiedInstances(record, module, overrideSource);
} catch (Exception e) { } catch (Exception e) {
throw new ModuleProxyGenerateFailedException("创建代理对象失败: " + clazz.getSimpleName(), e); throw new ModuleProxyGenerateFailedException("创建代理对象失败: " + clazz.getSimpleName(), e);
} }
} }
} }
private void generateProxiedInstances(MethodsListRecord record, MetaModule metaModule) { private void generateProxiedInstances(MethodsListRecord record, BaseMetaModule module, Class<? extends Module> overrideSource) {
try { try {
Class<? extends AgentRunningModule> clazz = metaModule.getClazz(); Class<? extends Module> clazz = module.getClazz();
Class<? extends AgentRunningModule> proxyClass = new ByteBuddy() Class<? extends Module> proxyClass = new ByteBuddy()
.subclass(clazz) .subclass(clazz)
.method(ElementMatchers.isOverriddenFrom(AgentRunningModule.class)) .method(ElementMatchers.isOverriddenFrom(overrideSource))
.intercept(MethodDelegation.to(new ModuleProxyInterceptor(record.post, record.pre))) .intercept(MethodDelegation.to(new ModuleProxyInterceptor(record.post, record.pre)))
.make() .make()
.load(ModuleProxyFactory.class.getClassLoader()) .load(ModuleProxyFactory.class.getClassLoader())
.getLoaded(); .getLoaded();
metaModule.setInstance(proxyClass.getConstructor().newInstance());
// new ByteBuddy()
// .subclass(clazz)
// .method(ElementMatchers.isOverriddenFrom(overrideSource))
// .intercept(MethodDelegation.to(new ModuleProxyInterceptor(record.post, record.pre)))
//
// .make()
// .saveIn(new File("./generated-classes"));
module.setInstance(proxyClass.getConstructor().newInstance());
} catch (Exception e) { } catch (Exception e) {
throw new ModuleProxyGenerateFailedException("模块Hook代理生成失败! 代理失败的模块名: " + metaModule.getClazz().getSimpleName(), e); throw new ModuleProxyGenerateFailedException("模块Hook代理生成失败! 代理失败的模块名: " + module.getClazz().getSimpleName(), e);
} }
} }
@@ -76,8 +158,8 @@ public class ModuleProxyFactory extends AgentBaseFactory {
List<MetaMethod> pre = new ArrayList<>(); List<MetaMethod> pre = new ArrayList<>();
//获取该类本身的hook逻辑 //获取该类本身的hook逻辑
collectHookMethods(post, pre, clazz); collectHookMethods(post, pre, clazz);
//获取它所继承、实现的抽象类或接口, 以AgentInteractionModule、ActiveModel为终点 //获取它所继承、实现的抽象类或接口, 以Module为终点收集继承链上所有父类和接口
Set<Class<?>> classes = collectExtendedClasses(clazz, AgentRunningModule.class); Set<Class<?>> classes = collectExtendedClasses(clazz, Module.class);
//获取这些类中的hook逻辑 //获取这些类中的hook逻辑
collectHookMethods(post, pre, classes); collectHookMethods(post, pre, classes);
return new MethodsListRecord(post, pre); return new MethodsListRecord(post, pre);
@@ -118,7 +200,7 @@ public class ModuleProxyFactory extends AgentBaseFactory {
private void collectHookMethods(List<MetaMethod> post, List<MetaMethod> pre, Class<?> clazz) { private void collectHookMethods(List<MetaMethod> post, List<MetaMethod> pre, Class<?> clazz) {
Method[] methods = clazz.getMethods(); Method[] methods = clazz.getDeclaredMethods();
for (Method method : methods) { for (Method method : methods) {
if (method.isAnnotationPresent(BeforeExecute.class)) { if (method.isAnnotationPresent(BeforeExecute.class)) {
MetaMethod metaMethod = new MetaMethod(); MetaMethod metaMethod = new MetaMethod();
@@ -134,30 +216,38 @@ public class ModuleProxyFactory extends AgentBaseFactory {
} }
} }
private void generateInstances() { @Getter
for (MetaModule metaModule : moduleList) { @SuppressWarnings("ClassCanBeRecord")
public static class ModuleProxyInterceptor {
private final List<MetaMethod> postHookMethods;
private final List<MetaMethod> preHookMethods;
public ModuleProxyInterceptor(List<MetaMethod> postHookMethods, List<MetaMethod> preHookMethods) {
this.postHookMethods = postHookMethods;
this.preHookMethods = preHookMethods;
}
@RuntimeType
public Object intercept(@Origin Method method, @AllArguments Object[] allArguments, @SuperCall Callable<?> zuper, @This Object proxy) throws Exception {
executeHookMethods(preHookMethods, proxy);
Object res = zuper.call();
executeHookMethods(postHookMethods, proxy);
return res;
}
private void executeHookMethods(List<MetaMethod> hookMethods, Object proxy) {
for (MetaMethod metaMethod : hookMethods) {
Method m = metaMethod.getMethod();
try { try {
Class<? extends AgentRunningModule> clazz = metaModule.getClazz(); m.setAccessible(true);
AgentRunningModule instance = clazz.getConstructor().newInstance(); m.invoke(proxy);
metaModule.setInstance(instance);
} catch (Exception e) { } catch (Exception e) {
throw new ModuleInstanceGenerateFailedException("模块实例构造失败:" + e.getMessage()); throw new ProxiedModuleRunningException("hook方法执行异常: " + m.getDeclaringClass() + "#" + m.getName(), e);
} }
} }
} }
private record ModuleProxyInterceptor(List<MetaMethod> postHookMethods, List<MetaMethod> preHookMethods) {
@RuntimeType
public Object intercept(@Origin Method method, @AllArguments Object[] allArguments, @SuperCall Callable<?> zuper, @This Object proxy) throws Exception {
for (MetaMethod metaMethod : preHookMethods) {
metaMethod.getMethod().invoke(proxy);
}
Object res = zuper.call();
for (MetaMethod metaMethod : postHookMethods) {
metaMethod.getMethod().invoke(proxy);
}
return res;
}
} }
record MethodsListRecord(List<MetaMethod> post, List<MetaMethod> pre) { record MethodsListRecord(List<MetaMethod> post, List<MetaMethod> pre) {

View File

@@ -6,31 +6,71 @@ import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext; import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
import work.slhaf.partner.api.agent.factory.context.ModuleFactoryContext; import work.slhaf.partner.api.agent.factory.context.ModuleFactoryContext;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule; import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.CoreModule;
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule; import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
import work.slhaf.partner.api.agent.factory.module.pojo.MetaSubModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule; import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
/** /**
* 负责扫描<code>@Module</code>注解获取模块实例 * <h2>Agent启动流程 2</h2>
*
* <p>
* 负责收集 {@link AgentModule} 与 {@link AgentSubModule} 注解所在类的信息,供后续工厂完成动态代理、模块与能力注入
* <p/>
*
* <ol>
* <li>
* <p>{@link ModuleRegisterFactory#setModuleList()}</p>
* 扫描 {@link AgentModule} 注解,获取执行模块信息: 类型、模块名称({@link AgentModule#name()}),执行顺序。并按照注解的 {@link AgentModule#order()} 字段进行排序
* </li>
* <li>
* <p>{@link ModuleRegisterFactory#setSubModuleList()}</p>
* 扫描 {@link AgentSubModule} 注册,获取子模块类型信息
* </li>
* <li>
* 两种模块都将存入各自的list中供后续模块完成注册与注入
* </li>
* </ol>
*
* <p>下一步流程请参阅{@link ModuleProxyFactory}</p>
*/ */
public class ModuleRegisterFactory extends AgentBaseFactory { public class ModuleRegisterFactory extends AgentBaseFactory {
private Reflections reflections; private Reflections reflections;
private List<MetaModule> moduleList; private List<MetaModule> moduleList;
private List<MetaSubModule> subModuleList;
@Override @Override
protected void setVariables(AgentRegisterContext context) { protected void setVariables(AgentRegisterContext context) {
ModuleFactoryContext factoryContext = context.getModuleFactoryContext(); ModuleFactoryContext factoryContext = context.getModuleFactoryContext();
reflections = context.getReflections(); reflections = context.getReflections();
moduleList = factoryContext.getModuleList(); moduleList = factoryContext.getAgentModuleList();
subModuleList = factoryContext.getAgentSubModuleList();
} }
@Override @Override
protected void run() { protected void run() {
setModuleList(); setModuleList();
setSubModuleList();
}
private void setSubModuleList() {
Set<Class<?>> subModules = reflections.getTypesAnnotatedWith(AgentSubModule.class);
for (Class<?> subModule : subModules) {
if (!ClassUtil.isNormalClass(subModule)) {
continue;
}
Class<? extends AgentRunningSubModule> clazz = subModule.asSubclass(AgentRunningSubModule.class);
MetaSubModule metaSubModule = new MetaSubModule();
metaSubModule.setClazz(clazz);
subModuleList.add(metaSubModule);
}
} }
private void setModuleList() { private void setModuleList() {
@@ -41,13 +81,24 @@ public class ModuleRegisterFactory extends AgentBaseFactory {
continue; continue;
} }
Class<? extends AgentRunningModule> clazz = module.asSubclass(AgentRunningModule.class); Class<? extends AgentRunningModule> clazz = module.asSubclass(AgentRunningModule.class);
AgentModule agentModule = clazz.getAnnotation(AgentModule.class); MetaModule metaModule = getMetaModule(clazz);
MetaModule metaModule = new MetaModule();
metaModule.setName(agentModule.name());
metaModule.setOrder(agentModule.order());
metaModule.setClazz(clazz);
moduleList.add(metaModule); moduleList.add(metaModule);
} }
moduleList.sort(Comparator.comparing(MetaModule::getOrder)); moduleList.sort(Comparator.comparing(MetaModule::getOrder));
} }
private static MetaModule getMetaModule(Class<? extends AgentRunningModule> clazz) {
MetaModule metaModule = new MetaModule();
AgentModule agentModule;
if (clazz.isAnnotationPresent(CoreModule.class)){
agentModule = CoreModule.class.getAnnotation(AgentModule.class);
}else{
agentModule = clazz.getAnnotation(AgentModule.class);
}
metaModule.setName(agentModule.name());
metaModule.setOrder(agentModule.order());
metaModule.setClazz(clazz);
return metaModule;
}
} }

View File

@@ -1,14 +1,17 @@
package work.slhaf.partner.api.agent.factory.module.annotation; package work.slhaf.partner.api.agent.factory.module.annotation;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityHolder;
import java.lang.annotation.*; import java.lang.annotation.*;
/** /**
* 用于注解执行模块 * 用于注解执行模块
*/ */
@Inherited
@Target(ElementType.TYPE) @Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@CapabilityHolder
@Inherited
public @interface AgentModule { public @interface AgentModule {
/** /**

View File

@@ -0,0 +1,14 @@
package work.slhaf.partner.api.agent.factory.module.annotation;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityHolder;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@CapabilityHolder
public @interface AgentSubModule {
}

View File

@@ -1,5 +1,9 @@
package work.slhaf.partner.api.agent.factory.module.annotation; package work.slhaf.partner.api.agent.factory.module.annotation;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
@Retention(RetentionPolicy.RUNTIME)
@AgentModule(name = "core",order = 5) @AgentModule(name = "core",order = 5)
public @interface CoreModule { public @interface CoreModule {
} }

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.agent.factory.module.annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface InjectModule {
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.api.agent.factory.module.exception;
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
public class ProxiedModuleRunningException extends AgentRuntimeException {
public ProxiedModuleRunningException(String message) {
super(message);
}
public ProxiedModuleRunningException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,10 @@
package work.slhaf.partner.api.agent.factory.module.pojo;
import lombok.Data;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.Module;
@Data
public abstract class BaseMetaModule <C extends Module> {
private Class<? extends C> clazz;
private C instance;
}

View File

@@ -1,13 +1,13 @@
package work.slhaf.partner.api.agent.factory.module.pojo; package work.slhaf.partner.api.agent.factory.module.pojo;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule; import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
@EqualsAndHashCode(callSuper = true)
@Data @Data
public class MetaModule { public class MetaModule extends BaseMetaModule<AgentRunningModule>{
private String name; private String name;
private int order; private int order;
private Class<? extends AgentRunningModule> clazz;
private AgentRunningModule instance;
private boolean enabled = true; private boolean enabled = true;
} }

View File

@@ -0,0 +1,10 @@
package work.slhaf.partner.api.agent.factory.module.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
@EqualsAndHashCode(callSuper = true)
@Data
public class MetaSubModule extends BaseMetaModule<AgentRunningSubModule>{
}

View File

@@ -3,30 +3,27 @@ package work.slhaf.partner.api.agent.runtime.config;
import lombok.Data; import lombok.Data;
import lombok.Setter; import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException;
import work.slhaf.partner.api.agent.factory.config.exception.ConfigUpdateFailedException; import work.slhaf.partner.api.agent.factory.config.exception.ConfigUpdateFailedException;
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException; import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException;
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig; import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule; import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.api.chat.pojo.Message;
import java.util.HashMap; import java.util.*;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@Slf4j @Slf4j
@Data @Data
public abstract class AgentConfigManager { public abstract class AgentConfigManager {
@Setter @Setter
public static AgentConfigManager INSTANCE; public static AgentConfigManager INSTANCE = new FileAgentConfigManager();
private static final String DEFAULT_KEY = "default"; private static final String DEFAULT_KEY = "default";
protected HashMap<String, ModelConfig> modelConfigMap; protected HashMap<String, ModelConfig> modelConfigMap;
protected HashMap<String, List<Message>> modelPromptMap; protected HashMap<String, List<Message>> modelPromptMap;
protected HashMap<String, Boolean> moduleEnabledStatus; protected HashMap<String, Boolean> moduleEnabledStatus;
protected List<MetaModule> moduleList; protected Map<Integer, List<MetaModule>> moduleOrderedMap = new LinkedHashMap<>();
protected Map<String, MetaModule> moduleMap = new HashMap<>();
public void load() { public void load() {
modelConfigMap = loadModelConfig(); modelConfigMap = loadModelConfig();
@@ -41,11 +38,24 @@ public abstract class AgentConfigManager {
protected abstract void dumpModuleEnabledStatus(); protected abstract void dumpModuleEnabledStatus();
protected abstract HashMap<String, Boolean> loadModuleEnabledStatusMap(); protected abstract HashMap<String, Boolean> loadModuleEnabledStatusMap(List<MetaModule> moduleList);
public void moduleEnabledStatusFilterAndRecord(List<MetaModule> moduleList) { public void moduleEnabledStatusFilterAndRecord(List<MetaModule> moduleList) {
this.moduleList = moduleList; updateModuleMap(moduleList);
this.moduleEnabledStatus = loadModuleEnabledStatusMap(); updateModuleEnabledStatus(moduleList);
}
private void updateModuleMap(List<MetaModule> moduleList) {
//在ModuleRegisterFactory已进行过排序操作
for (MetaModule module : moduleList) {
int k = module.getOrder();
moduleOrderedMap.computeIfAbsent(k, order -> new ArrayList<>()).add(module);
moduleMap.put(module.getName(), module);
}
}
private void updateModuleEnabledStatus(List<MetaModule> moduleList) {
this.moduleEnabledStatus = loadModuleEnabledStatusMap(moduleList);
boolean unmatch = false; boolean unmatch = false;
for (MetaModule metaModule : moduleList) { for (MetaModule metaModule : moduleList) {
@@ -62,27 +72,6 @@ public abstract class AgentConfigManager {
} }
} }
/**
* 对模型Config与Prompt分别进行检验,除了都必须包含default外还需要确保数量、key一致毕竟是模型配置与提示词
*/
public void check() {
log.info("[AgentConfigManager]: 执行config与prompt检测...");
if (!modelConfigMap.containsKey("default")) {
throw new ConfigNotExistException("缺少默认配置! 需确保存在一个模型配置的key为`default`");
}
if (!modelPromptMap.containsKey("basic")) {
throw new PromptNotExistException("缺少基础Prompt! 需要确保存在key为basic的Prompt文件它将与其他Prompt共同作用于模块节点。");
}
Set<String> configKeySet = new HashSet<>(modelConfigMap.keySet());
configKeySet.remove("default");
Set<String> promptKeySet = new HashSet<>(modelPromptMap.keySet());
promptKeySet.remove("basic");
if (!promptKeySet.containsAll(configKeySet)) {
log.warn("存在未被提示词包含的模型配置,该配置将无法生效!");
}
log.info("[AgentConfigManager]: 检测完毕.");
}
public List<Message> loadModelPrompt(String modelKey) { public List<Message> loadModelPrompt(String modelKey) {
if (!modelPromptMap.containsKey(modelKey)) { if (!modelPromptMap.containsKey(modelKey)) {
throw new PromptNotExistException("不存在的modelPrompt: " + modelKey); throw new PromptNotExistException("不存在的modelPrompt: " + modelKey);
@@ -108,12 +97,7 @@ public abstract class AgentConfigManager {
} }
moduleEnabledStatus.put(key, status); moduleEnabledStatus.put(key, status);
dumpModuleEnabledStatus(); dumpModuleEnabledStatus();
for (MetaModule metaModule : moduleList) { moduleMap.get(key).setEnabled(status);
if (metaModule.getName().equals(key)) {
metaModule.setEnabled(status);
break;
}
}
} }
} }

View File

@@ -22,12 +22,12 @@ import java.util.List;
* 将从当前运行目录的config文件夹下创建并读取配置 * 将从当前运行目录的config文件夹下创建并读取配置
*/ */
@Slf4j @Slf4j
public class DefaultAgentConfigManager extends AgentConfigManager { public class FileAgentConfigManager extends AgentConfigManager {
private static final String CONFIG_DIR = "./config/"; protected static final String CONFIG_DIR = "./config/";
private static final String MODEL_CONFIG_DIR = "./config/model/"; protected static final String MODEL_CONFIG_DIR = "./config/model/";
private static final String PROMPT_CONFIG_DIR = "./config/prompt/"; protected static final String PROMPT_CONFIG_DIR = "./config/prompt/";
private static final String MODULE_ENABLED_STATUS_CONFIG_FILE = CONFIG_DIR + "module_enabled_status.json"; protected static final String MODULE_ENABLED_STATUS_CONFIG_FILE = CONFIG_DIR + "module_enabled_status.json";
@Override @Override
@@ -74,10 +74,10 @@ public class DefaultAgentConfigManager extends AgentConfigManager {
} }
@Override @Override
protected HashMap<String, Boolean> loadModuleEnabledStatusMap() { protected HashMap<String, Boolean> loadModuleEnabledStatusMap(List<MetaModule> moduleList) {
File file = new File(MODULE_ENABLED_STATUS_CONFIG_FILE); File file = new File(MODULE_ENABLED_STATUS_CONFIG_FILE);
try { try {
HashMap<String, Boolean> moduleEnabledStatus = new HashMap<>(); moduleEnabledStatus = new HashMap<>();
if (!file.exists()) { if (!file.exists()) {
file.createNewFile(); file.createNewFile();
for (MetaModule module : moduleList) { for (MetaModule module : moduleList) {

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.agent.runtime.exception;
public class AgentRunningFailedException extends AgentRuntimeException{
public AgentRunningFailedException(String message) {
super(message);
}
public AgentRunningFailedException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -1,23 +1,36 @@
package work.slhaf.partner.api.agent.runtime.exception; package work.slhaf.partner.api.agent.runtime.exception;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class GlobalExceptionHandler { public class GlobalExceptionHandler {
public static GlobalExceptionHandler INSTANCE = new GlobalExceptionHandler(); public static GlobalExceptionHandler INSTANCE = new GlobalExceptionHandler();
private AgentExceptionCallback exceptionCallback = new DefaultAgentExceptionCallback(); private AgentExceptionCallback exceptionCallback = new LogAgentExceptionCallback();
public void handle(Throwable e) { public boolean handle(Throwable e) {
boolean exit;
switch (e.getClass().getSimpleName()) { Throwable cause = e.getCause();
case "AgentRuntimeException": switch (cause) {
exceptionCallback.onRuntimeException((AgentRuntimeException) e); case AgentRunningFailedException arfe -> {
break; exit = true;
case "AgentLaunchFailedException": exceptionCallback.onRuntimeException((AgentRuntimeException) cause);
exceptionCallback.onFailedException((AgentLaunchFailedException) e);
break;
default:
throw new RuntimeException("未经处理的异常!", e);
} }
case AgentRuntimeException are -> {
exit = false;
exceptionCallback.onRuntimeException((AgentRuntimeException) cause);
}
case AgentLaunchFailedException alfe -> {
exit = true;
exceptionCallback.onFailedException((AgentLaunchFailedException) cause);
}
default -> {
exit = true;
log.error("意外异常: ", cause);
}
}
return exit;
} }
public static void setExceptionCallback(AgentExceptionCallback callback) { public static void setExceptionCallback(AgentExceptionCallback callback) {

View File

@@ -3,7 +3,7 @@ package work.slhaf.partner.api.agent.runtime.exception;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@Slf4j @Slf4j
public class DefaultAgentExceptionCallback implements AgentExceptionCallback { public class LogAgentExceptionCallback implements AgentExceptionCallback {
@Override @Override
public void onRuntimeException(AgentRuntimeException e) { public void onRuntimeException(AgentRuntimeException e) {

View File

@@ -4,9 +4,23 @@ import work.slhaf.partner.api.agent.runtime.interaction.data.AgentInputData;
import work.slhaf.partner.api.agent.runtime.interaction.data.AgentOutputData; import work.slhaf.partner.api.agent.runtime.interaction.data.AgentOutputData;
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext; import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext;
public interface AgentGateway { public interface AgentGateway <I extends AgentInputData, O extends AgentOutputData, C extends RunningFlowContext>{
void launch(); void launch();
<I extends AgentInputData, O extends AgentOutputData, C extends RunningFlowContext> AgentInteractionAdapter<I, O, C> adapter(); default void receive(I inputData){
C finalInputData = adapter().parseInputData(inputData);
C outputContext = adapter().call(finalInputData);
O outputData = adapter().parseOutputData(outputContext);
send(outputData);
}
void send(O outputData);
/**
* 通过adapter提供的receive、send方法进行与客户端的交互行为
*
* @return adapter实例
*/
AgentInteractionAdapter<I, O, C> adapter();
} }

View File

@@ -8,34 +8,19 @@ import work.slhaf.partner.api.agent.runtime.interaction.flow.AgentRunningFlow;
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext; import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext;
import java.util.List; import java.util.List;
import java.util.Map;
public abstract class AgentInteractionAdapter<I extends AgentInputData, O extends AgentOutputData, C extends RunningFlowContext> { public abstract class AgentInteractionAdapter<I extends AgentInputData, O extends AgentOutputData, C extends RunningFlowContext> {
private static AgentInteractionAdapter<?,?,?> INSTANCE;
protected AgentRunningFlow<C> agentRunningFlow = new AgentRunningFlow<>(); protected AgentRunningFlow<C> agentRunningFlow = new AgentRunningFlow<>();
protected List<MetaModule> moduleList = AgentConfigManager.INSTANCE.getModuleList(); protected Map<Integer, List<MetaModule>> moduleOrderedMap = AgentConfigManager.INSTANCE.getModuleOrderedMap();
public void receive(I inputData) { public C call(C finalInputData){
C finalInputData = parseInputData(inputData); return agentRunningFlow.launch(moduleOrderedMap, finalInputData);
C outputContext = agentRunningFlow.launch(moduleList, finalInputData);
O outputData = parseOutputData(outputContext);
send(outputData);
} }
protected abstract O parseOutputData(C outputContext); protected abstract O parseOutputData(C outputContext);
protected abstract C parseInputData(I inputData); protected abstract C parseInputData(I inputData);
public abstract void send(O outputData);
public static <I extends AgentInputData, O extends AgentOutputData, C extends RunningFlowContext> AgentInteractionAdapter<I, O, C> getInstance() {
@SuppressWarnings("unchecked")
AgentInteractionAdapter<I, O, C> instance = (AgentInteractionAdapter<I, O, C>) INSTANCE;
return instance;
}
public static <I extends AgentInputData, O extends AgentOutputData, C extends RunningFlowContext> void setInstance(AgentInteractionAdapter<I, O, C> instance) {
INSTANCE = instance;
}
} }

View File

@@ -7,7 +7,7 @@ import lombok.EqualsAndHashCode;
@Data @Data
public abstract class AgentOutputData extends InteractionData{ public abstract class AgentOutputData extends InteractionData{
private int code; protected int code;
public static class StatusCode { public static class StatusCode {
public static final int SUCCESS = 1; public static final int SUCCESS = 1;

View File

@@ -1,24 +1,52 @@
package work.slhaf.partner.api.agent.runtime.interaction.flow; package work.slhaf.partner.api.agent.runtime.interaction.flow;
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule; import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
import work.slhaf.partner.api.agent.runtime.exception.GlobalExceptionHandler; import work.slhaf.partner.api.agent.runtime.exception.GlobalExceptionHandler;
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext; import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
/** /**
* Agent执行流程 * Agent执行流程
*/ */
public class AgentRunningFlow<C extends RunningFlowContext> { public class AgentRunningFlow<C extends RunningFlowContext> {
public C launch(List<MetaModule> moduleList, C interactionContext){ public C launch(Map<Integer, List<MetaModule>> modules, C interactionContext) {
try { try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
//流程执行启动 //流程执行启动
for (MetaModule metaModule : moduleList) { for (Map.Entry<Integer, List<MetaModule>> entry : modules.entrySet()) {
metaModule.getInstance().execute(interactionContext); List<Future<?>> futures = new ArrayList<>();
List<MetaModule> moduleList = entry.getValue();
for (MetaModule module : moduleList) {
Future<?> future = executor.submit(() -> {
try {
module.getInstance().execute(interactionContext);
} catch (Exception e) {
throw new AgentRuntimeException("模块执行出错: " + module.getName(), e);
} }
}catch (Exception e){ });
GlobalExceptionHandler.INSTANCE.handle(e); futures.add(future);
}
for (Future<?> future : futures) {
try {
future.get();
} catch (Exception e) {
boolean exit = GlobalExceptionHandler.INSTANCE.handle(e);
if (exit) throw new AgentRuntimeException("Agent执行出错!", e);
interactionContext.getErrMsg().add(e.getLocalizedMessage());
}
}
}
interactionContext.setOk(1);
} catch (Exception e) {
interactionContext.setOk(0);
interactionContext.getErrMsg().add(e.getLocalizedMessage());
} }
return interactionContext; return interactionContext;
} }

View File

@@ -17,7 +17,7 @@ public interface ActivateModel {
AgentConfigManager AGENT_CONFIG_MANAGER = AgentConfigManager.INSTANCE; AgentConfigManager AGENT_CONFIG_MANAGER = AgentConfigManager.INSTANCE;
@Init @Init(order = -1)
default void modelSettings() { default void modelSettings() {
Model model = new Model(); Model model = new Model();
ModelConfig modelConfig = AgentConfigManager.INSTANCE.loadModelConfig(modelKey()); ModelConfig modelConfig = AgentConfigManager.INSTANCE.loadModelConfig(modelKey());
@@ -87,6 +87,9 @@ public interface ActivateModel {
((Module) this).setModel(model); ((Module) this).setModel(model);
} }
/**
* 对应调用的模型配置名称
*/
String modelKey(); String modelKey();
boolean withBasicPrompt(); boolean withBasicPrompt();

View File

@@ -1,10 +1,38 @@
package work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts; package work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute;
import work.slhaf.partner.api.agent.factory.module.annotation.CoreModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext; import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext;
import java.io.IOException;
/** /**
* 流程执行模块基类 * 流程执行模块基类
*/ */
public abstract class AgentRunningModule extends Module { @Slf4j
public abstract void execute(RunningFlowContext context); public abstract class AgentRunningModule<C extends RunningFlowContext> extends Module {
public abstract void execute(C context) throws IOException, ClassNotFoundException;
@BeforeExecute
private void beforeLog() {
log.debug("[{}] 模块执行开始...", getModuleName());
}
@AfterExecute
private void afterLog() {
log.debug("[{}] 模块执行结束...", getModuleName());
}
private String getModuleName(){
if (this.getClass().isAnnotationPresent(AgentModule.class)) {
return this.getClass().getAnnotation(AgentModule.class).name();
} else if (this.getClass().isAnnotationPresent(CoreModule.class)) {
return CoreModule.class.getAnnotation(AgentModule.class).name();
}else {
return "Unknown Module";
}
}
} }

View File

@@ -1,13 +1,35 @@
package work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts; package work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts;
/** import lombok.extern.slf4j.Slf4j;
* 流程子模块基类 import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute;
* @param <I> 输入类型 import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
* @param <O> 输出类型 import work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute;
*/ import work.slhaf.partner.api.agent.factory.module.annotation.CoreModule;
@Slf4j
public abstract class AgentRunningSubModule<I, O> extends Module { public abstract class AgentRunningSubModule<I, O> extends Module {
public abstract O execute(I data); public abstract O execute(I data);
@BeforeExecute
private void beforeLog() {
log.debug("[{}] 模块执行开始...", getModuleName());
}
@AfterExecute
private void afterLog() {
log.debug("[{}] 模块执行结束...", getModuleName());
}
private String getModuleName(){
if (this.getClass().isAnnotationPresent(AgentModule.class)) {
return this.getClass().getAnnotation(AgentModule.class).name();
} else if (this.getClass().isAnnotationPresent(CoreModule.class)) {
return CoreModule.class.getAnnotation(AgentModule.class).name();
}else {
return "Unknown Module";
}
}
} }

View File

@@ -2,13 +2,11 @@ package work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityHolder;
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.Model; import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.Model;
/** /**
* 模块基类 * 模块基类
*/ */
@CapabilityHolder
public abstract class Module { public abstract class Module {
@Getter @Getter

View File

@@ -1,11 +1,18 @@
package work.slhaf.partner.api.agent.runtime.interaction.flow.entity; package work.slhaf.partner.api.agent.runtime.interaction.flow.entity;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.common.entity.PersistableObject;
import java.util.ArrayList;
import java.util.List;
/** /**
* 流程上下文 * 流程上下文
*/ */
@EqualsAndHashCode(callSuper = true)
@Data @Data
public abstract class RunningFlowContext { public abstract class RunningFlowContext extends PersistableObject {
protected int ok;
protected List<String> errMsg = new ArrayList<>();
} }

View File

@@ -1,11 +1,37 @@
package work.slhaf.partner.api.agent.util; package work.slhaf.partner.api.agent.util;
import org.reflections.Reflections;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors;
public final class AgentUtil { public final class AgentUtil {
public static boolean isAssignableFromAnnotation(Class<?> clazz,Class<? extends Annotation> targetAnnotation){
Set<Class<?>> visited = new HashSet<>();
return isAssignableFromAnnotation(clazz,targetAnnotation,visited);
}
private static boolean isAssignableFromAnnotation(Class<?> clazz,Class<? extends Annotation> targetAnnotation,Set<Class<?>> visited){
if (!visited.add(clazz)){
return false;
}
if (clazz.isAnnotationPresent(targetAnnotation)){
return true;
}
Annotation[] annotations = clazz.getAnnotations();
for (Annotation annotation : annotations) {
boolean ok = isAssignableFromAnnotation(annotation.annotationType(),targetAnnotation,visited);
if (ok){
return true;
}
}
return false;
}
public static String methodSignature(Method method) { public static String methodSignature(Method method) {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
sb.append("("); sb.append("(");
@@ -23,6 +49,7 @@ public final class AgentUtil {
public static Set<Class<?>> collectExtendedClasses(Class<?> clazz, Class<?> targetClass) { public static Set<Class<?>> collectExtendedClasses(Class<?> clazz, Class<?> targetClass) {
Set<Class<?>> classes = new HashSet<>(); Set<Class<?>> classes = new HashSet<>();
collectExtendedClasses(classes, clazz, targetClass); collectExtendedClasses(classes, clazz, targetClass);
classes.add(clazz);
return classes; return classes;
} }
@@ -36,10 +63,18 @@ public final class AgentUtil {
collectInterfaces(clazz, classes); collectInterfaces(clazz, classes);
} }
public static Set<Class<?>> getMethodAnnotationTypeSet(Class<? extends Annotation> clazz, Reflections reflections){
Set<Method> methods = reflections.getMethodsAnnotatedWith(clazz);
return methods.stream()
.map(Method::getDeclaringClass)
.collect(Collectors.toSet());
}
private static void collectInterfaces(Class<?> clazz, Set<Class<?>> classes) { private static void collectInterfaces(Class<?> clazz, Set<Class<?>> classes) {
for (Class<?> type : clazz.getInterfaces()) { for (Class<?> type : clazz.getInterfaces()) {
if (classes.add(type)) { if (classes.add(type)) {
collectInterfaces(type, classes); collectInterfaces(type, classes);
} }
} }
}} }
}

View File

@@ -6,11 +6,12 @@ import net.bytebuddy.matcher.ElementMatchers;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule; import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
public class ModuleProxyTest { public class ModuleProxyTest {
@Test @Test
public void test() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { public void test() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, IOException, ClassNotFoundException {
Class<? extends AgentRunningModule> clazz = new ByteBuddy().subclass(MyAgentRunningModule.class) Class<? extends AgentRunningModule> clazz = new ByteBuddy().subclass(MyAgentRunningModule.class)
.method(ElementMatchers.isOverriddenFrom(AgentRunningModule.class)) .method(ElementMatchers.isOverriddenFrom(AgentRunningModule.class))
.intercept(MethodDelegation.to( .intercept(MethodDelegation.to(

View File

@@ -0,0 +1,43 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<parent>
<artifactId>Partner</artifactId>
<groupId>work.slhaf</groupId>
<version>0.5.0</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>Partner-Main</artifactId>
<build>
<plugins>
<plugin>
<artifactId>maven-shade-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer>
<mainClass>work.slhaf.partner.Main</mainClass>
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<skipTests>true</skipTests>
</configuration>
</plugin>
</plugins>
</build>
<properties>
<maven.compiler.target>21</maven.compiler.target>
<maven.compiler.source>21</maven.compiler.source>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
</project>

View File

@@ -22,6 +22,12 @@
<artifactId>Partner-Api</artifactId> <artifactId>Partner-Api</artifactId>
<version>0.5.0</version> <version>0.5.0</version>
</dependency> </dependency>
<dependency>
<groupId>org.jetbrains.kotlinx</groupId>
<artifactId>kotlinx-coroutines-core</artifactId>
<version>1.10.2</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<properties> <properties>
@@ -30,4 +36,35 @@
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties> </properties>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>work.slhaf.partner.Main</mainClass>
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<skipTests>true</skipTests>
</configuration>
</plugin>
</plugins>
</build>
</project> </project>

View File

@@ -1,14 +0,0 @@
package work.slhaf;
import work.slhaf.partner.Agent;
import java.io.IOException;
import java.util.Scanner;
public class Main {
public static void main(String[] args) throws IOException {
Agent.initialize();
Scanner scanner = new Scanner(System.in);
while (!scanner.nextLine().equals("exit")) ;
}
}

View File

@@ -1,75 +0,0 @@
package work.slhaf.partner;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.config.Config;
import work.slhaf.partner.common.monitor.DebugMonitor;
import work.slhaf.partner.core.InteractionHub;
import work.slhaf.partner.core.interaction.agent_interface.InputReceiver;
import work.slhaf.partner.core.interaction.agent_interface.TaskCallback;
import work.slhaf.partner.core.interaction.data.InteractionInputData;
import work.slhaf.partner.core.interaction.data.InteractionOutputData;
import work.slhaf.partner.gateway.AgentWebSocketServer;
import work.slhaf.partner.gateway.MessageSender;
import java.io.IOException;
import java.time.LocalDateTime;
@Data
@Slf4j
public class Agent implements TaskCallback, InputReceiver {
private static volatile Agent agent;
private InteractionHub interactionHub;
private MessageSender messageSender;
public static void initialize() throws IOException {
if (agent == null) {
synchronized (Agent.class) {
if (agent == null) {
//加载配置
Config config = Config.getConfig();
agent = new Agent();
agent.setInteractionHub(InteractionHub.initialize());
agent.registerTaskCallback();
AgentWebSocketServer server = new AgentWebSocketServer(config.getWebSocketConfig().getPort(), agent);
server.launch();
agent.setMessageSender(server);
log.info("Agent 加载完毕..");
//启动监测线程
DebugMonitor.initialize();
}
}
}
}
public static Agent getInstance() throws IOException {
initialize();
return agent;
}
/**
* 接收用户输入,包装为标准输入数据类
*/
public void receiveInput(InteractionInputData inputData) throws IOException, ClassNotFoundException {
inputData.setLocalDateTime(LocalDateTime.now());
interactionHub.call(inputData);
}
/**
* 向用户返回输出内容
*/
public void sendToUser(String userInfo, String output) {
messageSender.sendMessage(new InteractionOutputData(output, userInfo));
}
@Override
public void onTaskFinished(String userInfo, String output) {
sendToUser(userInfo, output);
}
private void registerTaskCallback() {
interactionHub.setCallback(this);
}
}

View File

@@ -0,0 +1,16 @@
package work.slhaf.partner;
import work.slhaf.partner.api.agent.Agent;
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
import work.slhaf.partner.runtime.exception.PartnerExceptionCallback;
import work.slhaf.partner.runtime.interaction.WebSocketGateway;
public class Main {
public static void main(String[] args) {
Agent.newAgent(Main.class)
.setAgentConfigManager(PartnerAgentConfigManager.class)
.setGateway(WebSocketGateway.class)
.setAgentExceptionCallback(PartnerExceptionCallback.class)
.launch();
}
}

View File

@@ -1,138 +1,9 @@
package work.slhaf.partner.common.config; package work.slhaf.partner.common.config;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONArray;
import lombok.Data; import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import work.slhaf.partner.module.modules.core.CoreModel;
import work.slhaf.partner.module.modules.memory.selector.MemorySelector;
import work.slhaf.partner.module.modules.memory.updater.MemoryUpdater;
import work.slhaf.partner.module.modules.process.PostprocessExecutor;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.Scanner;
@Data @Data
@Slf4j
public class Config { public class Config {
private int port;
private static final String CONFIG_FILE_PATH = "./config/config.json";
private static final String LOG_FILE_PATH = "./data/log";
private static Config config;
private String agentId; private String agentId;
// private String basicCharacter;
private WebSocketConfig webSocketConfig;
private List<ModuleConfig> moduleConfigList;
private Config() {
}
public static Config getConfig() throws IOException {
if (config == null) {
File file = new File(CONFIG_FILE_PATH);
if (file.exists()) {
config = JSONUtil.readJSONObject(file, StandardCharsets.UTF_8).toBean(Config.class);
} else {
config = new Config();
Scanner scanner = new Scanner(System.in);
System.out.print("输入智能体名称: ");
config.setAgentId(scanner.nextLine());
System.out.println("(注意! 设定角色之后修改主配置文件将不会影响现有记忆除非同时更换agentId)");
System.out.println("\r\n--------模型配置--------\r\n");
generateModelConfig(scanner);
System.out.println("\r\n--------服务配置--------\r\n");
generateWsSocketConfig(scanner);
System.out.println("\r\n--------模块链配置--------\r\n");
generatePipelineConfig();
boolean launchOrNot = getLaunchOrNot(scanner);
//保存配置文件
String str = JSONUtil.toJsonPrettyStr(config);
FileUtils.writeStringToFile(file, str, StandardCharsets.UTF_8);
log.info("配置已保存");
if (!launchOrNot) {
System.exit(0);
}
}
config.generateCommonDirs();
}
return config;
}
private void generateCommonDirs() throws IOException {
Files.createDirectories(Paths.get(LOG_FILE_PATH));
}
private static boolean getLaunchOrNot(Scanner scanner) {
System.out.print("是否直接启动Partner?(y/n): ");
String input;
while (true) {
input = scanner.nextLine();
if (input.equals("y")) {
return true;
} else if (input.equals("n")) {
return false;
} else {
System.out.println("请输入y或n");
}
}
}
private static void generatePipelineConfig() {
List<ModuleConfig> moduleConfigList = List.of(
new ModuleConfig(MemorySelector.class.getName(), ModuleConfig.Constant.INTERNAL, null),
new ModuleConfig(CoreModel.class.getName(), ModuleConfig.Constant.INTERNAL, null),
new ModuleConfig(PostprocessExecutor.class.getName(),ModuleConfig.Constant.INTERNAL,null),
new ModuleConfig(MemoryUpdater.class.getName(), ModuleConfig.Constant.INTERNAL, null)
);
config.setModuleConfigList(moduleConfigList);
}
private static void generateWsSocketConfig(Scanner scanner) {
System.out.print("WebSocket port: ");
WebSocketConfig wsConfig = new WebSocketConfig();
wsConfig.setPort(scanner.nextInt());
config.setWebSocketConfig(wsConfig);
}
private static void generateModelConfig(Scanner scanner) throws IOException {
System.out.println("配置LLM APi:");
System.out.println("经测试, 目前只建议选择Qwen3: qwen-plus-latest或qwen-max-latest");
System.out.print("base_url: ");
String baseUrl = scanner.nextLine();
System.out.print("apikey: ");
String apikey = scanner.nextLine();
System.out.print("model: ");
String model = scanner.nextLine();
ModelConfig modelConfig = new ModelConfig();
modelConfig.setBaseUrl(baseUrl);
modelConfig.setApikey(apikey);
modelConfig.setModel(model);
InputStream stream = Config.class.getClassLoader().getResourceAsStream("modules/default_activated_model.json");
String content = new String(stream.readAllBytes(), StandardCharsets.UTF_8);
stream.close();
for (String s : JSONArray.parseArray(content, String.class)) {
modelConfig.generateConfig(s);
}
}
} }

View File

@@ -1,40 +0,0 @@
package work.slhaf.partner.common.config;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import org.apache.commons.io.FileUtils;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
@Data
public class ModelConfig {
private static final String MODEL_CONFIG_DIR_PATH = "./config/model/";
private static final HashMap<String, ModelConfig> modelConfigMap = new HashMap<>();
private String apikey;
private String baseUrl;
private String model;
public void generateConfig(String filename) throws IOException {
String str = JSONUtil.toJsonPrettyStr(this);
File file = new File(MODEL_CONFIG_DIR_PATH + filename + ".json");
FileUtils.writeStringToFile(file, str, StandardCharsets.UTF_8);
}
public static ModelConfig load(String modelKey) {
if (!modelConfigMap.containsKey(modelKey)) {
modelConfigMap.put(modelKey,loadConfig(modelKey));
}
return modelConfigMap.get(modelKey);
}
private static ModelConfig loadConfig(String modelKey) {
File file = new File(MODEL_CONFIG_DIR_PATH+modelKey+".json");
return JSONUtil.readJSONObject(file,StandardCharsets.UTF_8).toBean(ModelConfig.class);
}
}

View File

@@ -1,17 +0,0 @@
package work.slhaf.partner.common.config;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class ModuleConfig {
private String className;
private String type;
private String path;
public static class Constant {
public static final String INTERNAL = "internal";
public static final String EXTERNAL = "external";
}
}

View File

@@ -0,0 +1,40 @@
package work.slhaf.partner.common.config;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException;
import work.slhaf.partner.api.agent.runtime.config.FileAgentConfigManager;
import work.slhaf.partner.common.exception.ConfigLoadFailedException;
import java.io.File;
import java.nio.charset.StandardCharsets;
@EqualsAndHashCode(callSuper = true)
@Data
public final class PartnerAgentConfigManager extends FileAgentConfigManager {
private static final String COMMON_CONFIG_FILE = CONFIG_DIR + "common_config.json";
private Config config;
@Override
public void load() {
loadWebSocketConfig();
super.load();
}
private void loadWebSocketConfig() {
File file = new File(COMMON_CONFIG_FILE);
if (!file.exists()) {
throw new ConfigNotExistException("Partner Config Not Exist: " + COMMON_CONFIG_FILE);
}
config = JSONUtil.readJSONObject(file, StandardCharsets.UTF_8).toBean(Config.class);
if (config == null || config.getAgentId() == null) {
throw new ConfigLoadFailedException("Partner Config Load Failed: " + COMMON_CONFIG_FILE);
}
if (config.getPort() <= 0 || config.getPort() > 65535) {
throw new ConfigLoadFailedException("Invalid Websocket port: " + config.getPort());
}
}
}

View File

@@ -1,8 +0,0 @@
package work.slhaf.partner.common.config;
import lombok.Data;
@Data
public class WebSocketConfig {
private Integer port;
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.common.exception;
import work.slhaf.partner.api.agent.factory.config.exception.ConfigFactoryInitFailedException;
public class ConfigLoadFailedException extends ConfigFactoryInitFailedException {
public ConfigLoadFailedException(String message, Throwable cause) {
super(message, cause);
}
public ConfigLoadFailedException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.common.exception;
import work.slhaf.partner.api.agent.runtime.exception.AgentLaunchFailedException;
public class ServiceLoadFailedException extends AgentLaunchFailedException {
public ServiceLoadFailedException(String message, Throwable cause) {
super(message, cause);
}
public ServiceLoadFailedException(String message) {
super(message);
}
}

View File

@@ -1,43 +0,0 @@
package work.slhaf.partner.common.exception_handler;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.exception_handler.pojo.GlobalException;
import work.slhaf.partner.common.exception_handler.pojo.GlobalExceptionData;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
@Slf4j
public class GlobalExceptionHandler {
private static final String EXCEPTION_STATIC_PATH = "./data/exception_snapshot/";
public static void writeExceptionState(GlobalException exception) {
GlobalExceptionData exceptionData = exception.getData();
Path filePath = Paths.get(EXCEPTION_STATIC_PATH, exceptionData.getExceptionTime() + ".dat");
try {
Files.createDirectories(Path.of(EXCEPTION_STATIC_PATH));
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(filePath.toFile()));
oos.writeObject(exceptionData);
oos.close();
log.warn("[GlobalExceptionHandler] 捕获异常, 已保存到: {}", filePath);
} catch (IOException e) {
log.error("[GlobalExceptionHandler] 捕获异常, 保存失败: ", e);
}
}
public static GlobalExceptionData readExceptionState(String filePath) {
try {
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(filePath));
GlobalExceptionData exceptionData = (GlobalExceptionData) ois.readObject();
ois.close();
log.info("[GlobalExceptionHandler] 已从: {} 读取异常快照", filePath);
return exceptionData;
} catch (IOException | ClassNotFoundException e) {
log.error("[GlobalExceptionHandler] 读取异常, 读取失败: ", e);
return null;
}
}
}

View File

@@ -1,30 +0,0 @@
package work.slhaf.partner.common.exception_handler.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.core.cognation.cognation.CognationCore;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.session.SessionManager;
@EqualsAndHashCode(callSuper = true)
@Slf4j
@Data
public class GlobalException extends RuntimeException {
private GlobalExceptionData data;
public GlobalException(String message) {
super(message);
try {
this.data = new GlobalExceptionData();
this.data.setExceptionTime(System.currentTimeMillis());
this.data.setSessionManager(SessionManager.getInstance());
this.data.setCognationCore(CognationCore.getInstance());
this.data.setContext(InteractionContext.getInstance());
} catch (Exception e) {
log.error("[GlobalException] 捕获异常, 获取数据失败");
}
}
}

View File

@@ -1,26 +0,0 @@
package work.slhaf.partner.common.exception_handler.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.common.entity.PersistableObject;
import work.slhaf.partner.core.cognation.cognation.CognationCore;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.session.SessionManager;
import java.io.Serial;
import java.util.HashMap;
@EqualsAndHashCode(callSuper = true)
@Data
public class GlobalExceptionData extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private String exceptionMessage;
protected HashMap<String, InteractionContext> context;
protected SessionManager sessionManager;
protected CognationCore cognationCore;
protected Long exceptionTime;
}

View File

@@ -1,7 +1,7 @@
package work.slhaf.partner.common.util; package work.slhaf.partner.common.util;
import com.alibaba.fastjson2.JSONArray; import com.alibaba.fastjson2.JSONArray;
import work.slhaf.partner.Agent; import work.slhaf.partner.api.agent.Agent;
import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.api.chat.pojo.Message;
import java.io.InputStream; import java.io.InputStream;

View File

@@ -0,0 +1,56 @@
package work.slhaf.partner.core;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.CoordinateManager;
import work.slhaf.partner.api.agent.factory.capability.annotation.Coordinated;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.core.cognation.CognationCore;
import work.slhaf.partner.core.memory.MemoryCore;
import java.io.Serial;
import java.io.Serializable;
import java.util.HashSet;
import java.util.Set;
import static work.slhaf.partner.common.util.ExtractUtil.extractUserId;
@Data
@Slf4j
@CoordinateManager
public class CoordinatedManager implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
//在框架将自动注入core,详见CapabilityRegistryFactory
private CognationCore cognationCore;
private MemoryCore memoryCore;
private boolean isCacheSingleUser() {
return memoryCore.getUserDialogMap().size() <= 1;
}
@Coordinated(capability = "cognation")
public boolean isSingleUser() {
return isCacheSingleUser() && isChatMessagesSingleUser();
}
private boolean isChatMessagesSingleUser() {
Set<String> userIdSet = new HashSet<>();
cognationCore.getChatMessages().forEach(m -> {
if (m.getRole().equals(ChatConstant.Character.ASSISTANT)) {
return;
}
String userId = extractUserId(m.getContent());
if (userId == null || userId.isEmpty()) {
return;
}
userIdSet.add(userId);
});
return userIdSet.size() <= 1;
}
}

View File

@@ -1,56 +0,0 @@
package work.slhaf.partner.core;
import lombok.Data;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.exception_handler.GlobalExceptionHandler;
import work.slhaf.partner.common.exception_handler.pojo.GlobalException;
import work.slhaf.partner.core.interaction.agent_interface.TaskCallback;
import work.slhaf.partner.core.interaction.data.InteractionInputData;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import work.slhaf.partner.core.interaction.module.InteractionFlow;
import work.slhaf.partner.core.interaction.module.InteractionModulesLoader;
import work.slhaf.partner.module.modules.process.PreprocessExecutor;
import java.io.IOException;
import java.util.List;
@Data
@Slf4j
public class InteractionHub {
private static volatile InteractionHub interactionHub;
@ToString.Exclude
private TaskCallback callback;
private List<InteractionFlow> interactionModules;
public static InteractionHub initialize() throws IOException {
if (interactionHub == null) {
synchronized (InteractionHub.class) {
if (interactionHub == null) {
interactionHub = new InteractionHub();
//加载模块
interactionHub.setInteractionModules(InteractionModulesLoader.getInstance().registerInteractionModules());
log.info("InteractionHub注册完毕...");
}
}
}
return interactionHub;
}
public void call(InteractionInputData inputData) throws IOException, ClassNotFoundException {
InteractionContext interactionContext = PreprocessExecutor.getInstance().execute(inputData);
try {
for (InteractionFlow interactionModule : interactionModules) {
interactionModule.execute(interactionContext);
}
} catch (GlobalException e) {
GlobalExceptionHandler.writeExceptionState(e);
interactionContext.getCoreResponse().put("text", "[ERROR] " + e.getMessage());
} finally {
callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getString("text"));
interactionContext.clearUp();
}
}
}

View File

@@ -0,0 +1,93 @@
package work.slhaf.partner.core;
import cn.hutool.core.bean.BeanUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
import work.slhaf.partner.api.common.entity.PersistableObject;
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
@Slf4j
public abstract class PartnerCore<T extends PartnerCore<T>> extends PersistableObject {
private static final String STORAGE_DIR = "./data/memory/";
private final String id = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getAgentId();
public PartnerCore() throws IOException, ClassNotFoundException {
createStorageDirectory();
Path filePath = getFilePath(id);
if (Files.exists(filePath)) {
T deserialize = deserialize();
setupData(deserialize, (T) this);
} else {
FileUtils.createParentDirectories(filePath.toFile().getParentFile());
this.serialize();
}
setupHook(this);
log.info("[{}] 注册完毕", getCoreKey());
}
private void setupHook(PartnerCore<T> temp) {
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
temp.serialize();
log.info("[{}] 已保存", getCoreKey());
} catch (IOException e) {
log.error("[{}] 保存失败: ", getCoreKey(), e);
}
}));
}
private void setupData(T source, T current) {
BeanUtil.copyProperties(source, current);
}
public void serialize() throws IOException {
//先写入到临时文件,如果正常写入则覆盖原文件
Path filePath = getFilePath(id + "-temp");
Files.createDirectories(Path.of(STORAGE_DIR));
try {
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(filePath.toFile()));
oos.writeObject(this);
oos.close();
Path path = getFilePath(id);
Files.move(filePath, path, StandardCopyOption.REPLACE_EXISTING);
log.info("[{}] 已保存到: {}", getCoreKey(), path);
} catch (IOException e) {
Files.delete(filePath);
log.error("[{}] 序列化保存失败: {}", getCoreKey(), e.getMessage());
}
}
private T deserialize() throws IOException, ClassNotFoundException {
Path filePath = getFilePath(id);
try (ObjectInputStream ois = new ObjectInputStream(
new FileInputStream(filePath.toFile()))) {
T graph = (T) ois.readObject();
log.info("[{}] 已从文件加载: {}", getCoreKey(), filePath);
return graph;
}
}
private Path getFilePath(String s) {
return Paths.get(STORAGE_DIR, s + "-" + getCoreKey() + ".memory");
}
private void createStorageDirectory() {
try {
Files.createDirectories(Paths.get(STORAGE_DIR));
} catch (IOException e) {
log.error("[{}]创建存储目录失败: {}", getCoreKey(), e.getMessage());
}
}
protected abstract String getCoreKey();
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.core.action.entity;
import lombok.Data;
@Data
public class ActionData {
private String key;
private String[] array;
private String reason;
private String description;
}

View File

@@ -0,0 +1,5 @@
package work.slhaf.partner.core.action.entity;
public enum ActionStatus {
SUCCESS, FAILED, EXECUTING, WAITING
}

View File

@@ -0,0 +1,5 @@
package work.slhaf.partner.core.action.entity;
public enum ActionType {
IMMEDIATE, PLANNING
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.core.action.entity;
import lombok.Data;
import java.time.LocalDateTime;
@Data
public class MetaActionInfo {
private ActionData actionData;
private ActionStatus status;
private String Result;
private LocalDateTime dateTime;
}

View File

@@ -0,0 +1,29 @@
package work.slhaf.partner.core.cognation;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.api.agent.factory.capability.annotation.ToCoordinated;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.locks.Lock;
@Capability("cognation")
public interface CognationCapability {
List<Message> getChatMessages();
void setChatMessages(List<Message> chatMessages);
void cleanMessage(List<Message> messages);
Lock getMessageLock();
void addMetaMessage(String userId, MetaMessage metaMessage);
List<Message> unpackAndClear(String userId);
void refreshMemoryId();
void resetLastUpdatedTime();
long getLastUpdatedTime();
HashMap<String,List<MetaMessage>> getSingleMetaMessageMap();
String getCurrentMemoryId();
@ToCoordinated
boolean isSingleUser();
}

View File

@@ -0,0 +1,121 @@
package work.slhaf.partner.core.cognation;
import com.alibaba.fastjson2.JSONObject;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.core.PartnerCore;
import java.io.IOException;
import java.io.Serial;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@EqualsAndHashCode(callSuper = true)
@Slf4j
@CapabilityCore(value = "cognation")
@Getter
@Setter
public class CognationCore extends PartnerCore<CognationCore> {
@Serial
private static final long serialVersionUID = 1L;
private final ReentrantLock messageLock = new ReentrantLock();
/**
* 主模型的聊天记录
*/
private List<Message> chatMessages = new ArrayList<>();
private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap = new HashMap<>();
private String currentMemoryId;
private long lastUpdatedTime;
public CognationCore() throws IOException, ClassNotFoundException {
}
@CapabilityMethod
public List<Message> getChatMessages() {
return chatMessages;
}
@CapabilityMethod
public long getLastUpdatedTime(){
return lastUpdatedTime;
}
@CapabilityMethod
public HashMap<String,List<MetaMessage>> getSingleMetaMessageMap(){
return singleMetaMessageMap;
}
@CapabilityMethod
public String getCurrentMemoryId(){
return currentMemoryId;
}
@CapabilityMethod
public void setChatMessages(List<Message> chatMessages) {
this.chatMessages = chatMessages;
}
@CapabilityMethod
public void cleanMessage(List<Message> messages) {
messageLock.lock();
this.getChatMessages().removeAll(messages);
messageLock.unlock();
}
@CapabilityMethod
public Lock getMessageLock() {
return messageLock;
}
@CapabilityMethod
public void addMetaMessage(String userId, MetaMessage metaMessage) {
log.debug("[{}] 当前会话历史: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap));
if (singleMetaMessageMap.containsKey(userId)) {
singleMetaMessageMap.get(userId).add(metaMessage);
} else {
singleMetaMessageMap.put(userId, new java.util.ArrayList<>());
singleMetaMessageMap.get(userId).add(metaMessage);
}
log.debug("[{}] 会话历史更新: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap));
}
@CapabilityMethod
public List<Message> unpackAndClear(String userId) {
List<Message> messages = new ArrayList<>();
for (MetaMessage metaMessage : singleMetaMessageMap.get(userId)) {
messages.add(metaMessage.getUserMessage());
messages.add(metaMessage.getAssistantMessage());
}
singleMetaMessageMap.remove(userId);
return messages;
}
@CapabilityMethod
public void refreshMemoryId() {
currentMemoryId = UUID.randomUUID().toString();
}
@CapabilityMethod
public void resetLastUpdatedTime() {
lastUpdatedTime = System.currentTimeMillis();
}
@Override
protected String getCoreKey() {
return "cognation-core";
}
}

View File

@@ -1,166 +0,0 @@
package work.slhaf.partner.core.cognation;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.CoordinateManager;
import work.slhaf.partner.api.agent.factory.capability.annotation.Coordinated;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.common.exception_handler.GlobalExceptionHandler;
import work.slhaf.partner.common.exception_handler.pojo.GlobalException;
import work.slhaf.partner.core.cognation.cognation.CognationCore;
import work.slhaf.partner.core.cognation.common.pojo.MemoryResult;
import work.slhaf.partner.core.cognation.common.pojo.MemorySliceResult;
import work.slhaf.partner.core.cognation.submodule.cache.CacheCore;
import work.slhaf.partner.core.cognation.submodule.dispatch.DispatchCore;
import work.slhaf.partner.core.cognation.submodule.memory.MemoryCore;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import work.slhaf.partner.core.cognation.submodule.perceive.PerceiveCore;
import java.io.IOException;
import java.io.Serial;
import java.io.Serializable;
import java.time.LocalDate;
import java.util.*;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import static work.slhaf.partner.common.util.ExtractUtil.extractUserId;
@Data
@Slf4j
@CoordinateManager
public class CoordinatedManager implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
private static volatile CoordinatedManager coordinatedManager;
private final Lock sliceInsertLock = new ReentrantLock();
private CognationCore cognationCore;
private CacheCore cacheCore;
private MemoryCore memoryCore;
private PerceiveCore perceiveCore;
private DispatchCore dispatchCore;
private CoordinatedManager() {
}
public static CoordinatedManager getInstance() throws IOException, ClassNotFoundException {
if (coordinatedManager == null) {
synchronized (CoordinatedManager.class) {
if (coordinatedManager == null) {
coordinatedManager = new CoordinatedManager();
coordinatedManager.setCognationCore(CognationCore.getInstance());
coordinatedManager.setCores();
log.info("[CoordinatedManager] MemoryManager注册完毕...");
}
}
}
return coordinatedManager;
}
private void setCores() {
this.setCacheCore(this.getCognationCore().getCacheCore());
this.setMemoryCore(this.getCognationCore().getMemoryCore());
this.setPerceiveCore(this.getCognationCore().getPerceiveCore());
}
@Coordinated(capability = "memory")
public MemoryResult selectMemory(String topicPathStr) {
MemoryResult memoryResult;
List<String> topicPath = List.of(topicPathStr.split("->"));
try {
List<String> path = new ArrayList<>(topicPath);
//每日刷新缓存
cacheCore.checkCacheDate();
//检测缓存并更新计数, 查看是否需要放入缓存
cacheCore.updateCacheCounter(path);
//查看是否存在缓存,如果存在,则直接返回
if ((memoryResult = cacheCore.selectCache(path)) != null) {
return memoryResult;
}
memoryResult = memoryCore.selectMemory(path);
//尝试更新缓存
cacheCore.updateCache(topicPath, memoryResult);
} catch (Exception e) {
log.error("[CoordinatedManager] selectMemory error: ", e);
log.error("[CoordinatedManager] 路径: {}", topicPathStr);
log.error("[CoordinatedManager] 主题树: {}", memoryCore.getTopicTree());
memoryResult = new MemoryResult();
memoryResult.setRelatedMemorySliceResult(new ArrayList<>());
memoryResult.setMemorySliceResult(new CopyOnWriteArrayList<>());
GlobalExceptionHandler.writeExceptionState(new GlobalException(e.getLocalizedMessage()));
}
return cacheFilter(memoryResult);
}
@Coordinated(capability = "memory")
public MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException {
return cacheFilter(memoryCore.selectMemory(date));
}
private MemoryResult cacheFilter(MemoryResult memoryResult) {
//过滤掉与缓存重复的切片
CopyOnWriteArrayList<MemorySliceResult> memorySliceResult = memoryResult.getMemorySliceResult();
List<MemorySlice> relatedMemorySliceResult = memoryResult.getRelatedMemorySliceResult();
cacheCore.getDialogMap().forEach((k, v) -> {
memorySliceResult.removeIf(m -> m.getMemorySlice().getSummary().equals(v));
relatedMemorySliceResult.removeIf(m -> m.getSummary().equals(v));
});
return memoryResult;
}
@Coordinated(capability = "memory")
public void insertSlice(MemorySlice memorySlice, String topicPath) {
sliceInsertLock.lock();
List<String> topicPathList = Arrays.stream(topicPath.split("->")).toList();
try {
//检查是否存在当天对应的memorySlice并确定是否插入
//每日刷新缓存
cacheCore.checkCacheDate();
//如果topicPath在memorySliceCache中存在对应缓存由于进行的插入操作则需要移除该缓存但不清除相关计数
cacheCore.clearCacheByTopicPath(topicPathList);
memoryCore.insertMemory(topicPathList, memorySlice);
if (!memorySlice.isPrivate()) {
cacheCore.updateUserDialogMap(memorySlice);
}
} catch (Exception e) {
log.error("[CoordinatedManager] 插入记忆时出错: ", e);
GlobalExceptionHandler.writeExceptionState(new GlobalException("插入记忆时出错: " + e.getLocalizedMessage()));
}
log.debug("[CoordinatedManager] 插入切片: {}, 路径: {}", memorySlice, topicPath);
sliceInsertLock.unlock();
}
private boolean isCacheSingleUser() {
return cacheCore.getUserDialogMap().size() <= 1;
}
@Coordinated(capability = "cognation")
public boolean isSingleUser() {
return isCacheSingleUser() && isChatMessagesSingleUser();
}
private boolean isChatMessagesSingleUser() {
Set<String> userIdSet = new HashSet<>();
cognationCore.getChatMessages().forEach(m -> {
if (m.getRole().equals(ChatConstant.Character.ASSISTANT)) {
return;
}
String userId = extractUserId(m.getContent());
if (userId == null || userId.isEmpty()) {
return;
}
userIdSet.add(userId);
});
return userIdSet.size() <= 1;
}
}

View File

@@ -1,51 +0,0 @@
package work.slhaf.partner.core.cognation.cognation;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.api.agent.factory.capability.annotation.ToCoordinated;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.EvaluatedSlice;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.locks.Lock;
@Capability("cognation")
public interface CognationCapability {
@CapabilityMethod
List<Message> getChatMessages();
@CapabilityMethod
void setChatMessages(List<Message> chatMessages);
@CapabilityMethod
void cleanMessage(List<Message> messages);
@CapabilityMethod
void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices);
@CapabilityMethod
String getActivatedSlicesStr(String userId);
@CapabilityMethod
HashMap<String, List<EvaluatedSlice>> getActivatedSlices();
@CapabilityMethod
void clearActivatedSlices(String userId);
@CapabilityMethod
boolean hasActivatedSlices(String userId);
@CapabilityMethod
int getActivatedSlicesSize(String userId);
@CapabilityMethod
List<EvaluatedSlice> getActivatedSlices(String userId);
@CapabilityMethod
Lock getMessageLock();
@ToCoordinated
boolean isSingleUser();
}

View File

@@ -1,170 +0,0 @@
package work.slhaf.partner.core.cognation.cognation;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.common.entity.PersistableObject;
import work.slhaf.partner.core.cognation.cognation.pojo.ActiveData;
import work.slhaf.partner.core.cognation.submodule.cache.CacheCore;
import work.slhaf.partner.core.cognation.submodule.memory.MemoryCore;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.cognation.submodule.perceive.PerceiveCore;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.locks.ReentrantLock;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@CapabilityCore(value = "cognation")
public class CognationCore extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private static final String STORAGE_DIR = "./data/memory/";
private static volatile CognationCore cognationCore;
private MemoryCore memoryCore = new MemoryCore();
private CacheCore cacheCore = new CacheCore();
private PerceiveCore perceiveCore = new PerceiveCore();
private ReentrantLock messageLock = new ReentrantLock();
private ActiveData activeData;
/**
* 主模型的聊天记录
*/
private List<Message> chatMessages = new ArrayList<>();
public CognationCore() throws IOException, ClassNotFoundException {
createStorageDirectory();
Path filePath = getFilePath("partner");
if (Files.exists(filePath)) {
setupData(this);
} else {
FileUtils.createParentDirectories(filePath.toFile().getParentFile());
connectCores(this);
this.serialize();
}
setupHook(this);
log.info("CognationCore注册完毕...");
}
private void connectCores(CognationCore temp) {
temp.setCacheCore(CacheCore.getInstance());
temp.setMemoryCore(MemoryCore.getInstance());
temp.setPerceiveCore(PerceiveCore.getInstance());
}
private void setupHook(CognationCore temp) {
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
temp.serialize();
log.info("[CognationCore] CognationCore已保存");
} catch (IOException e) {
log.error("[CognationCore] CognationCore保存失败: ", e);
}
}));
}
private void setupData(CognationCore temp) throws IOException, ClassNotFoundException {
CognationCore deserialize = deserialize();
temp.activeData = deserialize.activeData;
temp.memoryCore = deserialize.memoryCore;
temp.cacheCore = deserialize.cacheCore;
temp.perceiveCore = deserialize.perceiveCore;
temp.chatMessages = deserialize.chatMessages;
}
public static CognationCore getInstance() {
return cognationCore;
}
public void serialize() throws IOException {
//先写入到临时文件,如果正常写入则覆盖原文件
Path filePath = getFilePath("partner-temp");
Files.createDirectories(Path.of(STORAGE_DIR));
try {
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(filePath.toFile()));
oos.writeObject(this);
oos.close();
Path path = getFilePath("partner");
Files.move(filePath, path, StandardCopyOption.REPLACE_EXISTING);
log.info("CognationCore 已保存到: {}", path);
} catch (IOException e) {
Files.delete(filePath);
log.error("序列化保存失败: {}", e.getMessage());
}
}
private static CognationCore deserialize() throws IOException, ClassNotFoundException {
Path filePath = getFilePath("partner");
try (ObjectInputStream ois = new ObjectInputStream(
new FileInputStream(filePath.toFile()))) {
CognationCore graph = (CognationCore) ois.readObject();
log.info("CognationCore 已从文件加载: {}", filePath);
return graph;
}
}
private static Path getFilePath(String s) {
return Paths.get(STORAGE_DIR, s + ".memory");
}
private static void createStorageDirectory() {
try {
Files.createDirectories(Paths.get(STORAGE_DIR));
} catch (IOException e) {
System.err.println("创建存储目录失败: " + e.getMessage());
}
}
public void cleanMessage(List<Message> messages) {
messageLock.lock();
this.getChatMessages().removeAll(messages);
messageLock.unlock();
}
public void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices) {
activeData.updateActivatedSlices(userId, memorySlices);
log.debug("[CoordinatedManager] 已更新激活切片, userId: {}", userId);
}
public String getActivatedSlicesStr(String userId) {
return activeData.getActivatedSlicesStr(userId);
}
public HashMap<String, List<EvaluatedSlice>> getActivatedSlices() {
return activeData.getActivatedSlices();
}
public void clearActivatedSlices(String userId) {
activeData.clearActivatedSlices(userId);
}
public boolean hasActivatedSlices(String userId) {
return activeData.hasActivatedSlices(userId);
}
public int getActivatedSlicesSize(String userId) {
return activeData.getActivatedSlices().get(userId).size();
}
public List<EvaluatedSlice> getActivatedSlices(String userId) {
return activeData.getActivatedSlices().get(userId);
}
}

View File

@@ -1,38 +0,0 @@
package work.slhaf.partner.core.cognation.cognation.pojo;
import lombok.Data;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.EvaluatedSlice;
import java.util.HashMap;
import java.util.List;
@Data
public class ActiveData {
private HashMap<String, List<EvaluatedSlice>> activatedSlices;
public void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices) {
activatedSlices.put(userId, memorySlices);
}
public String getActivatedSlicesStr(String userId) {
if (activatedSlices.containsKey(userId)) {
StringBuilder str = new StringBuilder();
activatedSlices.get(userId).forEach(slice -> str.append("\n\n").append("[").append(slice.getDate()).append("]\n")
.append(slice.getSummary()));
return str.toString();
} else {
return null;
}
}
public void clearActivatedSlices(String userId) {
activatedSlices.remove(userId);
}
public boolean hasActivatedSlices(String userId) {
if (!activatedSlices.containsKey(userId)){
return false;
}
return !activatedSlices.get(userId).isEmpty();
}
}

View File

@@ -1,4 +1,4 @@
package work.slhaf.partner.core.cognation.cognation.exception; package work.slhaf.partner.core.cognation.exception;
public class UserNotExistsException extends RuntimeException { public class UserNotExistsException extends RuntimeException {
public UserNotExistsException(String message) { public UserNotExistsException(String message) {

View File

@@ -1,16 +0,0 @@
package work.slhaf.partner.core.cognation.submodule.cache;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
@Capability(value = "cache")
public interface CacheCapability {
HashMap<LocalDateTime, String> getDialogMap();
ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId);
void updateDialogMap(LocalDateTime dateTime, String newDialogCache);
String getDialogMapStr();
String getUserDialogMapStr(String userId);
}

View File

@@ -1,182 +0,0 @@
package work.slhaf.partner.core.cognation.submodule.cache;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.api.common.entity.PersistableObject;
import work.slhaf.partner.core.cognation.common.pojo.MemoryResult;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import java.io.Serial;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
@EqualsAndHashCode(callSuper = true)
@Slf4j
@CapabilityCore(value = "cache")
@Getter
public class CacheCore extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
public static volatile CacheCore cacheCore;
/**
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键总结为值
* 该部分作为'主LLM'system prompt常驻
* 该部分作为近两日的整体对话缓存, 不区分用户
*/
private HashMap<LocalDateTime, String> dialogMap = new HashMap<>();
/**
* 近两日的区分用户的对话总结缓存在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质
*/
private ConcurrentHashMap<String/*userId*/, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = new ConcurrentHashMap<>();
/**
* memorySliceCache计数器每日清空
*/
private ConcurrentHashMap<List<String> /*触发查询的主题列表*/, Integer> memoryNodeCacheCounter = new ConcurrentHashMap<>();
/**
* 记忆切片缓存,每日清空
* 用于记录作为终点节点调用次数最多的记忆节点的切片数据
*/
private ConcurrentHashMap<List<String> /*主题路径*/, MemoryResult /*切片列表*/> memorySliceCache = new ConcurrentHashMap<>();
/**
* 缓存日期
*/
private LocalDate cacheDate;
/**
* 已被选中的切片时间戳集合,需要及时清理
*/
private Set<Long> selectedSlices = new HashSet<>();
public CacheCore() {
cacheCore = this;
}
public static CacheCore getInstance(){
return cacheCore;
}
@CapabilityMethod
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
List<LocalDateTime> keysToRemove = new ArrayList<>();
dialogMap.forEach((k, v) -> {
if (dateTime.minusDays(2).isAfter(k)) {
keysToRemove.add(k);
}
});
for (LocalDateTime temp : keysToRemove) {
dialogMap.remove(temp);
}
keysToRemove.clear();
//放入新缓存
dialogMap.put(dateTime, newDialogCache);
}
@CapabilityMethod
public HashMap<LocalDateTime, String> getDialogMap(){
return dialogMap;
}
@CapabilityMethod
public ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId) {
return this.getUserDialogMap().get(userId);
}
@CapabilityMethod
public String getDialogMapStr() {
StringBuilder str = new StringBuilder();
this.getDialogMap().forEach((dateTime, dialog) -> str.append("\n\n").append("[").append(dateTime).append("]\n")
.append(dialog));
return str.toString();
}
@CapabilityMethod
public String getUserDialogMapStr(String userId) {
if (this.getUserDialogMap().containsKey(userId)) {
StringBuilder str = new StringBuilder();
Collection<String> dialogMapValues = this.getDialogMap().values();
this.getUserDialogMap().get(userId).forEach((dateTime, dialog) -> {
if (dialogMapValues.contains(dialog)) {
return;
}
str.append("\n\n").append("[").append(dateTime).append("]\n")
.append(dialog);
});
return str.toString();
} else {
return null;
}
}
public void updateCacheCounter(List<String> topicPath) {
if (memoryNodeCacheCounter.containsKey(topicPath)) {
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
memoryNodeCacheCounter.put(topicPath, ++tempCount);
} else {
memoryNodeCacheCounter.put(topicPath, 1);
}
}
public void checkCacheDate() {
if (cacheDate == null || cacheDate.isBefore(LocalDate.now())) {
memorySliceCache.clear();
memoryNodeCacheCounter.clear();
cacheDate = LocalDate.now();
}
}
public void updateCache(List<String> topicPath, MemoryResult memoryResult) {
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
if (tempCount == null) {
log.warn("[CacheCore] tempCount为null? memoryNodeCacheCounter: {}; topicPath: {}", memoryNodeCacheCounter, topicPath);
return;
}
if (tempCount >= 5) {
memorySliceCache.put(topicPath, memoryResult);
}
}
public void updateUserDialogMap(MemorySlice slice) {
String summary = slice.getSummary();
LocalDateTime now = LocalDateTime.now();
//更新userDialogMap
//移除两天前上下文缓存(切片总结)
List<LocalDateTime> keysToRemove = new ArrayList<>();
userDialogMap.forEach((k, v) -> v.forEach((i, j) -> {
if (now.minusDays(2).isAfter(i)) {
keysToRemove.add(i);
}
}));
for (LocalDateTime dateTime : keysToRemove) {
userDialogMap.forEach((k, v) -> v.remove(dateTime));
}
//放入新缓存
userDialogMap
.computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>())
.merge(now, summary, (oldVal, newVal) -> oldVal + " " + newVal);
}
public void clearCacheByTopicPath(List<String> topicPath) {
memorySliceCache.remove(topicPath);
}
public MemoryResult selectCache(List<String> path) {
if (memorySliceCache.containsKey(path)) {
return memorySliceCache.get(path);
}
return null;
}
}

View File

@@ -1,4 +0,0 @@
package work.slhaf.partner.core.cognation.submodule.dispatch;
public interface DispatchCapability {
}

View File

@@ -1,32 +0,0 @@
package work.slhaf.partner.core.cognation.submodule.dispatch;
import work.slhaf.partner.api.common.entity.PersistableObject;
import work.slhaf.partner.core.cognation.submodule.dispatch.pojo.DispatchData;
import java.io.Serial;
public class DispatchCore extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
public static volatile DispatchCore dispatchCore;
public static DispatchCore getInstance() {
if (dispatchCore == null) {
synchronized (DispatchCore.class) {
if (dispatchCore == null) {
dispatchCore = new DispatchCore();
}
}
}
return dispatchCore;
}
public void dispatch(DispatchData dispatchData){
}
public void listDispatchData(){
}
}

View File

@@ -1,15 +0,0 @@
package work.slhaf.partner.core.cognation.submodule.dispatch.pojo;
import lombok.Data;
import java.time.LocalDateTime;
@Data
public class DispatchData {
private LocalDateTime dateTime;
private String userId;
private String comment;
//TODO 替换为<执行器>或者<插件>
private String executor;
}

View File

@@ -1,25 +0,0 @@
package work.slhaf.partner.core.cognation.submodule.memory;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.api.agent.factory.capability.annotation.ToCoordinated;
import work.slhaf.partner.core.cognation.common.pojo.MemoryResult;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import java.io.IOException;
import java.time.LocalDate;
@Capability(value = "memory")
public interface MemoryCapability {
void cleanSelectedSliceFilter();
String getTopicTree();
@ToCoordinated
MemoryResult selectMemory(String topicPathStr);
@ToCoordinated
MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException;
@ToCoordinated
void insertSlice(MemorySlice memorySlice, String topicPath);
}

View File

@@ -1,11 +0,0 @@
package work.slhaf.partner.core.interaction.agent_interface;
import work.slhaf.partner.core.interaction.data.InteractionInputData;
import java.io.IOException;
public interface InputReceiver {
void receiveInput(InteractionInputData inputData) throws IOException, ClassNotFoundException;
}

View File

@@ -1,5 +0,0 @@
package work.slhaf.partner.core.interaction.agent_interface;
public interface TaskCallback {
void onTaskFinished(String userInfo,String output);
}

View File

@@ -1,15 +0,0 @@
package work.slhaf.partner.core.interaction.data;
import lombok.Data;
import java.time.LocalDateTime;
@Data
public class InteractionInputData {
private String userInfo;
private String userNickName;
private String content;
private LocalDateTime localDateTime;
private String platform;
private boolean single;
}

View File

@@ -1,11 +0,0 @@
package work.slhaf.partner.core.interaction.data;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class InteractionOutputData {
private String content;
private String userInfo;
}

View File

@@ -1,9 +0,0 @@
package work.slhaf.partner.core.interaction.module;
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
import java.io.IOException;
public interface InteractionFlow {
void execute(InteractionContext context) throws IOException, ClassNotFoundException;
}

View File

@@ -1,60 +0,0 @@
package work.slhaf.partner.core.interaction.module;
import work.slhaf.partner.common.config.Config;
import work.slhaf.partner.common.config.ModuleConfig;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.List;
public class InteractionModulesLoader {
private static InteractionModulesLoader interactionModulesLoader;
public static InteractionModulesLoader getInstance(){
if (interactionModulesLoader == null) {
interactionModulesLoader = new InteractionModulesLoader();
}
return interactionModulesLoader;
}
public List<InteractionFlow> registerInteractionModules() throws IOException {
List<InteractionFlow> moduleList = new ArrayList<>();
List<ModuleConfig> moduleConfigList = Config.getConfig().getModuleConfigList();
for (ModuleConfig moduleConfig : moduleConfigList) {
if (ModuleConfig.Constant.INTERNAL.equals(moduleConfig.getType())) {
moduleList.add(loadInternalModule(moduleConfig.getClassName()));
} else if (ModuleConfig.Constant.EXTERNAL.equals(moduleConfig.getType())) {
moduleList.add(loadExternalModule(moduleConfig.getClassName(),moduleConfig.getPath()));
}
}
return moduleList;
}
private InteractionFlow loadExternalModule(String className, String path) {
try {
URL jarUrl = new File(path).toURI().toURL();
URLClassLoader loader = new URLClassLoader(new URL[]{jarUrl}, this.getClass().getClassLoader());
Class<?> clazz = loader.loadClass(className);
loader.close();
return (InteractionFlow) clazz.getMethod("getInstance").invoke(null);
} catch (ClassNotFoundException | InvocationTargetException | IllegalAccessException |
NoSuchMethodException | IOException e) {
throw new RuntimeException("Fail to load internal module: " + className ,e);
}
}
private static InteractionFlow loadInternalModule(String className) {
try {
Class<?> clazz = Class.forName(className);
return (InteractionFlow) clazz.getMethod("getInstance").invoke(null);
} catch (ClassNotFoundException | InvocationTargetException | IllegalAccessException | NoSuchMethodException e) {
throw new RuntimeException("Fail to load internal module: " + className,e);
}
}
}

View File

@@ -0,0 +1,52 @@
package work.slhaf.partner.core.memory;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.MemoryResult;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import java.io.IOException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
@Capability(value = "memory")
public interface MemoryCapability {
void cleanSelectedSliceFilter();
String getTopicTree();
HashMap<LocalDateTime, String> getDialogMap();
ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId);
void updateDialogMap(LocalDateTime dateTime, String newDialogCache);
String getDialogMapStr();
String getUserDialogMapStr(String userId);
void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices);
String getActivatedSlicesStr(String userId);
HashMap<String, List<EvaluatedSlice>> getActivatedSlices();
void clearActivatedSlices(String userId);
boolean hasActivatedSlices(String userId);
int getActivatedSlicesSize(String userId);
List<EvaluatedSlice> getActivatedSlices(String userId);
MemoryResult selectMemory(String topicPathStr);
MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException;
void insertSlice(MemorySlice memorySlice, String topicPath);
}

View File

@@ -1,33 +1,40 @@
package work.slhaf.partner.core.cognation.submodule.memory; package work.slhaf.partner.core.memory;
import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore; import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod; import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.api.common.entity.PersistableObject; import work.slhaf.partner.core.PartnerCore;
import work.slhaf.partner.core.cognation.common.pojo.MemoryResult; import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
import work.slhaf.partner.core.cognation.common.pojo.MemorySliceResult; import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
import work.slhaf.partner.core.cognation.submodule.memory.exception.UnExistedDateIndexException; import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.cognation.submodule.memory.exception.UnExistedTopicException; import work.slhaf.partner.core.memory.pojo.MemoryResult;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.node.MemoryNode; import work.slhaf.partner.core.memory.pojo.MemorySliceResult;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.node.TopicNode; import work.slhaf.partner.core.memory.pojo.node.MemoryNode;
import work.slhaf.partner.core.memory.pojo.node.TopicNode;
import java.io.IOException; import java.io.IOException;
import java.io.Serial; import java.io.Serial;
import java.time.LocalDate; import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.*; import java.util.*;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data
@CapabilityCore(value = "memory") @CapabilityCore(value = "memory")
public class MemoryCore extends PersistableObject { @Getter
@Setter
@Slf4j
public class MemoryCore extends PartnerCore<MemoryCore> {
@Serial @Serial
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
public static MemoryCore memoryCore;
/** /**
* key: 根主题名称 value: 根主题节点 * key: 根主题名称 value: 根主题节点
@@ -55,16 +62,17 @@ public class MemoryCore extends PersistableObject {
*/ */
private Set<Long> selectedSlices = new HashSet<>(); private Set<Long> selectedSlices = new HashSet<>();
private HashMap<String,List<String>> userIndex = new HashMap<>(); private HashMap<String, List<String>> userIndex = new HashMap<>();
public MemoryCore(){ private MemoryCache cache = new MemoryCache();
memoryCore = this;
private final Lock sliceInsertLock = new ReentrantLock();
public MemoryCore() throws IOException, ClassNotFoundException {
} }
public static MemoryCore getInstance() {
return memoryCore;
}
@CapabilityMethod
public MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException { public MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException {
MemoryResult memoryResult = new MemoryResult(); MemoryResult memoryResult = new MemoryResult();
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>(); CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
@@ -82,8 +90,176 @@ public class MemoryCore extends PersistableObject {
} }
} }
memoryResult.setMemorySliceResult(targetSliceList); memoryResult.setMemorySliceResult(targetSliceList);
return cacheFilter(memoryResult);
}
@CapabilityMethod
public void insertSlice(MemorySlice memorySlice, String topicPath) {
sliceInsertLock.lock();
List<String> topicPathList = Arrays.stream(topicPath.split("->")).toList();
try {
//检查是否存在当天对应的memorySlice并确定是否插入
//每日刷新缓存
checkCacheDate();
//如果topicPath在memorySliceCache中存在对应缓存由于进行的插入操作则需要移除该缓存但不清除相关计数
clearCacheByTopicPath(topicPathList);
insertMemory(topicPathList, memorySlice);
if (!memorySlice.isPrivate()) {
updateUserDialogMap(memorySlice);
}
} catch (Exception e) {
log.error("[CoordinatedManager] 插入记忆时出错: ", e);
}
log.debug("[CoordinatedManager] 插入切片: {}, 路径: {}", memorySlice, topicPath);
sliceInsertLock.unlock();
}
@CapabilityMethod
public String getTopicTree() {
StringBuilder stringBuilder = new StringBuilder();
for (Map.Entry<String, TopicNode> entry : topicNodes.entrySet()) {
String rootName = entry.getKey();
TopicNode rootNode = entry.getValue();
stringBuilder.append(rootName).append("[root]").append("\r\n");
printSubTopicsTreeFormat(rootNode, "", stringBuilder);
}
return stringBuilder.toString();
}
@CapabilityMethod
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
List<LocalDateTime> keysToRemove = new ArrayList<>();
HashMap<LocalDateTime, String> dialogMap = cache.dialogMap;
dialogMap.forEach((k, v) -> {
if (dateTime.minusDays(2).isAfter(k)) {
keysToRemove.add(k);
}
});
for (LocalDateTime temp : keysToRemove) {
dialogMap.remove(temp);
}
keysToRemove.clear();
//放入新缓存
dialogMap.put(dateTime, newDialogCache);
}
@CapabilityMethod
public HashMap<LocalDateTime, String> getDialogMap() {
return cache.dialogMap;
}
@CapabilityMethod
public ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId) {
return cache.userDialogMap.get(userId);
}
@CapabilityMethod
public String getDialogMapStr() {
StringBuilder str = new StringBuilder();
this.getDialogMap().forEach((dateTime, dialog) -> str.append("\n\n").append("[").append(dateTime).append("]\n")
.append(dialog));
return str.toString();
}
@CapabilityMethod
public String getUserDialogMapStr(String userId) {
ConcurrentHashMap<String, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = cache.userDialogMap;
if (userDialogMap.containsKey(userId)) {
StringBuilder str = new StringBuilder();
Collection<String> dialogMapValues = this.getDialogMap().values();
userDialogMap.get(userId).forEach((dateTime, dialog) -> {
if (dialogMapValues.contains(dialog)) {
return;
}
str.append("\n\n").append("[").append(dateTime).append("]\n")
.append(dialog);
});
return str.toString();
} else {
return null;
}
}
@CapabilityMethod
public MemoryResult selectMemory(String topicPathStr) {
MemoryResult memoryResult;
List<String> topicPath = List.of(topicPathStr.split("->"));
try {
List<String> path = new ArrayList<>(topicPath);
//每日刷新缓存
checkCacheDate();
//检测缓存并更新计数, 查看是否需要放入缓存
updateCacheCounter(path);
//查看是否存在缓存如果存在则直接返回
if ((memoryResult = selectCache(path)) != null) {
return memoryResult; return memoryResult;
} }
memoryResult = selectMemory(path);
//尝试更新缓存
updateCache(topicPath, memoryResult);
} catch (Exception e) {
log.error("[CoordinatedManager] selectMemory error: ", e);
log.error("[CoordinatedManager] 路径: {}", topicPathStr);
log.error("[CoordinatedManager] 主题树: {}", getTopicTree());
memoryResult = new MemoryResult();
memoryResult.setRelatedMemorySliceResult(new ArrayList<>());
memoryResult.setMemorySliceResult(new CopyOnWriteArrayList<>());
}
return cacheFilter(memoryResult);
}
@CapabilityMethod
public void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices) {
cache.activatedSlices.put(userId, memorySlices);
log.debug("[CoordinatedManager] 已更新激活切片, userId: {}", userId);
}
@CapabilityMethod
public String getActivatedSlicesStr(String userId) {
HashMap<String, List<EvaluatedSlice>> activatedSlices = cache.activatedSlices;
if (activatedSlices.containsKey(userId)) {
StringBuilder str = new StringBuilder();
activatedSlices.get(userId).forEach(slice -> str.append("\n\n").append("[").append(slice.getDate()).append("]\n")
.append(slice.getSummary()));
return str.toString();
} else {
return null;
}
}
@CapabilityMethod
public HashMap<String, List<EvaluatedSlice>> getActivatedSlices() {
return cache.activatedSlices;
}
@CapabilityMethod
public void clearActivatedSlices(String userId) {
cache.activatedSlices.remove(userId);
}
@CapabilityMethod
public boolean hasActivatedSlices(String userId) {
HashMap<String, List<EvaluatedSlice>> activatedSlices = cache.activatedSlices;
if (!activatedSlices.containsKey(userId)) {
return false;
}
return !activatedSlices.get(userId).isEmpty();
}
@CapabilityMethod
public int getActivatedSlicesSize(String userId) {
return cache.activatedSlices.get(userId).size();
}
@CapabilityMethod
public List<EvaluatedSlice> getActivatedSlices(String userId) {
return cache.activatedSlices.get(userId);
}
@CapabilityMethod
public void cleanSelectedSliceFilter() {
this.selectedSlices.clear();
}
private List<List<MemorySlice>> loadSlicesByDate(LocalDate date) throws IOException, ClassNotFoundException { private List<List<MemorySlice>> loadSlicesByDate(LocalDate date) throws IOException, ClassNotFoundException {
if (!dateIndex.containsKey(date)) { if (!dateIndex.containsKey(date)) {
@@ -98,18 +274,6 @@ public class MemoryCore extends PersistableObject {
return list; return list;
} }
@CapabilityMethod
public String getTopicTree() {
StringBuilder stringBuilder = new StringBuilder();
for (Map.Entry<String, TopicNode> entry : topicNodes.entrySet()) {
String rootName = entry.getKey();
TopicNode rootNode = entry.getValue();
stringBuilder.append(rootName).append("[root]").append("\r\n");
printSubTopicsTreeFormat(rootNode, "", stringBuilder);
}
return stringBuilder.toString();
}
private void printSubTopicsTreeFormat(TopicNode node, String prefix, StringBuilder stringBuilder) { private void printSubTopicsTreeFormat(TopicNode node, String prefix, StringBuilder stringBuilder) {
if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return; if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return;
@@ -122,7 +286,7 @@ public class MemoryCore extends PersistableObject {
} }
} }
public void insertMemory(List<String> topicPath, MemorySlice slice) throws IOException, ClassNotFoundException { private void insertMemory(List<String> topicPath, MemorySlice slice) throws IOException, ClassNotFoundException {
LocalDate now = LocalDate.now(); LocalDate now = LocalDate.now();
boolean hasSlice = false; boolean hasSlice = false;
MemoryNode node = null; MemoryNode node = null;
@@ -303,8 +467,6 @@ public class MemoryCore extends PersistableObject {
} }
} }
private TopicNode getTargetParentNode(List<String> topicPath, String targetTopic) { private TopicNode getTargetParentNode(List<String> topicPath, String targetTopic) {
String topTopic = topicPath.getFirst(); String topTopic = topicPath.getFirst();
if (!existedTopics.containsKey(topTopic)) { if (!existedTopics.containsKey(topTopic)) {
@@ -326,8 +488,126 @@ public class MemoryCore extends PersistableObject {
return targetParentNode; return targetParentNode;
} }
@CapabilityMethod public void updateCacheCounter(List<String> topicPath) {
public void cleanSelectedSliceFilter() { ConcurrentHashMap<List<String>, Integer> memoryNodeCacheCounter = cache.memoryNodeCacheCounter;
this.getSelectedSlices().clear(); if (memoryNodeCacheCounter.containsKey(topicPath)) {
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
memoryNodeCacheCounter.put(topicPath, ++tempCount);
} else {
memoryNodeCacheCounter.put(topicPath, 1);
}
}
private void checkCacheDate() {
if (cache.cacheDate == null || cache.cacheDate.isBefore(LocalDate.now())) {
cache.memorySliceCache.clear();
cache.memoryNodeCacheCounter.clear();
cache.cacheDate = LocalDate.now();
}
}
private void updateCache(List<String> topicPath, MemoryResult memoryResult) {
ConcurrentHashMap<List<String>, Integer> memoryNodeCacheCounter = cache.memoryNodeCacheCounter;
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
if (tempCount == null) {
log.warn("[CacheCore] tempCount为null? memoryNodeCacheCounter: {}; topicPath: {}", memoryNodeCacheCounter, topicPath);
return;
}
if (tempCount >= 5) {
cache.memorySliceCache.put(topicPath, memoryResult);
}
}
private void updateUserDialogMap(MemorySlice slice) {
String summary = slice.getSummary();
LocalDateTime now = LocalDateTime.now();
ConcurrentHashMap<String, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = cache.userDialogMap;
//更新userDialogMap
//移除两天前上下文缓存(切片总结)
List<LocalDateTime> keysToRemove = new ArrayList<>();
userDialogMap.forEach((k, v) -> v.forEach((i, j) -> {
if (now.minusDays(2).isAfter(i)) {
keysToRemove.add(i);
}
}));
for (LocalDateTime dateTime : keysToRemove) {
userDialogMap.forEach((k, v) -> v.remove(dateTime));
}
//放入新缓存
userDialogMap
.computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>())
.merge(now, summary, (oldVal, newVal) -> oldVal + " " + newVal);
}
private void clearCacheByTopicPath(List<String> topicPath) {
cache.memorySliceCache.remove(topicPath);
}
private MemoryResult selectCache(List<String> path) {
ConcurrentHashMap<List<String>, MemoryResult> memorySliceCache = cache.memorySliceCache;
if (memorySliceCache.containsKey(path)) {
return memorySliceCache.get(path);
}
return null;
}
@Override
protected String getCoreKey() {
return "memory-core";
}
public ConcurrentHashMap<String, ConcurrentHashMap<LocalDateTime, String>> getUserDialogMap() {
return cache.userDialogMap;
}
private MemoryResult cacheFilter(MemoryResult memoryResult) {
//过滤掉与缓存重复的切片
CopyOnWriteArrayList<MemorySliceResult> memorySliceResult = memoryResult.getMemorySliceResult();
List<MemorySlice> relatedMemorySliceResult = memoryResult.getRelatedMemorySliceResult();
cache.dialogMap.forEach((k, v) -> {
memorySliceResult.removeIf(m -> m.getMemorySlice().getSummary().equals(v));
relatedMemorySliceResult.removeIf(m -> m.getSummary().equals(v));
});
return memoryResult;
}
@SuppressWarnings("FieldMayBeFinal")
private static class MemoryCache {
/**
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键总结为值
* 该部分作为'主LLM'system prompt常驻
* 该部分作为近两日的整体对话缓存, 不区分用户
*/
private HashMap<LocalDateTime, String> dialogMap = new HashMap<>();
/**
* 近两日的区分用户的对话总结缓存在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质
*/
private ConcurrentHashMap<String/*userId*/, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = new ConcurrentHashMap<>();
/**
* memorySliceCache计数器每日清空
*/
private ConcurrentHashMap<List<String> /*触发查询的主题列表*/, Integer> memoryNodeCacheCounter = new ConcurrentHashMap<>();
/**
* 记忆切片缓存每日清空
* 用于记录作为终点节点调用次数最多的记忆节点的切片数据
*/
private ConcurrentHashMap<List<String> /*主题路径*/, MemoryResult /*切片列表*/> memorySliceCache = new ConcurrentHashMap<>();
/**
* 缓存日期
*/
private LocalDate cacheDate;
private HashMap<String, List<EvaluatedSlice>> activatedSlices = new HashMap<>();
private MemoryCache() {
}
} }
} }

View File

@@ -1,4 +1,4 @@
package work.slhaf.partner.core.cognation.submodule.memory.exception; package work.slhaf.partner.core.memory.exception;
public class NullSliceListException extends RuntimeException { public class NullSliceListException extends RuntimeException {
public NullSliceListException(String message) { public NullSliceListException(String message) {

View File

@@ -1,4 +1,4 @@
package work.slhaf.partner.core.cognation.submodule.memory.exception; package work.slhaf.partner.core.memory.exception;
public class UnExistedDateIndexException extends RuntimeException { public class UnExistedDateIndexException extends RuntimeException {
public UnExistedDateIndexException(String message) { public UnExistedDateIndexException(String message) {

View File

@@ -1,4 +1,4 @@
package work.slhaf.partner.core.cognation.submodule.memory.exception; package work.slhaf.partner.core.memory.exception;
public class UnExistedTopicException extends RuntimeException { public class UnExistedTopicException extends RuntimeException {
public UnExistedTopicException(String message) { public UnExistedTopicException(String message) {

View File

@@ -1,4 +1,4 @@
package work.slhaf.partner.core.cognation.submodule.memory.pojo; package work.slhaf.partner.core.memory.pojo;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;

View File

@@ -1,9 +1,8 @@
package work.slhaf.partner.core.cognation.common.pojo; package work.slhaf.partner.core.memory.pojo;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.common.entity.PersistableObject; import work.slhaf.partner.api.common.entity.PersistableObject;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import java.io.Serial; import java.io.Serial;
import java.util.List; import java.util.List;

View File

@@ -1,4 +1,4 @@
package work.slhaf.partner.core.cognation.submodule.memory.pojo; package work.slhaf.partner.core.memory.pojo;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;

View File

@@ -1,10 +1,9 @@
package work.slhaf.partner.core.cognation.common.pojo; package work.slhaf.partner.core.memory.pojo;
import com.alibaba.fastjson2.annotation.JSONField; import com.alibaba.fastjson2.annotation.JSONField;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.common.entity.PersistableObject; import work.slhaf.partner.api.common.entity.PersistableObject;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice;
import java.io.Serial; import java.io.Serial;

View File

@@ -1,11 +1,11 @@
package work.slhaf.partner.core.cognation.submodule.memory.pojo.node; package work.slhaf.partner.core.memory.pojo.node;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.common.entity.PersistableObject; import work.slhaf.partner.api.common.entity.PersistableObject;
import work.slhaf.partner.core.cognation.submodule.memory.exception.NullSliceListException; import work.slhaf.partner.core.memory.exception.NullSliceListException;
import work.slhaf.partner.core.cognation.submodule.memory.pojo.MemorySlice; import work.slhaf.partner.core.memory.pojo.MemorySlice;
import java.io.*; import java.io.*;
import java.nio.file.Files; import java.nio.file.Files;

View File

@@ -1,4 +1,4 @@
package work.slhaf.partner.core.cognation.submodule.memory.pojo.node; package work.slhaf.partner.core.memory.pojo.node;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;

View File

@@ -1,7 +1,7 @@
package work.slhaf.partner.core.cognation.submodule.perceive; package work.slhaf.partner.core.perceive;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability; import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.core.cognation.submodule.perceive.pojo.User; import work.slhaf.partner.core.perceive.pojo.User;
@Capability(value = "perceive") @Capability(value = "perceive")
public interface PerceiveCapability { public interface PerceiveCapability {

View File

@@ -1,13 +1,15 @@
package work.slhaf.partner.core.cognation.submodule.perceive; package work.slhaf.partner.core.perceive;
import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore; import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod; import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.api.common.entity.PersistableObject; import work.slhaf.partner.core.PartnerCore;
import work.slhaf.partner.core.cognation.cognation.exception.UserNotExistsException; import work.slhaf.partner.core.cognation.exception.UserNotExistsException;
import work.slhaf.partner.core.cognation.submodule.perceive.pojo.User; import work.slhaf.partner.core.perceive.pojo.User;
import java.io.IOException;
import java.io.Serial; import java.io.Serial;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
@@ -16,13 +18,13 @@ import java.util.UUID;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Data
@CapabilityCore(value = "perceive") @CapabilityCore(value = "perceive")
public class PerceiveCore extends PersistableObject { @Getter
@Setter
public class PerceiveCore extends PartnerCore<PerceiveCore> {
@Serial @Serial
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
private static volatile PerceiveCore perceiveCore;
private static final ReentrantLock usersLock = new ReentrantLock(); private static final ReentrantLock usersLock = new ReentrantLock();
/** /**
@@ -30,12 +32,7 @@ public class PerceiveCore extends PersistableObject {
*/ */
private List<User> users = new ArrayList<>(); private List<User> users = new ArrayList<>();
public PerceiveCore() { public PerceiveCore() throws IOException, ClassNotFoundException {
perceiveCore = this;
}
public static PerceiveCore getInstance() {
return perceiveCore;
} }
@CapabilityMethod @CapabilityMethod
@@ -94,4 +91,9 @@ public class PerceiveCore extends PersistableObject {
user.updateRelationChange(user.getRelationChange()); user.updateRelationChange(user.getRelationChange());
usersLock.unlock(); usersLock.unlock();
} }
@Override
protected String getCoreKey() {
return "perceive-core";
}
} }

View File

@@ -1,4 +1,4 @@
package work.slhaf.partner.core.cognation.submodule.perceive.pojo; package work.slhaf.partner.core.perceive.pojo;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;

View File

@@ -1,130 +0,0 @@
package work.slhaf.partner.core.session;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.api.common.entity.PersistableObject;
import work.slhaf.partner.common.config.Config;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class SessionManager extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private static final String STORAGE_DIR = "./data/session/";
private static volatile SessionManager sessionManager;
private String id;
private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap;
private String currentMemoryId;
private long lastUpdatedTime;
public static SessionManager getInstance() throws IOException, ClassNotFoundException {
if (sessionManager == null) {
synchronized (SessionManager.class) {
if (sessionManager == null) {
String id = Config.getConfig().getAgentId();
Path filePath = Paths.get(STORAGE_DIR, id + ".session");
if (Files.exists(filePath)) {
sessionManager = deserialize(id);
} else {
sessionManager = new SessionManager();
sessionManager.setSingleMetaMessageMap(new HashMap<>());
sessionManager.id = id;
sessionManager.setShutdownHook();
sessionManager.lastUpdatedTime = 0;
}
}
}
}
return sessionManager;
}
private void setShutdownHook() {
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
sessionManager.serialize();
log.info("[SessionManager] SessionManager 已保存");
} catch (IOException e) {
log.error("[SessionManager] 保存 SessionManager 失败: ", e);
}
}));
}
public void addMetaMessage(String userId, MetaMessage metaMessage) {
log.debug("[SessionManager] 当前会话历史: {}", JSONObject.toJSONString(singleMetaMessageMap));
if (singleMetaMessageMap.containsKey(userId)) {
singleMetaMessageMap.get(userId).add(metaMessage);
} else {
singleMetaMessageMap.put(userId, new java.util.ArrayList<>());
singleMetaMessageMap.get(userId).add(metaMessage);
}
log.debug("[SessionManager] 会话历史更新: {}", JSONObject.toJSONString(singleMetaMessageMap));
}
public List<Message> unpackAndClear(String userId) {
List<Message> messages = new ArrayList<>();
for (MetaMessage metaMessage : singleMetaMessageMap.get(userId)) {
messages.add(metaMessage.getUserMessage());
messages.add(metaMessage.getAssistantMessage());
}
singleMetaMessageMap.remove(userId);
return messages;
}
public void refreshMemoryId() {
currentMemoryId = UUID.randomUUID().toString();
}
public void serialize() throws IOException {
//先写入到临时文件,如果正常写入,则覆盖正式文件;否则删除临时文件
Path filePath = getFilePath(this.id + "-temp");
Files.createDirectories(Path.of(STORAGE_DIR));
try {
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(filePath.toFile()));
oos.writeObject(this);
oos.close();
Path path = getFilePath(this.id);
Files.move(filePath, path, StandardCopyOption.REPLACE_EXISTING);
log.info("[SessionManager] SessionManager 已保存到: {}", path);
} catch (IOException e) {
Files.delete(filePath);
log.error("[SessionManager] 序列化保存失败: {}", e.getMessage());
}
}
private static SessionManager deserialize(String id) throws IOException, ClassNotFoundException {
Path filePath = getFilePath(id);
try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(filePath.toFile()))) {
SessionManager sessionManager = (SessionManager) ois.readObject();
log.info("[SessionManager] SessionManager 已从文件加载: {}", filePath);
return sessionManager;
}
}
public void resetLastUpdatedTime() {
lastUpdatedTime = System.currentTimeMillis();
}
private static Path getFilePath(String id) {
return Paths.get(STORAGE_DIR, id + ".session");
}
}

View File

@@ -1,7 +0,0 @@
package work.slhaf.partner.gateway;
import work.slhaf.partner.core.interaction.data.InteractionOutputData;
public interface MessageSender {
void sendMessage(InteractionOutputData outputData);
}

Some files were not shown because too many files have changed in this diff Show More