新增配置加载功能并优化模型设置

- 新增 ConfigLoaderFactory 和 ModelConfigFactory 以及对应的默认实现用于加载模型配置和提示词列表
- 重构 ActivateModel 接口,支持基本提示和特定提示的加载,具体逻辑待实现,可通过ModelConfigFactory加载
- 优化模块注册和能力注入相关逻辑
- 添加了必要注释
This commit is contained in:
2025-07-31 22:13:10 +08:00
parent ade922cbc2
commit 64a7ed261e
35 changed files with 298 additions and 66 deletions

View File

@@ -4,14 +4,33 @@ import work.slhaf.partner.api.entity.AgentContext;
import work.slhaf.partner.api.factory.AgentRegisterFactory; import work.slhaf.partner.api.factory.AgentRegisterFactory;
import work.slhaf.partner.api.flow.AgentInteraction; import work.slhaf.partner.api.flow.AgentInteraction;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/** /**
* Agent启动类 * Agent启动类
*/ */
public class Agent { public class Agent {
private static final List<Runnable> runners = new ArrayList<>();
public static void run(Class<?> clazz) { public static void run(Class<?> clazz) {
AgentContext context = AgentRegisterFactory.launch(clazz.getPackage().getName()); AgentContext context = AgentRegisterFactory.launch(clazz.getPackage().getName());
AgentInteraction.launch(context); AgentInteraction.launch(context);
launchRunners();
} }
private static void launchRunners() {
ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor();
for (Runnable runner : runners) {
executorService.execute(runner);
}
executorService.close();
}
public static void addRunner(Runnable runnable) {
runners.add(runnable);
}
} }

View File

@@ -1,7 +1,6 @@
package work.slhaf.partner.api.flow.abstracts; package work.slhaf.partner.api.common.chat;
import lombok.Data; import lombok.Data;
import work.slhaf.partner.api.common.chat.ChatClient;
import work.slhaf.partner.api.common.chat.pojo.Message; import work.slhaf.partner.api.common.chat.pojo.Message;
import java.util.List; import java.util.List;

View File

@@ -2,6 +2,7 @@ package work.slhaf.partner.api.factory;
import cn.hutool.core.bean.BeanUtil; import cn.hutool.core.bean.BeanUtil;
import work.slhaf.partner.api.entity.AgentContext; import work.slhaf.partner.api.entity.AgentContext;
import work.slhaf.partner.api.factory.config.ConfigLoaderFactory;
import work.slhaf.partner.api.factory.entity.AgentRegisterContext; import work.slhaf.partner.api.factory.entity.AgentRegisterContext;
import work.slhaf.partner.api.factory.capability.CapabilityCheckFactory; import work.slhaf.partner.api.factory.capability.CapabilityCheckFactory;
import work.slhaf.partner.api.factory.capability.CapabilityRegisterFactory; import work.slhaf.partner.api.factory.capability.CapabilityRegisterFactory;
@@ -16,6 +17,8 @@ public class AgentRegisterFactory {
public static AgentContext launch(String path) { public static AgentContext launch(String path) {
AgentRegisterContext registerContext = new AgentRegisterContext(path); AgentRegisterContext registerContext = new AgentRegisterContext(path);
//流程 //流程
//0. 加载配置
new ConfigLoaderFactory().execute(registerContext);
//1. 执行register和check逻辑 //1. 执行register和check逻辑
new CapabilityRegisterFactory().execute(registerContext); new CapabilityRegisterFactory().execute(registerContext);
new CapabilityCheckFactory().execute(registerContext); new CapabilityCheckFactory().execute(registerContext);

View File

@@ -16,6 +16,9 @@ import java.util.stream.Collectors;
import static work.slhaf.partner.api.common.util.AgentUtil.methodSignature; import static work.slhaf.partner.api.common.util.AgentUtil.methodSignature;
/**
* 执行<code>Capability</code>相关检查
*/
public class CapabilityCheckFactory extends AgentBaseFactory { public class CapabilityCheckFactory extends AgentBaseFactory {
private Reflections reflections; private Reflections reflections;
@@ -38,6 +41,9 @@ public class CapabilityCheckFactory extends AgentBaseFactory {
checkInjectCapability(); checkInjectCapability();
} }
/**
* 检查<code>@InjectCapability</code>注解是否只用在<code>@CapabilityHolder</code>所标识类的字段上
*/
private void checkInjectCapability() { private void checkInjectCapability() {
reflections.getFieldsAnnotatedWith(InjectCapability.class).forEach(field -> { reflections.getFieldsAnnotatedWith(InjectCapability.class).forEach(field -> {
if (!field.getDeclaringClass().isAssignableFrom(CapabilityHolder.class)) { if (!field.getDeclaringClass().isAssignableFrom(CapabilityHolder.class)) {
@@ -46,6 +52,9 @@ public class CapabilityCheckFactory extends AgentBaseFactory {
}); });
} }
/**
* 检查是否包含协调方法,如果存在,则进一步检查是否存在<code>@CoordinateManager</code>提供对应的实现
*/
private void checkCoordinatedMethods() { private void checkCoordinatedMethods() {
//检查各个capability中是否含有ToCoordinated注解 //检查各个capability中是否含有ToCoordinated注解
//如果含有则需要查找AbstractCognationManager的子类,看这里是否有对应的Coordinated注解所在方法 //如果含有则需要查找AbstractCognationManager的子类,看这里是否有对应的Coordinated注解所在方法
@@ -93,6 +102,9 @@ public class CapabilityCheckFactory extends AgentBaseFactory {
return methodsCoordinated; return methodsCoordinated;
} }
/**
* 查看在<code>Capability</code>在对应的<core>CapabilityCore</core>中存在尚未实现的方法
*/
private void checkCapabilityMethods() { private void checkCapabilityMethods() {
HashMap<String, List<Method>> capabilitiesMethods = getCapabilityMethods(capabilities); HashMap<String, List<Method>> capabilitiesMethods = getCapabilityMethods(capabilities);
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
@@ -151,6 +163,9 @@ public class CapabilityCheckFactory extends AgentBaseFactory {
return capabilityMethods; return capabilityMethods;
} }
/**
* 检查<code>Capability</code>和<code>CapabilityCore</code>的数量和标识是否匹配
*/
private void checkCountAndCapabilities() { private void checkCountAndCapabilities() {
if (cores.size() != capabilities.size()) { if (cores.size() != capabilities.size()) {
throw new UnMatchedCapabilityException("Capability 注册异常: 已存在的CapabilityCore与Capability数量不匹配!"); throw new UnMatchedCapabilityException("Capability 注册异常: 已存在的CapabilityCore与Capability数量不匹配!");

View File

@@ -6,7 +6,7 @@ import work.slhaf.partner.api.factory.entity.AgentRegisterContext;
import work.slhaf.partner.api.factory.capability.annotation.Capability; import work.slhaf.partner.api.factory.capability.annotation.Capability;
import work.slhaf.partner.api.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.factory.capability.annotation.ToCoordinated; import work.slhaf.partner.api.factory.capability.annotation.ToCoordinated;
import work.slhaf.partner.api.factory.capability.exception.ProxySetFailedException; import work.slhaf.partner.api.factory.capability.exception.ProxySetFailedExceptionCapability;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.lang.reflect.Proxy; import java.lang.reflect.Proxy;
@@ -16,6 +16,9 @@ import java.util.function.Function;
import static work.slhaf.partner.api.common.util.AgentUtil.methodSignature; import static work.slhaf.partner.api.common.util.AgentUtil.methodSignature;
/**
* 负责执行<code>Capability</code>的注入逻辑
*/
public class CapabilityInjectFactory extends AgentBaseFactory { public class CapabilityInjectFactory extends AgentBaseFactory {
private Reflections reflections; private Reflections reflections;
@@ -37,7 +40,6 @@ public class CapabilityInjectFactory extends AgentBaseFactory {
Set<Field> fields = reflections.getFieldsAnnotatedWith(InjectCapability.class); Set<Field> fields = reflections.getFieldsAnnotatedWith(InjectCapability.class);
//在动态代理内部,通过函数路由表调用对应的方法 //在动态代理内部,通过函数路由表调用对应的方法
createProxy(fields); createProxy(fields);
} }
private void createProxy(Set<Field> fields) { private void createProxy(Set<Field> fields) {
@@ -60,7 +62,7 @@ public class CapabilityInjectFactory extends AgentBaseFactory {
field.set(capabilityHolderInstances.get(field.getDeclaringClass()), instance); field.set(capabilityHolderInstances.get(field.getDeclaringClass()), instance);
} }
} catch (Exception e) { } catch (Exception e) {
throw new ProxySetFailedException("代理设置失败", e); throw new ProxySetFailedExceptionCapability("代理设置失败", e);
} }
} }

View File

@@ -4,9 +4,9 @@ import org.reflections.Reflections;
import work.slhaf.partner.api.factory.entity.AgentBaseFactory; import work.slhaf.partner.api.factory.entity.AgentBaseFactory;
import work.slhaf.partner.api.factory.entity.AgentRegisterContext; import work.slhaf.partner.api.factory.entity.AgentRegisterContext;
import work.slhaf.partner.api.factory.capability.annotation.*; import work.slhaf.partner.api.factory.capability.annotation.*;
import work.slhaf.partner.api.factory.capability.exception.CoreInstancesCreateFailedException; import work.slhaf.partner.api.factory.capability.exception.CoreInstancesCreateFailedExceptionCapability;
import work.slhaf.partner.api.factory.capability.exception.DuplicateMethodException; import work.slhaf.partner.api.factory.capability.exception.DuplicateMethodException;
import work.slhaf.partner.api.factory.capability.exception.FactoryExecuteFailedException; import work.slhaf.partner.api.factory.capability.exception.CapabilityFactoryExecuteFailedException;
import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
@@ -16,9 +16,13 @@ import java.util.HashMap;
import java.util.Set; import java.util.Set;
import java.util.function.Function; import java.util.function.Function;
import static cn.hutool.core.util.ClassUtil.isNormalClass;
import static work.slhaf.partner.api.common.util.AgentUtil.methodSignature; import static work.slhaf.partner.api.common.util.AgentUtil.methodSignature;
/**
* 负责获取<code>@Capability</code>和<code>@CapabilityCore</code>标识的类,并生成函数路由表、设置<code>Core</code>实例用于后续注入
*/
public final class CapabilityRegisterFactory extends AgentBaseFactory { public final class CapabilityRegisterFactory extends AgentBaseFactory {
private Reflections reflections; private Reflections reflections;
@@ -35,19 +39,22 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
methodsRouterTable = context.getMethodsRouterTable(); methodsRouterTable = context.getMethodsRouterTable();
coordinatedMethodsRouterTable = context.getCoordinatedMethodsRouterTable(); coordinatedMethodsRouterTable = context.getCoordinatedMethodsRouterTable();
capabilityCoreInstances = context.getCapabilityCoreInstances(); capabilityCoreInstances = context.getCapabilityCoreInstances();
capabilityHolderInstances = context.getCapabilityHolderInstances();
cores = context.getCores(); cores = context.getCores();
capabilities = context.getCapabilities(); capabilities = context.getCapabilities();
capabilityHolderInstances = context.getCapabilityHolderInstances();
} }
@Override @Override
protected void run() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { protected void run() throws InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException {
setCapabilityCoreInstances(); setCapabilityCoreInstances();
setAnnotatedClasses(); setAnnotatedClasses();
generateRouterTable(); generateRouterTable();
} }
private void setAnnotatedClasses() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { /**
* 设置<code>CapabilityCore</code>、<code>Capability</code>注解标识类
*/
private void setAnnotatedClasses() throws InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException {
cores.addAll(reflections.getTypesAnnotatedWith(CapabilityCore.class)); cores.addAll(reflections.getTypesAnnotatedWith(CapabilityCore.class));
capabilities.addAll(reflections.getTypesAnnotatedWith(Capability.class)); capabilities.addAll(reflections.getTypesAnnotatedWith(Capability.class));
setCapabilityHolderInstances(); setCapabilityHolderInstances();
@@ -55,16 +62,25 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
private void setCapabilityHolderInstances() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { private void setCapabilityHolderInstances() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
for (Class<?> clazz : reflections.getTypesAnnotatedWith(CapabilityHolder.class)) { for (Class<?> clazz : reflections.getTypesAnnotatedWith(CapabilityHolder.class)) {
if (!isNormalClass(clazz)){
continue;
}
Object o = clazz.getDeclaredConstructor().newInstance(); Object o = clazz.getDeclaredConstructor().newInstance();
capabilityHolderInstances.put(clazz, o); capabilityHolderInstances.put(clazz, o);
} }
} }
/**
* 生成函数路由表
*/
private void generateRouterTable() { private void generateRouterTable() {
generateMethodsRouterTable(); generateMethodsRouterTable();
generateCoordinatedMethodsRouterTable(); generateCoordinatedMethodsRouterTable();
} }
/**
* 生成协调函数对应的函数路由表
*/
private void generateCoordinatedMethodsRouterTable() { private void generateCoordinatedMethodsRouterTable() {
Set<Method> methodsAnnotatedWith = reflections.getMethodsAnnotatedWith(Coordinated.class); Set<Method> methodsAnnotatedWith = reflections.getMethodsAnnotatedWith(Coordinated.class);
if (methodsAnnotatedWith.isEmpty()) { if (methodsAnnotatedWith.isEmpty()) {
@@ -85,11 +101,14 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
coordinatedMethodsRouterTable.put(key, function); coordinatedMethodsRouterTable.put(key, function);
}); });
} catch (Exception e) { } catch (Exception e) {
throw new FactoryExecuteFailedException("创建协调方法路由表出错", e); throw new CapabilityFactoryExecuteFailedException("创建协调方法路由表出错", e);
} }
} }
/**
* 获取<code>CoordinateManager</code>子类实例
*/
private HashMap<String, Object> getCoordinateManagerInstances() throws InvocationTargetException, InstantiationException, IllegalAccessException, NoSuchMethodException { private HashMap<String, Object> getCoordinateManagerInstances() throws InvocationTargetException, InstantiationException, IllegalAccessException, NoSuchMethodException {
HashMap<String, Object> map = new HashMap<>(); HashMap<String, Object> map = new HashMap<>();
for (Class<?> c : reflections.getTypesAnnotatedWith(CoordinateManager.class)) { for (Class<?> c : reflections.getTypesAnnotatedWith(CoordinateManager.class)) {
@@ -106,6 +125,9 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
return map; return map;
} }
/**
* 生成普通方法对应的函数路由表
*/
private void generateMethodsRouterTable() { private void generateMethodsRouterTable() {
//扫描`@Capability`与`@CapabilityMethod`注解的类与方法 //扫描`@Capability`与`@CapabilityMethod`注解的类与方法
//将`capabilityValue.methodSignature`作为key,函数对象为通过反射拿到的core实例对应的方法 //将`capabilityValue.methodSignature`作为key,函数对象为通过反射拿到的core实例对应的方法
@@ -127,6 +149,9 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
})); }));
} }
/**
* 反射获取<code>CapabilityCore</code>实例
*/
private void setCapabilityCoreInstances() { private void setCapabilityCoreInstances() {
try { try {
for (Class<?> core : cores) { for (Class<?> core : cores) {
@@ -136,7 +161,7 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
} }
} catch (InvocationTargetException | NoSuchMethodException | InstantiationException | } catch (InvocationTargetException | NoSuchMethodException | InstantiationException |
IllegalAccessException e) { IllegalAccessException e) {
throw new CoreInstancesCreateFailedException("core实例创建失败"); throw new CoreInstancesCreateFailedExceptionCapability("core实例创建失败");
} }
} }
} }

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.factory.capability.exception;
public class CapabilityFactoryExecuteFailedException extends RuntimeException {
public CapabilityFactoryExecuteFailedException(String message) {
super("CapabilityRegisterFactory 执行失败: " + message);
}
public CapabilityFactoryExecuteFailedException(String message, Throwable cause) {
super("CapabilityRegisterFactory 执行失败: " + message, cause);
}
}

View File

@@ -1,11 +0,0 @@
package work.slhaf.partner.api.factory.capability.exception;
public class CoreInstancesCreateFailedException extends FactoryExecuteFailedException {
public CoreInstancesCreateFailedException(String message) {
super(message);
}
public CoreInstancesCreateFailedException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.factory.capability.exception;
public class CoreInstancesCreateFailedExceptionCapability extends CapabilityFactoryExecuteFailedException {
public CoreInstancesCreateFailedExceptionCapability(String message) {
super(message);
}
public CoreInstancesCreateFailedExceptionCapability(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -1,11 +0,0 @@
package work.slhaf.partner.api.factory.capability.exception;
public class FactoryExecuteFailedException extends RuntimeException {
public FactoryExecuteFailedException(String message) {
super("CapabilityRegisterFactory 执行失败: " + message);
}
public FactoryExecuteFailedException(String message, Throwable cause) {
super("CapabilityRegisterFactory 执行失败: " + message, cause);
}
}

View File

@@ -1,11 +0,0 @@
package work.slhaf.partner.api.factory.capability.exception;
public class ProxySetFailedException extends FactoryExecuteFailedException{
public ProxySetFailedException(String message) {
super(message);
}
public ProxySetFailedException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.factory.capability.exception;
public class ProxySetFailedExceptionCapability extends CapabilityFactoryExecuteFailedException {
public ProxySetFailedExceptionCapability(String message) {
super(message);
}
public ProxySetFailedExceptionCapability(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,22 @@
package work.slhaf.partner.api.factory.config;
import work.slhaf.partner.api.factory.entity.AgentBaseFactory;
import work.slhaf.partner.api.factory.entity.AgentRegisterContext;
public class ConfigLoaderFactory extends AgentBaseFactory {
@Override
protected void setVariables(AgentRegisterContext context) {
}
@Override
protected void run() {
//反射获取是否存在其他的ModelConfigFactory子类如果存在则不使用默认配置工厂
ModelConfigFactory factory;
if (){
}
factory.load();
}
}

View File

@@ -0,0 +1,29 @@
package work.slhaf.partner.api.factory.config;
import work.slhaf.partner.api.common.chat.pojo.Message;
import work.slhaf.partner.api.factory.config.pojo.ModelConfig;
import java.util.HashMap;
import java.util.List;
/**
* 默认配置工厂
* 将从当前运行目录的config文件夹下创建并读取配置
*/
public class DefaultModelConfigFactory extends ModelConfigFactory {
private static final String MODEL_CONFIG_DIR = "./config/model/";
private static final String PROMPT_CONFIG_DIR = "./config/prompt/";
@Override
protected HashMap<String, List<Message>> loadPrompt() {
return null;
}
@Override
protected HashMap<String, ModelConfig> loadConfig() {
return null;
}
}

View File

@@ -0,0 +1,53 @@
package work.slhaf.partner.api.factory.config;
import work.slhaf.partner.api.common.chat.pojo.Message;
import work.slhaf.partner.api.factory.config.exception.UnExistModelConfigException;
import work.slhaf.partner.api.factory.config.exception.UnExistModelPromptException;
import work.slhaf.partner.api.factory.config.pojo.ModelConfig;
import java.util.HashMap;
import java.util.List;
public abstract class ModelConfigFactory {
public static ModelConfigFactory factory;
protected HashMap<String, ModelConfig> modelConfigMap;
protected HashMap<String, List<Message>> modelPromptMap;
public ModelConfigFactory() {
factory = this;
}
public void load() {
modelConfigMap = loadConfig();
modelPromptMap = loadPrompt();
}
protected abstract HashMap<String, List<Message>> loadPrompt();
protected abstract HashMap<String, ModelConfig> loadConfig();
public List<Message> loadModelPrompt(String modelKey){
if (!modelPromptMap.containsKey(modelKey)){
throw new UnExistModelPromptException("不存在的modelPrompt: "+modelKey);
}
return modelPromptMap.get(modelKey);
}
public ModelConfig loadModelConfig(String modelKey) {
if (!modelConfigMap.containsKey(modelKey)) {
throw new UnExistModelConfigException("不存在的modelKey: " + modelKey);
}
return modelConfigMap.get(modelKey);
}
public void updateModelConfig(String modelKey, ModelConfig config) {
if (!modelConfigMap.containsKey(modelKey)) {
throw new UnExistModelConfigException("不存在的modelKey: " + modelKey);
}
modelConfigMap.get(modelKey).setModel(config.getModel());
modelConfigMap.get(modelKey).setBaseUrl(config.getBaseUrl());
modelConfigMap.get(modelKey).setApikey(config.getApikey());
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.factory.config.exception;
public class ModelConfigFactoryFailedException extends RuntimeException {
public ModelConfigFactoryFailedException(String message, Throwable cause) {
super("ModelConfigFactory 执行失败: " + message, cause);
}
public ModelConfigFactoryFailedException(String message) {
super("ModelConfigFactory 执行失败: " + message);
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.factory.config.exception;
public class UnExistModelConfigException extends ModelConfigFactoryFailedException {
public UnExistModelConfigException(String message, Throwable e) {
super(message, e);
}
public UnExistModelConfigException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.api.factory.config.exception;
public class UnExistModelPromptException extends ModelConfigFactoryFailedException{
public UnExistModelPromptException(String message) {
super(message);
}
public UnExistModelPromptException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,10 @@
package work.slhaf.partner.api.factory.config.pojo;
import lombok.Data;
@Data
public class ModelConfig {
private String baseUrl;
private String apikey;
private String model;
}

View File

@@ -1,6 +1,6 @@
package work.slhaf.partner.api.factory.entity; package work.slhaf.partner.api.factory.entity;
import work.slhaf.partner.api.factory.capability.exception.FactoryExecuteFailedException; import work.slhaf.partner.api.factory.capability.exception.CapabilityFactoryExecuteFailedException;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
@@ -10,7 +10,7 @@ public abstract class AgentBaseFactory {
setVariables(context); setVariables(context);
run(); run();
} catch (Exception e) { } catch (Exception e) {
throw new FactoryExecuteFailedException(e.getMessage(), e); throw new CapabilityFactoryExecuteFailedException(e.getMessage(), e);
} }
} }

View File

@@ -4,6 +4,9 @@ import org.reflections.Reflections;
import work.slhaf.partner.api.factory.entity.AgentBaseFactory; import work.slhaf.partner.api.factory.entity.AgentBaseFactory;
import work.slhaf.partner.api.factory.entity.AgentRegisterContext; import work.slhaf.partner.api.factory.entity.AgentRegisterContext;
/**
* 负责扫描<code>@Module</code>注解获取模块实例
*/
public class ModuleRegisterFactory extends AgentBaseFactory { public class ModuleRegisterFactory extends AgentBaseFactory {
private Reflections reflections; private Reflections reflections;

View File

@@ -1,9 +1,7 @@
package work.slhaf.partner.api.factory.module.annotation; package work.slhaf.partner.api.factory.module.annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention; import java.lang.annotation.*;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/** /**
* 用于注解执行模块 * 用于注解执行模块

View File

@@ -0,0 +1,5 @@
package work.slhaf.partner.api.factory.module.annotation;
@AgentModule(name = "core",order = 5)
public @interface CoreModule {
}

View File

@@ -1,9 +1,12 @@
package work.slhaf.partner.api.flow.abstracts; package work.slhaf.partner.api.flow.abstracts;
import work.slhaf.partner.api.common.chat.ChatClient; import work.slhaf.partner.api.common.chat.ChatClient;
import work.slhaf.partner.api.common.chat.Model;
import work.slhaf.partner.api.common.chat.constant.ChatConstant; import work.slhaf.partner.api.common.chat.constant.ChatConstant;
import work.slhaf.partner.api.common.chat.pojo.ChatResponse; import work.slhaf.partner.api.common.chat.pojo.ChatResponse;
import work.slhaf.partner.api.common.chat.pojo.Message; import work.slhaf.partner.api.common.chat.pojo.Message;
import work.slhaf.partner.api.factory.config.ModelConfigFactory;
import work.slhaf.partner.api.factory.config.pojo.ModelConfig;
import work.slhaf.partner.api.factory.module.annotation.Before; import work.slhaf.partner.api.factory.module.annotation.Before;
import java.util.ArrayList; import java.util.ArrayList;
@@ -11,13 +14,22 @@ import java.util.List;
public interface ActivateModel { public interface ActivateModel {
@Before @Before
default void modelSettings() { default void modelSettings() {
// Model model = new Model(); Model model = new Model();
// ModelConfig modelConfig = ModelConfig.load(modelKey()); ModelConfig modelConfig = ModelConfigFactory.factory.loadModelConfig(modelKey());
// model.setBaseMessages(withAwareness() ? ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey(), promptModule()) : ResourcesUtil.Prompt.loadPrompt(modelKey(), promptModule())); model.setBaseMessages(withBasicPrompt() ? loadSpecificPromptAndBasicPrompt(modelKey(), promptModule()) : loadSpecificPrompt(modelKey(), promptModule()));
// model.setChatClient(new ChatClient(modelConfig.getBaseUrl(), modelConfig.getApikey(), modelConfig.getModel())); model.setChatClient(new ChatClient(modelConfig.getBaseUrl(), modelConfig.getApikey(), modelConfig.getModel()));
}
private List<Message> loadSpecificPrompt(String modelKey, String specificModule) {
return null;
}
private List<Message> loadSpecificPromptAndBasicPrompt(String modelKey, String specificModule) {
return null;
} }
default ChatResponse chat() { default ChatResponse chat() {
@@ -68,7 +80,7 @@ public interface ActivateModel {
String modelKey(); String modelKey();
boolean withAwareness(); boolean withBasicPrompt();
String promptModule(); String promptModule();
} }

View File

@@ -2,8 +2,12 @@ package work.slhaf.partner.api.flow.abstracts;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import work.slhaf.partner.api.common.chat.Model;
import work.slhaf.partner.api.factory.capability.annotation.CapabilityHolder; import work.slhaf.partner.api.factory.capability.annotation.CapabilityHolder;
/**
* 模块基类
*/
@CapabilityHolder @CapabilityHolder
public abstract class Module { public abstract class Module {

View File

@@ -14,7 +14,7 @@ public class ResourcesUtil {
private static final ClassLoader classloader = Agent.class.getClassLoader(); private static final ClassLoader classloader = Agent.class.getClassLoader();
public static class Prompt { public static class Prompt {
private static final String SELF_AWARENESS_PATH = "prompt/self_awareness.json"; private static final String SELF_AWARENESS_PATH = "prompt/basic_prompt.json";
private static final String MODULE_PROMPT_PREFIX_PATH = "prompt/module/"; private static final String MODULE_PROMPT_PREFIX_PATH = "prompt/module/";
public static List<Message> loadPromptWithSelfAwareness(String modelKey, String promptType) { public static List<Message> loadPromptWithSelfAwareness(String modelKey, String promptType) {

View File

@@ -69,7 +69,7 @@ public class CoreModel extends CoreModule implements ActivateModel {
} }
@Override @Override
public boolean withAwareness() { public boolean withBasicPrompt() {
return true; return true;
} }

View File

@@ -141,7 +141,7 @@ public class SliceSelectEvaluator extends AgentInteractionSubModule<EvaluatorInp
} }
@Override @Override
public boolean withAwareness() { public boolean withBasicPrompt() {
return false; return false;
} }

View File

@@ -114,7 +114,7 @@ public class MemorySelectExtractor extends AgentInteractionSubModule<Interaction
} }
@Override @Override
public boolean withAwareness() { public boolean withBasicPrompt() {
return false; return false;
} }

View File

@@ -71,7 +71,7 @@ public class MultiSummarizer extends AgentInteractionSubModule<SummarizeInput, S
} }
@Override @Override
public boolean withAwareness() { public boolean withBasicPrompt() {
return true; return true;
} }

View File

@@ -83,7 +83,7 @@ public class SingleSummarizer extends AgentInteractionSubModule<List<Message>,Vo
} }
@Override @Override
public boolean withAwareness() { public boolean withBasicPrompt() {
return false; return false;
} }

View File

@@ -49,7 +49,7 @@ public class TotalSummarizer extends AgentInteractionSubModule<HashMap<String, S
} }
@Override @Override
public boolean withAwareness() { public boolean withBasicPrompt() {
return true; return true;
} }

View File

@@ -91,7 +91,7 @@ public class RelationExtractor extends AgentInteractionSubModule<InteractionCont
} }
@Override @Override
public boolean withAwareness() { public boolean withBasicPrompt() {
return true; return true;
} }

View File

@@ -64,7 +64,7 @@ public class StaticMemoryExtractor extends AgentInteractionSubModule<Interaction
} }
@Override @Override
public boolean withAwareness() { public boolean withBasicPrompt() {
return true; return true;
} }