新增配置加载功能并优化模型设置

- 新增 ConfigLoaderFactory 和 ModelConfigFactory 以及对应的默认实现用于加载模型配置和提示词列表
- 重构 ActivateModel 接口,支持基本提示和特定提示的加载,具体逻辑待实现,可通过ModelConfigFactory加载
- 优化模块注册和能力注入相关逻辑
- 添加了必要注释
This commit is contained in:
2025-07-31 22:13:10 +08:00
parent ade922cbc2
commit 64a7ed261e
35 changed files with 298 additions and 66 deletions

View File

@@ -4,14 +4,33 @@ import work.slhaf.partner.api.entity.AgentContext;
import work.slhaf.partner.api.factory.AgentRegisterFactory;
import work.slhaf.partner.api.flow.AgentInteraction;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
* Agent启动类
*/
public class Agent {
private static final List<Runnable> runners = new ArrayList<>();
public static void run(Class<?> clazz) {
AgentContext context = AgentRegisterFactory.launch(clazz.getPackage().getName());
AgentInteraction.launch(context);
launchRunners();
}
private static void launchRunners() {
ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor();
for (Runnable runner : runners) {
executorService.execute(runner);
}
executorService.close();
}
public static void addRunner(Runnable runnable) {
runners.add(runnable);
}
}

View File

@@ -1,7 +1,6 @@
package work.slhaf.partner.api.flow.abstracts;
package work.slhaf.partner.api.common.chat;
import lombok.Data;
import work.slhaf.partner.api.common.chat.ChatClient;
import work.slhaf.partner.api.common.chat.pojo.Message;
import java.util.List;

View File

@@ -2,6 +2,7 @@ package work.slhaf.partner.api.factory;
import cn.hutool.core.bean.BeanUtil;
import work.slhaf.partner.api.entity.AgentContext;
import work.slhaf.partner.api.factory.config.ConfigLoaderFactory;
import work.slhaf.partner.api.factory.entity.AgentRegisterContext;
import work.slhaf.partner.api.factory.capability.CapabilityCheckFactory;
import work.slhaf.partner.api.factory.capability.CapabilityRegisterFactory;
@@ -16,6 +17,8 @@ public class AgentRegisterFactory {
public static AgentContext launch(String path) {
AgentRegisterContext registerContext = new AgentRegisterContext(path);
//流程
//0. 加载配置
new ConfigLoaderFactory().execute(registerContext);
//1. 执行register和check逻辑
new CapabilityRegisterFactory().execute(registerContext);
new CapabilityCheckFactory().execute(registerContext);

View File

@@ -16,6 +16,9 @@ import java.util.stream.Collectors;
import static work.slhaf.partner.api.common.util.AgentUtil.methodSignature;
/**
* 执行<code>Capability</code>相关检查
*/
public class CapabilityCheckFactory extends AgentBaseFactory {
private Reflections reflections;
@@ -38,6 +41,9 @@ public class CapabilityCheckFactory extends AgentBaseFactory {
checkInjectCapability();
}
/**
* 检查<code>@InjectCapability</code>注解是否只用在<code>@CapabilityHolder</code>所标识类的字段上
*/
private void checkInjectCapability() {
reflections.getFieldsAnnotatedWith(InjectCapability.class).forEach(field -> {
if (!field.getDeclaringClass().isAssignableFrom(CapabilityHolder.class)) {
@@ -46,6 +52,9 @@ public class CapabilityCheckFactory extends AgentBaseFactory {
});
}
/**
* 检查是否包含协调方法,如果存在,则进一步检查是否存在<code>@CoordinateManager</code>提供对应的实现
*/
private void checkCoordinatedMethods() {
//检查各个capability中是否含有ToCoordinated注解
//如果含有则需要查找AbstractCognationManager的子类,看这里是否有对应的Coordinated注解所在方法
@@ -93,6 +102,9 @@ public class CapabilityCheckFactory extends AgentBaseFactory {
return methodsCoordinated;
}
/**
* 查看在<code>Capability</code>在对应的<core>CapabilityCore</core>中存在尚未实现的方法
*/
private void checkCapabilityMethods() {
HashMap<String, List<Method>> capabilitiesMethods = getCapabilityMethods(capabilities);
StringBuilder sb = new StringBuilder();
@@ -151,6 +163,9 @@ public class CapabilityCheckFactory extends AgentBaseFactory {
return capabilityMethods;
}
/**
* 检查<code>Capability</code>和<code>CapabilityCore</code>的数量和标识是否匹配
*/
private void checkCountAndCapabilities() {
if (cores.size() != capabilities.size()) {
throw new UnMatchedCapabilityException("Capability 注册异常: 已存在的CapabilityCore与Capability数量不匹配!");

View File

@@ -6,7 +6,7 @@ import work.slhaf.partner.api.factory.entity.AgentRegisterContext;
import work.slhaf.partner.api.factory.capability.annotation.Capability;
import work.slhaf.partner.api.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.factory.capability.annotation.ToCoordinated;
import work.slhaf.partner.api.factory.capability.exception.ProxySetFailedException;
import work.slhaf.partner.api.factory.capability.exception.ProxySetFailedExceptionCapability;
import java.lang.reflect.Field;
import java.lang.reflect.Proxy;
@@ -16,6 +16,9 @@ import java.util.function.Function;
import static work.slhaf.partner.api.common.util.AgentUtil.methodSignature;
/**
* 负责执行<code>Capability</code>的注入逻辑
*/
public class CapabilityInjectFactory extends AgentBaseFactory {
private Reflections reflections;
@@ -37,7 +40,6 @@ public class CapabilityInjectFactory extends AgentBaseFactory {
Set<Field> fields = reflections.getFieldsAnnotatedWith(InjectCapability.class);
//在动态代理内部,通过函数路由表调用对应的方法
createProxy(fields);
}
private void createProxy(Set<Field> fields) {
@@ -60,7 +62,7 @@ public class CapabilityInjectFactory extends AgentBaseFactory {
field.set(capabilityHolderInstances.get(field.getDeclaringClass()), instance);
}
} catch (Exception e) {
throw new ProxySetFailedException("代理设置失败", e);
throw new ProxySetFailedExceptionCapability("代理设置失败", e);
}
}

View File

@@ -4,9 +4,9 @@ import org.reflections.Reflections;
import work.slhaf.partner.api.factory.entity.AgentBaseFactory;
import work.slhaf.partner.api.factory.entity.AgentRegisterContext;
import work.slhaf.partner.api.factory.capability.annotation.*;
import work.slhaf.partner.api.factory.capability.exception.CoreInstancesCreateFailedException;
import work.slhaf.partner.api.factory.capability.exception.CoreInstancesCreateFailedExceptionCapability;
import work.slhaf.partner.api.factory.capability.exception.DuplicateMethodException;
import work.slhaf.partner.api.factory.capability.exception.FactoryExecuteFailedException;
import work.slhaf.partner.api.factory.capability.exception.CapabilityFactoryExecuteFailedException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
@@ -16,9 +16,13 @@ import java.util.HashMap;
import java.util.Set;
import java.util.function.Function;
import static cn.hutool.core.util.ClassUtil.isNormalClass;
import static work.slhaf.partner.api.common.util.AgentUtil.methodSignature;
/**
* 负责获取<code>@Capability</code>和<code>@CapabilityCore</code>标识的类,并生成函数路由表、设置<code>Core</code>实例用于后续注入
*/
public final class CapabilityRegisterFactory extends AgentBaseFactory {
private Reflections reflections;
@@ -35,19 +39,22 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
methodsRouterTable = context.getMethodsRouterTable();
coordinatedMethodsRouterTable = context.getCoordinatedMethodsRouterTable();
capabilityCoreInstances = context.getCapabilityCoreInstances();
capabilityHolderInstances = context.getCapabilityHolderInstances();
cores = context.getCores();
capabilities = context.getCapabilities();
capabilityHolderInstances = context.getCapabilityHolderInstances();
}
@Override
protected void run() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
protected void run() throws InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException {
setCapabilityCoreInstances();
setAnnotatedClasses();
generateRouterTable();
}
private void setAnnotatedClasses() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
/**
* 设置<code>CapabilityCore</code>、<code>Capability</code>注解标识类
*/
private void setAnnotatedClasses() throws InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException {
cores.addAll(reflections.getTypesAnnotatedWith(CapabilityCore.class));
capabilities.addAll(reflections.getTypesAnnotatedWith(Capability.class));
setCapabilityHolderInstances();
@@ -55,16 +62,25 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
private void setCapabilityHolderInstances() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
for (Class<?> clazz : reflections.getTypesAnnotatedWith(CapabilityHolder.class)) {
if (!isNormalClass(clazz)){
continue;
}
Object o = clazz.getDeclaredConstructor().newInstance();
capabilityHolderInstances.put(clazz, o);
}
}
/**
* 生成函数路由表
*/
private void generateRouterTable() {
generateMethodsRouterTable();
generateCoordinatedMethodsRouterTable();
}
/**
* 生成协调函数对应的函数路由表
*/
private void generateCoordinatedMethodsRouterTable() {
Set<Method> methodsAnnotatedWith = reflections.getMethodsAnnotatedWith(Coordinated.class);
if (methodsAnnotatedWith.isEmpty()) {
@@ -85,11 +101,14 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
coordinatedMethodsRouterTable.put(key, function);
});
} catch (Exception e) {
throw new FactoryExecuteFailedException("创建协调方法路由表出错", e);
throw new CapabilityFactoryExecuteFailedException("创建协调方法路由表出错", e);
}
}
/**
* 获取<code>CoordinateManager</code>子类实例
*/
private HashMap<String, Object> getCoordinateManagerInstances() throws InvocationTargetException, InstantiationException, IllegalAccessException, NoSuchMethodException {
HashMap<String, Object> map = new HashMap<>();
for (Class<?> c : reflections.getTypesAnnotatedWith(CoordinateManager.class)) {
@@ -106,6 +125,9 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
return map;
}
/**
* 生成普通方法对应的函数路由表
*/
private void generateMethodsRouterTable() {
//扫描`@Capability`与`@CapabilityMethod`注解的类与方法
//将`capabilityValue.methodSignature`作为key,函数对象为通过反射拿到的core实例对应的方法
@@ -127,6 +149,9 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
}));
}
/**
* 反射获取<code>CapabilityCore</code>实例
*/
private void setCapabilityCoreInstances() {
try {
for (Class<?> core : cores) {
@@ -136,7 +161,7 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
}
} catch (InvocationTargetException | NoSuchMethodException | InstantiationException |
IllegalAccessException e) {
throw new CoreInstancesCreateFailedException("core实例创建失败");
throw new CoreInstancesCreateFailedExceptionCapability("core实例创建失败");
}
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.factory.capability.exception;
public class CapabilityFactoryExecuteFailedException extends RuntimeException {
public CapabilityFactoryExecuteFailedException(String message) {
super("CapabilityRegisterFactory 执行失败: " + message);
}
public CapabilityFactoryExecuteFailedException(String message, Throwable cause) {
super("CapabilityRegisterFactory 执行失败: " + message, cause);
}
}

View File

@@ -1,11 +0,0 @@
package work.slhaf.partner.api.factory.capability.exception;
public class CoreInstancesCreateFailedException extends FactoryExecuteFailedException {
public CoreInstancesCreateFailedException(String message) {
super(message);
}
public CoreInstancesCreateFailedException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.factory.capability.exception;
public class CoreInstancesCreateFailedExceptionCapability extends CapabilityFactoryExecuteFailedException {
public CoreInstancesCreateFailedExceptionCapability(String message) {
super(message);
}
public CoreInstancesCreateFailedExceptionCapability(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -1,11 +0,0 @@
package work.slhaf.partner.api.factory.capability.exception;
public class FactoryExecuteFailedException extends RuntimeException {
public FactoryExecuteFailedException(String message) {
super("CapabilityRegisterFactory 执行失败: " + message);
}
public FactoryExecuteFailedException(String message, Throwable cause) {
super("CapabilityRegisterFactory 执行失败: " + message, cause);
}
}

View File

@@ -1,11 +0,0 @@
package work.slhaf.partner.api.factory.capability.exception;
public class ProxySetFailedException extends FactoryExecuteFailedException{
public ProxySetFailedException(String message) {
super(message);
}
public ProxySetFailedException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.factory.capability.exception;
public class ProxySetFailedExceptionCapability extends CapabilityFactoryExecuteFailedException {
public ProxySetFailedExceptionCapability(String message) {
super(message);
}
public ProxySetFailedExceptionCapability(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,22 @@
package work.slhaf.partner.api.factory.config;
import work.slhaf.partner.api.factory.entity.AgentBaseFactory;
import work.slhaf.partner.api.factory.entity.AgentRegisterContext;
public class ConfigLoaderFactory extends AgentBaseFactory {
@Override
protected void setVariables(AgentRegisterContext context) {
}
@Override
protected void run() {
//反射获取是否存在其他的ModelConfigFactory子类如果存在则不使用默认配置工厂
ModelConfigFactory factory;
if (){
}
factory.load();
}
}

View File

@@ -0,0 +1,29 @@
package work.slhaf.partner.api.factory.config;
import work.slhaf.partner.api.common.chat.pojo.Message;
import work.slhaf.partner.api.factory.config.pojo.ModelConfig;
import java.util.HashMap;
import java.util.List;
/**
* 默认配置工厂
* 将从当前运行目录的config文件夹下创建并读取配置
*/
public class DefaultModelConfigFactory extends ModelConfigFactory {
private static final String MODEL_CONFIG_DIR = "./config/model/";
private static final String PROMPT_CONFIG_DIR = "./config/prompt/";
@Override
protected HashMap<String, List<Message>> loadPrompt() {
return null;
}
@Override
protected HashMap<String, ModelConfig> loadConfig() {
return null;
}
}

View File

@@ -0,0 +1,53 @@
package work.slhaf.partner.api.factory.config;
import work.slhaf.partner.api.common.chat.pojo.Message;
import work.slhaf.partner.api.factory.config.exception.UnExistModelConfigException;
import work.slhaf.partner.api.factory.config.exception.UnExistModelPromptException;
import work.slhaf.partner.api.factory.config.pojo.ModelConfig;
import java.util.HashMap;
import java.util.List;
public abstract class ModelConfigFactory {
public static ModelConfigFactory factory;
protected HashMap<String, ModelConfig> modelConfigMap;
protected HashMap<String, List<Message>> modelPromptMap;
public ModelConfigFactory() {
factory = this;
}
public void load() {
modelConfigMap = loadConfig();
modelPromptMap = loadPrompt();
}
protected abstract HashMap<String, List<Message>> loadPrompt();
protected abstract HashMap<String, ModelConfig> loadConfig();
public List<Message> loadModelPrompt(String modelKey){
if (!modelPromptMap.containsKey(modelKey)){
throw new UnExistModelPromptException("不存在的modelPrompt: "+modelKey);
}
return modelPromptMap.get(modelKey);
}
public ModelConfig loadModelConfig(String modelKey) {
if (!modelConfigMap.containsKey(modelKey)) {
throw new UnExistModelConfigException("不存在的modelKey: " + modelKey);
}
return modelConfigMap.get(modelKey);
}
public void updateModelConfig(String modelKey, ModelConfig config) {
if (!modelConfigMap.containsKey(modelKey)) {
throw new UnExistModelConfigException("不存在的modelKey: " + modelKey);
}
modelConfigMap.get(modelKey).setModel(config.getModel());
modelConfigMap.get(modelKey).setBaseUrl(config.getBaseUrl());
modelConfigMap.get(modelKey).setApikey(config.getApikey());
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.factory.config.exception;
public class ModelConfigFactoryFailedException extends RuntimeException {
public ModelConfigFactoryFailedException(String message, Throwable cause) {
super("ModelConfigFactory 执行失败: " + message, cause);
}
public ModelConfigFactoryFailedException(String message) {
super("ModelConfigFactory 执行失败: " + message);
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.factory.config.exception;
public class UnExistModelConfigException extends ModelConfigFactoryFailedException {
public UnExistModelConfigException(String message, Throwable e) {
super(message, e);
}
public UnExistModelConfigException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.factory.config.exception;
public class UnExistModelPromptException extends ModelConfigFactoryFailedException{
public UnExistModelPromptException(String message) {
super(message);
}
public UnExistModelPromptException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,10 @@
package work.slhaf.partner.api.factory.config.pojo;
import lombok.Data;
@Data
public class ModelConfig {
private String baseUrl;
private String apikey;
private String model;
}

View File

@@ -1,6 +1,6 @@
package work.slhaf.partner.api.factory.entity;
import work.slhaf.partner.api.factory.capability.exception.FactoryExecuteFailedException;
import work.slhaf.partner.api.factory.capability.exception.CapabilityFactoryExecuteFailedException;
import java.lang.reflect.InvocationTargetException;
@@ -10,7 +10,7 @@ public abstract class AgentBaseFactory {
setVariables(context);
run();
} catch (Exception e) {
throw new FactoryExecuteFailedException(e.getMessage(), e);
throw new CapabilityFactoryExecuteFailedException(e.getMessage(), e);
}
}

View File

@@ -4,6 +4,9 @@ import org.reflections.Reflections;
import work.slhaf.partner.api.factory.entity.AgentBaseFactory;
import work.slhaf.partner.api.factory.entity.AgentRegisterContext;
/**
* 负责扫描<code>@Module</code>注解获取模块实例
*/
public class ModuleRegisterFactory extends AgentBaseFactory {
private Reflections reflections;

View File

@@ -1,9 +1,7 @@
package work.slhaf.partner.api.factory.module.annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.annotation.*;
/**
* 用于注解执行模块

View File

@@ -0,0 +1,5 @@
package work.slhaf.partner.api.factory.module.annotation;
@AgentModule(name = "core",order = 5)
public @interface CoreModule {
}

View File

@@ -1,9 +1,12 @@
package work.slhaf.partner.api.flow.abstracts;
import work.slhaf.partner.api.common.chat.ChatClient;
import work.slhaf.partner.api.common.chat.Model;
import work.slhaf.partner.api.common.chat.constant.ChatConstant;
import work.slhaf.partner.api.common.chat.pojo.ChatResponse;
import work.slhaf.partner.api.common.chat.pojo.Message;
import work.slhaf.partner.api.factory.config.ModelConfigFactory;
import work.slhaf.partner.api.factory.config.pojo.ModelConfig;
import work.slhaf.partner.api.factory.module.annotation.Before;
import java.util.ArrayList;
@@ -11,13 +14,22 @@ import java.util.List;
public interface ActivateModel {
@Before
default void modelSettings() {
// Model model = new Model();
// ModelConfig modelConfig = ModelConfig.load(modelKey());
// model.setBaseMessages(withAwareness() ? ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey(), promptModule()) : ResourcesUtil.Prompt.loadPrompt(modelKey(), promptModule()));
// model.setChatClient(new ChatClient(modelConfig.getBaseUrl(), modelConfig.getApikey(), modelConfig.getModel()));
Model model = new Model();
ModelConfig modelConfig = ModelConfigFactory.factory.loadModelConfig(modelKey());
model.setBaseMessages(withBasicPrompt() ? loadSpecificPromptAndBasicPrompt(modelKey(), promptModule()) : loadSpecificPrompt(modelKey(), promptModule()));
model.setChatClient(new ChatClient(modelConfig.getBaseUrl(), modelConfig.getApikey(), modelConfig.getModel()));
}
private List<Message> loadSpecificPrompt(String modelKey, String specificModule) {
return null;
}
private List<Message> loadSpecificPromptAndBasicPrompt(String modelKey, String specificModule) {
return null;
}
default ChatResponse chat() {
@@ -68,7 +80,7 @@ public interface ActivateModel {
String modelKey();
boolean withAwareness();
boolean withBasicPrompt();
String promptModule();
}

View File

@@ -2,8 +2,12 @@ package work.slhaf.partner.api.flow.abstracts;
import lombok.Getter;
import lombok.Setter;
import work.slhaf.partner.api.common.chat.Model;
import work.slhaf.partner.api.factory.capability.annotation.CapabilityHolder;
/**
* 模块基类
*/
@CapabilityHolder
public abstract class Module {

View File

@@ -14,7 +14,7 @@ public class ResourcesUtil {
private static final ClassLoader classloader = Agent.class.getClassLoader();
public static class Prompt {
private static final String SELF_AWARENESS_PATH = "prompt/self_awareness.json";
private static final String SELF_AWARENESS_PATH = "prompt/basic_prompt.json";
private static final String MODULE_PROMPT_PREFIX_PATH = "prompt/module/";
public static List<Message> loadPromptWithSelfAwareness(String modelKey, String promptType) {

View File

@@ -69,7 +69,7 @@ public class CoreModel extends CoreModule implements ActivateModel {
}
@Override
public boolean withAwareness() {
public boolean withBasicPrompt() {
return true;
}

View File

@@ -141,7 +141,7 @@ public class SliceSelectEvaluator extends AgentInteractionSubModule<EvaluatorInp
}
@Override
public boolean withAwareness() {
public boolean withBasicPrompt() {
return false;
}

View File

@@ -114,7 +114,7 @@ public class MemorySelectExtractor extends AgentInteractionSubModule<Interaction
}
@Override
public boolean withAwareness() {
public boolean withBasicPrompt() {
return false;
}

View File

@@ -71,7 +71,7 @@ public class MultiSummarizer extends AgentInteractionSubModule<SummarizeInput, S
}
@Override
public boolean withAwareness() {
public boolean withBasicPrompt() {
return true;
}

View File

@@ -83,7 +83,7 @@ public class SingleSummarizer extends AgentInteractionSubModule<List<Message>,Vo
}
@Override
public boolean withAwareness() {
public boolean withBasicPrompt() {
return false;
}

View File

@@ -49,7 +49,7 @@ public class TotalSummarizer extends AgentInteractionSubModule<HashMap<String, S
}
@Override
public boolean withAwareness() {
public boolean withBasicPrompt() {
return true;
}

View File

@@ -91,7 +91,7 @@ public class RelationExtractor extends AgentInteractionSubModule<InteractionCont
}
@Override
public boolean withAwareness() {
public boolean withBasicPrompt() {
return true;
}

View File

@@ -64,7 +64,7 @@ public class StaticMemoryExtractor extends AgentInteractionSubModule<Interaction
}
@Override
public boolean withAwareness() {
public boolean withBasicPrompt() {
return true;
}