mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-14 09:43:03 +08:00
Compare commits
213 Commits
old
...
action-mod
| Author | SHA1 | Date | |
|---|---|---|---|
| f6afe21b43 | |||
| d381a97731 | |||
| 940beb2587 | |||
| 69d9f04f11 | |||
| e2bd9eb0af | |||
| 9ec03c4c95 | |||
| ecbbbc9954 | |||
| a5d26769e8 | |||
| 2db1bdf3e9 | |||
| 656d6b65e3 | |||
| 7c46f1d1ff | |||
| 406b4250aa | |||
| eab3d00fe8 | |||
| d47e9fbf95 | |||
| 4b77f26e7b | |||
| 650f9b27a1 | |||
| 9f479c5f6f | |||
| 227c735667 | |||
| b05b665960 | |||
| 882ec43f2b | |||
| 7cb565fd1b | |||
| 84b96b6645 | |||
| 2169376062 | |||
| 9bff74c8c7 | |||
| 76c9c27532 | |||
| 8524ca6f9f | |||
| 7dd2104689 | |||
| 6ba5784a7f | |||
| cdea8d6322 | |||
| 8ca2b9998d | |||
| d098b28f31 | |||
| 98e4d4cf1b | |||
| 70489e57f7 | |||
| a43c87006e | |||
| be43b7eec6 | |||
| 3bc2ce839a | |||
| fe5a366527 | |||
| 9f724cee5d | |||
| ad58b83020 | |||
| c9b64fec2a | |||
| 0eb4765235 | |||
| 050c39cbc7 | |||
| 08100aea8a | |||
| 2cd0774834 | |||
| 12df938d85 | |||
| 277c0d437f | |||
| 6b861f4b77 | |||
| d33b6617c1 | |||
| a1dcf4a6fa | |||
| 9c38719514 | |||
| 33df0fa017 | |||
| 08bda84471 | |||
| 76da3c29f8 | |||
| 558b589830 | |||
| 80d7c283c5 | |||
| b0bb40c5f0 | |||
| eec8f71096 | |||
| fbd30d1a96 | |||
| 346f925b66 | |||
| 04e8d9e531 | |||
| 63d1552de2 | |||
| 77eb9b92a4 | |||
| a1b4743eeb | |||
| 0768cddd2d | |||
| 75145cc547 | |||
| d1ca1cda7d | |||
| fac6609d6b | |||
| dce8825e58 | |||
| cd641ac8dd | |||
| 5ffdab9e4a | |||
| 830503eee4 | |||
| 96e74ec877 | |||
| 420d51af15 | |||
| 8ead306b7b | |||
| c793851107 | |||
| fb5cabc747 | |||
| c5f6c4e0ae | |||
| 200c0f3f13 | |||
| fdf398b86e | |||
| 774e2b6cd5 | |||
| 837a4c92d1 | |||
| ddd999d47b | |||
| 9694a022c7 | |||
| 31968c7076 | |||
| abec141e4e | |||
| cdb6ae9d01 | |||
| dd8d86d3c4 | |||
| 99b42620d0 | |||
| 70b8335d49 | |||
| 8ca475beeb | |||
| 4f36c0dd2d | |||
| 00993bd763 | |||
| a0bca668cb | |||
| c6118c41b0 | |||
| 872d21170a | |||
| 44ab6cfac8 | |||
| ec30ac1922 | |||
| 74b6d0c653 | |||
| de462866b2 | |||
| 4ea8926363 | |||
| 04c98c7856 | |||
| 0757856187 | |||
| 19ec93f248 | |||
| 5877b9e80d | |||
| 5db0b5fad1 | |||
| 623a86daab | |||
| 64f24d3fc3 | |||
| 3097efe453 | |||
| b58eeffd2f | |||
| 62cec79005 | |||
| 03a5935107 | |||
| 0ecaec0545 | |||
| 74f2c6c950 | |||
| f35a467ebc | |||
| 64b907707a | |||
| a6e33edc7a | |||
| 94ef79c67d | |||
| a222015abb | |||
| 1c562f0e7b | |||
| 89535a6b1c | |||
| 6e90bc8d67 | |||
| 0e741802d1 | |||
| db3435fccf | |||
| e3294ec302 | |||
| bf99e01b51 | |||
| 1bd23b20c4 | |||
| 442dd55686 | |||
| abe5dd5251 | |||
| 1f737c0e29 | |||
| d41074c814 | |||
| 621441601a | |||
| e00d77f076 | |||
| d614ac0b15 | |||
| 592e2604d9 | |||
| dcbd2c6569 | |||
| 476acb0641 | |||
| 88a14f36b2 | |||
| 05d1fff125 | |||
| 49a4c9eb01 | |||
| 9e76c3e7ad | |||
| 9762739138 | |||
| 1f5509c17d | |||
| ed042cfffa | |||
| 128592e23c | |||
| 5ba36ed3e8 | |||
| 4dea948f82 | |||
| dc4074715e | |||
| 225802c1a8 | |||
| e851e33b2e | |||
| cb28a5b068 | |||
| ad58567ada | |||
| 0eee12d685 | |||
| 1e6ff1b30c | |||
| 0413fc281d | |||
| 8a7681ae31 | |||
| 1947f25ed6 | |||
| 488246525f | |||
| 534dcd5ade | |||
| ad58c0cc7c | |||
| d546148d69 | |||
| bf2d5ac707 | |||
| 628234f6e2 | |||
| 4b852e0049 | |||
| 6e3deced77 | |||
| 6a351413a1 | |||
| ad973d4230 | |||
| 1d315a9b62 | |||
| 4e32129b31 | |||
| 3f59719e16 | |||
| c548cceec6 | |||
| b3098310b4 | |||
| f48d559a7b | |||
| 14a57f0be6 | |||
| dff7b69b51 | |||
| d77ffd1db6 | |||
| 264cdb09e5 | |||
| fea7f9c81f | |||
| a1520f117b | |||
| ae5caf8475 | |||
| 980d9384d1 | |||
| 9ba0d1363a | |||
| f6d5cad5cd | |||
| c3ca4145b8 | |||
| 5419722c40 | |||
| 31ebee3ded | |||
| 746fda1a5e | |||
| ec4fbb7f19 | |||
| f9c3cacfea | |||
| e35e18f3b7 | |||
| 83832d2060 | |||
| 4757425a15 | |||
| 21b3a0e846 | |||
| 6bfa941c35 | |||
| 456a7e04e8 | |||
| 5864760f35 | |||
| aee6d879e9 | |||
| d1ea8dde79 | |||
| 7094a8a68b | |||
| e78048f66d | |||
| 2f09c0cd71 | |||
| 8c43d6594f | |||
| 2d052442b1 | |||
| 84f7befb75 | |||
| 85818556f8 | |||
| cb1a25e9d5 | |||
| a10a149edb | |||
| 41bf19f43e | |||
| 941943f696 | |||
| a7d54349e4 | |||
| 3c2ac32708 | |||
| 7f9d007f07 | |||
| c1018d6b54 | |||
| 47684c78e0 |
11
.gitignore
vendored
11
.gitignore
vendored
@@ -36,8 +36,8 @@ build/
|
||||
|
||||
### Mac OS ###
|
||||
.DS_Store
|
||||
/data/
|
||||
/config/
|
||||
/backup/data/
|
||||
/backup/config/
|
||||
/Partner-Core/src/main/java/src/test/java/memory/test.json
|
||||
/Partner-Core/src/main/java/src/test/java/memory/result/input1.json
|
||||
/Partner-Core/src/main/java/src/test/java/memory/result/input2.json
|
||||
@@ -51,3 +51,10 @@ build/
|
||||
/backup/
|
||||
/Partner-Main/src/test/java/text/test.json
|
||||
/CLAUDE.md
|
||||
/config/
|
||||
/data/
|
||||
/generated-classes/
|
||||
/.idea/copilot.data.migration.ask2agent.xml
|
||||
/Partner-Main/data/
|
||||
/AGENTS.md
|
||||
/.serena/
|
||||
|
||||
6
.idea/copilot.data.migration.agent.xml
generated
Normal file
6
.idea/copilot.data.migration.agent.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="AgentMigrationStateService">
|
||||
<option name="migrationStatus" value="COMPLETED" />
|
||||
</component>
|
||||
</project>
|
||||
6
.idea/copilot.data.migration.ask.xml
generated
Normal file
6
.idea/copilot.data.migration.ask.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="AskMigrationStateService">
|
||||
<option name="migrationStatus" value="COMPLETED" />
|
||||
</component>
|
||||
</project>
|
||||
6
.idea/copilot.data.migration.edit.xml
generated
Normal file
6
.idea/copilot.data.migration.edit.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="EditMigrationStateService">
|
||||
<option name="migrationStatus" value="COMPLETED" />
|
||||
</component>
|
||||
</project>
|
||||
3
.idea/encodings.xml
generated
3
.idea/encodings.xml
generated
@@ -3,10 +3,13 @@
|
||||
<component name="Encoding">
|
||||
<file url="file://$PROJECT_DIR$/Partner-Api/src/main/java" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/Partner-Api/src/main/resources" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/Partner-Common/src/main/java" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/Partner-Common/src/main/resources" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/Partner-Main/src/main/java" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/Partner-Main/src/main/java/src/main/java" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/Partner-Main/src/main/java/src/main/resources" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/Partner-Main/src/main/resources" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/Partner-SandboxRunner/src/main/java" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/Partner-Test-Demo/src/main/java" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/Partner-Test-Demo/src/main/resources" charset="UTF-8" />
|
||||
<file url="file://$PROJECT_DIR$/src/main/java" charset="UTF-8" />
|
||||
|
||||
6
.idea/kotlinc.xml
generated
Normal file
6
.idea/kotlinc.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="KotlinJpsPluginSettings">
|
||||
<option name="version" value="2.2.0" />
|
||||
</component>
|
||||
</project>
|
||||
29
.idea/misc.xml
generated
29
.idea/misc.xml
generated
@@ -1,20 +1,33 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="EntryPointsManager">
|
||||
<list size="6">
|
||||
<item index="0" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.Capability" />
|
||||
<item index="1" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore" />
|
||||
<item index="2" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod" />
|
||||
<item index="3" class="java.lang.String" itemvalue="work.slhaf.partner.api.capability.annotation.CapabilityMethod" />
|
||||
<item index="4" class="java.lang.String" itemvalue="work.slhaf.partner.api.capability.annotation.CoordinateManager" />
|
||||
<item index="5" class="java.lang.String" itemvalue="work.slhaf.partner.api.register.capability.annotation.Capability" />
|
||||
<list size="15">
|
||||
<item index="0" class="java.lang.String" itemvalue="lombok.Data" />
|
||||
<item index="1" class="java.lang.String" itemvalue="net.bytebuddy.implementation.bind.annotation.RuntimeType" />
|
||||
<item index="2" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.Capability" />
|
||||
<item index="3" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore" />
|
||||
<item index="4" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod" />
|
||||
<item index="5" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.CoordinateManager" />
|
||||
<item index="6" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.capability.annotation.Coordinated" />
|
||||
<item index="7" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute" />
|
||||
<item index="8" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.module.annotation.AgentModule" />
|
||||
<item index="9" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule" />
|
||||
<item index="10" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute" />
|
||||
<item index="11" class="java.lang.String" itemvalue="work.slhaf.partner.api.agent.factory.module.annotation.Init" />
|
||||
<item index="12" class="java.lang.String" itemvalue="work.slhaf.partner.api.capability.annotation.CapabilityMethod" />
|
||||
<item index="13" class="java.lang.String" itemvalue="work.slhaf.partner.api.capability.annotation.CoordinateManager" />
|
||||
<item index="14" class="java.lang.String" itemvalue="work.slhaf.partner.api.register.capability.annotation.Capability" />
|
||||
</list>
|
||||
<writeAnnotations>
|
||||
<writeAnnotation name="work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability" />
|
||||
<writeAnnotation name="work.slhaf.partner.api.agent.factory.module.annotation.InjectModule" />
|
||||
</writeAnnotations>
|
||||
</component>
|
||||
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
||||
<component name="MavenProjectsManager">
|
||||
<option name="originalFiles">
|
||||
<list>
|
||||
<option value="$PROJECT_DIR$/pom.xml" />
|
||||
<option value="$PROJECT_DIR$/PartnerExecutor/pom.xml" />
|
||||
</list>
|
||||
</option>
|
||||
</component>
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
autoDetectedPackages:
|
||||
- factory
|
||||
- module
|
||||
- work.slhaf
|
||||
enableAutoDetect: true
|
||||
entryDisplayConfig:
|
||||
excludedPathPatterns: []
|
||||
skipJsCss: true
|
||||
funcDisplayConfig:
|
||||
skipConstructors: false
|
||||
skipFieldAccess: true
|
||||
skipFieldChange: true
|
||||
skipGetters: false
|
||||
skipNonProjectPackages: false
|
||||
skipPrivateMethods: false
|
||||
skipSetters: false
|
||||
ignoreSameClassCall: null
|
||||
ignoreSamePackageCall: null
|
||||
includedPackagePrefixes: null
|
||||
includedParentClasses: null
|
||||
maxColSize: 32
|
||||
maxNumFirst: 12
|
||||
maxNumFirstImportant: 1024
|
||||
maxNumHash: 3
|
||||
maxNumHashImportant: 256
|
||||
maxObjectDepth: 4
|
||||
maxStrSize: 4096
|
||||
name: xcodemap-filter
|
||||
openMainWindow: true
|
||||
recordMode: manual
|
||||
sourceDisplayConfig:
|
||||
color: blue
|
||||
startOnDebug: false
|
||||
29
Partner-Api/dependency-reduced-pom.xml
Normal file
29
Partner-Api/dependency-reduced-pom.xml
Normal file
@@ -0,0 +1,29 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||
<parent>
|
||||
<artifactId>Partner</artifactId>
|
||||
<groupId>work.slhaf</groupId>
|
||||
<version>0.5.0</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<artifactId>Partner-Api</artifactId>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>4.13.2</version>
|
||||
<scope>test</scope>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<artifactId>hamcrest-core</artifactId>
|
||||
<groupId>org.hamcrest</groupId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<properties>
|
||||
<maven.compiler.target>21</maven.compiler.target>
|
||||
<maven.compiler.source>21</maven.compiler.source>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
</properties>
|
||||
</project>
|
||||
@@ -1,5 +1,6 @@
|
||||
package work.slhaf.partner.api.agent;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.agent.factory.AgentRegisterFactory;
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentExceptionCallback;
|
||||
@@ -9,20 +10,30 @@ import work.slhaf.partner.api.agent.runtime.interaction.AgentGateway;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
/**
|
||||
* <h2>Agent 启动入口</h2>
|
||||
* 详细启动流程请参阅{@link AgentRegisterFactory}
|
||||
*/
|
||||
@Slf4j
|
||||
public final class Agent {
|
||||
|
||||
public static AgentGatewayStep newAgent(Class<?> clazz) {
|
||||
public static AgentConfigManagerStep newAgent(Class<?> clazz) {
|
||||
if (clazz == null) {
|
||||
throw new AgentLaunchFailedException("Agent class 和 interaction flow context 不能为 null");
|
||||
}
|
||||
return new AgentApp(clazz);
|
||||
}
|
||||
|
||||
public interface AgentConfigManagerStep {
|
||||
AgentGatewayStep setAgentConfigManager(Class<? extends AgentConfigManager> agentConfigManager);
|
||||
}
|
||||
|
||||
public interface AgentGatewayStep {
|
||||
AgentStep setGateway(AgentGateway gateway);
|
||||
AgentStep setGateway(Class<? extends AgentGateway> gateway);
|
||||
}
|
||||
|
||||
public interface AgentStep {
|
||||
@@ -30,9 +41,7 @@ public final class Agent {
|
||||
|
||||
AgentStep addAfterLaunchRunners(Runnable... runners);
|
||||
|
||||
AgentStep setAgentConfigManager(AgentConfigManager agentConfigManager);
|
||||
|
||||
AgentStep setAgentExceptionCallback(AgentExceptionCallback agentExceptionCallback);
|
||||
AgentStep setAgentExceptionCallback(Class<? extends AgentExceptionCallback> agentExceptionCallback);
|
||||
|
||||
AgentStep addScanPackage(String packageName);
|
||||
|
||||
@@ -42,21 +51,26 @@ public final class Agent {
|
||||
}
|
||||
|
||||
|
||||
public static class AgentApp implements AgentStep, AgentGatewayStep {
|
||||
public static class AgentApp implements AgentStep, AgentGatewayStep, AgentConfigManagerStep {
|
||||
|
||||
private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor();
|
||||
private final List<Runnable> beforeLaunchRunners = new ArrayList<>();
|
||||
private final List<Runnable> afterLaunchRunners = new ArrayList<>();
|
||||
private AgentGateway gateway;
|
||||
private final Class<?> applicationClass;
|
||||
private Class<? extends AgentConfigManager> agentConfigManagerClass;
|
||||
private Class<? extends AgentGateway> gatewayClass;
|
||||
private Class<? extends AgentExceptionCallback> agentExceptionCallbackClass;
|
||||
|
||||
private final CountDownLatch latch = new CountDownLatch(1);
|
||||
|
||||
private AgentApp(Class<?> clazz) {
|
||||
this.applicationClass = clazz;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentStep setGateway(AgentGateway gateway) {
|
||||
this.gateway = gateway;
|
||||
public AgentStep setGateway(Class<? extends AgentGateway> gateway) {
|
||||
this.gatewayClass = gateway;
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -73,14 +87,14 @@ public final class Agent {
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentStep setAgentConfigManager(AgentConfigManager agentConfigManager) {
|
||||
AgentConfigManager.setINSTANCE(agentConfigManager);
|
||||
public AgentGatewayStep setAgentConfigManager(Class<? extends AgentConfigManager> agentConfigManager) {
|
||||
this.agentConfigManagerClass = agentConfigManager;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentStep setAgentExceptionCallback(AgentExceptionCallback agentExceptionCallback) {
|
||||
GlobalExceptionHandler.setExceptionCallback(agentExceptionCallback);
|
||||
public AgentStep setAgentExceptionCallback(Class<? extends AgentExceptionCallback> agentExceptionCallback) {
|
||||
agentExceptionCallbackClass = agentExceptionCallback;
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -98,10 +112,38 @@ public final class Agent {
|
||||
|
||||
@Override
|
||||
public void launch() {
|
||||
launchRunners(beforeLaunchRunners);
|
||||
beforeLaunch();
|
||||
AgentRegisterFactory.launch(applicationClass.getPackageName());
|
||||
executorService.execute(() -> gateway.launch());
|
||||
afterLaunch();
|
||||
}
|
||||
|
||||
private void afterLaunch() {
|
||||
try {
|
||||
this.gateway = gatewayClass.getDeclaredConstructor().newInstance();
|
||||
executorService.execute(() -> {
|
||||
gateway.launch();
|
||||
latch.countDown();
|
||||
log.info("Gateway 启动完毕: {}", gatewayClass.getSimpleName());
|
||||
});
|
||||
latch.await();
|
||||
launchRunners(afterLaunchRunners);
|
||||
log.info("后置任务启动完毕");
|
||||
} catch (Exception e) {
|
||||
throw new AgentLaunchFailedException("Agent 后置任务启动失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void beforeLaunch() {
|
||||
try {
|
||||
AgentConfigManager.setINSTANCE(agentConfigManagerClass.getDeclaredConstructor().newInstance());
|
||||
log.info("配置管理器设置完毕: {}",agentConfigManagerClass.getSimpleName());
|
||||
GlobalExceptionHandler.setExceptionCallback(agentExceptionCallbackClass.getDeclaredConstructor().newInstance());
|
||||
log.info("异常处理回调设置完毕: {}",agentExceptionCallbackClass.getSimpleName());
|
||||
launchRunners(beforeLaunchRunners);
|
||||
log.info("前置任务启动完毕");
|
||||
} catch (Exception e) {
|
||||
throw new AgentLaunchFailedException("Agent 前置任务启动失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void launchRunners(List<Runnable> runners) {
|
||||
|
||||
@@ -14,14 +14,22 @@ import work.slhaf.partner.api.agent.factory.module.ModuleInitHookExecuteFactory;
|
||||
import work.slhaf.partner.api.agent.factory.module.ModuleProxyFactory;
|
||||
import work.slhaf.partner.api.agent.factory.module.ModuleRegisterFactory;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
|
||||
import work.slhaf.partner.api.agent.runtime.data.AgentContext;
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||
import work.slhaf.partner.api.agent.runtime.data.AgentContext;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.AgentRunningFlow;
|
||||
|
||||
import java.io.File;
|
||||
import java.net.URL;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* <h2>Agent 注册工厂</h2>
|
||||
*
|
||||
* <p>
|
||||
* 具体流程依次按照 {@link AgentRegisterFactory#launch(String)} 方法顺序执行,最终将执行模块列表对应实例交给 {@link AgentConfigManager} ,传递给 {@link AgentRunningFlow} 针对交互做出调用
|
||||
* <p/>
|
||||
*/
|
||||
public class AgentRegisterFactory {
|
||||
|
||||
private static final List<URL> urls = new ArrayList<>();
|
||||
@@ -35,20 +43,20 @@ public class AgentRegisterFactory {
|
||||
//流程
|
||||
//0. 加载配置
|
||||
new ConfigLoaderFactory().execute(registerContext);
|
||||
//1. 注册并检查Capability
|
||||
new CapabilityRegisterFactory().execute(registerContext);
|
||||
new CapabilityCheckFactory().execute(registerContext);
|
||||
//2. 注册并检查Module
|
||||
//1. 注册并检查Module
|
||||
new ModuleCheckFactory().execute(registerContext);
|
||||
new ModuleRegisterFactory().execute(registerContext);
|
||||
//3. 为module通过动态代理添加PostHook逻辑并进行实例化
|
||||
//2. 为module通过动态代理添加PostHook逻辑并进行实例化
|
||||
new ModuleProxyFactory().execute(registerContext);
|
||||
//3. 加载检查Capability层内容后进行能力层的内容注册
|
||||
new CapabilityCheckFactory().execute(registerContext);
|
||||
new CapabilityRegisterFactory().execute(registerContext);
|
||||
//. 先一步注入Capability,避免因前hook逻辑存在针对能力的引用而报错
|
||||
new CapabilityInjectFactory().execute(registerContext);
|
||||
//. 执行模块PreHook逻辑
|
||||
new ModuleInitHookExecuteFactory().execute(registerContext);
|
||||
|
||||
List<MetaModule> moduleList = registerContext.getModuleFactoryContext().getModuleList();
|
||||
List<MetaModule> moduleList = registerContext.getModuleFactoryContext().getAgentModuleList();
|
||||
AgentConfigManager.INSTANCE.moduleEnabledStatusFilterAndRecord(moduleList);
|
||||
|
||||
BeanUtil.copyProperties(registerContext, AgentContext.INSTANCE);
|
||||
|
||||
@@ -1,24 +1,52 @@
|
||||
package work.slhaf.partner.api.agent.factory.capability;
|
||||
|
||||
import cn.hutool.core.util.ClassUtil;
|
||||
import org.reflections.Reflections;
|
||||
import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.*;
|
||||
import work.slhaf.partner.api.agent.factory.capability.exception.DuplicateCapabilityException;
|
||||
import work.slhaf.partner.api.agent.factory.capability.exception.UnMatchedCapabilityException;
|
||||
import work.slhaf.partner.api.agent.factory.capability.exception.UnMatchedCapabilityMethodException;
|
||||
import work.slhaf.partner.api.agent.factory.capability.exception.UnMatchedCoordinatedMethodException;
|
||||
import work.slhaf.partner.api.agent.factory.capability.exception.*;
|
||||
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
|
||||
import work.slhaf.partner.api.agent.factory.context.CapabilityFactoryContext;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
|
||||
import work.slhaf.partner.api.agent.util.AgentUtil;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static work.slhaf.partner.api.agent.util.AgentUtil.isAssignableFromAnnotation;
|
||||
import static work.slhaf.partner.api.agent.util.AgentUtil.methodSignature;
|
||||
|
||||
/**
|
||||
* 执行<code>Capability</code>相关检查
|
||||
* <h2>Agent启动流程 4</h2>
|
||||
*
|
||||
* <p>负责通过反射收集 {@link Capability} 和 {@link CapabilityCore} 注解所在类,并判断是否存在被错误忽略的方法</p>
|
||||
*
|
||||
* <ol>
|
||||
* <li>
|
||||
* <p>{@link CapabilityCheckFactory#loadCoresAndCapabilities()}</p>
|
||||
* 通过反射收集 {@link Capability} 和 {@link CapabilityCore} 注解所在类为对应集合
|
||||
* </li>
|
||||
* <li>
|
||||
* <p>{@link CapabilityCheckFactory#checkCountAndCapabilities()}</p>
|
||||
* 检测 {@link Capability} 与 {@link CapabilityCore} 的数量、对应的能力是否相等。每一个core都将对应一个capability,并通过value属性进行匹配
|
||||
* </li>
|
||||
* <li>
|
||||
* <p>{@link CapabilityCheckFactory#checkCapabilityMethods()}</p>
|
||||
* 检测在 {@link Capability} 与 {@link CapabilityCore} 中是否存在对方尚未实现/注册的方法
|
||||
* </li>
|
||||
* <li>
|
||||
* <p>{@link CapabilityCheckFactory#checkCoordinatedMethods()}</p>
|
||||
* 检查是否包含协调方法({@link ToCoordinated}),如果存在,则进一步检查在 {@link CoordinateManager} 所注类中是否有提供对应的实现
|
||||
* </li>
|
||||
* <li>
|
||||
* <p>{@link CapabilityCheckFactory#checkInjectCapability()}</p>
|
||||
* 检查 {@link InjectCapability} 注解是否只用在 {@link CapabilityHolder} 所标识类的字段上。{@link AgentModule} 与 {@link AgentSubModule} 已经被 {@link CapabilityHolder} 标注
|
||||
* </li>
|
||||
* </ol>
|
||||
*
|
||||
* <p>下一步流程请参阅{@link CapabilityRegisterFactory}</p>
|
||||
*/
|
||||
public class CapabilityCheckFactory extends AgentBaseFactory {
|
||||
|
||||
@@ -37,19 +65,42 @@ public class CapabilityCheckFactory extends AgentBaseFactory {
|
||||
|
||||
@Override
|
||||
protected void run() {
|
||||
loadCoresAndCapabilities();
|
||||
checkCountAndCapabilities();
|
||||
checkCapabilityMethods();
|
||||
checkCoordinatedMethods();
|
||||
checkCoordinatedManager();
|
||||
checkInjectCapability();
|
||||
}
|
||||
|
||||
private void checkCoordinatedManager() {
|
||||
reflections.getTypesAnnotatedWith(CoordinateManager.class)
|
||||
.stream()
|
||||
.filter(ClassUtil::isNormalClass)
|
||||
.forEach(managerClass -> {
|
||||
try {
|
||||
if (!managerClass.getDeclaredConstructor().canAccess(null)) {
|
||||
throw new CapabilityCheckFailedException("CoordinateManager 所注类的无参构造方法未公开!");
|
||||
}
|
||||
} catch (NoSuchMethodException e) {
|
||||
throw new CapabilityCheckFailedException("CoordinateManager 所注类缺少无参构造方法!");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void loadCoresAndCapabilities() {
|
||||
cores.addAll(reflections.getTypesAnnotatedWith(CapabilityCore.class));
|
||||
capabilities.addAll(reflections.getTypesAnnotatedWith(Capability.class));
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查<code>@InjectCapability</code>注解是否只用在<code>@CapabilityHolder</code>所标识类的字段上
|
||||
*/
|
||||
private void checkInjectCapability() {
|
||||
reflections.getFieldsAnnotatedWith(InjectCapability.class).forEach(field -> {
|
||||
if (!field.getDeclaringClass().isAssignableFrom(CapabilityHolder.class)) {
|
||||
throw new UnMatchedCapabilityException("InjectCapability 注解只能用于 CapabilityHolder 注解所在类");
|
||||
Class<?> declaringClass = field.getDeclaringClass();
|
||||
if (!isAssignableFromAnnotation(declaringClass, CapabilityHolder.class)) {
|
||||
throw new UnMatchedCapabilityException("InjectCapability 注解只能用于 CapabilityHolder 注解所在类,检查该类是否使用了@CapabilityHolder注解或者受其标注的注解或父类: " + declaringClass);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -8,6 +8,9 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.ToCoordinated;
|
||||
import work.slhaf.partner.api.agent.factory.capability.exception.ProxySetFailedExceptionCapability;
|
||||
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
|
||||
import work.slhaf.partner.api.agent.factory.context.CapabilityFactoryContext;
|
||||
import work.slhaf.partner.api.agent.factory.module.ModuleInitHookExecuteFactory;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Proxy;
|
||||
@@ -18,9 +21,23 @@ import java.util.function.Function;
|
||||
import static work.slhaf.partner.api.agent.util.AgentUtil.methodSignature;
|
||||
|
||||
/**
|
||||
* 负责执行<code>Capability</code>的注入逻辑
|
||||
*/
|
||||
public class CapabilityInjectFactory extends AgentBaseFactory {
|
||||
* <h2>Agent启动流程 6</h2>
|
||||
*
|
||||
* <p>负责执行 {@link Capability} 的注入逻辑。</p>
|
||||
*
|
||||
* <p>实现方式:</p>
|
||||
* <ol>
|
||||
* <li>通过动态代理,为 {@link AgentModule} 与 {@link AgentSubModule} 中待注入的
|
||||
* <b>能力接口</b> 类型(即 {@link Capability} 标注的接口类)生成代理对象。
|
||||
* </li>
|
||||
* <li>在代理对象内部,根据调用方法的签名确定路由,将调用转发至对应的具体函数。
|
||||
* </li>
|
||||
* <li>通过此机制,实现了 {@link Capability} 单一语义层面上普通方法与协调方法的统一入口。
|
||||
* </li>
|
||||
* </ol>
|
||||
*
|
||||
* <p>下一步流程请参阅 {@link ModuleInitHookExecuteFactory}</p>
|
||||
*/public class CapabilityInjectFactory extends AgentBaseFactory {
|
||||
|
||||
private Reflections reflections;
|
||||
private HashMap<String, Function<Object[], Object>> coordinatedMethodsRouterTable;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package work.slhaf.partner.api.agent.factory.capability;
|
||||
|
||||
import cn.hutool.core.util.ClassUtil;
|
||||
import org.reflections.Reflections;
|
||||
import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.*;
|
||||
@@ -8,28 +9,61 @@ import work.slhaf.partner.api.agent.factory.capability.exception.CoreInstancesCr
|
||||
import work.slhaf.partner.api.agent.factory.capability.exception.DuplicateMethodException;
|
||||
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
|
||||
import work.slhaf.partner.api.agent.factory.context.CapabilityFactoryContext;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
|
||||
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static cn.hutool.core.util.ClassUtil.isNormalClass;
|
||||
import static work.slhaf.partner.api.agent.util.AgentUtil.methodSignature;
|
||||
|
||||
|
||||
/**
|
||||
* 负责获取<code>@Capability</code>和<code>@CapabilityCore</code>标识的类,并生成函数路由表、设置<code>Core</code>实例用于后续注入
|
||||
* <h2>Agent启动流程 5</h2>
|
||||
*
|
||||
* <p>
|
||||
* 负责收集注解 {@link Capability} 和 {@link CapabilityCore} 标识的类,并生成函数路由表、创建core、capability实例,以及放入instanceMap供后续进行注入操作
|
||||
* </p>
|
||||
*
|
||||
* <ol>
|
||||
* <li>
|
||||
* <p>{@link CapabilityRegisterFactory#setCoreInstances()}</p>
|
||||
* 通过反射调用无参构造函数创建core实例,并将实例放入instanceMap供后续使用
|
||||
* </li>
|
||||
* <li>
|
||||
* <p>{@link CapabilityRegisterFactory#generateRouterTable()}</p>
|
||||
* 生成函数路由表:
|
||||
* <ul>
|
||||
* <li>
|
||||
* <p>{@link CapabilityRegisterFactory#generateMethodsRouterTable()}</p>
|
||||
* 生成普通方法对应的函数路由表
|
||||
* </li>
|
||||
* <li>
|
||||
* <p>{@link CapabilityRegisterFactory#generateCoordinatedMethodsRouterTable()}</p>
|
||||
* 生成协调方法对应的函数路由表
|
||||
* </li>
|
||||
* </ul>
|
||||
* </li>
|
||||
* <li>
|
||||
* 函数路由表生成完毕、core实例创建完毕之后,将交由下一工厂完成能力(Capability)注入操作,注入到 {@link AgentModule} 与 {@link AgentSubModule} 对应的实例中
|
||||
* </li>
|
||||
* </ol>
|
||||
*
|
||||
* <p>下一步流程请参阅{@link CapabilityInjectFactory}</p>
|
||||
*/
|
||||
public final class CapabilityRegisterFactory extends AgentBaseFactory {
|
||||
public class CapabilityRegisterFactory extends AgentBaseFactory {
|
||||
|
||||
private Reflections reflections;
|
||||
private HashMap<String, Function<Object[], Object>> methodsRouterTable;
|
||||
private HashMap<String, Function<Object[], Object>> coordinatedMethodsRouterTable;
|
||||
private HashMap<Class<?>, Object> capabilityCoreInstances;
|
||||
private HashMap<Class<?>, Object> coreInstances;
|
||||
private HashMap<Class<?>, Object> capabilityHolderInstances;
|
||||
private Set<Class<?>> cores;
|
||||
private Set<Class<?>> capabilities;
|
||||
@@ -40,35 +74,35 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
|
||||
reflections = context.getReflections();
|
||||
methodsRouterTable = factoryContext.getMethodsRouterTable();
|
||||
coordinatedMethodsRouterTable = factoryContext.getCoordinatedMethodsRouterTable();
|
||||
capabilityCoreInstances = factoryContext.getCapabilityCoreInstances();
|
||||
coreInstances = factoryContext.getCapabilityCoreInstances();
|
||||
cores = factoryContext.getCores();
|
||||
capabilities = factoryContext.getCapabilities();
|
||||
capabilityHolderInstances = factoryContext.getCapabilityHolderInstances();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void run() throws InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException {
|
||||
setCapabilityCoreInstances();
|
||||
setAnnotatedClasses();
|
||||
protected void run() {
|
||||
setCapabilityHolderInstances();
|
||||
setCoreInstances();
|
||||
generateRouterTable();
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置<code>CapabilityCore</code>、<code>Capability</code>注解标识类
|
||||
*/
|
||||
private void setAnnotatedClasses() throws InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException {
|
||||
cores.addAll(reflections.getTypesAnnotatedWith(CapabilityCore.class));
|
||||
capabilities.addAll(reflections.getTypesAnnotatedWith(Capability.class));
|
||||
setCapabilityHolderInstances();
|
||||
private void setCapabilityHolderInstances() {
|
||||
Set<Class<?>> collect = reflections.getTypesAnnotatedWith(CapabilityHolder.class).stream()
|
||||
.filter(ClassUtil::isNormalClass)
|
||||
.filter(clazz -> !capabilityHolderInstances.containsKey(clazz))
|
||||
.collect(Collectors.toSet());
|
||||
for (Class<?> clazz : collect) {
|
||||
try {
|
||||
Constructor<?> constructor = clazz.getDeclaredConstructor();
|
||||
if (constructor.canAccess(null)) {
|
||||
throw new CapabilityFactoryExecuteFailedException("缺少无参构造方法的类: " + clazz);
|
||||
}
|
||||
|
||||
private void setCapabilityHolderInstances() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
|
||||
for (Class<?> clazz : reflections.getTypesAnnotatedWith(CapabilityHolder.class)) {
|
||||
if (!isNormalClass(clazz)){
|
||||
continue;
|
||||
}
|
||||
Object o = clazz.getDeclaredConstructor().newInstance();
|
||||
Object o = constructor.newInstance();
|
||||
capabilityHolderInstances.put(clazz, o);
|
||||
} catch (Exception e) {
|
||||
throw new CapabilityFactoryExecuteFailedException("创建代理对象失败: " + clazz, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,7 +150,7 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
|
||||
for (Class<?> c : reflections.getTypesAnnotatedWith(CoordinateManager.class)) {
|
||||
Constructor<?> constructor = c.getDeclaredConstructor();
|
||||
Object instance = constructor.newInstance();
|
||||
|
||||
setCores(instance, c);
|
||||
Arrays.stream(c.getMethods())
|
||||
.filter(method -> method.isAnnotationPresent(Coordinated.class))
|
||||
.forEach(method -> {
|
||||
@@ -127,18 +161,26 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
|
||||
return map;
|
||||
}
|
||||
|
||||
private void setCores(Object cmInstance, Class<?> cmClazz) throws IllegalAccessException {
|
||||
for (Field field : cmClazz.getFields()) {
|
||||
if (field.getType().isAnnotationPresent(CapabilityCore.class)) {
|
||||
field.setAccessible(true);
|
||||
field.set(cmInstance, coreInstances.get(field.getType()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成普通方法对应的函数路由表
|
||||
* 扫描`@Capability`与`@CapabilityMethod`注解的类与方法
|
||||
* 将`capabilityValue.methodSignature`作为key,函数对象为通过反射拿到的core实例对应的方法
|
||||
*/
|
||||
private void generateMethodsRouterTable() {
|
||||
//扫描`@Capability`与`@CapabilityMethod`注解的类与方法
|
||||
//将`capabilityValue.methodSignature`作为key,函数对象为通过反射拿到的core实例对应的方法
|
||||
cores.forEach(core -> Arrays.stream(core.getMethods())
|
||||
.filter(method -> method.isAnnotationPresent(CapabilityMethod.class))
|
||||
.forEach(method -> {
|
||||
Function<Object[], Object> function = args -> {
|
||||
try {
|
||||
return method.invoke(capabilityCoreInstances.get(core), args);
|
||||
return method.invoke(coreInstances.get(core), args);
|
||||
} catch (IllegalAccessException | InvocationTargetException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
@@ -154,12 +196,12 @@ public final class CapabilityRegisterFactory extends AgentBaseFactory {
|
||||
/**
|
||||
* 反射获取<code>CapabilityCore</code>实例
|
||||
*/
|
||||
private void setCapabilityCoreInstances() {
|
||||
private void setCoreInstances() {
|
||||
try {
|
||||
for (Class<?> core : cores) {
|
||||
Constructor<?> constructor = core.getDeclaredConstructor();
|
||||
constructor.setAccessible(true);
|
||||
capabilityCoreInstances.put(core, constructor.newInstance());
|
||||
coreInstances.put(core, constructor.newInstance());
|
||||
}
|
||||
} catch (InvocationTargetException | NoSuchMethodException | InstantiationException |
|
||||
IllegalAccessException e) {
|
||||
|
||||
@@ -5,6 +5,9 @@ import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
/**
|
||||
* Core的协调类,该注解的实现类中如果存在任何{@link CapabilityCore}实例的引用,都将被自动注入
|
||||
*/
|
||||
@Target(ElementType.TYPE)
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
public @interface CoordinateManager {
|
||||
|
||||
@@ -1,16 +1,31 @@
|
||||
package work.slhaf.partner.api.agent.factory.config;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException;
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException;
|
||||
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
|
||||
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
|
||||
import work.slhaf.partner.api.agent.factory.context.ConfigFactoryContext;
|
||||
import work.slhaf.partner.api.agent.factory.module.ModuleCheckFactory;
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||
import work.slhaf.partner.api.agent.runtime.config.DefaultAgentConfigManager;
|
||||
import work.slhaf.partner.api.agent.runtime.config.FileAgentConfigManager;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* <h2>Agent启动流程 0</h2>
|
||||
* <p>
|
||||
* 通过指定的 {@link AgentConfigManager} 或者默认的 {@link FileAgentConfigManager} 加载配置文件
|
||||
* <p/>
|
||||
*
|
||||
* <p>下一步流程请参阅{@link ModuleCheckFactory}</p>
|
||||
*/
|
||||
@Slf4j
|
||||
public class ConfigLoaderFactory extends AgentBaseFactory {
|
||||
|
||||
private AgentConfigManager agentConfigManager;
|
||||
@@ -24,7 +39,7 @@ public class ConfigLoaderFactory extends AgentBaseFactory {
|
||||
modelPromptMap = factoryContext.getModelPromptMap();
|
||||
|
||||
if (AgentConfigManager.INSTANCE == null) {
|
||||
AgentConfigManager.setINSTANCE(new DefaultAgentConfigManager());
|
||||
AgentConfigManager.setINSTANCE(new FileAgentConfigManager());
|
||||
}
|
||||
|
||||
agentConfigManager = AgentConfigManager.INSTANCE;
|
||||
@@ -33,9 +48,30 @@ public class ConfigLoaderFactory extends AgentBaseFactory {
|
||||
@Override
|
||||
protected void run() {
|
||||
agentConfigManager.load();
|
||||
agentConfigManager.check();
|
||||
modelConfigMap.putAll(agentConfigManager.getModelConfigMap());
|
||||
modelPromptMap.putAll(agentConfigManager.getModelPromptMap());
|
||||
check();
|
||||
}
|
||||
|
||||
/**
|
||||
* 对模型Config与Prompt分别进行检验,除了都必须包含default外,还需要确保数量、key一致,毕竟是模型配置与提示词
|
||||
*/
|
||||
private void check() {
|
||||
log.info("执行config与prompt检测...");
|
||||
if (!modelConfigMap.containsKey("default")) {
|
||||
throw new ConfigNotExistException("缺少默认配置! 需确保存在一个模型配置的key为`default`");
|
||||
}
|
||||
if (!modelPromptMap.containsKey("basic")) {
|
||||
throw new PromptNotExistException("缺少基础Prompt! 需要确保存在key为basic的Prompt文件,它将与其他Prompt共同作用于模块节点。");
|
||||
}
|
||||
Set<String> configKeySet = new HashSet<>(modelConfigMap.keySet());
|
||||
configKeySet.remove("default");
|
||||
Set<String> promptKeySet = new HashSet<>(modelPromptMap.keySet());
|
||||
promptKeySet.remove("basic");
|
||||
if (!promptKeySet.containsAll(configKeySet)) {
|
||||
log.warn("存在未被提示词包含的模型配置,该配置将无法生效!");
|
||||
}
|
||||
//检查提示词数量与`ActivateModel`的实现数量是否一致
|
||||
log.info("检测完毕.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package work.slhaf.partner.api.agent.factory.context;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
|
||||
@@ -12,6 +13,6 @@ public class CapabilityFactoryContext {
|
||||
private final HashMap<String, Function<Object[], Object>> coordinatedMethodsRouterTable = new HashMap<>();
|
||||
private final HashMap<Class<?>, Object> capabilityCoreInstances = new HashMap<>();
|
||||
private final HashMap<Class<?>, Object> capabilityHolderInstances = new HashMap<>();
|
||||
private Set<Class<?>> cores;
|
||||
private Set<Class<?>> capabilities;
|
||||
private Set<Class<?>> cores = new HashSet<>();
|
||||
private Set<Class<?>> capabilities = new HashSet<>();
|
||||
}
|
||||
|
||||
@@ -2,11 +2,13 @@ package work.slhaf.partner.api.agent.factory.context;
|
||||
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.MetaSubModule;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ModuleFactoryContext {
|
||||
private List<MetaModule> moduleList = new ArrayList<>();
|
||||
private List<MetaModule> agentModuleList = new ArrayList<>();
|
||||
private List<MetaSubModule> agentSubModuleList = new ArrayList<>();
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@ import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
|
||||
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
|
||||
import work.slhaf.partner.api.agent.factory.module.exception.ModuleCheckException;
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
|
||||
@@ -19,6 +19,32 @@ import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static work.slhaf.partner.api.agent.util.AgentUtil.getMethodAnnotationTypeSet;
|
||||
|
||||
/**
|
||||
* <h2>Agent启动流程 1</h2>
|
||||
*
|
||||
* <p>
|
||||
* 检查模块部分抽象类与注解、接口的使用方式
|
||||
* </p>
|
||||
*
|
||||
* <ol>
|
||||
* <li>
|
||||
* <p>{@link ModuleCheckFactory#annotationAbstractCheck(Set, Class)}</p>
|
||||
* 所有添加了 {@link AgentModule} 注解的类都将作为Agent的执行模块,为规范模块入口,都必须实现抽象类: {@link AgentRunningModule}; {@link AgentSubModule} 注解所在类则必须实现 {@link AgentRunningSubModule}
|
||||
* </li>
|
||||
* <li>
|
||||
* <p>{@link ModuleCheckFactory#moduleConstructorsCheck(Set)}</p>
|
||||
* 所有 {@link AgentModule} 与 {@link AgentSubModule} 注解所在类都必须具备空参构造方法,初始化逻辑可放在 @Init 注解所处方法中,将在 Capability 与 subModules 注入后才会执行
|
||||
* </li>
|
||||
* <li>
|
||||
* <p>{@link ModuleCheckFactory#activateModelImplCheck()}</p>
|
||||
* 检查实现了 {@link ActivateModel} 的模块数量、名称与prompt是否一致
|
||||
* </li>
|
||||
* </ol>
|
||||
*
|
||||
* <p>下一步流程请参阅{@link ModuleRegisterFactory}</p>
|
||||
*/
|
||||
public class ModuleCheckFactory extends AgentBaseFactory {
|
||||
|
||||
private Reflections reflections;
|
||||
@@ -30,15 +56,43 @@ public class ModuleCheckFactory extends AgentBaseFactory {
|
||||
|
||||
@Override
|
||||
protected void run() {
|
||||
Set<Class<?>> types = reflections.getTypesAnnotatedWith(AgentModule.class);
|
||||
//检查注解AgentModule所在类是否继承了AgentInteractionModule
|
||||
agentModuleAnnotationCheck(types);
|
||||
AnnotatedModules annotatedModules = getAnnotatedModules();
|
||||
ExtendedModules extendedModules = getExtendedModules();
|
||||
checkIfClassCorresponds(annotatedModules, extendedModules);
|
||||
//检查注解AgentModule或AgentSubModule所在类是否继承了对应的抽象类
|
||||
annotationAbstractCheck(annotatedModules.moduleTypes(), AgentRunningModule.class);
|
||||
annotationAbstractCheck(annotatedModules.subModuleTypes(), AgentRunningSubModule.class);
|
||||
//检查AgentModule是否具备无参构造方法
|
||||
moduleConstructorsCheck(types);
|
||||
//检查hook注解所在方法是否位于AgentInteractionModule子类/AgentInteractionSubModule子类/ActivateModel子类
|
||||
hookLocationCheck();
|
||||
moduleConstructorsCheck(annotatedModules.moduleTypes());
|
||||
moduleConstructorsCheck(annotatedModules.subModuleTypes());
|
||||
//检查实现了ActivateModel的模块数量、名称与prompt是否一致
|
||||
activateModelImplCheck();
|
||||
//检查hook注解所在位置是否正确
|
||||
hookLocationCheck();
|
||||
}
|
||||
|
||||
private ExtendedModules getExtendedModules() {
|
||||
Set<Class<?>> moduleTypes = reflections.getSubTypesOf(AgentRunningModule.class)
|
||||
.stream()
|
||||
.filter(ClassUtil::isNormalClass)
|
||||
.collect(Collectors.toSet());
|
||||
Set<Class<?>> subModuleTypes = reflections.getSubTypesOf(AgentRunningSubModule.class)
|
||||
.stream()
|
||||
.filter(ClassUtil::isNormalClass)
|
||||
.collect(Collectors.toSet());
|
||||
return new ExtendedModules(moduleTypes, subModuleTypes);
|
||||
}
|
||||
|
||||
private AnnotatedModules getAnnotatedModules() {
|
||||
Set<Class<?>> moduleTypes = reflections.getTypesAnnotatedWith(AgentModule.class)
|
||||
.stream()
|
||||
.filter(ClassUtil::isNormalClass)
|
||||
.collect(Collectors.toSet());
|
||||
Set<Class<?>> subModuleTypes = reflections.getTypesAnnotatedWith(AgentSubModule.class)
|
||||
.stream()
|
||||
.filter(ClassUtil::isNormalClass)
|
||||
.collect(Collectors.toSet());
|
||||
return new AnnotatedModules(moduleTypes, subModuleTypes);
|
||||
}
|
||||
|
||||
private void moduleConstructorsCheck(Set<Class<?>> types) {
|
||||
@@ -76,24 +130,10 @@ public class ModuleCheckFactory extends AgentBaseFactory {
|
||||
preHookLocationCheck();
|
||||
//检查@Init注解
|
||||
initHookLocationCheck();
|
||||
//检查@AgentModule注解是否只位于普通类上
|
||||
agentModuleLocationCheck();
|
||||
}
|
||||
|
||||
private void agentModuleLocationCheck() {
|
||||
Set<Class<?>> types = reflections.getTypesAnnotatedWith(AgentModule.class);
|
||||
for (Class<?> type : types) {
|
||||
if (!ClassUtil.isNormalClass(type)) {
|
||||
throw new ModuleCheckException("AgentModule 注解仅能位于普通类上! 异常类信息: " + type.getSimpleName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void initHookLocationCheck() {
|
||||
Set<Method> methods = reflections.getMethodsAnnotatedWith(Init.class);
|
||||
Set<Class<?>> types = methods.stream()
|
||||
.map(Method::getDeclaringClass)
|
||||
.collect(Collectors.toSet());
|
||||
Set<Class<?>> types = getMethodAnnotationTypeSet(AgentModule.class, reflections);
|
||||
checkLocation(types);
|
||||
}
|
||||
|
||||
@@ -129,15 +169,59 @@ public class ModuleCheckFactory extends AgentBaseFactory {
|
||||
}
|
||||
}
|
||||
|
||||
private void agentModuleAnnotationCheck(Set<Class<?>> types) {
|
||||
private void annotationAbstractCheck(Set<Class<?>> types, Class<?> clazz) {
|
||||
for (Class<?> type : types) {
|
||||
if (type.isAnnotation()) {
|
||||
continue;
|
||||
}
|
||||
if (AgentRunningModule.class.isAssignableFrom(type) && ClassUtil.isNormalClass(type)) {
|
||||
if (clazz.isAssignableFrom(type) && ClassUtil.isNormalClass(type)) {
|
||||
continue;
|
||||
}
|
||||
throw new ModuleCheckException("存在未继承AgentInteractionModule.class的AgentModule实现: " + type.getSimpleName());
|
||||
}
|
||||
}
|
||||
|
||||
private void checkIfClassCorresponds(AnnotatedModules annotatedModules, ExtendedModules extendedModules) {
|
||||
// 检查是否有被@AgentModule注解但没有继承AgentRunningModule的类
|
||||
checkSets(annotatedModules.moduleTypes(), extendedModules.moduleTypes(),
|
||||
"存在被@AgentModule注解但未继承AgentRunningModule的类");
|
||||
|
||||
// 检查是否有继承AgentRunningModule但没有被@AgentModule注解的类
|
||||
checkSets(extendedModules.moduleTypes(), annotatedModules.moduleTypes(),
|
||||
"存在继承AgentRunningModule但未被@AgentModule注解的类");
|
||||
|
||||
// 检查是否有被@AgentSubModule注解但没有继承AgentRunningSubModule的类
|
||||
checkSets(annotatedModules.subModuleTypes(), extendedModules.subModuleTypes(),
|
||||
"存在被@AgentSubModule注解但未继承AgentRunningSubModule的类");
|
||||
|
||||
// 检查是否有继承AgentRunningSubModule但没有被@AgentSubModule注解的类
|
||||
checkSets(extendedModules.subModuleTypes(), annotatedModules.subModuleTypes(),
|
||||
"存在继承AgentRunningSubModule但未被@AgentSubModule注解的类");
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查源集合中是否有不在目标集合中的元素
|
||||
* @param source 源集合
|
||||
* @param target 目标集合
|
||||
* @param errorMessage 错误信息前缀
|
||||
*/
|
||||
private void checkSets(Set<Class<?>> source, Set<Class<?>> target, String errorMessage) {
|
||||
// 只有在需要时才创建HashSet以节省内存
|
||||
if (!target.containsAll(source)) {
|
||||
// 使用流式处理找出差异部分,避免创建完整的中间集合
|
||||
String classNames = source.stream()
|
||||
.filter(clazz -> !target.contains(clazz))
|
||||
.map(Class::getSimpleName)
|
||||
.limit(10) // 限制显示数量,避免信息泄露
|
||||
.collect(Collectors.joining(", ", "[", "]"));
|
||||
|
||||
throw new ModuleCheckException(errorMessage + ": " + classNames);
|
||||
}
|
||||
}
|
||||
|
||||
private record AnnotatedModules(Set<Class<?>> moduleTypes, Set<Class<?>> subModuleTypes) {
|
||||
}
|
||||
|
||||
private record ExtendedModules(Set<Class<?>> moduleTypes, Set<Class<?>> subModuleTypes) {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
package work.slhaf.partner.api.agent.factory.module;
|
||||
|
||||
import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
|
||||
import work.slhaf.partner.api.agent.factory.AgentRegisterFactory;
|
||||
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
|
||||
import work.slhaf.partner.api.agent.factory.context.ModuleFactoryContext;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
|
||||
import work.slhaf.partner.api.agent.factory.module.exception.ModuleInitHookExecuteFailedException;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.BaseMetaModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.MetaMethod;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.MetaSubModule;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.Module;
|
||||
import work.slhaf.partner.api.agent.util.AgentUtil;
|
||||
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.Arrays;
|
||||
@@ -20,39 +26,61 @@ import static work.slhaf.partner.api.agent.util.AgentUtil.collectExtendedClasses
|
||||
import static work.slhaf.partner.api.agent.util.AgentUtil.methodSignature;
|
||||
|
||||
/**
|
||||
* 负责执行前hook逻辑
|
||||
* <h2>Agent启动流程 7</h2>
|
||||
*
|
||||
* <p>负责执行初始化hook逻辑,即 {@link Init} 注解所在方法</p>
|
||||
*
|
||||
* <ol>
|
||||
* <li>
|
||||
* <p>{@link ModuleInitHookExecuteFactory#collectInitHookMethods(Class, Class)}</p>
|
||||
* 分别遍历前置模块拿到的模块列表({@link ModuleInitHookExecuteFactory#moduleList}, {@link ModuleInitHookExecuteFactory#subModuleList}),通过 {@link AgentUtil#collectExtendedClasses(Class, Class)} 收集到当前模块类的继承链上的所有类后,收集其所有带有 {@link Init} 注解的方法
|
||||
* </li>
|
||||
* <li>
|
||||
* <p>{@link ModuleInitHookExecuteFactory#proceedInitMethods(BaseMetaModule, List)}</p>
|
||||
* 收集好初始化方法后,将通过反射执行该方法,所用实例即为前置模块中收集到的执行模块与子模块的 {@link MetaModule} 与 {@link MetaSubModule} 内容
|
||||
* </li>
|
||||
* </ol>
|
||||
*
|
||||
* <p>Agent启动流程到此进行完毕。整个工厂执行链中均为针对 {@link AgentRegisterContext} 进行的操作,在 {@link AgentRegisterFactory} 中,将进行最终处理以及将必要内容进行传递。</p>
|
||||
*/
|
||||
public class ModuleInitHookExecuteFactory extends AgentBaseFactory {
|
||||
|
||||
private List<MetaModule> moduleList;
|
||||
private List<MetaSubModule> subModuleList;
|
||||
|
||||
@Override
|
||||
protected void setVariables(AgentRegisterContext context) {
|
||||
ModuleFactoryContext factoryContext = context.getModuleFactoryContext();
|
||||
moduleList = factoryContext.getModuleList();
|
||||
moduleList = factoryContext.getAgentModuleList();
|
||||
subModuleList = factoryContext.getAgentSubModuleList();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void run() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
|
||||
protected void run() {
|
||||
//遍历模块列表,并向上查找@Init注解
|
||||
for (MetaSubModule metaSubModule : subModuleList) {
|
||||
List<MetaMethod> initHookMethods = collectInitHookMethods(metaSubModule.getClazz(),AgentRunningModule.class);
|
||||
proceedInitMethods(metaSubModule, initHookMethods);
|
||||
}
|
||||
|
||||
for (MetaModule metaModule : moduleList) {
|
||||
List<MetaMethod> initHookMethods = collectInitHookMethods(metaModule.getClazz());
|
||||
List<MetaMethod> initHookMethods = collectInitHookMethods(metaModule.getClazz(), AgentRunningSubModule.class);
|
||||
proceedInitMethods(metaModule, initHookMethods);
|
||||
}
|
||||
}
|
||||
|
||||
private void proceedInitMethods(MetaModule metaModule, List<MetaMethod> initHookMethods) {
|
||||
private void proceedInitMethods(BaseMetaModule metaModule, List<MetaMethod> initHookMethods) {
|
||||
for (MetaMethod metaMethod : initHookMethods) {
|
||||
try {
|
||||
metaMethod.getMethod().invoke(metaModule.getInstance());
|
||||
} catch (IllegalAccessException | InvocationTargetException e) {
|
||||
throw new ModuleInitHookExecuteFailedException("模块的init hook方法执行失败! 模块: " + metaModule.getName() + " 方法签名: " + methodSignature(metaMethod.getMethod()), e);
|
||||
throw new ModuleInitHookExecuteFailedException("模块的init hook方法执行失败! 模块: " + metaModule.getClazz().getSimpleName() + " 方法签名: " + methodSignature(metaMethod.getMethod()), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private List<MetaMethod> collectInitHookMethods(Class<?> clazz) {
|
||||
Set<Class<?>> classes = collectExtendedClasses(clazz, AgentRunningModule.class);
|
||||
private List<MetaMethod> collectInitHookMethods(Class<?> clazz, Class<? extends Module> target) {
|
||||
Set<Class<?>> classes = collectExtendedClasses(clazz, target);
|
||||
return classes.stream()
|
||||
.map(Class::getDeclaredMethods)
|
||||
.flatMap(Arrays::stream)
|
||||
|
||||
@@ -1,21 +1,29 @@
|
||||
package work.slhaf.partner.api.agent.factory.module;
|
||||
|
||||
import lombok.Getter;
|
||||
import net.bytebuddy.ByteBuddy;
|
||||
import net.bytebuddy.implementation.MethodDelegation;
|
||||
import net.bytebuddy.implementation.bind.annotation.*;
|
||||
import net.bytebuddy.matcher.ElementMatchers;
|
||||
import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
|
||||
import work.slhaf.partner.api.agent.factory.capability.CapabilityCheckFactory;
|
||||
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
|
||||
import work.slhaf.partner.api.agent.factory.context.CapabilityFactoryContext;
|
||||
import work.slhaf.partner.api.agent.factory.context.ModuleFactoryContext;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.exception.ModuleInstanceGenerateFailedException;
|
||||
import work.slhaf.partner.api.agent.factory.module.exception.ModuleProxyGenerateFailedException;
|
||||
import work.slhaf.partner.api.agent.factory.module.exception.ProxiedModuleRunningException;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.BaseMetaModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.MetaMethod;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.MetaSubModule;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.Module;
|
||||
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.Callable;
|
||||
@@ -24,50 +32,124 @@ import java.util.stream.Collectors;
|
||||
import static work.slhaf.partner.api.agent.util.AgentUtil.collectExtendedClasses;
|
||||
|
||||
/**
|
||||
* 通过扫描注解<code>@BeforeExecute</code>,获取到各个模块的后hook逻辑并通过动态代理添加到执行逻辑之后
|
||||
* <h2>Agent启动流程 3</h2>
|
||||
*
|
||||
* <p>
|
||||
* 扫描前置模块各个hook注解生成代理对象,放入对应的list中并按照类型为键放入 {@link ModuleProxyFactory#capabilityHolderInstances} 中供后续完成能力(capability)注入
|
||||
* <p/>
|
||||
*
|
||||
* <ol>
|
||||
*
|
||||
* <li>
|
||||
* <p>{@link ModuleProxyFactory#createProxiedInstances()}</p>
|
||||
* 根据moduleList中的类型信息,向上查找继承链获取所有hook方法收集为{@link MethodsListRecord},然后通过ByteBuddy根据收集到的preHook与postHook生成代理对象,放入对应的 {@link MetaModule} 对象以及 instanceMap 中
|
||||
* </li>
|
||||
* <li>
|
||||
* <p>{@link ModuleProxyFactory#injectSubModule()}</p>
|
||||
* 通过反射将子模块实例注入到执行模块中带有注解 {@link InjectModule} 的字段
|
||||
* </li>
|
||||
* </ol>
|
||||
*
|
||||
* <p>下一步流程请参阅{@link CapabilityCheckFactory}</p>
|
||||
*/
|
||||
public class ModuleProxyFactory extends AgentBaseFactory {
|
||||
|
||||
private List<MetaModule> moduleList;
|
||||
private List<MetaSubModule> subModuleList;
|
||||
private HashMap<Class<?>, Object> capabilityHolderInstances;
|
||||
private final HashMap<Class<?>, Object> subModuleInstances = new HashMap<>();
|
||||
private final HashMap<Class<?>, Object> moduleInstances = new HashMap<>();
|
||||
|
||||
@Override
|
||||
protected void setVariables(AgentRegisterContext context) {
|
||||
ModuleFactoryContext factoryContext = context.getModuleFactoryContext();
|
||||
moduleList = factoryContext.getModuleList();
|
||||
CapabilityFactoryContext capabilityFactoryContext = context.getCapabilityFactoryContext();
|
||||
moduleList = factoryContext.getAgentModuleList();
|
||||
subModuleList = factoryContext.getAgentSubModuleList();
|
||||
capabilityHolderInstances = capabilityFactoryContext.getCapabilityHolderInstances();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void run() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
|
||||
generateInstances();
|
||||
setHookProxy();
|
||||
protected void run() {
|
||||
createProxiedInstances();
|
||||
injectSubModule();
|
||||
}
|
||||
|
||||
private void setHookProxy() {
|
||||
private void injectSubModule() {
|
||||
for (MetaModule module : moduleList) {
|
||||
//因为实际上ByteBuddy生成的是module.getClazz()的子类,所以应当使用getDeclaredFields()获取字段
|
||||
Arrays.stream(module.getClazz().getDeclaredFields())
|
||||
.filter(field -> field.isAnnotationPresent(InjectModule.class))
|
||||
.forEach(field -> {
|
||||
try {
|
||||
field.setAccessible(true);
|
||||
field.set(
|
||||
moduleInstances.get(module.getClazz()),
|
||||
subModuleInstances.get(field.getType())
|
||||
);
|
||||
} catch (IllegalAccessException e) {
|
||||
throw new ModuleInstanceGenerateFailedException("模块实例注入失败", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private void createProxiedInstances() {
|
||||
generateModuleProxy(moduleList, AgentRunningModule.class);
|
||||
generateModuleProxy(subModuleList, AgentRunningSubModule.class);
|
||||
updateInstanceMap(moduleInstances, moduleList);
|
||||
updateInstanceMap(subModuleInstances, subModuleList);
|
||||
updateCapabilityHolderInstances();
|
||||
}
|
||||
|
||||
private void updateCapabilityHolderInstances() {
|
||||
capabilityHolderInstances.putAll(moduleInstances);
|
||||
capabilityHolderInstances.putAll(subModuleInstances);
|
||||
}
|
||||
|
||||
private void updateInstanceMap(HashMap<Class<?>, Object> instanceMap, List<? extends BaseMetaModule> list) {
|
||||
for (BaseMetaModule baseMetaModule : list) {
|
||||
instanceMap.put(baseMetaModule.getClazz(), baseMetaModule.getInstance());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private void generateModuleProxy(List<? extends BaseMetaModule> list, Class<? extends Module> overrideSource) {
|
||||
for (BaseMetaModule module : list) {
|
||||
Class<?> clazz = module.getClazz();
|
||||
try {
|
||||
MethodsListRecord record = collectHookMethods(clazz);
|
||||
//生成实例
|
||||
generateProxiedInstances(record, module);
|
||||
generateProxiedInstances(record, module, overrideSource);
|
||||
} catch (Exception e) {
|
||||
throw new ModuleProxyGenerateFailedException("创建代理对象失败: " + clazz.getSimpleName(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void generateProxiedInstances(MethodsListRecord record, MetaModule metaModule) {
|
||||
private void generateProxiedInstances(MethodsListRecord record, BaseMetaModule module, Class<? extends Module> overrideSource) {
|
||||
try {
|
||||
Class<? extends AgentRunningModule> clazz = metaModule.getClazz();
|
||||
Class<? extends AgentRunningModule> proxyClass = new ByteBuddy()
|
||||
Class<? extends Module> clazz = module.getClazz();
|
||||
Class<? extends Module> proxyClass = new ByteBuddy()
|
||||
.subclass(clazz)
|
||||
.method(ElementMatchers.isOverriddenFrom(AgentRunningModule.class))
|
||||
.method(ElementMatchers.isOverriddenFrom(overrideSource))
|
||||
.intercept(MethodDelegation.to(new ModuleProxyInterceptor(record.post, record.pre)))
|
||||
.make()
|
||||
.load(ModuleProxyFactory.class.getClassLoader())
|
||||
.getLoaded();
|
||||
metaModule.setInstance(proxyClass.getConstructor().newInstance());
|
||||
|
||||
// new ByteBuddy()
|
||||
// .subclass(clazz)
|
||||
// .method(ElementMatchers.isOverriddenFrom(overrideSource))
|
||||
// .intercept(MethodDelegation.to(new ModuleProxyInterceptor(record.post, record.pre)))
|
||||
//
|
||||
// .make()
|
||||
// .saveIn(new File("./generated-classes"));
|
||||
|
||||
module.setInstance(proxyClass.getConstructor().newInstance());
|
||||
} catch (Exception e) {
|
||||
throw new ModuleProxyGenerateFailedException("模块Hook代理生成失败! 代理失败的模块名: " + metaModule.getClazz().getSimpleName(), e);
|
||||
throw new ModuleProxyGenerateFailedException("模块Hook代理生成失败! 代理失败的模块名: " + module.getClazz().getSimpleName(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,8 +158,8 @@ public class ModuleProxyFactory extends AgentBaseFactory {
|
||||
List<MetaMethod> pre = new ArrayList<>();
|
||||
//获取该类本身的hook逻辑
|
||||
collectHookMethods(post, pre, clazz);
|
||||
//获取它所继承、实现的抽象类或接口, 以AgentInteractionModule、ActiveModel为终点
|
||||
Set<Class<?>> classes = collectExtendedClasses(clazz, AgentRunningModule.class);
|
||||
//获取它所继承、实现的抽象类或接口, 以Module为终点,收集继承链上所有父类和接口
|
||||
Set<Class<?>> classes = collectExtendedClasses(clazz, Module.class);
|
||||
//获取这些类中的hook逻辑
|
||||
collectHookMethods(post, pre, classes);
|
||||
return new MethodsListRecord(post, pre);
|
||||
@@ -118,7 +200,7 @@ public class ModuleProxyFactory extends AgentBaseFactory {
|
||||
|
||||
|
||||
private void collectHookMethods(List<MetaMethod> post, List<MetaMethod> pre, Class<?> clazz) {
|
||||
Method[] methods = clazz.getMethods();
|
||||
Method[] methods = clazz.getDeclaredMethods();
|
||||
for (Method method : methods) {
|
||||
if (method.isAnnotationPresent(BeforeExecute.class)) {
|
||||
MetaMethod metaMethod = new MetaMethod();
|
||||
@@ -134,30 +216,38 @@ public class ModuleProxyFactory extends AgentBaseFactory {
|
||||
}
|
||||
}
|
||||
|
||||
private void generateInstances() {
|
||||
for (MetaModule metaModule : moduleList) {
|
||||
@Getter
|
||||
@SuppressWarnings("ClassCanBeRecord")
|
||||
public static class ModuleProxyInterceptor {
|
||||
|
||||
private final List<MetaMethod> postHookMethods;
|
||||
private final List<MetaMethod> preHookMethods;
|
||||
|
||||
public ModuleProxyInterceptor(List<MetaMethod> postHookMethods, List<MetaMethod> preHookMethods) {
|
||||
this.postHookMethods = postHookMethods;
|
||||
this.preHookMethods = preHookMethods;
|
||||
}
|
||||
|
||||
@RuntimeType
|
||||
public Object intercept(@Origin Method method, @AllArguments Object[] allArguments, @SuperCall Callable<?> zuper, @This Object proxy) throws Exception {
|
||||
executeHookMethods(preHookMethods, proxy);
|
||||
Object res = zuper.call();
|
||||
executeHookMethods(postHookMethods, proxy);
|
||||
return res;
|
||||
}
|
||||
|
||||
private void executeHookMethods(List<MetaMethod> hookMethods, Object proxy) {
|
||||
for (MetaMethod metaMethod : hookMethods) {
|
||||
Method m = metaMethod.getMethod();
|
||||
try {
|
||||
Class<? extends AgentRunningModule> clazz = metaModule.getClazz();
|
||||
AgentRunningModule instance = clazz.getConstructor().newInstance();
|
||||
metaModule.setInstance(instance);
|
||||
m.setAccessible(true);
|
||||
m.invoke(proxy);
|
||||
} catch (Exception e) {
|
||||
throw new ModuleInstanceGenerateFailedException("模块实例构造失败:" + e.getMessage());
|
||||
throw new ProxiedModuleRunningException("hook方法执行异常: " + m.getDeclaringClass() + "#" + m.getName(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private record ModuleProxyInterceptor(List<MetaMethod> postHookMethods, List<MetaMethod> preHookMethods) {
|
||||
@RuntimeType
|
||||
public Object intercept(@Origin Method method, @AllArguments Object[] allArguments, @SuperCall Callable<?> zuper, @This Object proxy) throws Exception {
|
||||
for (MetaMethod metaMethod : preHookMethods) {
|
||||
metaMethod.getMethod().invoke(proxy);
|
||||
}
|
||||
Object res = zuper.call();
|
||||
for (MetaMethod metaMethod : postHookMethods) {
|
||||
metaMethod.getMethod().invoke(proxy);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
record MethodsListRecord(List<MetaMethod> post, List<MetaMethod> pre) {
|
||||
|
||||
@@ -6,31 +6,71 @@ import work.slhaf.partner.api.agent.factory.AgentBaseFactory;
|
||||
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext;
|
||||
import work.slhaf.partner.api.agent.factory.context.ModuleFactoryContext;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.CoreModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.MetaSubModule;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* 负责扫描<code>@Module</code>注解获取模块实例
|
||||
* <h2>Agent启动流程 2</h2>
|
||||
*
|
||||
* <p>
|
||||
* 负责收集 {@link AgentModule} 与 {@link AgentSubModule} 注解所在类的信息,供后续工厂完成动态代理、模块与能力注入
|
||||
* <p/>
|
||||
*
|
||||
* <ol>
|
||||
* <li>
|
||||
* <p>{@link ModuleRegisterFactory#setModuleList()}</p>
|
||||
* 扫描 {@link AgentModule} 注解,获取执行模块信息: 类型、模块名称({@link AgentModule#name()}),执行顺序。并按照注解的 {@link AgentModule#order()} 字段进行排序
|
||||
* </li>
|
||||
* <li>
|
||||
* <p>{@link ModuleRegisterFactory#setSubModuleList()}</p>
|
||||
* 扫描 {@link AgentSubModule} 注册,获取子模块类型信息
|
||||
* </li>
|
||||
* <li>
|
||||
* 两种模块都将存入各自的list中,供后续模块完成注册与注入
|
||||
* </li>
|
||||
* </ol>
|
||||
*
|
||||
* <p>下一步流程请参阅{@link ModuleProxyFactory}</p>
|
||||
*/
|
||||
public class ModuleRegisterFactory extends AgentBaseFactory {
|
||||
|
||||
private Reflections reflections;
|
||||
private List<MetaModule> moduleList;
|
||||
private List<MetaSubModule> subModuleList;
|
||||
|
||||
@Override
|
||||
protected void setVariables(AgentRegisterContext context) {
|
||||
ModuleFactoryContext factoryContext = context.getModuleFactoryContext();
|
||||
reflections = context.getReflections();
|
||||
moduleList = factoryContext.getModuleList();
|
||||
moduleList = factoryContext.getAgentModuleList();
|
||||
subModuleList = factoryContext.getAgentSubModuleList();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void run() {
|
||||
setModuleList();
|
||||
setSubModuleList();
|
||||
}
|
||||
|
||||
private void setSubModuleList() {
|
||||
Set<Class<?>> subModules = reflections.getTypesAnnotatedWith(AgentSubModule.class);
|
||||
for (Class<?> subModule : subModules) {
|
||||
if (!ClassUtil.isNormalClass(subModule)) {
|
||||
continue;
|
||||
}
|
||||
Class<? extends AgentRunningSubModule> clazz = subModule.asSubclass(AgentRunningSubModule.class);
|
||||
MetaSubModule metaSubModule = new MetaSubModule();
|
||||
metaSubModule.setClazz(clazz);
|
||||
subModuleList.add(metaSubModule);
|
||||
}
|
||||
}
|
||||
|
||||
private void setModuleList() {
|
||||
@@ -41,13 +81,24 @@ public class ModuleRegisterFactory extends AgentBaseFactory {
|
||||
continue;
|
||||
}
|
||||
Class<? extends AgentRunningModule> clazz = module.asSubclass(AgentRunningModule.class);
|
||||
AgentModule agentModule = clazz.getAnnotation(AgentModule.class);
|
||||
MetaModule metaModule = new MetaModule();
|
||||
metaModule.setName(agentModule.name());
|
||||
metaModule.setOrder(agentModule.order());
|
||||
metaModule.setClazz(clazz);
|
||||
MetaModule metaModule = getMetaModule(clazz);
|
||||
moduleList.add(metaModule);
|
||||
}
|
||||
moduleList.sort(Comparator.comparing(MetaModule::getOrder));
|
||||
}
|
||||
|
||||
private static MetaModule getMetaModule(Class<? extends AgentRunningModule> clazz) {
|
||||
MetaModule metaModule = new MetaModule();
|
||||
AgentModule agentModule;
|
||||
if (clazz.isAnnotationPresent(CoreModule.class)){
|
||||
agentModule = CoreModule.class.getAnnotation(AgentModule.class);
|
||||
}else{
|
||||
agentModule = clazz.getAnnotation(AgentModule.class);
|
||||
}
|
||||
metaModule.setName(agentModule.name());
|
||||
metaModule.setOrder(agentModule.order());
|
||||
metaModule.setClazz(clazz);
|
||||
return metaModule;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
package work.slhaf.partner.api.agent.factory.module.annotation;
|
||||
|
||||
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityHolder;
|
||||
|
||||
import java.lang.annotation.*;
|
||||
|
||||
/**
|
||||
* 用于注解执行模块
|
||||
*/
|
||||
@Inherited
|
||||
@Target(ElementType.TYPE)
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
@CapabilityHolder
|
||||
@Inherited
|
||||
public @interface AgentModule {
|
||||
|
||||
/**
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
package work.slhaf.partner.api.agent.factory.module.annotation;
|
||||
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityHolder;
|
||||
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
@Target(ElementType.TYPE)
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
@CapabilityHolder
|
||||
public @interface AgentSubModule {
|
||||
}
|
||||
@@ -1,5 +1,9 @@
|
||||
package work.slhaf.partner.api.agent.factory.module.annotation;
|
||||
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
@AgentModule(name = "core",order = 5)
|
||||
public @interface CoreModule {
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package work.slhaf.partner.api.agent.factory.module.annotation;
|
||||
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
@Target(ElementType.FIELD)
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
public @interface InjectModule {
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package work.slhaf.partner.api.agent.factory.module.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||
|
||||
public class ProxiedModuleRunningException extends AgentRuntimeException {
|
||||
public ProxiedModuleRunningException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public ProxiedModuleRunningException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package work.slhaf.partner.api.agent.factory.module.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.Module;
|
||||
|
||||
@Data
|
||||
public abstract class BaseMetaModule <C extends Module> {
|
||||
private Class<? extends C> clazz;
|
||||
private C instance;
|
||||
}
|
||||
@@ -1,13 +1,13 @@
|
||||
package work.slhaf.partner.api.agent.factory.module.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class MetaModule {
|
||||
public class MetaModule extends BaseMetaModule<AgentRunningModule>{
|
||||
private String name;
|
||||
private int order;
|
||||
private Class<? extends AgentRunningModule> clazz;
|
||||
private AgentRunningModule instance;
|
||||
private boolean enabled = true;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
package work.slhaf.partner.api.agent.factory.module.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class MetaSubModule extends BaseMetaModule<AgentRunningSubModule>{
|
||||
}
|
||||
@@ -3,30 +3,27 @@ package work.slhaf.partner.api.agent.runtime.config;
|
||||
import lombok.Data;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException;
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.ConfigUpdateFailedException;
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException;
|
||||
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
|
||||
@Slf4j
|
||||
@Data
|
||||
public abstract class AgentConfigManager {
|
||||
|
||||
@Setter
|
||||
public static AgentConfigManager INSTANCE;
|
||||
public static AgentConfigManager INSTANCE = new FileAgentConfigManager();
|
||||
private static final String DEFAULT_KEY = "default";
|
||||
|
||||
protected HashMap<String, ModelConfig> modelConfigMap;
|
||||
protected HashMap<String, List<Message>> modelPromptMap;
|
||||
protected HashMap<String, Boolean> moduleEnabledStatus;
|
||||
protected List<MetaModule> moduleList;
|
||||
protected Map<Integer, List<MetaModule>> moduleOrderedMap = new LinkedHashMap<>();
|
||||
protected Map<String, MetaModule> moduleMap = new HashMap<>();
|
||||
|
||||
public void load() {
|
||||
modelConfigMap = loadModelConfig();
|
||||
@@ -41,11 +38,24 @@ public abstract class AgentConfigManager {
|
||||
|
||||
protected abstract void dumpModuleEnabledStatus();
|
||||
|
||||
protected abstract HashMap<String, Boolean> loadModuleEnabledStatusMap();
|
||||
protected abstract HashMap<String, Boolean> loadModuleEnabledStatusMap(List<MetaModule> moduleList);
|
||||
|
||||
public void moduleEnabledStatusFilterAndRecord(List<MetaModule> moduleList) {
|
||||
this.moduleList = moduleList;
|
||||
this.moduleEnabledStatus = loadModuleEnabledStatusMap();
|
||||
updateModuleMap(moduleList);
|
||||
updateModuleEnabledStatus(moduleList);
|
||||
}
|
||||
|
||||
private void updateModuleMap(List<MetaModule> moduleList) {
|
||||
//在ModuleRegisterFactory已进行过排序操作
|
||||
for (MetaModule module : moduleList) {
|
||||
int k = module.getOrder();
|
||||
moduleOrderedMap.computeIfAbsent(k, order -> new ArrayList<>()).add(module);
|
||||
moduleMap.put(module.getName(), module);
|
||||
}
|
||||
}
|
||||
|
||||
private void updateModuleEnabledStatus(List<MetaModule> moduleList) {
|
||||
this.moduleEnabledStatus = loadModuleEnabledStatusMap(moduleList);
|
||||
|
||||
boolean unmatch = false;
|
||||
for (MetaModule metaModule : moduleList) {
|
||||
@@ -62,27 +72,6 @@ public abstract class AgentConfigManager {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 对模型Config与Prompt分别进行检验,除了都必须包含default外,还需要确保数量、key一致,毕竟是模型配置与提示词
|
||||
*/
|
||||
public void check() {
|
||||
log.info("[AgentConfigManager]: 执行config与prompt检测...");
|
||||
if (!modelConfigMap.containsKey("default")) {
|
||||
throw new ConfigNotExistException("缺少默认配置! 需确保存在一个模型配置的key为`default`");
|
||||
}
|
||||
if (!modelPromptMap.containsKey("basic")) {
|
||||
throw new PromptNotExistException("缺少基础Prompt! 需要确保存在key为basic的Prompt文件,它将与其他Prompt共同作用于模块节点。");
|
||||
}
|
||||
Set<String> configKeySet = new HashSet<>(modelConfigMap.keySet());
|
||||
configKeySet.remove("default");
|
||||
Set<String> promptKeySet = new HashSet<>(modelPromptMap.keySet());
|
||||
promptKeySet.remove("basic");
|
||||
if (!promptKeySet.containsAll(configKeySet)) {
|
||||
log.warn("存在未被提示词包含的模型配置,该配置将无法生效!");
|
||||
}
|
||||
log.info("[AgentConfigManager]: 检测完毕.");
|
||||
}
|
||||
|
||||
public List<Message> loadModelPrompt(String modelKey) {
|
||||
if (!modelPromptMap.containsKey(modelKey)) {
|
||||
throw new PromptNotExistException("不存在的modelPrompt: " + modelKey);
|
||||
@@ -108,12 +97,7 @@ public abstract class AgentConfigManager {
|
||||
}
|
||||
moduleEnabledStatus.put(key, status);
|
||||
dumpModuleEnabledStatus();
|
||||
for (MetaModule metaModule : moduleList) {
|
||||
if (metaModule.getName().equals(key)) {
|
||||
metaModule.setEnabled(status);
|
||||
break;
|
||||
}
|
||||
}
|
||||
moduleMap.get(key).setEnabled(status);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -22,12 +22,12 @@ import java.util.List;
|
||||
* 将从当前运行目录的config文件夹下创建并读取配置
|
||||
*/
|
||||
@Slf4j
|
||||
public class DefaultAgentConfigManager extends AgentConfigManager {
|
||||
public class FileAgentConfigManager extends AgentConfigManager {
|
||||
|
||||
private static final String CONFIG_DIR = "./config/";
|
||||
private static final String MODEL_CONFIG_DIR = "./config/model/";
|
||||
private static final String PROMPT_CONFIG_DIR = "./config/prompt/";
|
||||
private static final String MODULE_ENABLED_STATUS_CONFIG_FILE = CONFIG_DIR + "module_enabled_status.json";
|
||||
protected static final String CONFIG_DIR = "./config/";
|
||||
protected static final String MODEL_CONFIG_DIR = "./config/model/";
|
||||
protected static final String PROMPT_CONFIG_DIR = "./config/prompt/";
|
||||
protected static final String MODULE_ENABLED_STATUS_CONFIG_FILE = CONFIG_DIR + "module_enabled_status.json";
|
||||
|
||||
|
||||
@Override
|
||||
@@ -74,10 +74,10 @@ public class DefaultAgentConfigManager extends AgentConfigManager {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected HashMap<String, Boolean> loadModuleEnabledStatusMap() {
|
||||
protected HashMap<String, Boolean> loadModuleEnabledStatusMap(List<MetaModule> moduleList) {
|
||||
File file = new File(MODULE_ENABLED_STATUS_CONFIG_FILE);
|
||||
try {
|
||||
HashMap<String, Boolean> moduleEnabledStatus = new HashMap<>();
|
||||
moduleEnabledStatus = new HashMap<>();
|
||||
if (!file.exists()) {
|
||||
file.createNewFile();
|
||||
for (MetaModule module : moduleList) {
|
||||
@@ -0,0 +1,11 @@
|
||||
package work.slhaf.partner.api.agent.runtime.exception;
|
||||
|
||||
public class AgentRunningFailedException extends AgentRuntimeException{
|
||||
public AgentRunningFailedException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public AgentRunningFailedException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
||||
@@ -1,23 +1,36 @@
|
||||
package work.slhaf.partner.api.agent.runtime.exception;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class GlobalExceptionHandler {
|
||||
|
||||
public static GlobalExceptionHandler INSTANCE = new GlobalExceptionHandler();
|
||||
|
||||
private AgentExceptionCallback exceptionCallback = new DefaultAgentExceptionCallback();
|
||||
private AgentExceptionCallback exceptionCallback = new LogAgentExceptionCallback();
|
||||
|
||||
public void handle(Throwable e) {
|
||||
|
||||
switch (e.getClass().getSimpleName()) {
|
||||
case "AgentRuntimeException":
|
||||
exceptionCallback.onRuntimeException((AgentRuntimeException) e);
|
||||
break;
|
||||
case "AgentLaunchFailedException":
|
||||
exceptionCallback.onFailedException((AgentLaunchFailedException) e);
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException("未经处理的异常!", e);
|
||||
public boolean handle(Throwable e) {
|
||||
boolean exit;
|
||||
Throwable cause = e.getCause();
|
||||
switch (cause) {
|
||||
case AgentRunningFailedException arfe -> {
|
||||
exit = true;
|
||||
exceptionCallback.onRuntimeException((AgentRuntimeException) cause);
|
||||
}
|
||||
case AgentRuntimeException are -> {
|
||||
exit = false;
|
||||
exceptionCallback.onRuntimeException((AgentRuntimeException) cause);
|
||||
}
|
||||
case AgentLaunchFailedException alfe -> {
|
||||
exit = true;
|
||||
exceptionCallback.onFailedException((AgentLaunchFailedException) cause);
|
||||
}
|
||||
default -> {
|
||||
exit = true;
|
||||
log.error("意外异常: ", cause);
|
||||
}
|
||||
}
|
||||
return exit;
|
||||
}
|
||||
|
||||
public static void setExceptionCallback(AgentExceptionCallback callback) {
|
||||
|
||||
@@ -3,7 +3,7 @@ package work.slhaf.partner.api.agent.runtime.exception;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class DefaultAgentExceptionCallback implements AgentExceptionCallback {
|
||||
public class LogAgentExceptionCallback implements AgentExceptionCallback {
|
||||
|
||||
@Override
|
||||
public void onRuntimeException(AgentRuntimeException e) {
|
||||
@@ -4,9 +4,23 @@ import work.slhaf.partner.api.agent.runtime.interaction.data.AgentInputData;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.data.AgentOutputData;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext;
|
||||
|
||||
public interface AgentGateway {
|
||||
public interface AgentGateway <I extends AgentInputData, O extends AgentOutputData, C extends RunningFlowContext>{
|
||||
|
||||
void launch();
|
||||
|
||||
<I extends AgentInputData, O extends AgentOutputData, C extends RunningFlowContext> AgentInteractionAdapter<I, O, C> adapter();
|
||||
default void receive(I inputData){
|
||||
C finalInputData = adapter().parseInputData(inputData);
|
||||
C outputContext = adapter().call(finalInputData);
|
||||
O outputData = adapter().parseOutputData(outputContext);
|
||||
send(outputData);
|
||||
}
|
||||
|
||||
void send(O outputData);
|
||||
|
||||
/**
|
||||
* 通过adapter提供的receive、send方法进行与客户端的交互行为
|
||||
*
|
||||
* @return adapter实例
|
||||
*/
|
||||
AgentInteractionAdapter<I, O, C> adapter();
|
||||
}
|
||||
|
||||
@@ -8,34 +8,19 @@ import work.slhaf.partner.api.agent.runtime.interaction.flow.AgentRunningFlow;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public abstract class AgentInteractionAdapter<I extends AgentInputData, O extends AgentOutputData, C extends RunningFlowContext> {
|
||||
|
||||
private static AgentInteractionAdapter<?,?,?> INSTANCE;
|
||||
|
||||
protected AgentRunningFlow<C> agentRunningFlow = new AgentRunningFlow<>();
|
||||
protected List<MetaModule> moduleList = AgentConfigManager.INSTANCE.getModuleList();
|
||||
protected Map<Integer, List<MetaModule>> moduleOrderedMap = AgentConfigManager.INSTANCE.getModuleOrderedMap();
|
||||
|
||||
public void receive(I inputData) {
|
||||
C finalInputData = parseInputData(inputData);
|
||||
C outputContext = agentRunningFlow.launch(moduleList, finalInputData);
|
||||
O outputData = parseOutputData(outputContext);
|
||||
send(outputData);
|
||||
public C call(C finalInputData){
|
||||
return agentRunningFlow.launch(moduleOrderedMap, finalInputData);
|
||||
}
|
||||
|
||||
protected abstract O parseOutputData(C outputContext);
|
||||
|
||||
protected abstract C parseInputData(I inputData);
|
||||
|
||||
public abstract void send(O outputData);
|
||||
|
||||
public static <I extends AgentInputData, O extends AgentOutputData, C extends RunningFlowContext> AgentInteractionAdapter<I, O, C> getInstance() {
|
||||
@SuppressWarnings("unchecked")
|
||||
AgentInteractionAdapter<I, O, C> instance = (AgentInteractionAdapter<I, O, C>) INSTANCE;
|
||||
return instance;
|
||||
}
|
||||
|
||||
public static <I extends AgentInputData, O extends AgentOutputData, C extends RunningFlowContext> void setInstance(AgentInteractionAdapter<I, O, C> instance) {
|
||||
INSTANCE = instance;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import lombok.EqualsAndHashCode;
|
||||
@Data
|
||||
public abstract class AgentOutputData extends InteractionData{
|
||||
|
||||
private int code;
|
||||
protected int code;
|
||||
|
||||
public static class StatusCode {
|
||||
public static final int SUCCESS = 1;
|
||||
|
||||
@@ -1,24 +1,48 @@
|
||||
package work.slhaf.partner.api.agent.runtime.interaction.flow;
|
||||
|
||||
import work.slhaf.partner.api.agent.factory.module.pojo.MetaModule;
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||
import work.slhaf.partner.api.agent.runtime.exception.GlobalExceptionHandler;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.Future;
|
||||
|
||||
/**
|
||||
* Agent执行流程
|
||||
*/
|
||||
public class AgentRunningFlow<C extends RunningFlowContext> {
|
||||
|
||||
public C launch(List<MetaModule> moduleList, C interactionContext){
|
||||
try {
|
||||
public C launch(Map<Integer, List<MetaModule>> modules, C interactionContext) {
|
||||
try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
|
||||
//流程执行启动
|
||||
for (MetaModule metaModule : moduleList) {
|
||||
metaModule.getInstance().execute(interactionContext);
|
||||
for (Map.Entry<Integer, List<MetaModule>> entry : modules.entrySet()) {
|
||||
List<Future<?>> futures = new ArrayList<>();
|
||||
List<MetaModule> moduleList = entry.getValue();
|
||||
for (MetaModule module : moduleList) {
|
||||
Future<?> future = executor.submit(() -> {
|
||||
module.getInstance().execute(interactionContext);
|
||||
});
|
||||
futures.add(future);
|
||||
}
|
||||
for (Future<?> future : futures) {
|
||||
try {
|
||||
future.get();
|
||||
} catch (Exception e) {
|
||||
GlobalExceptionHandler.INSTANCE.handle(e);
|
||||
boolean exit = GlobalExceptionHandler.INSTANCE.handle(e);
|
||||
if (exit) throw new AgentRuntimeException("Agent执行出错!", e);
|
||||
interactionContext.getErrMsg().add(e.getLocalizedMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
interactionContext.setOk(1);
|
||||
} catch (Exception e) {
|
||||
interactionContext.setOk(0);
|
||||
interactionContext.getErrMsg().add(e.getLocalizedMessage());
|
||||
}
|
||||
return interactionContext;
|
||||
}
|
||||
|
||||
@@ -17,13 +17,13 @@ public interface ActivateModel {
|
||||
|
||||
AgentConfigManager AGENT_CONFIG_MANAGER = AgentConfigManager.INSTANCE;
|
||||
|
||||
@Init
|
||||
@Init(order = -1)
|
||||
default void modelSettings() {
|
||||
Model model = new Model();
|
||||
ModelConfig modelConfig = AgentConfigManager.INSTANCE.loadModelConfig(modelKey());
|
||||
model.setBaseMessages(withBasicPrompt() ? loadSpecificPromptAndBasicPrompt(modelKey()) : loadSpecificPrompt(modelKey()));
|
||||
model.setChatClient(new ChatClient(modelConfig.getBaseUrl(), modelConfig.getApikey(), modelConfig.getModel()));
|
||||
((Module) this).setModel(model);
|
||||
setModel(model);
|
||||
}
|
||||
|
||||
default void updateModelSettings(ChatClient newChatClient) {
|
||||
@@ -87,6 +87,9 @@ public interface ActivateModel {
|
||||
((Module) this).setModel(model);
|
||||
}
|
||||
|
||||
/**
|
||||
* 对应调用的模型配置名称
|
||||
*/
|
||||
String modelKey();
|
||||
|
||||
boolean withBasicPrompt();
|
||||
|
||||
@@ -1,10 +1,36 @@
|
||||
package work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.CoreModule;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext;
|
||||
|
||||
/**
|
||||
* 流程执行模块基类
|
||||
*/
|
||||
public abstract class AgentRunningModule extends Module {
|
||||
public abstract void execute(RunningFlowContext context);
|
||||
@Slf4j
|
||||
public abstract class AgentRunningModule<C extends RunningFlowContext> extends Module {
|
||||
public abstract void execute(C context);
|
||||
|
||||
@BeforeExecute
|
||||
private void beforeLog() {
|
||||
log.debug("[{}] 模块执行开始...", getModuleName());
|
||||
}
|
||||
|
||||
@AfterExecute
|
||||
private void afterLog() {
|
||||
log.debug("[{}] 模块执行结束...", getModuleName());
|
||||
}
|
||||
|
||||
private String getModuleName(){
|
||||
if (this.getClass().isAnnotationPresent(AgentModule.class)) {
|
||||
return this.getClass().getAnnotation(AgentModule.class).name();
|
||||
} else if (this.getClass().isAnnotationPresent(CoreModule.class)) {
|
||||
return CoreModule.class.getAnnotation(AgentModule.class).name();
|
||||
}else {
|
||||
return "Unknown Module";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,35 @@
|
||||
package work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts;
|
||||
|
||||
|
||||
/**
|
||||
* 流程子模块基类
|
||||
* @param <I> 输入类型
|
||||
* @param <O> 输出类型
|
||||
*/
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.CoreModule;
|
||||
|
||||
@Slf4j
|
||||
public abstract class AgentRunningSubModule<I, O> extends Module {
|
||||
|
||||
public abstract O execute(I data);
|
||||
|
||||
|
||||
@BeforeExecute
|
||||
private void beforeLog() {
|
||||
log.debug("[{}] 模块执行开始...", getModuleName());
|
||||
}
|
||||
|
||||
@AfterExecute
|
||||
private void afterLog() {
|
||||
log.debug("[{}] 模块执行结束...", getModuleName());
|
||||
}
|
||||
|
||||
private String getModuleName(){
|
||||
if (this.getClass().isAnnotationPresent(AgentModule.class)) {
|
||||
return this.getClass().getAnnotation(AgentModule.class).name();
|
||||
} else if (this.getClass().isAnnotationPresent(CoreModule.class)) {
|
||||
return CoreModule.class.getAnnotation(AgentModule.class).name();
|
||||
}else {
|
||||
return "Unknown Module";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,13 +2,11 @@ package work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityHolder;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.Model;
|
||||
|
||||
/**
|
||||
* 模块基类
|
||||
*/
|
||||
@CapabilityHolder
|
||||
public abstract class Module {
|
||||
|
||||
@Getter
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
package work.slhaf.partner.api.agent.runtime.interaction.flow.entity;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.api.common.entity.PersistableObject;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 流程上下文
|
||||
*/
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public abstract class RunningFlowContext {
|
||||
|
||||
public abstract class RunningFlowContext extends PersistableObject {
|
||||
protected int ok;
|
||||
protected List<String> errMsg = new ArrayList<>();
|
||||
}
|
||||
|
||||
@@ -1,11 +1,37 @@
|
||||
package work.slhaf.partner.api.agent.util;
|
||||
|
||||
import org.reflections.Reflections;
|
||||
|
||||
import java.lang.annotation.Annotation;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public final class AgentUtil {
|
||||
|
||||
public static boolean isAssignableFromAnnotation(Class<?> clazz,Class<? extends Annotation> targetAnnotation){
|
||||
Set<Class<?>> visited = new HashSet<>();
|
||||
return isAssignableFromAnnotation(clazz,targetAnnotation,visited);
|
||||
}
|
||||
|
||||
private static boolean isAssignableFromAnnotation(Class<?> clazz,Class<? extends Annotation> targetAnnotation,Set<Class<?>> visited){
|
||||
if (!visited.add(clazz)){
|
||||
return false;
|
||||
}
|
||||
if (clazz.isAnnotationPresent(targetAnnotation)){
|
||||
return true;
|
||||
}
|
||||
Annotation[] annotations = clazz.getAnnotations();
|
||||
for (Annotation annotation : annotations) {
|
||||
boolean ok = isAssignableFromAnnotation(annotation.annotationType(),targetAnnotation,visited);
|
||||
if (ok){
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public static String methodSignature(Method method) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("(");
|
||||
@@ -23,6 +49,7 @@ public final class AgentUtil {
|
||||
public static Set<Class<?>> collectExtendedClasses(Class<?> clazz, Class<?> targetClass) {
|
||||
Set<Class<?>> classes = new HashSet<>();
|
||||
collectExtendedClasses(classes, clazz, targetClass);
|
||||
classes.add(clazz);
|
||||
return classes;
|
||||
}
|
||||
|
||||
@@ -36,10 +63,18 @@ public final class AgentUtil {
|
||||
collectInterfaces(clazz, classes);
|
||||
}
|
||||
|
||||
public static Set<Class<?>> getMethodAnnotationTypeSet(Class<? extends Annotation> clazz, Reflections reflections){
|
||||
Set<Method> methods = reflections.getMethodsAnnotatedWith(clazz);
|
||||
return methods.stream()
|
||||
.map(Method::getDeclaringClass)
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
private static void collectInterfaces(Class<?> clazz, Set<Class<?>> classes) {
|
||||
for (Class<?> type : clazz.getInterfaces()) {
|
||||
if (classes.add(type)) {
|
||||
collectInterfaces(type, classes);
|
||||
}
|
||||
}
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package work.slhaf.partner.api.chat;
|
||||
|
||||
import cn.hutool.core.io.IORuntimeException;
|
||||
import cn.hutool.http.HttpRequest;
|
||||
import cn.hutool.http.HttpResponse;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||
import work.slhaf.partner.api.chat.pojo.ChatBody;
|
||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
||||
@@ -13,6 +15,7 @@ import work.slhaf.partner.api.chat.pojo.PrimaryChatResponse;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class ChatClient {
|
||||
@@ -34,6 +37,8 @@ public class ChatClient {
|
||||
|
||||
public ChatResponse runChat(List<Message> messages) {
|
||||
HttpRequest request = HttpRequest.post(url);
|
||||
request.setConnectionTimeout(2000);
|
||||
request.setReadTimeout(15000);
|
||||
request.header("Content-Type", "application/json");
|
||||
request.header("Authorization", "Bearer " + apikey);
|
||||
|
||||
@@ -53,17 +58,26 @@ public class ChatClient {
|
||||
.build();
|
||||
}
|
||||
|
||||
HttpResponse response = request.body(JSONUtil.toJsonStr(body)).execute();
|
||||
ChatResponse finalResponse;
|
||||
|
||||
try {
|
||||
HttpResponse response = request.body(JSONUtil.toJsonStr(body)).execute();
|
||||
PrimaryChatResponse primaryChatResponse = JSONUtil.toBean(response.body(), PrimaryChatResponse.class);
|
||||
finalResponse = ChatResponse.builder()
|
||||
.type(ChatConstant.Response.SUCCESS)
|
||||
.status(ChatConstant.ResponseStatus.SUCCESS)
|
||||
.message(primaryChatResponse.getChoices().get(0).getMessage().getContent())
|
||||
.usageBean(primaryChatResponse.getUsage())
|
||||
.build();
|
||||
|
||||
response.close();
|
||||
} catch (IORuntimeException e) {
|
||||
log.error("请求超时", e);
|
||||
finalResponse = ChatResponse.builder()
|
||||
.message("连接超时")
|
||||
.status(ChatConstant.ResponseStatus.FAILED)
|
||||
.usageBean(null)
|
||||
.build();
|
||||
}
|
||||
return finalResponse;
|
||||
}
|
||||
|
||||
|
||||
@@ -8,8 +8,7 @@ public class ChatConstant {
|
||||
public static final String ASSISTANT = "assistant";
|
||||
}
|
||||
|
||||
public static class Response {
|
||||
public static final String SUCCESS = "success";
|
||||
public static final String ERROR = "error";
|
||||
public enum ResponseStatus {
|
||||
SUCCESS, FAILED
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,13 +4,14 @@ import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class ChatResponse {
|
||||
private String type;
|
||||
private ChatConstant.ResponseStatus status;
|
||||
private String message;
|
||||
private PrimaryChatResponse.UsageBean usageBean;
|
||||
}
|
||||
|
||||
@@ -6,11 +6,12 @@ import net.bytebuddy.matcher.ElementMatchers;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
|
||||
public class ModuleProxyTest {
|
||||
@Test
|
||||
public void test() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
|
||||
public void test() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, IOException, ClassNotFoundException {
|
||||
Class<? extends AgentRunningModule> clazz = new ByteBuddy().subclass(MyAgentRunningModule.class)
|
||||
.method(ElementMatchers.isOverriddenFrom(AgentRunningModule.class))
|
||||
.intercept(MethodDelegation.to(
|
||||
|
||||
28
Partner-Common/pom.xml
Normal file
28
Partner-Common/pom.xml
Normal file
@@ -0,0 +1,28 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>work.slhaf</groupId>
|
||||
<artifactId>Partner</artifactId>
|
||||
<version>0.5.0</version>
|
||||
</parent>
|
||||
|
||||
<artifactId>Partner-Common</artifactId>
|
||||
|
||||
<dependencies>
|
||||
<!-- https://mvnrepository.com/artifact/io.modelcontextprotocol.sdk/mcp -->
|
||||
<dependency>
|
||||
<groupId>io.modelcontextprotocol.sdk</groupId>
|
||||
<artifactId>mcp</artifactId>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<properties>
|
||||
<maven.compiler.source>21</maven.compiler.source>
|
||||
<maven.compiler.target>21</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
</properties>
|
||||
|
||||
</project>
|
||||
@@ -0,0 +1,155 @@
|
||||
package work.slhaf.partner.common.mcp;
|
||||
|
||||
import io.modelcontextprotocol.common.McpTransportContext;
|
||||
import io.modelcontextprotocol.json.McpJsonMapper;
|
||||
import io.modelcontextprotocol.json.TypeRef;
|
||||
import io.modelcontextprotocol.server.McpStatelessServerHandler;
|
||||
import io.modelcontextprotocol.spec.McpClientTransport;
|
||||
import io.modelcontextprotocol.spec.McpSchema;
|
||||
import io.modelcontextprotocol.spec.McpStatelessServerTransport;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.publisher.Sinks;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
|
||||
public final class InProcessMcpTransport implements McpClientTransport, McpStatelessServerTransport {
|
||||
|
||||
// 每个 transport 只处理一条 inbound 流
|
||||
private final Sinks.Many<McpSchema.JSONRPCMessage> inbound =
|
||||
Sinks.many().unicast().onBackpressureBuffer();
|
||||
|
||||
private final AtomicBoolean clientConnected = new AtomicBoolean(false);
|
||||
private final AtomicBoolean serverConnected = new AtomicBoolean(false);
|
||||
|
||||
/**
|
||||
* 对端
|
||||
*/
|
||||
private volatile InProcessMcpTransport peer;
|
||||
|
||||
private volatile McpStatelessServerHandler serverHandler;
|
||||
|
||||
public record Pair(InProcessMcpTransport clientSide, InProcessMcpTransport serverSide) {
|
||||
}
|
||||
|
||||
public static Pair pair() {
|
||||
InProcessMcpTransport client = new InProcessMcpTransport();
|
||||
InProcessMcpTransport server = new InProcessMcpTransport();
|
||||
|
||||
client.peer = server;
|
||||
server.peer = client;
|
||||
|
||||
return new Pair(client, server);
|
||||
}
|
||||
|
||||
/* ======================================================
|
||||
* Internal receive: peer.sendMessage -> this.receive
|
||||
* ====================================================== */
|
||||
private void receive(McpSchema.JSONRPCMessage message) {
|
||||
if (inbound.tryEmitNext(message).isFailure()) {
|
||||
throw new RuntimeException("Failed to receive message: " + message);
|
||||
}
|
||||
}
|
||||
|
||||
/* ======================================================
|
||||
* Client → Server sendMessage
|
||||
* ====================================================== */
|
||||
@Override
|
||||
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
|
||||
InProcessMcpTransport p = this.peer;
|
||||
if (p == null) {
|
||||
return Mono.error(new IllegalStateException("Transport is not linked"));
|
||||
}
|
||||
return Mono.fromRunnable(() -> p.receive(message));
|
||||
}
|
||||
|
||||
/* ======================================================
|
||||
* Client connect(handler) 处理 server → client 消息
|
||||
* ====================================================== */
|
||||
@Override
|
||||
public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
|
||||
if (!clientConnected.compareAndSet(false, true)) {
|
||||
return Mono.error(new IllegalStateException("Client already connected"));
|
||||
}
|
||||
|
||||
return inbound.asFlux()
|
||||
.concatMap(msg ->
|
||||
handler.apply(Mono.just(msg))
|
||||
// handler may emit response message → send back to server
|
||||
.flatMap(resp -> resp != null ? sendMessage(resp) : Mono.empty())
|
||||
)
|
||||
.doFinally(sig -> clientConnected.set(false))
|
||||
.then();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setExceptionHandler(Consumer<Throwable> handler) {
|
||||
McpClientTransport.super.setExceptionHandler(handler);
|
||||
}
|
||||
|
||||
/* ======================================================
|
||||
* Server: bind stateless handler = process client → server inbound
|
||||
* ====================================================== */
|
||||
@Override
|
||||
public void setMcpHandler(McpStatelessServerHandler handler) {
|
||||
this.serverHandler = handler;
|
||||
|
||||
if (!serverConnected.compareAndSet(false, true)) {
|
||||
throw new IllegalStateException("Server already connected");
|
||||
}
|
||||
|
||||
// 订阅 client → server 消息
|
||||
inbound.asFlux()
|
||||
.concatMap(this::handleServerMessage)
|
||||
.doFinally(sig -> serverConnected.set(false))
|
||||
.subscribe();
|
||||
}
|
||||
|
||||
/**
|
||||
* Server 端处理 JSONRPCMessage
|
||||
*/
|
||||
private Mono<Void> handleServerMessage(McpSchema.JSONRPCMessage msg) {
|
||||
// 创建 transport context(简单实现即可)
|
||||
McpTransportContext ctx = key -> null;
|
||||
|
||||
if (msg instanceof McpSchema.JSONRPCRequest req) {
|
||||
return serverHandler.handleRequest(ctx, req)
|
||||
.flatMap(this::sendMessage);
|
||||
}
|
||||
|
||||
if (msg instanceof McpSchema.JSONRPCNotification noti) {
|
||||
return serverHandler.handleNotification(ctx, noti);
|
||||
}
|
||||
|
||||
return Mono.empty();
|
||||
}
|
||||
|
||||
/* ======================================================
|
||||
* other boilerplate
|
||||
* ====================================================== */
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
McpClientTransport.super.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> closeGracefully() {
|
||||
inbound.tryEmitComplete();
|
||||
clientConnected.set(false);
|
||||
serverConnected.set(false);
|
||||
return Mono.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> T unmarshalFrom(Object data, TypeRef<T> typeRef) {
|
||||
return McpJsonMapper.getDefault().convertValue(data, typeRef);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> protocolVersions() {
|
||||
return McpClientTransport.super.protocolVersions();
|
||||
}
|
||||
}
|
||||
43
Partner-Main/dependency-reduced-pom.xml
Normal file
43
Partner-Main/dependency-reduced-pom.xml
Normal file
@@ -0,0 +1,43 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||
<parent>
|
||||
<artifactId>Partner</artifactId>
|
||||
<groupId>work.slhaf</groupId>
|
||||
<version>0.5.0</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<artifactId>Partner-Main</artifactId>
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<artifactId>maven-shade-plugin</artifactId>
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>shade</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<transformers>
|
||||
<transformer>
|
||||
<mainClass>work.slhaf.partner.Main</mainClass>
|
||||
</transformer>
|
||||
</transformers>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<configuration>
|
||||
<skipTests>true</skipTests>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
<properties>
|
||||
<maven.compiler.target>21</maven.compiler.target>
|
||||
<maven.compiler.source>21</maven.compiler.source>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
</properties>
|
||||
</project>
|
||||
@@ -22,6 +22,64 @@
|
||||
<artifactId>Partner-Api</artifactId>
|
||||
<version>0.5.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.jetbrains.kotlinx</groupId>
|
||||
<artifactId>kotlinx-coroutines-core</artifactId>
|
||||
<version>1.10.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.jetbrains.kotlinx</groupId>
|
||||
<artifactId>kotlinx-coroutines-test</artifactId>
|
||||
<version>1.10.2</version>
|
||||
</dependency>
|
||||
<!-- https://mvnrepository.com/artifact/org.nd4j/nd4j-api -->
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-api</artifactId>
|
||||
<version>1.0.0-M2.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.microsoft.onnxruntime</groupId>
|
||||
<artifactId>onnxruntime</artifactId>
|
||||
<version>1.23.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.djl.huggingface</groupId>
|
||||
<artifactId>tokenizers</artifactId>
|
||||
<version>0.34.0</version>
|
||||
</dependency>
|
||||
<!-- https://mvnrepository.com/artifact/io.modelcontextprotocol.sdk/mcp -->
|
||||
<dependency>
|
||||
<groupId>io.modelcontextprotocol.sdk</groupId>
|
||||
<artifactId>mcp</artifactId>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>work.slhaf</groupId>
|
||||
<artifactId>Partner-Common</artifactId>
|
||||
<version>0.5.0</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
<version>5.20.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-junit-jupiter</artifactId>
|
||||
<version>5.20.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-inline</artifactId>
|
||||
<version>5.2.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.cronutils</groupId>
|
||||
<artifactId>cron-utils</artifactId>
|
||||
<version>9.2.1</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<properties>
|
||||
@@ -30,4 +88,36 @@
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
</properties>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-shade-plugin</artifactId>
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>shade</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<transformers>
|
||||
<transformer
|
||||
implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
|
||||
<mainClass>work.slhaf.partner.Main</mainClass>
|
||||
</transformer>
|
||||
</transformers>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<configuration>
|
||||
<skipTests>true</skipTests>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
|
||||
</project>
|
||||
@@ -1,14 +0,0 @@
|
||||
package work.slhaf;
|
||||
|
||||
import work.slhaf.partner.Agent;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Scanner;
|
||||
|
||||
public class Main {
|
||||
public static void main(String[] args) throws IOException {
|
||||
Agent.initialize();
|
||||
Scanner scanner = new Scanner(System.in);
|
||||
while (!scanner.nextLine().equals("exit")) ;
|
||||
}
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package work.slhaf.partner;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.common.config.Config;
|
||||
import work.slhaf.partner.common.monitor.DebugMonitor;
|
||||
import work.slhaf.partner.core.InteractionHub;
|
||||
import work.slhaf.partner.core.interaction.agent_interface.InputReceiver;
|
||||
import work.slhaf.partner.core.interaction.agent_interface.TaskCallback;
|
||||
import work.slhaf.partner.core.interaction.data.InteractionInputData;
|
||||
import work.slhaf.partner.core.interaction.data.InteractionOutputData;
|
||||
import work.slhaf.partner.gateway.AgentWebSocketServer;
|
||||
import work.slhaf.partner.gateway.MessageSender;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
@Data
|
||||
@Slf4j
|
||||
public class Agent implements TaskCallback, InputReceiver {
|
||||
|
||||
private static volatile Agent agent;
|
||||
private InteractionHub interactionHub;
|
||||
private MessageSender messageSender;
|
||||
|
||||
public static void initialize() throws IOException {
|
||||
if (agent == null) {
|
||||
synchronized (Agent.class) {
|
||||
if (agent == null) {
|
||||
//加载配置
|
||||
Config config = Config.getConfig();
|
||||
agent = new Agent();
|
||||
agent.setInteractionHub(InteractionHub.initialize());
|
||||
agent.registerTaskCallback();
|
||||
AgentWebSocketServer server = new AgentWebSocketServer(config.getWebSocketConfig().getPort(), agent);
|
||||
server.launch();
|
||||
agent.setMessageSender(server);
|
||||
log.info("Agent 加载完毕..");
|
||||
//启动监测线程
|
||||
DebugMonitor.initialize();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static Agent getInstance() throws IOException {
|
||||
initialize();
|
||||
return agent;
|
||||
}
|
||||
|
||||
/**
|
||||
* 接收用户输入,包装为标准输入数据类
|
||||
*/
|
||||
public void receiveInput(InteractionInputData inputData) throws IOException, ClassNotFoundException {
|
||||
inputData.setLocalDateTime(LocalDateTime.now());
|
||||
interactionHub.call(inputData);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 向用户返回输出内容
|
||||
*/
|
||||
public void sendToUser(String userInfo, String output) {
|
||||
messageSender.sendMessage(new InteractionOutputData(output, userInfo));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onTaskFinished(String userInfo, String output) {
|
||||
sendToUser(userInfo, output);
|
||||
}
|
||||
|
||||
private void registerTaskCallback() {
|
||||
interactionHub.setCallback(this);
|
||||
}
|
||||
}
|
||||
18
Partner-Main/src/main/java/work/slhaf/partner/Main.java
Normal file
18
Partner-Main/src/main/java/work/slhaf/partner/Main.java
Normal file
@@ -0,0 +1,18 @@
|
||||
package work.slhaf.partner;
|
||||
|
||||
import work.slhaf.partner.api.agent.Agent;
|
||||
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||
import work.slhaf.partner.common.vector.VectorClient;
|
||||
import work.slhaf.partner.runtime.exception.PartnerExceptionCallback;
|
||||
import work.slhaf.partner.runtime.interaction.WebSocketGateway;
|
||||
|
||||
public class Main {
|
||||
public static void main(String[] args) {
|
||||
Agent.newAgent(Main.class)
|
||||
.setAgentConfigManager(PartnerAgentConfigManager.class)
|
||||
.setGateway(WebSocketGateway.class)
|
||||
.setAgentExceptionCallback(PartnerExceptionCallback.class)
|
||||
.addAfterLaunchRunners(VectorClient::load)
|
||||
.launch();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package work.slhaf.partner.common;
|
||||
|
||||
public final class Constant {
|
||||
|
||||
public static final class Path {
|
||||
public static final String DATA = "data";
|
||||
public static final String MEMORY_DATA = DATA + "/memory";
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,138 +1,24 @@
|
||||
package work.slhaf.partner.common.config;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.alibaba.fastjson2.JSONArray;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import work.slhaf.partner.module.modules.core.CoreModel;
|
||||
import work.slhaf.partner.module.modules.memory.selector.MemorySelector;
|
||||
import work.slhaf.partner.module.modules.memory.updater.MemoryUpdater;
|
||||
import work.slhaf.partner.module.modules.process.PostprocessExecutor;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.List;
|
||||
import java.util.Scanner;
|
||||
|
||||
@Data
|
||||
@Slf4j
|
||||
public class Config {
|
||||
|
||||
private static final String CONFIG_FILE_PATH = "./config/config.json";
|
||||
private static final String LOG_FILE_PATH = "./data/log";
|
||||
private static Config config;
|
||||
|
||||
private String agentId;
|
||||
// private String basicCharacter;
|
||||
|
||||
private WebSocketConfig webSocketConfig;
|
||||
private VectorConfig vectorConfig;
|
||||
|
||||
private List<ModuleConfig> moduleConfigList;
|
||||
|
||||
private Config() {
|
||||
@Data
|
||||
public static class VectorConfig {
|
||||
private int type;
|
||||
private String ollamaEmbeddingUrl;
|
||||
private String ollamaEmbeddingModel;
|
||||
private String tokenizerPath;
|
||||
private String embeddingModelPath;
|
||||
}
|
||||
|
||||
public static Config getConfig() throws IOException {
|
||||
if (config == null) {
|
||||
File file = new File(CONFIG_FILE_PATH);
|
||||
if (file.exists()) {
|
||||
config = JSONUtil.readJSONObject(file, StandardCharsets.UTF_8).toBean(Config.class);
|
||||
} else {
|
||||
config = new Config();
|
||||
Scanner scanner = new Scanner(System.in);
|
||||
|
||||
System.out.print("输入智能体名称: ");
|
||||
config.setAgentId(scanner.nextLine());
|
||||
|
||||
System.out.println("(注意! 设定角色之后修改主配置文件将不会影响现有记忆,除非同时更换agentId)");
|
||||
|
||||
System.out.println("\r\n--------模型配置--------\r\n");
|
||||
generateModelConfig(scanner);
|
||||
|
||||
System.out.println("\r\n--------服务配置--------\r\n");
|
||||
generateWsSocketConfig(scanner);
|
||||
|
||||
System.out.println("\r\n--------模块链配置--------\r\n");
|
||||
generatePipelineConfig();
|
||||
|
||||
boolean launchOrNot = getLaunchOrNot(scanner);
|
||||
|
||||
//保存配置文件
|
||||
String str = JSONUtil.toJsonPrettyStr(config);
|
||||
FileUtils.writeStringToFile(file, str, StandardCharsets.UTF_8);
|
||||
log.info("配置已保存");
|
||||
|
||||
if (!launchOrNot) {
|
||||
System.exit(0);
|
||||
@Data
|
||||
public static class WebSocketConfig {
|
||||
private int port;
|
||||
}
|
||||
}
|
||||
config.generateCommonDirs();
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
private void generateCommonDirs() throws IOException {
|
||||
Files.createDirectories(Paths.get(LOG_FILE_PATH));
|
||||
}
|
||||
|
||||
private static boolean getLaunchOrNot(Scanner scanner) {
|
||||
System.out.print("是否直接启动Partner?(y/n): ");
|
||||
String input;
|
||||
while (true) {
|
||||
input = scanner.nextLine();
|
||||
if (input.equals("y")) {
|
||||
return true;
|
||||
} else if (input.equals("n")) {
|
||||
return false;
|
||||
} else {
|
||||
System.out.println("请输入y或n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void generatePipelineConfig() {
|
||||
List<ModuleConfig> moduleConfigList = List.of(
|
||||
new ModuleConfig(MemorySelector.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
||||
new ModuleConfig(CoreModel.class.getName(), ModuleConfig.Constant.INTERNAL, null),
|
||||
new ModuleConfig(PostprocessExecutor.class.getName(),ModuleConfig.Constant.INTERNAL,null),
|
||||
new ModuleConfig(MemoryUpdater.class.getName(), ModuleConfig.Constant.INTERNAL, null)
|
||||
);
|
||||
config.setModuleConfigList(moduleConfigList);
|
||||
}
|
||||
|
||||
private static void generateWsSocketConfig(Scanner scanner) {
|
||||
System.out.print("WebSocket port: ");
|
||||
WebSocketConfig wsConfig = new WebSocketConfig();
|
||||
wsConfig.setPort(scanner.nextInt());
|
||||
config.setWebSocketConfig(wsConfig);
|
||||
}
|
||||
|
||||
private static void generateModelConfig(Scanner scanner) throws IOException {
|
||||
System.out.println("配置LLM APi:");
|
||||
System.out.println("经测试, 目前只建议选择Qwen3: qwen-plus-latest或qwen-max-latest");
|
||||
System.out.print("base_url: ");
|
||||
String baseUrl = scanner.nextLine();
|
||||
System.out.print("apikey: ");
|
||||
String apikey = scanner.nextLine();
|
||||
System.out.print("model: ");
|
||||
String model = scanner.nextLine();
|
||||
|
||||
ModelConfig modelConfig = new ModelConfig();
|
||||
modelConfig.setBaseUrl(baseUrl);
|
||||
modelConfig.setApikey(apikey);
|
||||
modelConfig.setModel(model);
|
||||
|
||||
InputStream stream = Config.class.getClassLoader().getResourceAsStream("modules/default_activated_model.json");
|
||||
String content = new String(stream.readAllBytes(), StandardCharsets.UTF_8);
|
||||
stream.close();
|
||||
for (String s : JSONArray.parseArray(content, String.class)) {
|
||||
modelConfig.generateConfig(s);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
package work.slhaf.partner.common.config;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.HashMap;
|
||||
|
||||
@Data
|
||||
public class ModelConfig {
|
||||
|
||||
private static final String MODEL_CONFIG_DIR_PATH = "./config/model/";
|
||||
private static final HashMap<String, ModelConfig> modelConfigMap = new HashMap<>();
|
||||
|
||||
private String apikey;
|
||||
private String baseUrl;
|
||||
private String model;
|
||||
|
||||
public void generateConfig(String filename) throws IOException {
|
||||
String str = JSONUtil.toJsonPrettyStr(this);
|
||||
File file = new File(MODEL_CONFIG_DIR_PATH + filename + ".json");
|
||||
FileUtils.writeStringToFile(file, str, StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
public static ModelConfig load(String modelKey) {
|
||||
if (!modelConfigMap.containsKey(modelKey)) {
|
||||
modelConfigMap.put(modelKey,loadConfig(modelKey));
|
||||
}
|
||||
|
||||
return modelConfigMap.get(modelKey);
|
||||
}
|
||||
|
||||
private static ModelConfig loadConfig(String modelKey) {
|
||||
File file = new File(MODEL_CONFIG_DIR_PATH+modelKey+".json");
|
||||
return JSONUtil.readJSONObject(file,StandardCharsets.UTF_8).toBean(ModelConfig.class);
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package work.slhaf.partner.common.config;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public class ModuleConfig {
|
||||
private String className;
|
||||
private String type;
|
||||
private String path;
|
||||
|
||||
public static class Constant {
|
||||
public static final String INTERNAL = "internal";
|
||||
public static final String EXTERNAL = "external";
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package work.slhaf.partner.common.config;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException;
|
||||
import work.slhaf.partner.api.agent.runtime.config.FileAgentConfigManager;
|
||||
import work.slhaf.partner.common.exception.ConfigLoadFailedException;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public final class PartnerAgentConfigManager extends FileAgentConfigManager {
|
||||
|
||||
private static final String COMMON_CONFIG_FILE = CONFIG_DIR + "common_config.json";
|
||||
|
||||
private Config config;
|
||||
|
||||
@Override
|
||||
public void load() {
|
||||
loadWebSocketConfig();
|
||||
super.load();
|
||||
}
|
||||
|
||||
private void loadWebSocketConfig() {
|
||||
File file = new File(COMMON_CONFIG_FILE);
|
||||
if (!file.exists()) {
|
||||
throw new ConfigNotExistException("Partner Config Not Exist: " + COMMON_CONFIG_FILE);
|
||||
}
|
||||
config = JSONUtil.readJSONObject(file, StandardCharsets.UTF_8).toBean(Config.class);
|
||||
if (config == null || config.getAgentId() == null) {
|
||||
throw new ConfigLoadFailedException("Partner Config Load Failed: " + COMMON_CONFIG_FILE);
|
||||
}
|
||||
int port = config.getWebSocketConfig().getPort();
|
||||
if (port <= 0 || port > 65535) {
|
||||
throw new ConfigLoadFailedException("Invalid Websocket port: " + port);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package work.slhaf.partner.common.config;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class WebSocketConfig {
|
||||
private Integer port;
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package work.slhaf.partner.common.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.ConfigFactoryInitFailedException;
|
||||
|
||||
public class ConfigLoadFailedException extends ConfigFactoryInitFailedException {
|
||||
public ConfigLoadFailedException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
public ConfigLoadFailedException(String message) {
|
||||
super(message);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package work.slhaf.partner.common.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentLaunchFailedException;
|
||||
|
||||
public class ServiceLoadFailedException extends AgentLaunchFailedException {
|
||||
public ServiceLoadFailedException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
public ServiceLoadFailedException(String message) {
|
||||
super(message);
|
||||
}
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package work.slhaf.partner.common.exception_handler;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.common.exception_handler.pojo.GlobalException;
|
||||
import work.slhaf.partner.common.exception_handler.pojo.GlobalExceptionData;
|
||||
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
|
||||
@Slf4j
|
||||
public class GlobalExceptionHandler {
|
||||
|
||||
private static final String EXCEPTION_STATIC_PATH = "./data/exception_snapshot/";
|
||||
|
||||
public static void writeExceptionState(GlobalException exception) {
|
||||
GlobalExceptionData exceptionData = exception.getData();
|
||||
Path filePath = Paths.get(EXCEPTION_STATIC_PATH, exceptionData.getExceptionTime() + ".dat");
|
||||
try {
|
||||
Files.createDirectories(Path.of(EXCEPTION_STATIC_PATH));
|
||||
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(filePath.toFile()));
|
||||
oos.writeObject(exceptionData);
|
||||
oos.close();
|
||||
log.warn("[GlobalExceptionHandler] 捕获异常, 已保存到: {}", filePath);
|
||||
} catch (IOException e) {
|
||||
log.error("[GlobalExceptionHandler] 捕获异常, 保存失败: ", e);
|
||||
}
|
||||
}
|
||||
|
||||
public static GlobalExceptionData readExceptionState(String filePath) {
|
||||
try {
|
||||
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(filePath));
|
||||
GlobalExceptionData exceptionData = (GlobalExceptionData) ois.readObject();
|
||||
ois.close();
|
||||
log.info("[GlobalExceptionHandler] 已从: {} 读取异常快照", filePath);
|
||||
return exceptionData;
|
||||
} catch (IOException | ClassNotFoundException e) {
|
||||
log.error("[GlobalExceptionHandler] 读取异常, 读取失败: ", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package work.slhaf.partner.common.exception_handler.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.core.cognation.cognation.CognationCore;
|
||||
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
|
||||
import work.slhaf.partner.core.session.SessionManager;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Slf4j
|
||||
@Data
|
||||
public class GlobalException extends RuntimeException {
|
||||
|
||||
private GlobalExceptionData data;
|
||||
|
||||
public GlobalException(String message) {
|
||||
super(message);
|
||||
try {
|
||||
this.data = new GlobalExceptionData();
|
||||
this.data.setExceptionTime(System.currentTimeMillis());
|
||||
this.data.setSessionManager(SessionManager.getInstance());
|
||||
this.data.setCognationCore(CognationCore.getInstance());
|
||||
this.data.setContext(InteractionContext.getInstance());
|
||||
} catch (Exception e) {
|
||||
log.error("[GlobalException] 捕获异常, 获取数据失败");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package work.slhaf.partner.common.exception_handler.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import work.slhaf.partner.api.common.entity.PersistableObject;
|
||||
import work.slhaf.partner.core.cognation.cognation.CognationCore;
|
||||
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
|
||||
import work.slhaf.partner.core.session.SessionManager;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.util.HashMap;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class GlobalExceptionData extends PersistableObject {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private String exceptionMessage;
|
||||
|
||||
protected HashMap<String, InteractionContext> context;
|
||||
protected SessionManager sessionManager;
|
||||
protected CognationCore cognationCore;
|
||||
protected Long exceptionTime;
|
||||
}
|
||||
@@ -1,14 +1,9 @@
|
||||
package work.slhaf.partner.common.thread;
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.*;
|
||||
|
||||
@Getter
|
||||
public class InteractionThreadPoolExecutor {
|
||||
|
||||
private static InteractionThreadPoolExecutor interactionThreadPoolExecutor;
|
||||
@@ -33,9 +28,29 @@ public class InteractionThreadPoolExecutor {
|
||||
|
||||
public <T> void invokeAll(List<Callable<T>> tasks) {
|
||||
try {
|
||||
executorService.invokeAll(tasks);
|
||||
List<Future<T>> futures = executorService.invokeAll(tasks);
|
||||
for (Future<T> future : futures) {
|
||||
future.get();
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
} catch (ExecutionException e) {
|
||||
throw new RuntimeException(e.getCause());
|
||||
}
|
||||
}
|
||||
|
||||
public <T> List<T> invokeAllAndReturn(List<Callable<T>> tasks) {
|
||||
try {
|
||||
List<Future<T>> futures = executorService.invokeAll(tasks);
|
||||
List<T> results = new ArrayList<>();
|
||||
for (Future<T> future : futures) {
|
||||
results.add(future.get());
|
||||
}
|
||||
return results;
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
} catch (ExecutionException e) {
|
||||
throw new RuntimeException(e.getCause());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
package work.slhaf.partner.common.util;
|
||||
|
||||
public class PathUtil {
|
||||
public static String buildPathStr(String... path) {
|
||||
StringBuilder str = new StringBuilder();
|
||||
for (int i = 0; i < path.length; i++) {
|
||||
str.append(path[i]);
|
||||
if (i < path.length - 1) {
|
||||
str.append("/");
|
||||
}
|
||||
}
|
||||
return str.toString();
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package work.slhaf.partner.common.util;
|
||||
|
||||
import com.alibaba.fastjson2.JSONArray;
|
||||
import work.slhaf.partner.Agent;
|
||||
import work.slhaf.partner.api.agent.Agent;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
|
||||
import java.io.InputStream;
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
package work.slhaf.partner.common.vector;
|
||||
|
||||
import cn.hutool.http.HttpRequest;
|
||||
import cn.hutool.http.HttpResponse;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
public class OllamaVectorClient extends VectorClient {
|
||||
|
||||
private String ollamaEmbeddingUrl;
|
||||
private String ollamaEmbeddingModel;
|
||||
|
||||
protected OllamaVectorClient(String url, String model) {
|
||||
this.ollamaEmbeddingUrl = url;
|
||||
this.ollamaEmbeddingModel = model;
|
||||
|
||||
compute("test");
|
||||
}
|
||||
|
||||
@Override
|
||||
protected float[] doCompute(String input) {
|
||||
Map<String, String> param = Map.of("model", ollamaEmbeddingModel, "input", input);
|
||||
HttpRequest request = HttpRequest.get(ollamaEmbeddingUrl).body(JSONObject.toJSONString(param));
|
||||
try (HttpResponse response = request.execute()) {
|
||||
if (!response.isOk())
|
||||
throw new VectorClientExecuteException("嵌入模型执行出错");
|
||||
String resStr = response.body();
|
||||
EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class);
|
||||
return embeddingResponse.getEmbeddings()[0];
|
||||
} catch (Exception e) {
|
||||
throw new VectorClientExecuteException("嵌入模型执行出错", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
private static class EmbeddingModelResponse {
|
||||
private String model;
|
||||
private float[][] embeddings;
|
||||
private long total_duration;
|
||||
private long load_duration;
|
||||
private int prompt_eval_count;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package work.slhaf.partner.common.vector;
|
||||
|
||||
import ai.djl.huggingface.tokenizers.Encoding;
|
||||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@SuppressWarnings("FieldMayBeFinal")
|
||||
public class OnnxVectorClient extends VectorClient {
|
||||
|
||||
private String tokenizerPath;
|
||||
private String modelPath;
|
||||
|
||||
private HuggingFaceTokenizer tokenizer;
|
||||
private OrtSession session;
|
||||
private OrtEnvironment env;
|
||||
|
||||
protected OnnxVectorClient(String tokenizer, String model) {
|
||||
this.tokenizerPath = tokenizer;
|
||||
this.modelPath = model;
|
||||
|
||||
loadTokenizer();
|
||||
loadModel();
|
||||
compute("test");
|
||||
}
|
||||
|
||||
private void loadModel() {
|
||||
try {
|
||||
env = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions ops = new OrtSession.SessionOptions();
|
||||
session = env.createSession(modelPath, ops);
|
||||
} catch (Exception e) {
|
||||
throw new VectorClientLoadFailedException("加载ONNX模型失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void loadTokenizer() {
|
||||
try {
|
||||
tokenizer = HuggingFaceTokenizer.newInstance(Path.of(tokenizerPath));
|
||||
} catch (Exception e) {
|
||||
throw new VectorClientLoadFailedException("加载Tokenizer失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected float[] doCompute(String input) {
|
||||
try {
|
||||
Encoding encode = tokenizer.encode(input);
|
||||
long[] ids = encode.getIds();
|
||||
long[] attentionMask = encode.getAttentionMask();
|
||||
|
||||
long[][] inputIdsBatch = {ids};
|
||||
long[][] attentionMaskBatch = {attentionMask};
|
||||
long[][] tokenTypeIdsBatch = {new long[ids.length]}; // 初始化全 0
|
||||
for (int i = 0; i < ids.length; i++)
|
||||
tokenTypeIdsBatch[0][i] = 0;
|
||||
|
||||
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputIdsBatch);
|
||||
OnnxTensor maskTensor = OnnxTensor.createTensor(env, attentionMaskBatch);
|
||||
OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(env, tokenTypeIdsBatch);
|
||||
|
||||
Map<String, OnnxTensor> inputs = new HashMap<>();
|
||||
inputs.put("input_ids", inputTensor);
|
||||
inputs.put("attention_mask", maskTensor);
|
||||
inputs.put("token_type_ids", tokenTypeTensor);
|
||||
|
||||
OrtSession.Result result = session.run(inputs);
|
||||
OnnxTensor embeddingTensor = (OnnxTensor) result.get(0);
|
||||
return embeddingTensor.getFloatBuffer().array();
|
||||
} catch (Exception e) {
|
||||
throw new VectorClientExecuteException("嵌入模型执行出错", e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package work.slhaf.partner.common.vector;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||
import work.slhaf.partner.common.config.Config.VectorConfig;
|
||||
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||
import work.slhaf.partner.common.exception.ServiceLoadFailedException;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
||||
|
||||
@Slf4j
|
||||
public abstract class VectorClient {
|
||||
|
||||
public static boolean status;
|
||||
public static VectorClient INSTANCE;
|
||||
|
||||
public static void load() {
|
||||
PartnerAgentConfigManager configManager = (PartnerAgentConfigManager) AgentConfigManager.INSTANCE;
|
||||
VectorConfig vectorConfig = configManager.getConfig().getVectorConfig();
|
||||
int type = vectorConfig.getType();
|
||||
try {
|
||||
switch (type) {
|
||||
case 0:
|
||||
status = false;
|
||||
break;
|
||||
case 1:
|
||||
status = true;
|
||||
INSTANCE = new OllamaVectorClient(vectorConfig.getOllamaEmbeddingUrl(),
|
||||
vectorConfig.getOllamaEmbeddingModel());
|
||||
break;
|
||||
case 2:
|
||||
status = true;
|
||||
INSTANCE = new OnnxVectorClient(vectorConfig.getTokenizerPath(),
|
||||
vectorConfig.getEmbeddingModelPath());
|
||||
break;
|
||||
default:
|
||||
throw new ServiceLoadFailedException(
|
||||
"加载向量客户端失败! type: 0 -> 不启用语义缓存; type: 1 -> ollama; type: 2 -> ONNX RUNTIME");
|
||||
}
|
||||
log.info("向量客户端加载完毕");
|
||||
} catch (VectorClientLoadFailedException | VectorClientExecuteException exception) {
|
||||
status = false;
|
||||
log.error("向量客户端加载失败", exception);
|
||||
}
|
||||
}
|
||||
|
||||
public float[] compute(String input) {
|
||||
if (!status) {
|
||||
return null;
|
||||
}
|
||||
return doCompute(input);
|
||||
}
|
||||
|
||||
protected abstract float[] doCompute(String input);
|
||||
|
||||
public double compare(float[] v1, float[] v2) {
|
||||
if (!status) {
|
||||
return 0;
|
||||
}
|
||||
try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) {
|
||||
return Transforms.cosineSim(a1, a2);
|
||||
}
|
||||
}
|
||||
|
||||
public float[] weightedAverage(float[] newVector, float[] primaryVector) {
|
||||
try (INDArray primary = Nd4j.create(primaryVector);
|
||||
INDArray latest = Nd4j.create(newVector)) {
|
||||
|
||||
// 1️⃣ 计算余弦相似度
|
||||
double similarity = Transforms.cosineSim(primary, latest);
|
||||
|
||||
// 2️⃣ 根据相似度决定更新比例 α(差异越大,新输入影响越强)
|
||||
double alpha = (1.0 - similarity) * 0.5;
|
||||
alpha = Math.max(0.05, Math.min(alpha, 0.5));
|
||||
|
||||
// 3️⃣ 按比例混合旧向量与新向量
|
||||
INDArray updated = primary.mul(1 - alpha).add(latest.mul(alpha));
|
||||
|
||||
// 4️⃣ 归一化结果(保持方向空间一致)
|
||||
updated = updated.div(updated.norm2Number());
|
||||
|
||||
return updated.toFloatVector();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package work.slhaf.partner.common.vector.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||
|
||||
public class VectorClientExecuteException extends AgentRuntimeException {
|
||||
|
||||
public VectorClientExecuteException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public VectorClientExecuteException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package work.slhaf.partner.common.vector.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||
|
||||
public class VectorClientLoadFailedException extends AgentRuntimeException {
|
||||
|
||||
public VectorClientLoadFailedException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public VectorClientLoadFailedException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package work.slhaf.partner.core;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CoordinateManager;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.Coordinated;
|
||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||
import work.slhaf.partner.core.cognation.CognationCore;
|
||||
import work.slhaf.partner.core.memory.MemoryCore;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.io.Serializable;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import static work.slhaf.partner.common.util.ExtractUtil.extractUserId;
|
||||
|
||||
@Data
|
||||
@Slf4j
|
||||
@CoordinateManager
|
||||
public class CoordinatedManager implements Serializable {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
//在框架将自动注入core,详见CapabilityRegistryFactory
|
||||
private CognationCore cognationCore;
|
||||
private MemoryCore memoryCore;
|
||||
|
||||
|
||||
private boolean isCacheSingleUser() {
|
||||
return memoryCore.getUserDialogMap().size() <= 1;
|
||||
}
|
||||
|
||||
@Coordinated(capability = "cognation")
|
||||
public boolean isSingleUser() {
|
||||
return isCacheSingleUser() && isChatMessagesSingleUser();
|
||||
}
|
||||
|
||||
private boolean isChatMessagesSingleUser() {
|
||||
Set<String> userIdSet = new HashSet<>();
|
||||
cognationCore.getChatMessages().forEach(m -> {
|
||||
if (m.getRole().equals(ChatConstant.Character.ASSISTANT)) {
|
||||
return;
|
||||
}
|
||||
String userId = extractUserId(m.getContent());
|
||||
if (userId == null || userId.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
userIdSet.add(userId);
|
||||
});
|
||||
return userIdSet.size() <= 1;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
package work.slhaf.partner.core;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.common.exception_handler.GlobalExceptionHandler;
|
||||
import work.slhaf.partner.common.exception_handler.pojo.GlobalException;
|
||||
import work.slhaf.partner.core.interaction.agent_interface.TaskCallback;
|
||||
import work.slhaf.partner.core.interaction.data.InteractionInputData;
|
||||
import work.slhaf.partner.core.interaction.data.context.InteractionContext;
|
||||
import work.slhaf.partner.core.interaction.module.InteractionFlow;
|
||||
import work.slhaf.partner.core.interaction.module.InteractionModulesLoader;
|
||||
import work.slhaf.partner.module.modules.process.PreprocessExecutor;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Slf4j
|
||||
public class InteractionHub {
|
||||
|
||||
private static volatile InteractionHub interactionHub;
|
||||
|
||||
@ToString.Exclude
|
||||
private TaskCallback callback;
|
||||
private List<InteractionFlow> interactionModules;
|
||||
|
||||
public static InteractionHub initialize() throws IOException {
|
||||
if (interactionHub == null) {
|
||||
synchronized (InteractionHub.class) {
|
||||
if (interactionHub == null) {
|
||||
interactionHub = new InteractionHub();
|
||||
//加载模块
|
||||
interactionHub.setInteractionModules(InteractionModulesLoader.getInstance().registerInteractionModules());
|
||||
log.info("InteractionHub注册完毕...");
|
||||
}
|
||||
}
|
||||
}
|
||||
return interactionHub;
|
||||
}
|
||||
|
||||
public void call(InteractionInputData inputData) throws IOException, ClassNotFoundException {
|
||||
InteractionContext interactionContext = PreprocessExecutor.getInstance().execute(inputData);
|
||||
try {
|
||||
for (InteractionFlow interactionModule : interactionModules) {
|
||||
interactionModule.execute(interactionContext);
|
||||
}
|
||||
} catch (GlobalException e) {
|
||||
GlobalExceptionHandler.writeExceptionState(e);
|
||||
interactionContext.getCoreResponse().put("text", "[ERROR] " + e.getMessage());
|
||||
} finally {
|
||||
callback.onTaskFinished(interactionContext.getUserInfo(), interactionContext.getCoreResponse().getString("text"));
|
||||
interactionContext.clearUp();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package work.slhaf.partner.core;
|
||||
|
||||
import cn.hutool.core.bean.BeanUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||
import work.slhaf.partner.api.common.entity.PersistableObject;
|
||||
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
|
||||
import static work.slhaf.partner.common.Constant.Path.MEMORY_DATA;
|
||||
|
||||
@Slf4j
|
||||
public abstract class PartnerCore<T extends PartnerCore<T>> extends PersistableObject {
|
||||
|
||||
private final String id = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getAgentId();
|
||||
|
||||
public PartnerCore() throws IOException, ClassNotFoundException {
|
||||
createStorageDirectory();
|
||||
Path filePath = getFilePath(id);
|
||||
if (Files.exists(filePath)) {
|
||||
T deserialize = deserialize();
|
||||
setupData(deserialize, (T) this);
|
||||
} else {
|
||||
FileUtils.createParentDirectories(filePath.toFile().getParentFile());
|
||||
this.serialize();
|
||||
}
|
||||
setupHook(this);
|
||||
log.info("[{}] 注册完毕", getCoreKey());
|
||||
|
||||
}
|
||||
|
||||
private void setupHook(PartnerCore<T> temp) {
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||
try {
|
||||
temp.serialize();
|
||||
log.info("[{}] 已保存", getCoreKey());
|
||||
} catch (IOException e) {
|
||||
log.error("[{}] 保存失败: ", getCoreKey(), e);
|
||||
}
|
||||
}));
|
||||
|
||||
}
|
||||
|
||||
private void setupData(T source, T current) {
|
||||
BeanUtil.copyProperties(source, current);
|
||||
}
|
||||
|
||||
public void serialize() throws IOException {
|
||||
//先写入到临时文件,如果正常写入则覆盖原文件
|
||||
Path filePath = getFilePath(id + "-temp");
|
||||
Files.createDirectories(Path.of(MEMORY_DATA));
|
||||
try {
|
||||
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(filePath.toFile()));
|
||||
oos.writeObject(this);
|
||||
oos.close();
|
||||
Path path = getFilePath(id);
|
||||
Files.move(filePath, path, StandardCopyOption.REPLACE_EXISTING);
|
||||
log.info("[{}] 已保存到: {}", getCoreKey(), path);
|
||||
} catch (IOException e) {
|
||||
Files.delete(filePath);
|
||||
log.error("[{}] 序列化保存失败: {}", getCoreKey(), e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private T deserialize() throws IOException, ClassNotFoundException {
|
||||
Path filePath = getFilePath(id);
|
||||
try (ObjectInputStream ois = new ObjectInputStream(
|
||||
new FileInputStream(filePath.toFile()))) {
|
||||
T graph = (T) ois.readObject();
|
||||
log.info("[{}] 已从文件加载: {}", getCoreKey(), filePath);
|
||||
return graph;
|
||||
}
|
||||
}
|
||||
|
||||
private Path getFilePath(String s) {
|
||||
return Paths.get(MEMORY_DATA, s + "-" + getCoreKey() + ".memory");
|
||||
}
|
||||
|
||||
private void createStorageDirectory() {
|
||||
try {
|
||||
Files.createDirectories(Paths.get(MEMORY_DATA));
|
||||
} catch (IOException e) {
|
||||
log.error("[{}]创建存储目录失败: {}", getCoreKey(), e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
protected abstract String getCoreKey();
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package work.slhaf.partner.core.action;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
||||
import work.slhaf.partner.core.action.entity.ActionData;
|
||||
import work.slhaf.partner.core.action.entity.MetaAction;
|
||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||
import work.slhaf.partner.core.action.entity.PhaserRecord;
|
||||
import work.slhaf.partner.core.action.entity.cache.CacheAdjustData;
|
||||
import work.slhaf.partner.core.action.runner.RunnerClient;
|
||||
import work.slhaf.partner.module.modules.action.interventor.entity.MetaIntervention;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Phaser;
|
||||
|
||||
@Capability(value = "action")
|
||||
public interface ActionCapability {
|
||||
|
||||
void putAction(@NonNull ActionData actionData);
|
||||
|
||||
Set<ActionData> listActions(@Nullable ActionData.ActionStatus actionStatus, @Nullable String source);
|
||||
|
||||
List<ActionData> popPendingAction(String userId);
|
||||
|
||||
List<ActionData> listPendingAction(String userId);
|
||||
|
||||
void putPendingActions(String userId, ActionData actionData);
|
||||
|
||||
List<String> selectTendencyCache(String input);
|
||||
|
||||
void updateTendencyCache(CacheAdjustData data);
|
||||
|
||||
ExecutorService getExecutor(ActionCore.ExecutorType type);
|
||||
|
||||
PhaserRecord putPhaserRecord(Phaser phaser, ActionData actionData);
|
||||
|
||||
void removePhaserRecord(Phaser phaser);
|
||||
|
||||
List<PhaserRecord> listPhaserRecords();
|
||||
|
||||
PhaserRecord getPhaserRecord(String tendency, String source);
|
||||
|
||||
MetaAction loadMetaAction(@NonNull String actionKey);
|
||||
|
||||
MetaActionInfo loadMetaActionInfo(@NonNull String actionKey);
|
||||
|
||||
Map<String, MetaActionInfo> listAvailableMetaActions();
|
||||
|
||||
boolean checkExists(String... actionKeys);
|
||||
|
||||
RunnerClient runnerClient();
|
||||
|
||||
void handleInterventions(List<MetaIntervention> interventions, ActionData data);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,446 @@
|
||||
package work.slhaf.partner.core.action;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
|
||||
import work.slhaf.partner.common.vector.VectorClient;
|
||||
import work.slhaf.partner.core.PartnerCore;
|
||||
import work.slhaf.partner.core.action.entity.ActionData;
|
||||
import work.slhaf.partner.core.action.entity.MetaAction;
|
||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||
import work.slhaf.partner.core.action.entity.PhaserRecord;
|
||||
import work.slhaf.partner.core.action.entity.cache.ActionCacheData;
|
||||
import work.slhaf.partner.core.action.entity.cache.CacheAdjustData;
|
||||
import work.slhaf.partner.core.action.entity.cache.CacheAdjustMetaData;
|
||||
import work.slhaf.partner.core.action.exception.ActionDataNotFoundException;
|
||||
import work.slhaf.partner.core.action.exception.MetaActionNotFoundException;
|
||||
import work.slhaf.partner.core.action.runner.RunnerClient;
|
||||
import work.slhaf.partner.core.action.runner.SandboxRunnerClient;
|
||||
import work.slhaf.partner.module.modules.action.interventor.entity.InterventionType;
|
||||
import work.slhaf.partner.module.modules.action.interventor.entity.MetaIntervention;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.*;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@SuppressWarnings("FieldMayBeFinal")
|
||||
@CapabilityCore(value = "action")
|
||||
@Slf4j
|
||||
public class ActionCore extends PartnerCore<ActionCore> {
|
||||
|
||||
/**
|
||||
* 持久行动池
|
||||
*/
|
||||
private CopyOnWriteArraySet<ActionData> actionPool = new CopyOnWriteArraySet<>();
|
||||
|
||||
/**
|
||||
* 待确认任务,以userId区分不同用户,因为需要跨请求确认
|
||||
*/
|
||||
private HashMap<String, List<ActionData>> pendingActions = new HashMap<>();
|
||||
|
||||
/**
|
||||
* 语义缓存与行为倾向映射
|
||||
*/
|
||||
private List<ActionCacheData> actionCache = new ArrayList<>();
|
||||
|
||||
private final Lock cacheLock = new ReentrantLock();
|
||||
|
||||
// 由于当前的执行器逻辑实现,平台线程池大小不得小于 2,这里规定为最小为 4
|
||||
private final ExecutorService platformExecutor = Executors
|
||||
.newFixedThreadPool(Math.max(Runtime.getRuntime().availableProcessors(), 4));
|
||||
private final ExecutorService virtualExecutor = Executors.newVirtualThreadPerTaskExecutor();
|
||||
|
||||
/**
|
||||
* 已存在的行动程序,键格式为‘<MCP-ServerName>::<Tool-Name>’,值为 MCP Server 通过 Resources 相关渠道传递的行动程序元信息
|
||||
*/
|
||||
private final ConcurrentHashMap<String, MetaActionInfo> existedMetaActions = new ConcurrentHashMap<>();
|
||||
private final List<PhaserRecord> phaserRecords = new ArrayList<>();
|
||||
private RunnerClient runnerClient;
|
||||
|
||||
public ActionCore() throws IOException, ClassNotFoundException {
|
||||
// TODO 通过 AgentConfigManager指定采用何种 runnerClient
|
||||
runnerClient = new SandboxRunnerClient(existedMetaActions, virtualExecutor);
|
||||
setupShutdownHook();
|
||||
}
|
||||
|
||||
private void setupShutdownHook() {
|
||||
// 将执行中的行动状态置为失败
|
||||
val executingActionSet = listActions(ActionData.ActionStatus.EXECUTING, null);
|
||||
for (ActionData actionData : executingActionSet) {
|
||||
actionData.setStatus(ActionData.ActionStatus.FAILED);
|
||||
actionData.setResult("由于系统中断而失败");
|
||||
}
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void putAction(@NonNull ActionData actionData) {
|
||||
actionPool.removeIf(data -> data.getUuid().equals(actionData.getUuid())); // 用来应对 ScheduledActionData 的重新排列
|
||||
actionPool.add(actionData);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public Set<ActionData> listActions(@Nullable ActionData.ActionStatus actionStatus, @Nullable String source) {
|
||||
return actionPool.stream()
|
||||
.filter(actionData -> actionStatus == null || actionData.getStatus().equals(actionStatus))
|
||||
.filter(actionData -> source == null || actionData.getSource().equals(source))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public synchronized void putPendingActions(String userId, ActionData actionData) {
|
||||
pendingActions.computeIfAbsent(userId, k -> {
|
||||
List<ActionData> temp = new ArrayList<>();
|
||||
temp.add(actionData);
|
||||
return temp;
|
||||
});
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public synchronized List<ActionData> popPendingAction(String userId) {
|
||||
List<ActionData> infos = pendingActions.get(userId);
|
||||
pendingActions.remove(userId);
|
||||
return infos;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public List<ActionData> listPendingAction(String userId) {
|
||||
return pendingActions.get(userId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 计算输入内容的语义向量,根据与{@link ActionCacheData#getInputVector()}的相似度挑取缓存,后续将根据评估结果来更新计数
|
||||
*
|
||||
* @param input 本次输入内容
|
||||
* @return 命中的行为倾向集合
|
||||
*/
|
||||
@CapabilityMethod
|
||||
public List<String> selectTendencyCache(String input) {
|
||||
if (!VectorClient.status) {
|
||||
return null;
|
||||
}
|
||||
VectorClient vectorClient = VectorClient.INSTANCE;
|
||||
// 计算本次输入的向量
|
||||
float[] vector = vectorClient.compute(input);
|
||||
if (vector == null)
|
||||
return null;
|
||||
// 与现有缓存比对,将匹配到的收集并返回
|
||||
return actionCache.parallelStream()
|
||||
.filter(ActionCacheData::isActivated)
|
||||
.filter(data -> {
|
||||
double compared = vectorClient.compare(vector, data.getInputVector());
|
||||
return compared > data.getThreshold();
|
||||
})
|
||||
.map(ActionCacheData::getTendency)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void updateTendencyCache(CacheAdjustData data) {
|
||||
VectorClient vectorClient = VectorClient.INSTANCE;
|
||||
List<CacheAdjustMetaData> list = data.getMetaDataList();
|
||||
String input = data.getInput();
|
||||
float[] inputVector = vectorClient.compute(input);
|
||||
|
||||
List<CacheAdjustMetaData> matchAndPassed = new ArrayList<>();
|
||||
List<CacheAdjustMetaData> matchNotPassed = new ArrayList<>();
|
||||
List<CacheAdjustMetaData> notMatchPassed = new ArrayList<>();
|
||||
|
||||
for (CacheAdjustMetaData metaData : list) {
|
||||
if (metaData.isHit() && metaData.isPassed()) {
|
||||
matchAndPassed.add(metaData);
|
||||
} else if (metaData.isHit()) {
|
||||
matchNotPassed.add(metaData);
|
||||
} else if (!metaData.isPassed()) {
|
||||
notMatchPassed.add(metaData);
|
||||
}
|
||||
}
|
||||
|
||||
platformExecutor.execute(() -> adjustMatchAndPassed(matchAndPassed, inputVector, input, vectorClient));
|
||||
platformExecutor.execute(() -> adjustMatchNotPassed(matchNotPassed, vectorClient));
|
||||
platformExecutor.execute(() -> adjustNotMatchPassed(notMatchPassed, inputVector, input, vectorClient));
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public ExecutorService getExecutor(ExecutorType type) {
|
||||
return switch (type) {
|
||||
case VIRTUAL -> virtualExecutor;
|
||||
case PLATFORM -> platformExecutor;
|
||||
};
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public Map<String, MetaActionInfo> listAvailableActions() {
|
||||
return existedMetaActions;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public synchronized PhaserRecord putPhaserRecord(Phaser phaser, ActionData actionData) {
|
||||
PhaserRecord record = new PhaserRecord(phaser, actionData);
|
||||
phaserRecords.add(record);
|
||||
return record;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public synchronized void removePhaserRecord(Phaser phaser) {
|
||||
PhaserRecord remove = null;
|
||||
for (PhaserRecord record : phaserRecords) {
|
||||
if (record.phaser().equals(phaser)) {
|
||||
remove = record;
|
||||
}
|
||||
}
|
||||
|
||||
if (remove != null) {
|
||||
phaserRecords.remove(remove);
|
||||
}
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public PhaserRecord getPhaserRecord(String tendency, String source) {
|
||||
for (PhaserRecord record : phaserRecords) {
|
||||
ActionData data = record.actionData();
|
||||
if (data.getTendency().equals(tendency) && data.getSource().equals(source)) {
|
||||
return record;
|
||||
}
|
||||
}
|
||||
throw new ActionDataNotFoundException("未找到对应的 Phaser 记录: tendency=" + tendency + ", source=" + source);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public MetaAction loadMetaAction(@NonNull String actionKey) {
|
||||
MetaActionInfo metaActionInfo = existedMetaActions.get(actionKey);
|
||||
if (metaActionInfo == null) {
|
||||
throw new MetaActionNotFoundException("未找到对应的行动程序信息" + actionKey);
|
||||
}
|
||||
|
||||
String[] split = actionKey.split("::");
|
||||
if (split.length < 2) {
|
||||
throw new MetaActionNotFoundException("未找到对应的行动程序,原因: 传入的 actionKey(" + actionKey + ") 存在异常");
|
||||
}
|
||||
return new MetaAction(
|
||||
split[1],
|
||||
metaActionInfo.isIo(),
|
||||
MetaAction.Type.MCP,
|
||||
split[0]
|
||||
);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public List<PhaserRecord> listPhaserRecords() {
|
||||
return phaserRecords;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public MetaActionInfo loadMetaActionInfo(@NonNull String actionKey) {
|
||||
MetaActionInfo info = existedMetaActions.get(actionKey);
|
||||
if (info == null) {
|
||||
throw new MetaActionNotFoundException("未找到对应的行动程序描述信息: " + actionKey);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public boolean checkExists(String... actionKeys) {
|
||||
return existedMetaActions.keySet().containsAll(Arrays.asList(actionKeys));
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public RunnerClient runnerClient() {
|
||||
return runnerClient;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void handleInterventions(List<MetaIntervention> interventions, ActionData actionData) {
|
||||
// 加载数据
|
||||
if (actionData == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 加锁确保同步
|
||||
synchronized (actionData.getStatus()) {
|
||||
applyInterventions(interventions, actionData);
|
||||
}
|
||||
}
|
||||
|
||||
private void applyInterventions(List<MetaIntervention> interventions, ActionData actionData) {
|
||||
boolean rebuildCleanTag = false;
|
||||
|
||||
interventions.sort(Comparator.comparingInt(MetaIntervention::getOrder));
|
||||
|
||||
for (MetaIntervention intervention : interventions) {
|
||||
List<MetaAction> actions = intervention.getActions()
|
||||
.stream()
|
||||
.map(this::loadMetaAction)
|
||||
.toList();
|
||||
|
||||
switch (intervention.getType()) {
|
||||
case InterventionType.APPEND -> handleAppend(actionData, intervention.getOrder(), actions);
|
||||
case InterventionType.INSERT -> handleInsert(actionData, intervention.getOrder(), actions);
|
||||
case InterventionType.DELETE -> handleDelete(actionData, intervention.getOrder(), actions);
|
||||
case InterventionType.CANCEL -> handleCancel(actionData);
|
||||
case InterventionType.REBUILD -> {
|
||||
if (!rebuildCleanTag) {
|
||||
cleanActionData(actionData);
|
||||
rebuildCleanTag = true;
|
||||
}
|
||||
handleRebuild(actionData, intervention.getOrder(), actions);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 在未进入执行阶段的行动单元组新增新的行动
|
||||
*/
|
||||
private void handleAppend(ActionData actionData, int order, List<MetaAction> actions) {
|
||||
if (order <= actionData.getExecutingStage())
|
||||
return;
|
||||
|
||||
actionData.getActionChain().put(order, actions);
|
||||
}
|
||||
|
||||
/**
|
||||
* 在未进入执行阶段和正处于行动阶段的行动单元组插入新的行动
|
||||
*/
|
||||
private void handleInsert(ActionData actionData, int order, List<MetaAction> actions) {
|
||||
if (order < actionData.getExecutingStage())
|
||||
return;
|
||||
|
||||
actionData.getActionChain().computeIfAbsent(order, k -> new ArrayList<>()).addAll(actions);
|
||||
}
|
||||
|
||||
private void handleDelete(ActionData actionData, int order, List<MetaAction> actions) {
|
||||
if (order <= actionData.getExecutingStage())
|
||||
return;
|
||||
|
||||
Map<Integer, List<MetaAction>> actionChain = actionData.getActionChain();
|
||||
if (actionChain.containsKey(order)) {
|
||||
actionChain.get(order).removeAll(actions);
|
||||
if (actionChain.get(order).isEmpty()) {
|
||||
actionChain.remove(order);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void handleCancel(ActionData actionData) {
|
||||
actionData.setStatus(ActionData.ActionStatus.FAILED);
|
||||
actionData.setResult("行动取消");
|
||||
}
|
||||
|
||||
private void handleRebuild(ActionData actionData, int order, List<MetaAction> actions) {
|
||||
Map<Integer, List<MetaAction>> actionChain = actionData.getActionChain();
|
||||
actionChain.put(order, actions);
|
||||
}
|
||||
|
||||
private void cleanActionData(ActionData actionData) {
|
||||
actionData.getActionChain().clear();
|
||||
actionData.setExecutingStage(0);
|
||||
actionData.setStatus(ActionData.ActionStatus.PREPARE);
|
||||
actionData.getHistory().clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* 命中缓存且评估通过时
|
||||
*
|
||||
* @param matchAndPassed 该类型的带调整缓存信息列表
|
||||
* @param inputVector 本次输入内容的语义向量
|
||||
* @param vectorClient 向量客户端
|
||||
*/
|
||||
private void adjustMatchAndPassed(List<CacheAdjustMetaData> matchAndPassed, float[] inputVector, String input,
|
||||
VectorClient vectorClient) {
|
||||
matchAndPassed.forEach(adjustData -> {
|
||||
// 获取原始缓存条目
|
||||
String tendency = adjustData.getTendency();
|
||||
ActionCacheData primaryCacheData = selectCacheData(tendency);
|
||||
if (primaryCacheData == null) {
|
||||
return;
|
||||
}
|
||||
primaryCacheData.updateAfterMatchAndPassed(inputVector, vectorClient, input);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 针对命中缓存、但评估未通过的条目与输入进行处理
|
||||
*
|
||||
* @param matchNotPassed 该类型的带调整缓存信息列表
|
||||
* @param vectorClient 向量客户端
|
||||
*/
|
||||
private void adjustMatchNotPassed(List<CacheAdjustMetaData> matchNotPassed, VectorClient vectorClient) {
|
||||
List<ActionCacheData> toRemove = new ArrayList<>();
|
||||
matchNotPassed.forEach(adjustData -> {
|
||||
// 获取原始缓存条目
|
||||
String tendency = adjustData.getTendency();
|
||||
ActionCacheData primaryCacheData = selectCacheData(tendency);
|
||||
if (primaryCacheData == null) {
|
||||
return;
|
||||
}
|
||||
boolean remove = primaryCacheData.updateAfterMatchNotPassed(vectorClient);
|
||||
if (remove) {
|
||||
toRemove.add(primaryCacheData);
|
||||
}
|
||||
|
||||
});
|
||||
cacheLock.lock();
|
||||
actionCache.removeAll(toRemove);
|
||||
cacheLock.unlock();
|
||||
}
|
||||
|
||||
/**
|
||||
* 针对未命中但评估通过的缓存做出调整:
|
||||
* <ol>
|
||||
* <h3>如果存在缓存条目</h3>
|
||||
* <li>
|
||||
* 若已生效,但此时未匹配到则说明尚未生效或者阈值、向量{@link ActionCacheData#getInputVector()}存在问题,调低阈值,同时带权移动平均
|
||||
* </li>
|
||||
* <li>
|
||||
* 若未生效,则只增加计数并带权移动平均
|
||||
* </li>
|
||||
* </ol>
|
||||
* 如果不存在缓存条目,则新增并填充字段
|
||||
*
|
||||
* @param notMatchPassed 该类型的带调整缓存信息列表
|
||||
* @param inputVector 本次输入内容的语义向量
|
||||
* @param input 本次输入内容
|
||||
* @param vectorClient 向量客户端
|
||||
*/
|
||||
private void adjustNotMatchPassed(List<CacheAdjustMetaData> notMatchPassed, float[] inputVector, String input,
|
||||
VectorClient vectorClient) {
|
||||
notMatchPassed.forEach(adjustData -> {
|
||||
// 获取原始缓存条目
|
||||
String tendency = adjustData.getTendency();
|
||||
ActionCacheData primaryCacheData = selectCacheData(tendency);
|
||||
float[] tendencyVector = vectorClient.compute(tendency);
|
||||
if (primaryCacheData == null) {
|
||||
actionCache.add(new ActionCacheData(tendency, tendencyVector, inputVector, input));
|
||||
return;
|
||||
}
|
||||
primaryCacheData.updateAfterNotMatchPassed(input, inputVector, tendencyVector, vectorClient);
|
||||
});
|
||||
}
|
||||
|
||||
private ActionCacheData selectCacheData(String tendency) {
|
||||
for (ActionCacheData actionCacheData : actionCache) {
|
||||
if (actionCacheData.getTendency().equals(tendency)) {
|
||||
return actionCacheData;
|
||||
}
|
||||
}
|
||||
log.warn("[{}] 未找到行为倾向[{}]对应的缓存条目,可能是代码逻辑存在错误", getCoreKey(), tendency);
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getCoreKey() {
|
||||
return "action-core";
|
||||
}
|
||||
|
||||
public enum ExecutorType {
|
||||
VIRTUAL, PLATFORM
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
package work.slhaf.partner.core.action.entity
|
||||
|
||||
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.HistoryAction
|
||||
import java.time.ZonedDateTime
|
||||
import java.util.*
|
||||
|
||||
/**
|
||||
* 行动模块传递的行动数据,包含行动uuid、倾向、状态、行动链、结果、发起原因、行动描述等信息。
|
||||
*/
|
||||
sealed class ActionData {
|
||||
/**
|
||||
* 行动ID
|
||||
*/
|
||||
val uuid: String = UUID.randomUUID().toString()
|
||||
|
||||
/**
|
||||
* 行动倾向
|
||||
*/
|
||||
abstract val tendency: String
|
||||
|
||||
/**
|
||||
* 行动状态
|
||||
*/
|
||||
var status: ActionStatus = ActionStatus.PREPARE
|
||||
|
||||
/**
|
||||
* 行动链
|
||||
*/
|
||||
abstract val actionChain: MutableMap<Int, MutableList<MetaAction>>
|
||||
|
||||
/**
|
||||
* 行动阶段(当前阶段)
|
||||
*/
|
||||
var executingStage: Int = 0
|
||||
|
||||
/**
|
||||
* 行动结果
|
||||
*/
|
||||
lateinit var result: String
|
||||
|
||||
val history: MutableMap<Int, MutableList<HistoryAction>> = mutableMapOf()
|
||||
|
||||
/**
|
||||
* 修复上下文
|
||||
*/
|
||||
val additionalContext: MutableMap<Int, MutableList<String>> = mutableMapOf()
|
||||
|
||||
/**
|
||||
* 行动原因
|
||||
*/
|
||||
abstract val reason: String
|
||||
|
||||
/**
|
||||
* 行动描述
|
||||
*/
|
||||
abstract val description: String
|
||||
|
||||
/**
|
||||
* 行动来源
|
||||
*/
|
||||
abstract val source: String
|
||||
|
||||
enum class ActionStatus {
|
||||
/**
|
||||
* 执行成功
|
||||
*/
|
||||
SUCCESS,
|
||||
|
||||
/**
|
||||
* 执行失败
|
||||
*/
|
||||
FAILED,
|
||||
|
||||
/**
|
||||
* 执行中
|
||||
*/
|
||||
EXECUTING,
|
||||
|
||||
/**
|
||||
* 暂时中断
|
||||
*/
|
||||
INTERRUPTED,
|
||||
|
||||
/**
|
||||
* 预备执行
|
||||
*/
|
||||
PREPARE
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 计划行动数据类,继承自{@link ActionData},扩展了属性{@link ScheduledActionData#type}和{@link ScheduledActionData#scheduleContent},用于标识计划类型(单次还是周期性任务)和计划内容
|
||||
*/
|
||||
data class ScheduledActionData(
|
||||
override val tendency: String,
|
||||
override val actionChain: MutableMap<Int, MutableList<MetaAction>>,
|
||||
override val reason: String,
|
||||
override val description: String,
|
||||
override val source: String,
|
||||
val scheduleType: ScheduleType,
|
||||
val scheduleContent: String,
|
||||
) : ActionData() {
|
||||
|
||||
val scheduleHistories = ArrayList<ScheduleHistory>()
|
||||
|
||||
fun recordAndReset() {
|
||||
val newHistory = ScheduleHistory(ZonedDateTime.now(), result, history.toMap())
|
||||
scheduleHistories.add(newHistory)
|
||||
|
||||
additionalContext.clear()
|
||||
executingStage = 0
|
||||
for (entry in actionChain) {
|
||||
for (action in entry.value) {
|
||||
action.params.clear()
|
||||
action.result.reset()
|
||||
}
|
||||
}
|
||||
|
||||
status = ActionStatus.PREPARE
|
||||
}
|
||||
|
||||
enum class ScheduleType {
|
||||
CYCLE,
|
||||
ONCE
|
||||
}
|
||||
|
||||
data class ScheduleHistory(
|
||||
val endTime: ZonedDateTime,
|
||||
val result: String,
|
||||
val history: Map<Int, List<HistoryAction>>
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 即时行动数据类
|
||||
*/
|
||||
data class ImmediateActionData(
|
||||
override val tendency: String,
|
||||
override val actionChain: MutableMap<Int, MutableList<MetaAction>>,
|
||||
override val reason: String,
|
||||
override val description: String,
|
||||
override val source: String,
|
||||
) : ActionData()
|
||||
@@ -0,0 +1,10 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ActionFileMetaData {
|
||||
private String content;
|
||||
private String name;
|
||||
private String ext;
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class GeneratedData {
|
||||
private List<String> dependencies;
|
||||
private String code;
|
||||
private String codeType;
|
||||
private boolean serialize;
|
||||
private JSONObject responseSchema;
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
public class McpData {
|
||||
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package work.slhaf.partner.core.action.entity
|
||||
|
||||
|
||||
/**
|
||||
* 行动链中的单一元素,封装了调用外部行动程序的必要信息与结果容器,可被[work.slhaf.partner.core.action.ActionCapability]执行
|
||||
*/
|
||||
data class MetaAction(
|
||||
/**
|
||||
* 行动name,用于标识行动程序
|
||||
*/
|
||||
val name: String,
|
||||
/**
|
||||
* 是否IO密集,用于决定使用何种线程池
|
||||
*/
|
||||
val io: Boolean = false,
|
||||
/**
|
||||
* 行动程序类型,可分为 MCP、ORIGIN 两种,前者对应读取到的 MCP Tool、后者对应生成的临时行动程序
|
||||
*/
|
||||
val type: Type,
|
||||
/**
|
||||
* 当类型为 MCP 时,该字段对应相应 MCP Client 注册时生成的 id;
|
||||
* 当类型为 ORIGIN 时,该字段对应相应的磁盘路径字符串
|
||||
*/
|
||||
val location: String,
|
||||
) {
|
||||
|
||||
/**
|
||||
* 行动程序可接受的参数,由调用处设置
|
||||
*/
|
||||
val params: MutableMap<String, Any> = mutableMapOf()
|
||||
|
||||
/**
|
||||
* 行动结果,包括执行状态和相应内容(执行结果或者错误信息)
|
||||
*/
|
||||
val result = Result()
|
||||
|
||||
val key: String
|
||||
/**
|
||||
* actionKey 将由 location+name 共同定位
|
||||
*
|
||||
* @return actionKey
|
||||
*/
|
||||
get() = "$location::$name"
|
||||
|
||||
class Result {
|
||||
var status = Status.WAITING
|
||||
var data: String? = null
|
||||
|
||||
fun reset() {
|
||||
status = Status.WAITING
|
||||
data = null
|
||||
}
|
||||
|
||||
enum class Status {
|
||||
SUCCESS,
|
||||
FAILED,
|
||||
WAITING
|
||||
}
|
||||
}
|
||||
|
||||
enum class Type {
|
||||
/**
|
||||
* 将调用的 MCP 工具,可包括远程、本地任意服务
|
||||
*/
|
||||
MCP,
|
||||
|
||||
/**
|
||||
* 适用于‘临时生成’的行动程序,在生成后根据序列化选项及执行情况,进行持久化
|
||||
*/
|
||||
ORIGIN
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
public class MetaActionInfo {
|
||||
private boolean io;
|
||||
|
||||
private Map<String, Object> params;
|
||||
private String description;
|
||||
private List<String> tags;
|
||||
|
||||
private List<String> preActions;
|
||||
private List<String> postActions;
|
||||
/**
|
||||
* 是否严格依赖前置行动的成功执行,若为true且前置行动失败则不执行该行动,后置任务多为触发式。默认即执行。
|
||||
*/
|
||||
private boolean strictDependencies;
|
||||
|
||||
private JSONObject responseSchema;
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import work.slhaf.partner.core.action.entity.ActionData.ActionStatus;
|
||||
|
||||
import java.util.concurrent.Phaser;
|
||||
|
||||
public record PhaserRecord(Phaser phaser, ActionData actionData) {
|
||||
|
||||
public void fail() {
|
||||
actionData.setStatus(ActionStatus.FAILED);
|
||||
}
|
||||
|
||||
/**
|
||||
* 负责将 ActionData 的状态设置为 INTERRUPTED
|
||||
* 同时循环检查进行阻塞
|
||||
*/
|
||||
public void interrupt() {
|
||||
actionData.setStatus(ActionStatus.INTERRUPTED);
|
||||
while (actionData().getStatus() == ActionStatus.INTERRUPTED) {
|
||||
try {
|
||||
Thread.sleep(500);
|
||||
} catch (InterruptedException ignored) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将状态重新设置为 EXECUTING ,恢复 interrupt 阻塞状态
|
||||
*/
|
||||
public void complete() {
|
||||
actionData().setStatus(ActionStatus.EXECUTING);
|
||||
}
|
||||
}
|
||||
181
Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/cache/ActionCacheData.java
vendored
Normal file
181
Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/cache/ActionCacheData.java
vendored
Normal file
@@ -0,0 +1,181 @@
|
||||
package work.slhaf.partner.core.action.entity.cache;
|
||||
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.common.vector.VectorClient;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ActionCacheData {
|
||||
private boolean activated = false;
|
||||
private int inputMatchCount = 1;
|
||||
|
||||
private float[] inputVector;
|
||||
private float[] tendencyVector;
|
||||
private String tendency;
|
||||
private double threshold = 0.75;
|
||||
|
||||
private List<String> validSamples = new ArrayList<>();
|
||||
private int failedCount = 0;
|
||||
private Type type = Type.PRIMARY;
|
||||
|
||||
public ActionCacheData(String tendency, float[] tendencyVector, float[] inputVector, String input) {
|
||||
this.tendency = tendency;
|
||||
this.inputVector = inputVector;
|
||||
this.tendencyVector = tendencyVector;
|
||||
this.validSamples.add(input);
|
||||
}
|
||||
|
||||
/**
|
||||
* 命中缓存且评估通过时,根据输入内容的语义向量与现有的输入语义向量进行带权移动平均,以相似度为权重,同时降低失败计数,为零时置为上一级缓存类型{@link ActionCacheData.Type}
|
||||
*
|
||||
* @param inputVector 本次输入内容对应的语义向量
|
||||
* @param vectorClient 向量客户端
|
||||
* @param input 本次输入内容
|
||||
*/
|
||||
public synchronized void updateAfterMatchAndPassed(float[] inputVector, VectorClient vectorClient, String input) {
|
||||
updateInputVector(inputVector, vectorClient);
|
||||
addValidSample(input);
|
||||
reduceFailedCount();
|
||||
updateType();
|
||||
addInputMatchCount();
|
||||
}
|
||||
|
||||
private void updateType() {
|
||||
if (this.failedCount == 0) {
|
||||
this.type = switch (type) {
|
||||
case PRIMARY, REBUILD_V1 -> ActionCacheData.Type.PRIMARY;
|
||||
case REBUILD_V2 -> ActionCacheData.Type.REBUILD_V1;
|
||||
case REBUILD_V3 -> ActionCacheData.Type.REBUILD_V2;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private void reduceFailedCount() {
|
||||
this.failedCount = Math.max(this.failedCount - 1, 0);
|
||||
}
|
||||
|
||||
private void addValidSample(String input) {
|
||||
if (this.validSamples.size() == 12) {
|
||||
this.validSamples.removeFirst();
|
||||
}
|
||||
this.validSamples.add(input);
|
||||
}
|
||||
|
||||
private void updateInputVector(float[] inputVector, VectorClient vectorClient) {
|
||||
this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
|
||||
}
|
||||
|
||||
/**
|
||||
* 针对命中缓存、但评估未通过的条目与输入进行处理: 增加失败计数(必要时重建并更新类型等级)、调高阈值(0.02),由于缓存匹配但评估未通过,所以不进行带权移动平均
|
||||
*
|
||||
* @param vectorClient 向量客户端
|
||||
* @return 是否需要删除(已在REBUILD_V3状态且达到最大误判次数的)
|
||||
*/
|
||||
public synchronized boolean updateAfterMatchNotPassed(VectorClient vectorClient) {
|
||||
adjustThreshold();
|
||||
addFailedCount();
|
||||
if (this.failedCount < 3) {
|
||||
return false;
|
||||
}
|
||||
if (this.type == Type.REBUILD_V3) {
|
||||
return true;
|
||||
}
|
||||
rebuildAndSwitchType(vectorClient);
|
||||
return false;
|
||||
}
|
||||
|
||||
private void rebuildAndSwitchType(VectorClient vectorClient) {
|
||||
this.type = switch (this.type) {
|
||||
case PRIMARY -> {
|
||||
//样本顺序反转后,以全部样本重建
|
||||
this.validSamples = this.validSamples.reversed();
|
||||
rebuildWithSamples(vectorClient);
|
||||
yield Type.REBUILD_V1;
|
||||
}
|
||||
case REBUILD_V1 -> {
|
||||
//截取后一半样本,反转后以此重建
|
||||
List<String> temp = this.validSamples.subList(this.validSamples.size() / 2, this.validSamples.size());
|
||||
this.validSamples = temp.reversed();
|
||||
rebuildWithSamples(vectorClient);
|
||||
yield Type.REBUILD_V2;
|
||||
}
|
||||
case REBUILD_V2 -> {
|
||||
//截取后四分之一样本,反转后以此重建
|
||||
List<String> temp = this.validSamples.subList(this.validSamples.size() / 4, this.validSamples.size());
|
||||
this.validSamples = temp.reversed();
|
||||
rebuildWithSamples(vectorClient);
|
||||
yield Type.REBUILD_V3;
|
||||
}
|
||||
case REBUILD_V3 -> null;
|
||||
};
|
||||
//阈值减0.05,防止重建后一直升高
|
||||
this.threshold = Math.max(this.threshold - 0.05, 0.75);
|
||||
this.failedCount = 0;
|
||||
}
|
||||
|
||||
private void rebuildWithSamples(VectorClient vectorClient) {
|
||||
for (int i = 0; i < this.validSamples.size(); i++) {
|
||||
String sample = this.validSamples.get(i);
|
||||
if (i == 0) {
|
||||
this.inputVector = vectorClient.compute(sample);
|
||||
} else {
|
||||
float[] newSampleVector = vectorClient.compute(sample);
|
||||
this.inputVector = vectorClient.weightedAverage(this.inputVector, newSampleVector);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void addFailedCount() {
|
||||
this.failedCount = Math.min(this.failedCount + 1, 3);
|
||||
}
|
||||
|
||||
private void adjustThreshold() {
|
||||
double newThreshold = this.threshold + 0.03;
|
||||
this.threshold = Math.min(newThreshold, 0.95);
|
||||
}
|
||||
|
||||
/**
|
||||
* 针对未命中但评估通过的已存在缓存做出调整:
|
||||
* <ol>
|
||||
* <li>
|
||||
* 若已生效,但此时未匹配到则说明阈值或者向量{@link ActionCacheData#getInputVector()}存在问题,调低阈值,同时带权移动平均
|
||||
* </li>
|
||||
* <li>
|
||||
* 若未生效,则只增加计数并带权移动平均
|
||||
* </li>
|
||||
* </ol>
|
||||
*
|
||||
* @param input 本次输入内容
|
||||
* @param inputVector 本次输入内容对应的语义向量
|
||||
* @param tendencyVector 本次倾向对应的语义向量
|
||||
* @param vectorClient 向量客户端
|
||||
*/
|
||||
public synchronized void updateAfterNotMatchPassed(String input, float[] inputVector, float[] tendencyVector, VectorClient vectorClient) {
|
||||
if (this.activated) {
|
||||
reduceThreshold();
|
||||
this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
|
||||
} else {
|
||||
addValidSample(input);
|
||||
this.tendencyVector = vectorClient.weightedAverage(tendencyVector, this.tendencyVector);
|
||||
addInputMatchCount();
|
||||
}
|
||||
}
|
||||
|
||||
private void reduceThreshold() {
|
||||
double newThreshold = this.threshold - 0.02;
|
||||
this.threshold = Math.max(newThreshold, 0.75);
|
||||
}
|
||||
|
||||
private void addInputMatchCount() {
|
||||
this.inputMatchCount += 1;
|
||||
if (inputMatchCount >= 6) {
|
||||
this.activated = true;
|
||||
}
|
||||
}
|
||||
|
||||
public enum Type {
|
||||
PRIMARY, REBUILD_V1, REBUILD_V2, REBUILD_V3
|
||||
}
|
||||
}
|
||||
11
Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/cache/CacheAdjustData.java
vendored
Normal file
11
Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/cache/CacheAdjustData.java
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
package work.slhaf.partner.core.action.entity.cache;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class CacheAdjustData {
|
||||
private String input;
|
||||
private List<CacheAdjustMetaData> metaDataList;
|
||||
}
|
||||
10
Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/cache/CacheAdjustMetaData.java
vendored
Normal file
10
Partner-Main/src/main/java/work/slhaf/partner/core/action/entity/cache/CacheAdjustMetaData.java
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
package work.slhaf.partner.core.action.entity.cache;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class CacheAdjustMetaData {
|
||||
private String tendency;
|
||||
private boolean passed;
|
||||
private boolean hit;
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package work.slhaf.partner.core.action.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||
|
||||
public class ActionDataNotFoundException extends AgentRuntimeException {
|
||||
public ActionDataNotFoundException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public ActionDataNotFoundException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package work.slhaf.partner.core.action.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentLaunchFailedException;
|
||||
|
||||
public class ActionInitFailedException extends AgentLaunchFailedException {
|
||||
public ActionInitFailedException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
public ActionInitFailedException(String message) {
|
||||
super(message);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package work.slhaf.partner.core.action.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||
|
||||
public class ActionLoadFailedException extends AgentRuntimeException {
|
||||
public ActionLoadFailedException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public ActionLoadFailedException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package work.slhaf.partner.core.action.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||
|
||||
public class ActionSerializeFailedException extends AgentRuntimeException {
|
||||
public ActionSerializeFailedException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public ActionSerializeFailedException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package work.slhaf.partner.core.action.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||
|
||||
public class MetaActionNotFoundException extends AgentRuntimeException {
|
||||
public MetaActionNotFoundException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public MetaActionNotFoundException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,107 @@
|
||||
package work.slhaf.partner.core.action.runner;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import io.modelcontextprotocol.server.McpStatelessAsyncServer;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import work.slhaf.partner.core.action.entity.ActionFileMetaData;
|
||||
import work.slhaf.partner.core.action.entity.MetaAction;
|
||||
import work.slhaf.partner.core.action.entity.MetaAction.Result;
|
||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||
import work.slhaf.partner.core.action.exception.ActionInitFailedException;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
|
||||
import static work.slhaf.partner.common.Constant.Path.DATA;
|
||||
import static work.slhaf.partner.common.util.PathUtil.buildPathStr;
|
||||
|
||||
/**
|
||||
* 执行客户端抽象类
|
||||
* <br/>
|
||||
* 只负责暴露序列化、执行等相应接口,具体逻辑交给下游实现
|
||||
* <br/>
|
||||
* 默认存在两类实现,{@link LocalRunnerClient} 和 {@link SandboxRunnerClient}
|
||||
* <ol>
|
||||
* LocalRunnerClient:
|
||||
* <li>
|
||||
* 对应本地运行环境,可在本地启动 MCP 客户端将 RunnerClient 暴露的能力接口转发至本地 MCP Client 并执行
|
||||
* </li>
|
||||
* SandboxRunnerClient:
|
||||
* <li>
|
||||
* 对应沙盒运行环境,该 Client 仅作为沙盒环境的客户端,不持有额外能力,仅保持远端连接已存在行动的内容更新
|
||||
* </li>
|
||||
* </ol>
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class RunnerClient {
|
||||
|
||||
protected final String ACTION_PATH;
|
||||
|
||||
protected final ConcurrentHashMap<String, MetaActionInfo> existedMetaActions;
|
||||
protected final ExecutorService executor;
|
||||
//TODO 仍可提供内部 MCP,但调用方式需要结合 AgentContext来获取,否则生命周期不合
|
||||
protected McpStatelessAsyncServer innerMcpServer;
|
||||
|
||||
/**
|
||||
* ActionCore 将注入虚拟线程池
|
||||
*/
|
||||
public RunnerClient(ConcurrentHashMap<String, MetaActionInfo> existedMetaActions, ExecutorService executor, @Nullable String baseActionPath) {
|
||||
this.existedMetaActions = existedMetaActions;
|
||||
this.executor = executor;
|
||||
baseActionPath = baseActionPath == null ? DATA : baseActionPath;
|
||||
this.ACTION_PATH = buildPathStr(baseActionPath, "action");
|
||||
|
||||
createPath(ACTION_PATH);
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行行动程序
|
||||
*/
|
||||
public void submit(MetaAction metaAction) {
|
||||
// 获取已存在行动列表
|
||||
Result result = metaAction.getResult();
|
||||
if (!result.getStatus().equals(Result.Status.WAITING)) {
|
||||
return;
|
||||
}
|
||||
RunnerResponse response = doRun(metaAction);
|
||||
result.setData(response.getData());
|
||||
result.setStatus(response.isOk() ? Result.Status.SUCCESS : Result.Status.FAILED);
|
||||
}
|
||||
|
||||
protected abstract RunnerResponse doRun(MetaAction metaAction);
|
||||
|
||||
public abstract String buildTmpPath(String actionKey, String codeType);
|
||||
|
||||
public abstract void tmpSerialize(MetaAction tempAction, String code, String codeType) throws IOException;
|
||||
|
||||
public abstract void persistSerialize(MetaActionInfo metaActionInfo, ActionFileMetaData fileMetaData);
|
||||
|
||||
protected void createPath(String pathStr) {
|
||||
val path = Path.of(pathStr);
|
||||
try {
|
||||
Files.createDirectory(path);
|
||||
} catch (IOException e) {
|
||||
if (!Files.exists(path)) {
|
||||
throw new ActionInitFailedException("目录创建失败: " + pathStr, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 列出执行环境下的系统依赖情况
|
||||
*/
|
||||
public abstract JSONObject listSysDependencies();
|
||||
|
||||
@Data
|
||||
public static class RunnerResponse {
|
||||
private boolean ok;
|
||||
private String data;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package work.slhaf.partner.core.action.runner;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import work.slhaf.partner.core.action.entity.ActionFileMetaData;
|
||||
import work.slhaf.partner.core.action.entity.MetaAction;
|
||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
|
||||
/**
|
||||
* 基于 Http 与 WebSocket 的沙盒执行器客户端,负责:
|
||||
* <ul>
|
||||
* <li>
|
||||
* 发送行动单元数据
|
||||
* </li>
|
||||
* <li>
|
||||
* 实时更新获取已存在行动列表
|
||||
* </li>
|
||||
* <li>
|
||||
* 向传入的 MetaAction 回写执行结果
|
||||
* </li>
|
||||
* </ul>
|
||||
*/
|
||||
public class SandboxRunnerClient extends RunnerClient {
|
||||
|
||||
public SandboxRunnerClient(ConcurrentHashMap<String, MetaActionInfo> existedMetaActions, ExecutorService executor) { // 连接沙盒执行器(websocket)
|
||||
super(existedMetaActions, executor, null);
|
||||
}
|
||||
|
||||
protected RunnerResponse doRun(MetaAction metaAction) {
|
||||
// 调用沙盒执行器
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public JSONObject listSysDependencies() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String buildTmpPath(String actionKey, String codeType) {
|
||||
throw new UnsupportedOperationException("Unimplemented method 'buildTmpPath'");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void tmpSerialize(MetaAction tempAction, String code, String codeType) throws IOException {
|
||||
throw new UnsupportedOperationException("Unimplemented method 'tmpSerialize'");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void persistSerialize(MetaActionInfo metaActionInfo, ActionFileMetaData fileMetaData) {
|
||||
throw new UnsupportedOperationException("Unimplemented method 'persistSerialize'");
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package work.slhaf.partner.core.cognation;
|
||||
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.ToCoordinated;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
|
||||
@Capability("cognation")
|
||||
public interface CognationCapability {
|
||||
|
||||
List<Message> getChatMessages();
|
||||
void cleanMessage(List<Message> messages);
|
||||
Lock getMessageLock();
|
||||
void addMetaMessage(String userId, MetaMessage metaMessage);
|
||||
List<Message> unpackAndClear(String userId);
|
||||
void refreshMemoryId();
|
||||
void resetLastUpdatedTime();
|
||||
long getLastUpdatedTime();
|
||||
HashMap<String,List<MetaMessage>> getSingleMetaMessageMap();
|
||||
String getCurrentMemoryId();
|
||||
|
||||
@ToCoordinated
|
||||
boolean isSingleUser();
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package work.slhaf.partner.core.cognation;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||
import work.slhaf.partner.core.PartnerCore;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.Serial;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Slf4j
|
||||
@CapabilityCore(value = "cognation")
|
||||
@Getter
|
||||
@Setter
|
||||
public class CognationCore extends PartnerCore<CognationCore> {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private final ReentrantLock messageLock = new ReentrantLock();
|
||||
|
||||
/**
|
||||
* 主模型的聊天记录
|
||||
*/
|
||||
private List<Message> chatMessages = new ArrayList<>();
|
||||
private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap = new HashMap<>();
|
||||
private String currentMemoryId;
|
||||
private long lastUpdatedTime;
|
||||
|
||||
public CognationCore() throws IOException, ClassNotFoundException {
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public List<Message> getChatMessages() {
|
||||
return chatMessages;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public long getLastUpdatedTime(){
|
||||
return lastUpdatedTime;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public HashMap<String,List<MetaMessage>> getSingleMetaMessageMap(){
|
||||
return singleMetaMessageMap;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public String getCurrentMemoryId(){
|
||||
return currentMemoryId;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void cleanMessage(List<Message> messages) {
|
||||
messageLock.lock();
|
||||
this.getChatMessages().removeAll(messages);
|
||||
messageLock.unlock();
|
||||
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public Lock getMessageLock() {
|
||||
return messageLock;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void addMetaMessage(String userId, MetaMessage metaMessage) {
|
||||
log.debug("[{}] 当前会话历史: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap));
|
||||
if (singleMetaMessageMap.containsKey(userId)) {
|
||||
singleMetaMessageMap.get(userId).add(metaMessage);
|
||||
} else {
|
||||
singleMetaMessageMap.put(userId, new java.util.ArrayList<>());
|
||||
singleMetaMessageMap.get(userId).add(metaMessage);
|
||||
}
|
||||
log.debug("[{}] 会话历史更新: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap));
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public List<Message> unpackAndClear(String userId) {
|
||||
List<Message> messages = new ArrayList<>();
|
||||
for (MetaMessage metaMessage : singleMetaMessageMap.get(userId)) {
|
||||
messages.add(metaMessage.getUserMessage());
|
||||
messages.add(metaMessage.getAssistantMessage());
|
||||
}
|
||||
singleMetaMessageMap.remove(userId);
|
||||
return messages;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void refreshMemoryId() {
|
||||
currentMemoryId = UUID.randomUUID().toString();
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void resetLastUpdatedTime() {
|
||||
lastUpdatedTime = System.currentTimeMillis();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getCoreKey() {
|
||||
return "cognation-core";
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user