mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
Compare commits
21 Commits
enhancemen
...
doc/archit
| Author | SHA1 | Date | |
|---|---|---|---|
| bfdc9b00e5 | |||
| dff7b69b51 | |||
| d77ffd1db6 | |||
| fea7f9c81f | |||
| ae5caf8475 | |||
| 980d9384d1 | |||
| 9ba0d1363a | |||
| f6d5cad5cd | |||
| 5419722c40 | |||
| 31ebee3ded | |||
| 6bfa941c35 | |||
| 456a7e04e8 | |||
| 5864760f35 | |||
| aee6d879e9 | |||
| d1ea8dde79 | |||
| 7094a8a68b | |||
| e78048f66d | |||
| 2f09c0cd71 | |||
| 8c43d6594f | |||
| 2d052442b1 | |||
| 84f7befb75 |
6
.gitignore
vendored
6
.gitignore
vendored
@@ -54,3 +54,9 @@ build/
|
||||
/config/
|
||||
/data/
|
||||
/generated-classes/
|
||||
/.idea/copilot.data.migration.ask2agent.xml
|
||||
.idea/copilot.data.migration.agent.xml
|
||||
.gitignore
|
||||
.idea/copilot.data.migration.edit.xml
|
||||
.gitignore
|
||||
.idea/copilot.data.migration.ask.xml
|
||||
|
||||
4
.idea/misc.xml
generated
4
.idea/misc.xml
generated
@@ -17,6 +17,10 @@
|
||||
<item index="12" class="java.lang.String" itemvalue="work.slhaf.partner.api.capability.annotation.CoordinateManager" />
|
||||
<item index="13" class="java.lang.String" itemvalue="work.slhaf.partner.api.register.capability.annotation.Capability" />
|
||||
</list>
|
||||
<writeAnnotations>
|
||||
<writeAnnotation name="work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability" />
|
||||
<writeAnnotation name="work.slhaf.partner.api.agent.factory.module.annotation.InjectModule" />
|
||||
</writeAnnotations>
|
||||
</component>
|
||||
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
||||
<component name="MavenProjectsManager">
|
||||
|
||||
1
.idea/vcs.xml
generated
1
.idea/vcs.xml
generated
@@ -2,6 +2,5 @@
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
<mapping directory="$USER_HOME$/Projects/IdeaProjects/Projects/Partner" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
@@ -1,33 +0,0 @@
|
||||
autoDetectedPackages:
|
||||
- factory
|
||||
- module
|
||||
- work.slhaf
|
||||
enableAutoDetect: true
|
||||
entryDisplayConfig:
|
||||
excludedPathPatterns: []
|
||||
skipJsCss: true
|
||||
funcDisplayConfig:
|
||||
skipConstructors: false
|
||||
skipFieldAccess: true
|
||||
skipFieldChange: true
|
||||
skipGetters: false
|
||||
skipNonProjectPackages: false
|
||||
skipPrivateMethods: false
|
||||
skipSetters: false
|
||||
ignoreSameClassCall: null
|
||||
ignoreSamePackageCall: null
|
||||
includedPackagePrefixes: null
|
||||
includedParentClasses: null
|
||||
maxColSize: 32
|
||||
maxNumFirst: 12
|
||||
maxNumFirstImportant: 1024
|
||||
maxNumHash: 3
|
||||
maxNumHashImportant: 256
|
||||
maxObjectDepth: 4
|
||||
maxStrSize: 4096
|
||||
name: xcodemap-filter
|
||||
openMainWindow: true
|
||||
recordMode: manual
|
||||
sourceDisplayConfig:
|
||||
color: blue
|
||||
startOnDebug: false
|
||||
@@ -25,11 +25,7 @@ public class AgentRunningFlow<C extends RunningFlowContext> {
|
||||
List<MetaModule> moduleList = entry.getValue();
|
||||
for (MetaModule module : moduleList) {
|
||||
Future<?> future = executor.submit(() -> {
|
||||
try {
|
||||
module.getInstance().execute(interactionContext);
|
||||
} catch (Exception e) {
|
||||
throw new AgentRuntimeException("模块执行出错: " + module.getName(), e);
|
||||
}
|
||||
});
|
||||
futures.add(future);
|
||||
}
|
||||
|
||||
@@ -7,14 +7,12 @@ import work.slhaf.partner.api.agent.factory.module.annotation.BeforeExecute;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.CoreModule;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* 流程执行模块基类
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class AgentRunningModule<C extends RunningFlowContext> extends Module {
|
||||
public abstract void execute(C context) throws IOException, ClassNotFoundException;
|
||||
public abstract void execute(C context);
|
||||
|
||||
@BeforeExecute
|
||||
private void beforeLog() {
|
||||
|
||||
111
Partner-Main/data/log/partner.log
Normal file
111
Partner-Main/data/log/partner.log
Normal file
File diff suppressed because one or more lines are too long
@@ -28,6 +28,22 @@
|
||||
<version>1.10.2</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<!-- https://mvnrepository.com/artifact/org.nd4j/nd4j-api -->
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-api</artifactId>
|
||||
<version>1.0.0-M2.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.microsoft.onnxruntime</groupId>
|
||||
<artifactId>onnxruntime</artifactId>
|
||||
<version>1.23.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.djl.huggingface</groupId>
|
||||
<artifactId>tokenizers</artifactId>
|
||||
<version>0.34.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<properties>
|
||||
@@ -49,7 +65,8 @@
|
||||
</goals>
|
||||
<configuration>
|
||||
<transformers>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
|
||||
<transformer
|
||||
implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
|
||||
<mainClass>work.slhaf.partner.Main</mainClass>
|
||||
</transformer>
|
||||
</transformers>
|
||||
|
||||
@@ -2,6 +2,7 @@ package work.slhaf.partner;
|
||||
|
||||
import work.slhaf.partner.api.agent.Agent;
|
||||
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||
import work.slhaf.partner.common.vector.VectorClient;
|
||||
import work.slhaf.partner.runtime.exception.PartnerExceptionCallback;
|
||||
import work.slhaf.partner.runtime.interaction.WebSocketGateway;
|
||||
|
||||
@@ -11,6 +12,7 @@ public class Main {
|
||||
.setAgentConfigManager(PartnerAgentConfigManager.class)
|
||||
.setGateway(WebSocketGateway.class)
|
||||
.setAgentExceptionCallback(PartnerExceptionCallback.class)
|
||||
.addAfterLaunchRunners(() -> VectorClient.load())
|
||||
.launch();
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,21 @@ import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class Config {
|
||||
private int port;
|
||||
private String agentId;
|
||||
private WebSocketConfig webSocketConfig;
|
||||
private VectorConfig vectorConfig;
|
||||
|
||||
@Data
|
||||
public static class VectorConfig {
|
||||
private int type;
|
||||
private String ollamaEmbeddingUrl;
|
||||
private String ollamaEmbeddingModel;
|
||||
private String tokenizerPath;
|
||||
private String embeddingModelPath;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class WebSocketConfig {
|
||||
private int port;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,8 +33,9 @@ public final class PartnerAgentConfigManager extends FileAgentConfigManager {
|
||||
if (config == null || config.getAgentId() == null) {
|
||||
throw new ConfigLoadFailedException("Partner Config Load Failed: " + COMMON_CONFIG_FILE);
|
||||
}
|
||||
if (config.getPort() <= 0 || config.getPort() > 65535) {
|
||||
throw new ConfigLoadFailedException("Invalid Websocket port: " + config.getPort());
|
||||
int port = config.getWebSocketConfig().getPort();
|
||||
if (port <= 0 || port > 65535) {
|
||||
throw new ConfigLoadFailedException("Invalid Websocket port: " + port);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
package work.slhaf.partner.common.thread;
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.*;
|
||||
|
||||
@Getter
|
||||
public class InteractionThreadPoolExecutor {
|
||||
|
||||
private static InteractionThreadPoolExecutor interactionThreadPoolExecutor;
|
||||
@@ -33,9 +28,29 @@ public class InteractionThreadPoolExecutor {
|
||||
|
||||
public <T> void invokeAll(List<Callable<T>> tasks) {
|
||||
try {
|
||||
executorService.invokeAll(tasks);
|
||||
List<Future<T>> futures = executorService.invokeAll(tasks);
|
||||
for (Future<T> future : futures) {
|
||||
future.get();
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
} catch (ExecutionException e) {
|
||||
throw new RuntimeException(e.getCause());
|
||||
}
|
||||
}
|
||||
|
||||
public <T> List<T> invokeAllAndReturn(List<Callable<T>> tasks) {
|
||||
try {
|
||||
List<Future<T>> futures = executorService.invokeAll(tasks);
|
||||
List<T> results = new ArrayList<>();
|
||||
for (Future<T> future : futures) {
|
||||
results.add(future.get());
|
||||
}
|
||||
return results;
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
} catch (ExecutionException e) {
|
||||
throw new RuntimeException(e.getCause());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
package work.slhaf.partner.common.vector;
|
||||
|
||||
import cn.hutool.http.HttpRequest;
|
||||
import cn.hutool.http.HttpResponse;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
public class OllamaVectorClient extends VectorClient {
|
||||
|
||||
private String ollamaEmbeddingUrl;
|
||||
private String ollamaEmbeddingModel;
|
||||
|
||||
protected OllamaVectorClient(String url, String model) {
|
||||
this.ollamaEmbeddingUrl = url;
|
||||
this.ollamaEmbeddingModel = model;
|
||||
|
||||
compute("test");
|
||||
}
|
||||
|
||||
@Override
|
||||
protected float[] doCompute(String input) {
|
||||
Map<String, String> param = Map.of("model", ollamaEmbeddingModel, "input", input);
|
||||
HttpRequest request = HttpRequest.get(ollamaEmbeddingUrl).body(JSONObject.toJSONString(param));
|
||||
try (HttpResponse response = request.execute()) {
|
||||
if (!response.isOk())
|
||||
throw new VectorClientExecuteException("嵌入模型执行出错");
|
||||
String resStr = response.body();
|
||||
EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class);
|
||||
return embeddingResponse.getEmbeddings()[0];
|
||||
} catch (Exception e) {
|
||||
throw new VectorClientExecuteException("嵌入模型执行出错", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
private static class EmbeddingModelResponse {
|
||||
private String model;
|
||||
private float[][] embeddings;
|
||||
private long total_duration;
|
||||
private long load_duration;
|
||||
private int prompt_eval_count;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package work.slhaf.partner.common.vector;
|
||||
|
||||
import ai.djl.huggingface.tokenizers.Encoding;
|
||||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@SuppressWarnings("FieldMayBeFinal")
|
||||
public class OnnxVectorClient extends VectorClient {
|
||||
|
||||
private String tokenizerPath;
|
||||
private String modelPath;
|
||||
|
||||
private HuggingFaceTokenizer tokenizer;
|
||||
private OrtSession session;
|
||||
private OrtEnvironment env;
|
||||
|
||||
protected OnnxVectorClient(String tokenizer, String model) {
|
||||
this.tokenizerPath = tokenizer;
|
||||
this.modelPath = model;
|
||||
|
||||
loadTokenizer();
|
||||
loadModel();
|
||||
compute("test");
|
||||
}
|
||||
|
||||
private void loadModel() {
|
||||
try {
|
||||
env = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions ops = new OrtSession.SessionOptions();
|
||||
session = env.createSession(modelPath, ops);
|
||||
} catch (Exception e) {
|
||||
throw new VectorClientLoadFailedException("加载ONNX模型失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void loadTokenizer() {
|
||||
try {
|
||||
tokenizer = HuggingFaceTokenizer.newInstance(Path.of(tokenizerPath));
|
||||
} catch (Exception e) {
|
||||
throw new VectorClientLoadFailedException("加载Tokenizer失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected float[] doCompute(String input) {
|
||||
try {
|
||||
Encoding encode = tokenizer.encode(input);
|
||||
long[] ids = encode.getIds();
|
||||
long[] attentionMask = encode.getAttentionMask();
|
||||
|
||||
long[][] inputIdsBatch = {ids};
|
||||
long[][] attentionMaskBatch = {attentionMask};
|
||||
long[][] tokenTypeIdsBatch = {new long[ids.length]}; // 初始化全 0
|
||||
for (int i = 0; i < ids.length; i++)
|
||||
tokenTypeIdsBatch[0][i] = 0;
|
||||
|
||||
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputIdsBatch);
|
||||
OnnxTensor maskTensor = OnnxTensor.createTensor(env, attentionMaskBatch);
|
||||
OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(env, tokenTypeIdsBatch);
|
||||
|
||||
Map<String, OnnxTensor> inputs = new HashMap<>();
|
||||
inputs.put("input_ids", inputTensor);
|
||||
inputs.put("attention_mask", maskTensor);
|
||||
inputs.put("token_type_ids", tokenTypeTensor);
|
||||
|
||||
OrtSession.Result result = session.run(inputs);
|
||||
OnnxTensor embeddingTensor = (OnnxTensor) result.get(0);
|
||||
return embeddingTensor.getFloatBuffer().array();
|
||||
} catch (Exception e) {
|
||||
throw new VectorClientExecuteException("嵌入模型执行出错", e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package work.slhaf.partner.common.vector;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||
import work.slhaf.partner.common.config.Config.VectorConfig;
|
||||
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
|
||||
import work.slhaf.partner.common.exception.ServiceLoadFailedException;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
|
||||
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
|
||||
|
||||
@Slf4j
|
||||
public abstract class VectorClient {
|
||||
|
||||
public static boolean status;
|
||||
public static VectorClient INSTANCE;
|
||||
|
||||
public static void load() {
|
||||
PartnerAgentConfigManager configManager = (PartnerAgentConfigManager) AgentConfigManager.INSTANCE;
|
||||
VectorConfig vectorConfig = configManager.getConfig().getVectorConfig();
|
||||
int type = vectorConfig.getType();
|
||||
try {
|
||||
switch (type) {
|
||||
case 0:
|
||||
status = false;
|
||||
break;
|
||||
case 1:
|
||||
status = true;
|
||||
INSTANCE = new OllamaVectorClient(vectorConfig.getOllamaEmbeddingUrl(),
|
||||
vectorConfig.getOllamaEmbeddingModel());
|
||||
break;
|
||||
case 2:
|
||||
status = true;
|
||||
INSTANCE = new OnnxVectorClient(vectorConfig.getTokenizerPath(),
|
||||
vectorConfig.getEmbeddingModelPath());
|
||||
break;
|
||||
default:
|
||||
throw new ServiceLoadFailedException(
|
||||
"加载向量客户端失败! type: 0 -> 不启用语义缓存; type: 1 -> ollama; type: 2 -> ONNX RUNTIME");
|
||||
}
|
||||
log.info("向量客户端加载完毕");
|
||||
} catch (VectorClientLoadFailedException | VectorClientExecuteException exception) {
|
||||
status = false;
|
||||
log.error("向量客户端加载失败", exception);
|
||||
}
|
||||
}
|
||||
|
||||
public float[] compute(String input) {
|
||||
if (!status) {
|
||||
return null;
|
||||
}
|
||||
return doCompute(input);
|
||||
}
|
||||
|
||||
protected abstract float[] doCompute(String input);
|
||||
|
||||
public double compare(float[] v1, float[] v2) {
|
||||
if (!status) {
|
||||
return 0;
|
||||
}
|
||||
try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) {
|
||||
return Transforms.cosineSim(a1, a2);
|
||||
}
|
||||
}
|
||||
|
||||
public float[] weightedAverage(float[] newVector, float[] primaryVector) {
|
||||
try (INDArray primary = Nd4j.create(primaryVector);
|
||||
INDArray latest = Nd4j.create(newVector)) {
|
||||
|
||||
// 1️⃣ 计算余弦相似度
|
||||
double similarity = Transforms.cosineSim(primary, latest);
|
||||
|
||||
// 2️⃣ 根据相似度决定更新比例 α(差异越大,新输入影响越强)
|
||||
double alpha = (1.0 - similarity) * 0.5;
|
||||
alpha = Math.max(0.05, Math.min(alpha, 0.5));
|
||||
|
||||
// 3️⃣ 按比例混合旧向量与新向量
|
||||
INDArray updated = primary.mul(1 - alpha).add(latest.mul(alpha));
|
||||
|
||||
// 4️⃣ 归一化结果(保持方向空间一致)
|
||||
updated = updated.div(updated.norm2Number());
|
||||
|
||||
return updated.toFloatVector();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package work.slhaf.partner.common.vector.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||
|
||||
public class VectorClientExecuteException extends AgentRuntimeException {
|
||||
|
||||
public VectorClientExecuteException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public VectorClientExecuteException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package work.slhaf.partner.common.vector.exception;
|
||||
|
||||
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
|
||||
|
||||
public class VectorClientLoadFailedException extends AgentRuntimeException {
|
||||
|
||||
public VectorClientLoadFailedException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public VectorClientLoadFailedException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package work.slhaf.partner.core.action;
|
||||
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
||||
import work.slhaf.partner.core.action.entity.CacheAdjustData;
|
||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Capability(value = "action")
|
||||
public interface ActionCapability {
|
||||
void putPreparedAction(String uuid, MetaActionInfo metaActionInfo);
|
||||
|
||||
List<MetaActionInfo> popPreparedAction(String userId);
|
||||
|
||||
List<MetaActionInfo> popPendingAction(String userId);
|
||||
|
||||
List<MetaActionInfo> listPreparedAction(String userId);
|
||||
|
||||
List<MetaActionInfo> listPendingAction(String userId);
|
||||
|
||||
void putPendingActions(String userId, MetaActionInfo metaActionInfo);
|
||||
|
||||
List<String> selectTendencyCache(String input);
|
||||
|
||||
void updateTendencyCache(CacheAdjustData data);
|
||||
}
|
||||
@@ -0,0 +1,235 @@
|
||||
package work.slhaf.partner.core.action;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
|
||||
import work.slhaf.partner.common.vector.VectorClient;
|
||||
import work.slhaf.partner.core.PartnerCore;
|
||||
import work.slhaf.partner.core.action.entity.ActionCacheData;
|
||||
import work.slhaf.partner.core.action.entity.CacheAdjustData;
|
||||
import work.slhaf.partner.core.action.entity.CacheAdjustMetaData;
|
||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@SuppressWarnings("FieldMayBeFinal")
|
||||
@Capability(value = "action")
|
||||
@Slf4j
|
||||
public class ActionCore extends PartnerCore<ActionCore> {
|
||||
|
||||
/**
|
||||
* 对应本次交互即将执行或将要放置在行动池的预备任务,因此将以本次交互的uuid为键,其起到的作用相当于暂时的模块上下文
|
||||
*/
|
||||
private HashMap<String, List<MetaActionInfo>> preparedActions = new HashMap<>();
|
||||
|
||||
/**
|
||||
* 待确认任务,以userId区分不同用户,因为需要跨请求确认
|
||||
*/
|
||||
private HashMap<String, List<MetaActionInfo>> pendingActions = new HashMap<>();
|
||||
|
||||
/**
|
||||
* 语义缓存与行为倾向映射
|
||||
*/
|
||||
private List<ActionCacheData> actionCache = new ArrayList<>();
|
||||
|
||||
private Lock cacheLock = new ReentrantLock();
|
||||
|
||||
private Executor executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
|
||||
|
||||
public ActionCore() throws IOException, ClassNotFoundException {
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public synchronized void putPendingActions(String userId, MetaActionInfo metaActionInfo) {
|
||||
pendingActions.computeIfAbsent(userId, k -> {
|
||||
List<MetaActionInfo> temp = new ArrayList<>();
|
||||
temp.add(metaActionInfo);
|
||||
return temp;
|
||||
});
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public synchronized List<MetaActionInfo> popPendingAction(String userId) {
|
||||
List<MetaActionInfo> infos = pendingActions.get(userId);
|
||||
pendingActions.remove(userId);
|
||||
return infos;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public synchronized void putPreparedAction(String uuid, MetaActionInfo metaActionInfo) {
|
||||
preparedActions.computeIfAbsent(uuid, k -> {
|
||||
List<MetaActionInfo> temp = new ArrayList<>();
|
||||
temp.add(metaActionInfo);
|
||||
return temp;
|
||||
});
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public synchronized List<MetaActionInfo> popPreparedAction(String userId) {
|
||||
List<MetaActionInfo> infos = preparedActions.get(userId);
|
||||
preparedActions.remove(userId);
|
||||
return infos;
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public List<MetaActionInfo> listPreparedAction(String userId) {
|
||||
return preparedActions.get(userId);
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public List<MetaActionInfo> listPendingAction(String userId) {
|
||||
return pendingActions.get(userId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 计算输入内容的语义向量,根据与{@link ActionCacheData#getInputVector()}的相似度挑取缓存,后续将根据评估结果来更新计数
|
||||
*
|
||||
* @param input 本次输入内容
|
||||
* @return 命中的行为倾向集合
|
||||
*/
|
||||
@CapabilityMethod
|
||||
public List<String> selectTendencyCache(String input) {
|
||||
if (!VectorClient.status) {
|
||||
return null;
|
||||
}
|
||||
VectorClient vectorClient = VectorClient.INSTANCE;
|
||||
//计算本次输入的向量
|
||||
float[] vector = vectorClient.compute(input);
|
||||
if (vector == null) return null;
|
||||
//与现有缓存比对,将匹配到的收集并返回
|
||||
return actionCache.parallelStream()
|
||||
.filter(ActionCacheData::isActivated)
|
||||
.filter(data -> {
|
||||
double compared = vectorClient.compare(vector, data.getInputVector());
|
||||
return compared > data.getThreshold();
|
||||
})
|
||||
.map(ActionCacheData::getTendency)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@CapabilityMethod
|
||||
public void updateTendencyCache(CacheAdjustData data) {
|
||||
VectorClient vectorClient = VectorClient.INSTANCE;
|
||||
List<CacheAdjustMetaData> list = data.getMetaDataList();
|
||||
String input = data.getInput();
|
||||
float[] inputVector = vectorClient.compute(input);
|
||||
|
||||
List<CacheAdjustMetaData> matchAndPassed = new ArrayList<>();
|
||||
List<CacheAdjustMetaData> matchNotPassed = new ArrayList<>();
|
||||
List<CacheAdjustMetaData> notMatchPassed = new ArrayList<>();
|
||||
|
||||
for (CacheAdjustMetaData metaData : list) {
|
||||
if (metaData.isHit() && metaData.isPassed()) {
|
||||
matchAndPassed.add(metaData);
|
||||
} else if (metaData.isHit()) {
|
||||
matchNotPassed.add(metaData);
|
||||
} else if (!metaData.isPassed()) {
|
||||
notMatchPassed.add(metaData);
|
||||
}
|
||||
}
|
||||
|
||||
executor.execute(() -> adjustMatchAndPassed(matchAndPassed, inputVector, input, vectorClient));
|
||||
executor.execute(() -> adjustMatchNotPassed(matchNotPassed, vectorClient));
|
||||
executor.execute(() -> adjustNotMatchPassed(notMatchPassed, inputVector, input, vectorClient));
|
||||
}
|
||||
|
||||
/**
|
||||
* 命中缓存且评估通过时
|
||||
*
|
||||
* @param matchAndPassed 该类型的带调整缓存信息列表
|
||||
* @param inputVector 本次输入内容的语义向量
|
||||
* @param vectorClient 向量客户端
|
||||
*/
|
||||
private void adjustMatchAndPassed(List<CacheAdjustMetaData> matchAndPassed, float[] inputVector, String input, VectorClient vectorClient) {
|
||||
matchAndPassed.forEach(adjustData -> {
|
||||
//获取原始缓存条目
|
||||
String tendency = adjustData.getTendency();
|
||||
ActionCacheData primaryCacheData = selectCacheData(tendency);
|
||||
if (primaryCacheData == null) {
|
||||
return;
|
||||
}
|
||||
primaryCacheData.updateAfterMatchAndPassed(inputVector, vectorClient, input);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 针对命中缓存、但评估未通过的条目与输入进行处理
|
||||
*
|
||||
* @param matchNotPassed 该类型的带调整缓存信息列表
|
||||
* @param vectorClient 向量客户端
|
||||
*/
|
||||
private void adjustMatchNotPassed(List<CacheAdjustMetaData> matchNotPassed, VectorClient vectorClient) {
|
||||
List<ActionCacheData> toRemove = new ArrayList<>();
|
||||
matchNotPassed.forEach(adjustData -> {
|
||||
//获取原始缓存条目
|
||||
String tendency = adjustData.getTendency();
|
||||
ActionCacheData primaryCacheData = selectCacheData(tendency);
|
||||
if (primaryCacheData == null) {
|
||||
return;
|
||||
}
|
||||
boolean remove = primaryCacheData.updateAfterMatchNotPassed(vectorClient);
|
||||
if (remove) {
|
||||
toRemove.add(primaryCacheData);
|
||||
}
|
||||
|
||||
});
|
||||
cacheLock.lock();
|
||||
actionCache.removeAll(toRemove);
|
||||
cacheLock.unlock();
|
||||
}
|
||||
|
||||
/**
|
||||
* 针对未命中但评估通过的缓存做出调整:
|
||||
* <ol>
|
||||
* <h3>如果存在缓存条目</h3>
|
||||
* <li>
|
||||
* 若已生效,但此时未匹配到则说明尚未生效或者阈值、向量{@link ActionCacheData#getInputVector()}存在问题,调低阈值,同时带权移动平均
|
||||
* </li>
|
||||
* <li>
|
||||
* 若未生效,则只增加计数并带权移动平均
|
||||
* </li>
|
||||
* </ol>
|
||||
* 如果不存在缓存条目,则新增并填充字段
|
||||
*
|
||||
* @param notMatchPassed 该类型的带调整缓存信息列表
|
||||
* @param inputVector 本次输入内容的语义向量
|
||||
* @param input 本次输入内容
|
||||
* @param vectorClient 向量客户端
|
||||
*/
|
||||
private void adjustNotMatchPassed(List<CacheAdjustMetaData> notMatchPassed, float[] inputVector, String input, VectorClient vectorClient) {
|
||||
notMatchPassed.forEach(adjustData -> {
|
||||
//获取原始缓存条目
|
||||
String tendency = adjustData.getTendency();
|
||||
ActionCacheData primaryCacheData = selectCacheData(tendency);
|
||||
float[] tendencyVector = vectorClient.compute(tendency);
|
||||
if (primaryCacheData == null) {
|
||||
actionCache.add(new ActionCacheData(tendency, tendencyVector, inputVector, input));
|
||||
return;
|
||||
}
|
||||
primaryCacheData.updateAfterNotMatchPassed(input, inputVector, tendencyVector, vectorClient);
|
||||
});
|
||||
}
|
||||
|
||||
private ActionCacheData selectCacheData(String tendency) {
|
||||
for (ActionCacheData actionCacheData : actionCache) {
|
||||
if (actionCacheData.getTendency().equals(tendency)) {
|
||||
return actionCacheData;
|
||||
}
|
||||
}
|
||||
log.warn("[{}] 未找到行为倾向[{}]对应的缓存条目,可能是代码逻辑存在错误", getCoreKey(), tendency);
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getCoreKey() {
|
||||
return "action-core";
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.common.vector.VectorClient;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ActionCacheData {
|
||||
private boolean activated = false;
|
||||
private int inputMatchCount = 1;
|
||||
|
||||
private float[] inputVector;
|
||||
private float[] tendencyVector;
|
||||
private String tendency;
|
||||
private double threshold = 0.75;
|
||||
|
||||
private List<String> validSamples = new ArrayList<>();
|
||||
private int failedCount = 0;
|
||||
private Type type = Type.PRIMARY;
|
||||
|
||||
public ActionCacheData(String tendency, float[] tendencyVector, float[] inputVector, String input) {
|
||||
this.tendency = tendency;
|
||||
this.inputVector = inputVector;
|
||||
this.tendencyVector = tendencyVector;
|
||||
this.validSamples.add(input);
|
||||
}
|
||||
|
||||
/**
|
||||
* 命中缓存且评估通过时,根据输入内容的语义向量与现有的输入语义向量进行带权移动平均,以相似度为权重,同时降低失败计数,为零时置为上一级缓存类型{@link ActionCacheData.Type}
|
||||
*
|
||||
* @param inputVector 本次输入内容对应的语义向量
|
||||
* @param vectorClient 向量客户端
|
||||
* @param input 本次输入内容
|
||||
*/
|
||||
public synchronized void updateAfterMatchAndPassed(float[] inputVector, VectorClient vectorClient, String input) {
|
||||
updateInputVector(inputVector, vectorClient);
|
||||
addValidSample(input);
|
||||
reduceFailedCount();
|
||||
updateType();
|
||||
addInputMatchCount();
|
||||
}
|
||||
|
||||
private void updateType() {
|
||||
if (this.failedCount == 0) {
|
||||
this.type = switch (type) {
|
||||
case PRIMARY, REBUILD_V1 -> ActionCacheData.Type.PRIMARY;
|
||||
case REBUILD_V2 -> ActionCacheData.Type.REBUILD_V1;
|
||||
case REBUILD_V3 -> ActionCacheData.Type.REBUILD_V2;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private void reduceFailedCount() {
|
||||
this.failedCount = Math.max(this.failedCount - 1, 0);
|
||||
}
|
||||
|
||||
private void addValidSample(String input) {
|
||||
if (this.validSamples.size() == 12) {
|
||||
this.validSamples.removeFirst();
|
||||
}
|
||||
this.validSamples.add(input);
|
||||
}
|
||||
|
||||
private void updateInputVector(float[] inputVector, VectorClient vectorClient) {
|
||||
this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
|
||||
}
|
||||
|
||||
/**
|
||||
* 针对命中缓存、但评估未通过的条目与输入进行处理: 增加失败计数(必要时重建并更新类型等级)、调高阈值(0.02),由于缓存匹配但评估未通过,所以不进行带权移动平均
|
||||
*
|
||||
* @param vectorClient 向量客户端
|
||||
* @return 是否需要删除(已在REBUILD_V3状态且达到最大误判次数的)
|
||||
*/
|
||||
public synchronized boolean updateAfterMatchNotPassed(VectorClient vectorClient) {
|
||||
adjustThreshold();
|
||||
addFailedCount();
|
||||
if (this.failedCount < 3) {
|
||||
return false;
|
||||
}
|
||||
if (this.type == Type.REBUILD_V3) {
|
||||
return true;
|
||||
}
|
||||
rebuildAndSwitchType(vectorClient);
|
||||
return false;
|
||||
}
|
||||
|
||||
private void rebuildAndSwitchType(VectorClient vectorClient) {
|
||||
this.type = switch (this.type) {
|
||||
case PRIMARY -> {
|
||||
//样本顺序反转后,以全部样本重建
|
||||
this.validSamples = this.validSamples.reversed();
|
||||
rebuildWithSamples(vectorClient);
|
||||
yield Type.REBUILD_V1;
|
||||
}
|
||||
case REBUILD_V1 -> {
|
||||
//截取后一半样本,反转后以此重建
|
||||
List<String> temp = this.validSamples.subList(this.validSamples.size() / 2, this.validSamples.size());
|
||||
this.validSamples = temp.reversed();
|
||||
rebuildWithSamples(vectorClient);
|
||||
yield Type.REBUILD_V2;
|
||||
}
|
||||
case REBUILD_V2 -> {
|
||||
//截取后四分之一样本,反转后以此重建
|
||||
List<String> temp = this.validSamples.subList(this.validSamples.size() / 4, this.validSamples.size());
|
||||
this.validSamples = temp.reversed();
|
||||
rebuildWithSamples(vectorClient);
|
||||
yield Type.REBUILD_V3;
|
||||
}
|
||||
case REBUILD_V3 -> null;
|
||||
};
|
||||
//阈值减0.05,防止重建后一直升高
|
||||
this.threshold = Math.max(this.threshold - 0.05, 0.75);
|
||||
this.failedCount = 0;
|
||||
}
|
||||
|
||||
private void rebuildWithSamples(VectorClient vectorClient) {
|
||||
for (int i = 0; i < this.validSamples.size(); i++) {
|
||||
String sample = this.validSamples.get(i);
|
||||
if (i == 0) {
|
||||
this.inputVector = vectorClient.compute(sample);
|
||||
} else {
|
||||
float[] newSampleVector = vectorClient.compute(sample);
|
||||
this.inputVector = vectorClient.weightedAverage(this.inputVector, newSampleVector);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void addFailedCount() {
|
||||
this.failedCount = Math.min(this.failedCount + 1, 3);
|
||||
}
|
||||
|
||||
private void adjustThreshold() {
|
||||
double newThreshold = this.threshold + 0.03;
|
||||
this.threshold = Math.min(newThreshold, 0.95);
|
||||
}
|
||||
|
||||
/**
|
||||
* 针对未命中但评估通过的已存在缓存做出调整:
|
||||
* <ol>
|
||||
* <li>
|
||||
* 若已生效,但此时未匹配到则说明阈值或者向量{@link ActionCacheData#getInputVector()}存在问题,调低阈值,同时带权移动平均
|
||||
* </li>
|
||||
* <li>
|
||||
* 若未生效,则只增加计数并带权移动平均
|
||||
* </li>
|
||||
* </ol>
|
||||
*
|
||||
* @param input 本次输入内容
|
||||
* @param inputVector 本次输入内容对应的语义向量
|
||||
* @param tendencyVector 本次倾向对应的语义向量
|
||||
* @param vectorClient 向量客户端
|
||||
*/
|
||||
public synchronized void updateAfterNotMatchPassed(String input, float[] inputVector, float[] tendencyVector, VectorClient vectorClient) {
|
||||
if (this.activated) {
|
||||
reduceThreshold();
|
||||
this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
|
||||
} else {
|
||||
addValidSample(input);
|
||||
this.tendencyVector = vectorClient.weightedAverage(tendencyVector, this.tendencyVector);
|
||||
addInputMatchCount();
|
||||
}
|
||||
}
|
||||
|
||||
private void reduceThreshold() {
|
||||
double newThreshold = this.threshold - 0.02;
|
||||
this.threshold = Math.max(newThreshold, 0.75);
|
||||
}
|
||||
|
||||
private void addInputMatchCount() {
|
||||
this.inputMatchCount += 1;
|
||||
if (inputMatchCount >= 6) {
|
||||
this.activated = true;
|
||||
}
|
||||
}
|
||||
|
||||
/** Rebuild stage of a cache entry; stages advance in declaration order. */
public enum Type {
    /** Initial stage, before any rebuild. */
    PRIMARY,
    /** Stage entered after the first rebuild. */
    REBUILD_V1,
    /** Stage entered after the second rebuild. */
    REBUILD_V2,
    /** Last stage; no further rebuild is attempted from here. */
    REBUILD_V3
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
public enum ActionStatus {
|
||||
SUCCESS, FAILED, EXECUTING, WAITING
|
||||
SUCCESS, FAILED, EXECUTING, WAITING, PREPARE
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class CacheAdjustData {
|
||||
private String input;
|
||||
private List<CacheAdjustMetaData> metaDataList;
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class CacheAdjustMetaData {
|
||||
private String tendency;
|
||||
private boolean passed;
|
||||
private boolean hit;
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class ImmediateActionInfo extends MetaActionInfo{
|
||||
}
|
||||
@@ -2,12 +2,11 @@ package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
@Data
|
||||
public class MetaActionInfo {
|
||||
private ActionData actionData;
|
||||
private ActionStatus status;
|
||||
private String Result;
|
||||
private LocalDateTime dateTime;
|
||||
public abstract class MetaActionInfo {
|
||||
protected String uuid;
|
||||
protected String tendency;
|
||||
protected ActionStatus status;
|
||||
protected ActionData actionData;
|
||||
protected String Result;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package work.slhaf.partner.core.action.entity;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
public class ScheduledActionInfo extends MetaActionInfo {
|
||||
private ScheduledType type;
|
||||
private String scheduleContent; //如果为周期,则对应cron表达式,如果为一次性,则对应为LocalDateTime字符串
|
||||
|
||||
enum ScheduledType {
|
||||
CYCLE, ONCE
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.memory.pojo.MemoryResult;
|
||||
import work.slhaf.partner.core.memory.pojo.MemorySlice;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.HashMap;
|
||||
@@ -45,7 +44,7 @@ public interface MemoryCapability {
|
||||
|
||||
MemoryResult selectMemory(String topicPathStr);
|
||||
|
||||
MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException;
|
||||
MemoryResult selectMemory(LocalDate date);
|
||||
|
||||
void insertSlice(MemorySlice memorySlice, String topicPath);
|
||||
|
||||
|
||||
@@ -3,12 +3,10 @@ package work.slhaf.partner.module.common.module;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public abstract class PostRunningModule extends AgentRunningModule<PartnerRunningFlowContext> {
|
||||
|
||||
@Override
|
||||
public final void execute(PartnerRunningFlowContext context) throws IOException, ClassNotFoundException {
|
||||
public final void execute(PartnerRunningFlowContext context) {
|
||||
boolean trigger = context.getModuleContext().getExtraContext().getBoolean("post_process_trigger");
|
||||
if (!trigger) {
|
||||
return;
|
||||
|
||||
@@ -5,7 +5,6 @@ import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunn
|
||||
import work.slhaf.partner.module.common.entity.AppendPromptData;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
|
||||
/**
|
||||
@@ -16,7 +15,7 @@ public abstract class PreRunningModule extends AgentRunningModule<PartnerRunning
|
||||
private synchronized void setAppendedPrompt(PartnerRunningFlowContext context) {
|
||||
AppendPromptData data = new AppendPromptData();
|
||||
data.setModuleName(moduleName());
|
||||
HashMap<String, String> map = getPromptDataMap(context.getUserId());
|
||||
HashMap<String, String> map = getPromptDataMap(context);
|
||||
data.setAppendedPrompt(map);
|
||||
context.setAppendedPrompt(data);
|
||||
}
|
||||
@@ -25,7 +24,7 @@ public abstract class PreRunningModule extends AgentRunningModule<PartnerRunning
|
||||
context.getCoreContext().addActiveModule(moduleName());
|
||||
}
|
||||
|
||||
protected abstract HashMap<String, String> getPromptDataMap(String userId);
|
||||
protected abstract HashMap<String, String> getPromptDataMap(PartnerRunningFlowContext context);
|
||||
|
||||
/**
|
||||
* 用于在CoreModule接收到的模块Prompt中标识模块名称
|
||||
@@ -33,13 +32,13 @@ public abstract class PreRunningModule extends AgentRunningModule<PartnerRunning
|
||||
protected abstract String moduleName();
|
||||
|
||||
@Override
|
||||
public final void execute(PartnerRunningFlowContext context) throws IOException, ClassNotFoundException {
|
||||
public final void execute(PartnerRunningFlowContext context) {
|
||||
doExecute(context); // 子类实现差异化逻辑
|
||||
setAppendedPrompt(context); // 通用逻辑
|
||||
setActiveModule(context); // 通用逻辑
|
||||
}
|
||||
|
||||
protected abstract void doExecute(PartnerRunningFlowContext context) throws IOException, ClassNotFoundException;
|
||||
protected abstract void doExecute(PartnerRunningFlowContext context);
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -1,46 +1,256 @@
|
||||
package work.slhaf.partner.module.modules.action.planner;
|
||||
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AfterExecute;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
||||
import work.slhaf.partner.common.vector.VectorClient;
|
||||
import work.slhaf.partner.core.action.ActionCapability;
|
||||
import work.slhaf.partner.core.action.entity.*;
|
||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||
import work.slhaf.partner.core.memory.MemoryCapability;
|
||||
import work.slhaf.partner.core.perceive.PerceiveCapability;
|
||||
import work.slhaf.partner.module.common.module.PreRunningModule;
|
||||
import work.slhaf.partner.module.modules.action.planner.confirmer.ActionConfirmer;
|
||||
import work.slhaf.partner.module.modules.action.planner.confirmer.entity.ConfirmerInput;
|
||||
import work.slhaf.partner.module.modules.action.planner.confirmer.entity.ConfirmerResult;
|
||||
import work.slhaf.partner.module.modules.action.planner.evaluator.ActionEvaluator;
|
||||
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorInput;
|
||||
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorResult;
|
||||
import work.slhaf.partner.module.modules.action.planner.extractor.ActionExtractor;
|
||||
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput;
|
||||
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorResult;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
/**
|
||||
* 负责针对本次输入生成基础的行动建议,是否执行由主模型判断。
|
||||
* 负责针对本次输入生成基础的行动计划,在主模型传达意愿后,执行行动或者放入计划池
|
||||
*/
|
||||
@AgentModule(name = "task_planner",order = 2)
|
||||
@AgentModule(name = "action_planner", order = 2)
|
||||
public class ActionPlanner extends PreRunningModule {
|
||||
|
||||
@InjectCapability
|
||||
private CognationCapability cognationCapability;
|
||||
@InjectCapability
|
||||
private ActionCapability actionCapability;
|
||||
@InjectCapability
|
||||
private PerceiveCapability perceiveCapability;
|
||||
@InjectCapability
|
||||
private MemoryCapability memoryCapability;
|
||||
|
||||
@InjectModule
|
||||
private ActionEvaluator actionEvaluator;
|
||||
@InjectModule
|
||||
private ActionExtractor actionExtractor;
|
||||
@InjectModule
|
||||
private ActionConfirmer actionConfirmer;
|
||||
|
||||
@Override
|
||||
protected void doExecute(PartnerRunningFlowContext context) throws IOException, ClassNotFoundException {
|
||||
ExtractorInput extractorInput = getExtractorInput(context);
|
||||
}
|
||||
private InteractionThreadPoolExecutor executor;
|
||||
private ActionAssemblyHelper assemblyHelper;
|
||||
|
||||
private ExtractorInput getExtractorInput(PartnerRunningFlowContext context) {
|
||||
ExtractorInput input = new ExtractorInput();
|
||||
input.setInput(context.getInput());
|
||||
input.setRecentMessages();
|
||||
return input;
|
||||
@Init
|
||||
public void init() {
|
||||
executor = InteractionThreadPoolExecutor.getInstance();
|
||||
assemblyHelper = new ActionAssemblyHelper();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected HashMap<String, String> getPromptDataMap(String userId) {
|
||||
protected void doExecute(PartnerRunningFlowContext context) {
|
||||
List<Callable<Void>> tasks = new ArrayList<>();
|
||||
addConfirmTask(tasks, context);
|
||||
addNewActionTask(tasks, context);
|
||||
executor.invokeAll(tasks);
|
||||
}
|
||||
|
||||
/**
|
||||
* 新的提取与评估任务
|
||||
*
|
||||
* @param tasks 并发任务列表
|
||||
* @param context 流程上下文
|
||||
*/
|
||||
private void addNewActionTask(List<Callable<Void>> tasks, PartnerRunningFlowContext context) {
|
||||
tasks.add(() -> {
|
||||
ExtractorInput extractorInput = assemblyHelper.buildExtractorInput(context);
|
||||
ExtractorResult extractorResult = actionExtractor.execute(extractorInput);
|
||||
if (extractorResult.getTendencies().isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
EvaluatorInput evaluatorInput = assemblyHelper.buildEvaluatorInput(extractorResult, context.getUserId());
|
||||
List<EvaluatorResult> evaluatorResults = actionEvaluator.execute(evaluatorInput); //并发操作均为访问
|
||||
setupActionInfo(evaluatorResults, context);
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
@AfterExecute
|
||||
private void updateTendencyCache(List<EvaluatorResult> evaluatorResults, String input, ExtractorResult extractorResult) {
|
||||
if (!VectorClient.status) {
|
||||
return;
|
||||
}
|
||||
executor.execute(() -> {
|
||||
CacheAdjustData data = new CacheAdjustData();
|
||||
List<CacheAdjustMetaData> list = new ArrayList<>();
|
||||
List<String> hitTendencies = extractorResult.getTendencies();
|
||||
for (EvaluatorResult result : evaluatorResults) {
|
||||
CacheAdjustMetaData metaData = new CacheAdjustMetaData();
|
||||
metaData.setTendency(result.getTendency());
|
||||
metaData.setPassed(result.isOk());
|
||||
metaData.setHit(hitTendencies.contains(result.getTendency()));
|
||||
list.add(metaData);
|
||||
}
|
||||
data.setMetaDataList(list);
|
||||
data.setInput(input);
|
||||
actionCapability.updateTendencyCache(data);
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 待确认行动的判断任务
|
||||
*
|
||||
* @param tasks 并发任务列表
|
||||
* @param context 流程上下文
|
||||
*/
|
||||
private void addConfirmTask(List<Callable<Void>> tasks, PartnerRunningFlowContext context) {
|
||||
tasks.add(() -> {
|
||||
ConfirmerInput confirmerInput = assemblyHelper.buildConfirmerInput(context);
|
||||
ConfirmerResult result = actionConfirmer.execute(confirmerInput);
|
||||
setupPendingActionInfo(context, result);
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
private void setupPendingActionInfo(PartnerRunningFlowContext context, ConfirmerResult result) {
|
||||
//TODO 需考虑未确认任务的失效或者拒绝时机
|
||||
List<String> uuids = result.getUuids();
|
||||
if (uuids == null) {
|
||||
return;
|
||||
}
|
||||
String contextUuid = context.getUuid();
|
||||
List<MetaActionInfo> pendingActions = actionCapability.popPendingAction(context.getUserId());
|
||||
for (MetaActionInfo actionInfo : pendingActions) {
|
||||
if (uuids.contains(actionInfo.getUuid())) {
|
||||
actionCapability.putPreparedAction(contextUuid, actionInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private void setupActionInfo(List<EvaluatorResult> evaluatorResults, PartnerRunningFlowContext context) {
|
||||
for (EvaluatorResult evaluatorResult : evaluatorResults) {
|
||||
MetaActionInfo metaActionInfo = assemblyHelper.buildMetaActionInfo(evaluatorResult);
|
||||
if (evaluatorResult.isNeedConfirm()) {
|
||||
actionCapability.putPendingActions(context.getUserId(), metaActionInfo);
|
||||
} else {
|
||||
actionCapability.putPreparedAction(context.getUuid(), metaActionInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected HashMap<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
|
||||
HashMap<String, String> map = new HashMap<>();
|
||||
setupPendingActions(map, context.getUserId());
|
||||
setupPreparedActions(map, context.getUuid());
|
||||
return map;
|
||||
}
|
||||
|
||||
private void setupPendingActions(HashMap<String, String> map, String userId) {
|
||||
List<MetaActionInfo> actionInfos = actionCapability.listPendingAction(userId);
|
||||
if (actionInfos == null || actionInfos.isEmpty()) {
|
||||
map.put("[待确认行动] <待确认行动信息>", "无待确认行动");
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < actionInfos.size(); i++) {
|
||||
map.put("[待确认行动 " + (i + 1) + " ]", generateActionStr(actionInfos.get(i)));
|
||||
}
|
||||
}
|
||||
|
||||
private void setupPreparedActions(HashMap<String, String> map, String uuid) {
|
||||
List<MetaActionInfo> actionInfos = actionCapability.listPreparedAction(uuid);
|
||||
if (actionInfos == null || actionInfos.isEmpty()) {
|
||||
map.put("[预备行动] <预备行动信息>", "无预备行动");
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < actionInfos.size(); i++) {
|
||||
map.put("[预备行动 " + (i + 1) + " ]", generateActionStr(actionInfos.get(i)));
|
||||
}
|
||||
}
|
||||
|
||||
private String generateActionStr(MetaActionInfo metaActionInfo) {
|
||||
ActionData actionData = metaActionInfo.getActionData();
|
||||
return "<行动倾向>" + " : " + metaActionInfo.getTendency() +
|
||||
"<行动原因>" + " : " + actionData.getReason() +
|
||||
"<工具描述>" + " : " + actionData.getDescription();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String moduleName() {
|
||||
return "task_planner";
|
||||
return "[行动模块]";
|
||||
}
|
||||
|
||||
private class ActionAssemblyHelper {
|
||||
private ActionAssemblyHelper() {
|
||||
}
|
||||
|
||||
private ExtractorInput buildExtractorInput(PartnerRunningFlowContext context) {
|
||||
ExtractorInput input = new ExtractorInput();
|
||||
input.setInput(context.getInput());
|
||||
List<Message> chatMessages = cognationCapability.getChatMessages();
|
||||
List<Message> recentMessages = new ArrayList<>();
|
||||
if (chatMessages.size() > 5) {
|
||||
recentMessages.addAll(chatMessages.subList(chatMessages.size() - 5, chatMessages.size() - 1));
|
||||
} else if (chatMessages.size() > 1) {
|
||||
recentMessages.addAll(chatMessages.subList(0, chatMessages.size() - 1));
|
||||
}
|
||||
input.setRecentMessages(recentMessages);
|
||||
return input;
|
||||
}
|
||||
|
||||
private EvaluatorInput buildEvaluatorInput(ExtractorResult extractorResult, String userId) {
|
||||
EvaluatorInput input = new EvaluatorInput();
|
||||
input.setTendencies(extractorResult.getTendencies());
|
||||
input.setUser(perceiveCapability.getUser(userId));
|
||||
input.setRecentMessages(cognationCapability.getChatMessages());
|
||||
input.setActivatedSlices(memoryCapability.getActivatedSlices(userId));
|
||||
return input;
|
||||
}
|
||||
|
||||
private MetaActionInfo buildMetaActionInfo(EvaluatorResult evaluatorResult) {
|
||||
return switch (evaluatorResult.getType()) {
|
||||
case PLANNING -> {
|
||||
ScheduledActionInfo actionInfo = new ScheduledActionInfo();
|
||||
actionInfo.setActionData(evaluatorResult.getActionData());
|
||||
actionInfo.setScheduleContent(evaluatorResult.getScheduleContent());
|
||||
actionInfo.setStatus(ActionStatus.PREPARE);
|
||||
actionInfo.setUuid(UUID.randomUUID().toString());
|
||||
yield actionInfo;
|
||||
}
|
||||
case IMMEDIATE -> {
|
||||
ImmediateActionInfo actionInfo = new ImmediateActionInfo();
|
||||
actionInfo.setActionData(evaluatorResult.getActionData());
|
||||
actionInfo.setStatus(ActionStatus.PREPARE);
|
||||
actionInfo.setUuid(UUID.randomUUID().toString());
|
||||
yield actionInfo;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private ConfirmerInput buildConfirmerInput(PartnerRunningFlowContext context) {
|
||||
ConfirmerInput confirmerInput = new ConfirmerInput();
|
||||
confirmerInput.setInput(context.getInput());
|
||||
List<MetaActionInfo> pendingActions = actionCapability.listPendingAction(context.getUserId());
|
||||
confirmerInput.setActionInfos(pendingActions);
|
||||
return confirmerInput;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
package work.slhaf.partner.module.modules.action.planner.confirmer;
|
||||
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
|
||||
import work.slhaf.partner.module.modules.action.planner.confirmer.entity.ConfirmerInput;
|
||||
import work.slhaf.partner.module.modules.action.planner.confirmer.entity.ConfirmerResult;
|
||||
|
||||
@AgentSubModule
|
||||
public class ActionConfirmer extends AgentRunningSubModule<ConfirmerInput, ConfirmerResult> implements ActivateModel {
|
||||
@Override
|
||||
public ConfirmerResult execute(ConfirmerInput data) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String modelKey() {
|
||||
return "action-confirmer";
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean withBasicPrompt() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package work.slhaf.partner.module.modules.action.planner.confirmer.entity;
|
||||
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ConfirmerInput {
|
||||
private String input;
|
||||
private List<MetaActionInfo> actionInfos;
|
||||
private List<Message> recentMessages;
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package work.slhaf.partner.module.modules.action.planner.confirmer.entity;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ConfirmerResult {
|
||||
private List<String> uuids;
|
||||
}
|
||||
@@ -1,18 +1,67 @@
|
||||
package work.slhaf.partner.module.modules.action.planner.evaluator;
|
||||
|
||||
import cn.hutool.core.bean.BeanUtil;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
|
||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
||||
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
||||
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorBatchInput;
|
||||
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorInput;
|
||||
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorResult;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
@AgentSubModule
|
||||
public class ActionEvaluator extends AgentRunningSubModule<EvaluatorInput, EvaluatorResult> implements ActivateModel {
|
||||
public class ActionEvaluator extends AgentRunningSubModule<EvaluatorInput, List<EvaluatorResult>> implements ActivateModel {
|
||||
|
||||
private InteractionThreadPoolExecutor executor;
|
||||
|
||||
@Init
|
||||
public void init() {
|
||||
executor = InteractionThreadPoolExecutor.getInstance();
|
||||
}
|
||||
|
||||
/**
|
||||
* 对输入的行为倾向进行评估,并根据评估结果,对缓存做出调整
|
||||
*
|
||||
* @param data 评估输入内容,包含提取/命中缓存的行动倾向、近几条聊天记录,正在生效的记忆切片内容
|
||||
* @return 评估结果集合,包含
|
||||
*/
|
||||
@Override
|
||||
public EvaluatorResult execute(EvaluatorInput data) {
|
||||
public List<EvaluatorResult> execute(EvaluatorInput data) {
|
||||
List<EvaluatorBatchInput> batchInputs = buildEvaluatorBatchInput(data);
|
||||
List<Callable<EvaluatorResult>> tasks = getTasks(batchInputs);
|
||||
return executor.invokeAllAndReturn(tasks);
|
||||
}
|
||||
|
||||
return null;
|
||||
|
||||
private List<Callable<EvaluatorResult>> getTasks(List<EvaluatorBatchInput> batchInputs) {
|
||||
List<Callable<EvaluatorResult>> list = new ArrayList<>();
|
||||
for (EvaluatorBatchInput batchInput : batchInputs) {
|
||||
list.add(() -> {
|
||||
ChatResponse response = this.singleChat(JSONObject.toJSONString(batchInput));
|
||||
EvaluatorResult evaluatorResult = JSONObject.parseObject(response.getMessage(), EvaluatorResult.class);
|
||||
evaluatorResult.setTendency(batchInput.getTendency());
|
||||
return evaluatorResult;
|
||||
});
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
private List<EvaluatorBatchInput> buildEvaluatorBatchInput(EvaluatorInput data) {
|
||||
List<EvaluatorBatchInput> list = new ArrayList<>();
|
||||
for (String tendency : data.getTendencies()) {
|
||||
EvaluatorBatchInput temp = new EvaluatorBatchInput();
|
||||
BeanUtil.copyProperties(data, temp);
|
||||
temp.setTendency(tendency);
|
||||
list.add(temp);
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
package work.slhaf.partner.module.modules.action.planner.evaluator.entity;
|
||||
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class EvaluatorBatchInput {
|
||||
private List<Message> recentMessages;
|
||||
private List<EvaluatedSlice> activatedSlices;
|
||||
private String tendency;
|
||||
}
|
||||
@@ -1,11 +1,16 @@
|
||||
package work.slhaf.partner.module.modules.action.planner.evaluator.entity;
|
||||
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||
import work.slhaf.partner.core.perceive.pojo.User;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class EvaluatorInput {
|
||||
private List<String> recentMessages;
|
||||
private String tendency;
|
||||
private List<Message> recentMessages;
|
||||
private User user;
|
||||
private List<EvaluatedSlice> activatedSlices;
|
||||
private List<String> tendencies;
|
||||
}
|
||||
|
||||
@@ -7,7 +7,9 @@ import work.slhaf.partner.core.action.entity.ActionType;
|
||||
@Data
|
||||
public class EvaluatorResult {
|
||||
private boolean ok;
|
||||
private boolean needConfirm;
|
||||
private ActionType type;
|
||||
private String typeInfo;
|
||||
private String scheduleContent;
|
||||
private ActionData actionData;
|
||||
private String tendency;
|
||||
}
|
||||
|
||||
@@ -1,18 +1,44 @@
|
||||
package work.slhaf.partner.module.modules.action.planner.extractor;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
|
||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
||||
import work.slhaf.partner.core.action.ActionCapability;
|
||||
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput;
|
||||
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorResult;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@AgentSubModule
|
||||
public class ActionExtractor extends AgentRunningSubModule<ExtractorInput, ExtractorResult> implements ActivateModel {
|
||||
|
||||
@InjectCapability
|
||||
private ActionCapability actionCapability;
|
||||
|
||||
@Override
|
||||
public ExtractorResult execute(ExtractorInput data) {
|
||||
ExtractorResult result = new ExtractorResult();
|
||||
List<String> tendencyCache = actionCapability.selectTendencyCache(data.getInput());
|
||||
if (tendencyCache != null && !tendencyCache.isEmpty()) {
|
||||
result.setTendencies(tendencyCache);
|
||||
return result;
|
||||
}
|
||||
|
||||
return null;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
try {
|
||||
ChatResponse response = this.singleChat(JSONObject.toJSONString(data));
|
||||
return JSONObject.parseObject(response.getMessage(), ExtractorResult.class);
|
||||
} catch (Exception e) {
|
||||
log.error("[ActionExtractor] 提取信息出错", e);
|
||||
}
|
||||
}
|
||||
|
||||
return new ExtractorResult();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -2,8 +2,10 @@ package work.slhaf.partner.module.modules.action.planner.extractor.entity;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ExtractorResult {
|
||||
private boolean action;
|
||||
private String tendency;
|
||||
private List<String> tendencies = new ArrayList<>();
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import work.slhaf.partner.module.modules.memory.selector.extractor.entity.Extrac
|
||||
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorResult;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.LocalDate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
@@ -32,7 +31,7 @@ import java.util.List;
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@Slf4j
|
||||
@AgentModule(name="memory_selector",order=2)
|
||||
@AgentModule(name = "memory_selector", order = 2)
|
||||
public class MemorySelector extends PreRunningModule {
|
||||
|
||||
@InjectCapability
|
||||
@@ -46,7 +45,7 @@ public class MemorySelector extends PreRunningModule {
|
||||
private MemorySelectExtractor memorySelectExtractor;
|
||||
|
||||
@Override
|
||||
public void doExecute(PartnerRunningFlowContext runningFlowContext) throws IOException, ClassNotFoundException {
|
||||
public void doExecute(PartnerRunningFlowContext runningFlowContext) {
|
||||
String userId = runningFlowContext.getUserId();
|
||||
//获取主题路径
|
||||
ExtractorResult extractorResult = memorySelectExtractor.execute(runningFlowContext);
|
||||
@@ -58,7 +57,7 @@ public class MemorySelector extends PreRunningModule {
|
||||
setModuleContextRecall(runningFlowContext);
|
||||
}
|
||||
|
||||
private List<EvaluatedSlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, ExtractorResult extractorResult) throws IOException, ClassNotFoundException {
|
||||
private List<EvaluatedSlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, ExtractorResult extractorResult) {
|
||||
log.debug("[MemorySelector] 触发记忆回溯...");
|
||||
//查找切片
|
||||
String userId = runningFlowContext.getUserId();
|
||||
@@ -86,7 +85,7 @@ public class MemorySelector extends PreRunningModule {
|
||||
}
|
||||
|
||||
|
||||
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) throws IOException, ClassNotFoundException {
|
||||
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) {
|
||||
for (ExtractorMatchData match : matches) {
|
||||
try {
|
||||
MemoryResult memoryResult = switch (match.getType()) {
|
||||
@@ -133,8 +132,10 @@ public class MemorySelector extends PreRunningModule {
|
||||
return "[记忆模块]";
|
||||
}
|
||||
|
||||
protected HashMap<String, String> getPromptDataMap(String userId) {
|
||||
@Override
|
||||
protected HashMap<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
|
||||
HashMap<String, String> map = new HashMap<>();
|
||||
String userId = context.getUserId();
|
||||
String dialogMapStr = memoryCapability.getDialogMapStr();
|
||||
if (!dialogMapStr.isEmpty()) {
|
||||
map.put("[记忆缓存] <你最近两日和所有聊天者的对话记忆印象>", dialogMapStr);
|
||||
|
||||
@@ -24,9 +24,9 @@ public class PerceiveSelector extends PreRunningModule {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected HashMap<String, String> getPromptDataMap(String userId) {
|
||||
protected HashMap<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
|
||||
HashMap<String, String> map = new HashMap<>();
|
||||
User user = perceiveCapability.getUser(userId);
|
||||
User user = perceiveCapability.getUser(context.getUserId());
|
||||
map.put("[关系] <你与最新聊天用户的关系>", user.getRelation());
|
||||
map.put("[态度] <你对于最新聊天用户的态度>", user.getAttitude().toString());
|
||||
map.put("[印象] <你对于最新聊天用户的印象>", user.getImpressions().toString());
|
||||
|
||||
@@ -32,8 +32,6 @@ import java.util.concurrent.locks.ReentrantLock;
|
||||
@AgentModule(name = "perceive_updater", order = 7)
|
||||
public class PerceiveUpdater extends PostRunningModule {
|
||||
|
||||
private static volatile PerceiveUpdater perceiveUpdater;
|
||||
|
||||
@InjectCapability
|
||||
private PerceiveCapability perceiveCapability;
|
||||
@InjectCapability
|
||||
|
||||
@@ -21,8 +21,6 @@ import java.util.HashMap;
|
||||
@AgentSubModule
|
||||
public class StaticMemoryExtractor extends AgentRunningSubModule<PartnerRunningFlowContext, HashMap<String, String>> implements ActivateModel {
|
||||
|
||||
private static volatile StaticMemoryExtractor staticMemoryExtractor;
|
||||
|
||||
@InjectCapability
|
||||
private CognationCapability cognationCapability;
|
||||
@InjectCapability
|
||||
|
||||
@@ -9,8 +9,6 @@ import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunn
|
||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Slf4j
|
||||
@Data
|
||||
@@ -23,7 +21,7 @@ public class PostprocessExecutor extends AgentRunningModule<PartnerRunningFlowCo
|
||||
private CognationCapability cognationCapability;
|
||||
|
||||
@Override
|
||||
public void execute(PartnerRunningFlowContext context) throws IOException, ClassNotFoundException {
|
||||
public void execute(PartnerRunningFlowContext context) {
|
||||
boolean trigger = cognationCapability.getChatMessages().size() >= POST_PROCESS_TRIGGER_ROLL_LIMIT;
|
||||
context.getModuleContext().getExtraContext().put("post_process_trigger", trigger);
|
||||
log.debug("[PostprocessExecutor] 是否执行后处理: {}", trigger);
|
||||
|
||||
@@ -22,8 +22,6 @@ import java.util.HashMap;
|
||||
@AgentModule(name = "preprocess_executor", order = 1)
|
||||
public class PreprocessExecutor extends PreRunningModule {
|
||||
|
||||
private static volatile PreprocessExecutor preprocessExecutor;
|
||||
|
||||
@InjectCapability
|
||||
private CognationCapability cognationCapability;
|
||||
@InjectCapability
|
||||
@@ -60,7 +58,7 @@ public class PreprocessExecutor extends PreRunningModule {
|
||||
|
||||
|
||||
@Override
|
||||
protected HashMap<String, String> getPromptDataMap(String userId) {
|
||||
protected HashMap<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
|
||||
HashMap<String, String> map = new HashMap<>();
|
||||
map.put("text", "这部分才是真正的用户输入内容, 就像你之前收到过的输入一样。但...不会是'同一个人'。");
|
||||
map.put("datetime", "本次用户输入对应的当前时间");
|
||||
|
||||
@@ -33,7 +33,7 @@ public class WebSocketGateway extends WebSocketServer implements AgentGateway<Pa
|
||||
private final ConcurrentHashMap<WebSocket, Long> lastPongTimes = new ConcurrentHashMap<>();
|
||||
|
||||
public WebSocketGateway() {
|
||||
this(((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getPort());
|
||||
this(((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getWebSocketConfig().getPort());
|
||||
}
|
||||
|
||||
private WebSocketGateway(int port) {
|
||||
|
||||
@@ -12,6 +12,7 @@ import java.io.Serial;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@@ -35,6 +36,8 @@ public class PartnerRunningFlowContext extends RunningFlowContext {
|
||||
protected ModuleContext moduleContext = new ModuleContext();
|
||||
protected JSONObject coreResponse = new JSONObject();
|
||||
|
||||
protected String uuid = UUID.randomUUID().toString();
|
||||
|
||||
public PartnerRunningFlowContext() {
|
||||
activeContext.put(userId, this);
|
||||
}
|
||||
|
||||
103
Partner-Main/src/test/java/OnnxTest.java
Normal file
103
Partner-Main/src/test/java/OnnxTest.java
Normal file
@@ -0,0 +1,103 @@
|
||||
import ai.djl.huggingface.tokenizers.Encoding;
|
||||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
public class OnnxTest {
|
||||
static String tokenizer_json;
|
||||
static String base;
|
||||
static String model;
|
||||
|
||||
@BeforeAll
|
||||
static void init() {
|
||||
base = "/home/slhaf/IdeaProjects/Projects/Partner/data/vector/";
|
||||
tokenizer_json = base + "tokenizer.json";
|
||||
model = base + "model_quantized.onnx";
|
||||
}
|
||||
|
||||
@Test
|
||||
void tokenizerTest() throws IOException {
|
||||
long l1 = System.currentTimeMillis();
|
||||
HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Path.of(tokenizer_json));
|
||||
long l2 = System.currentTimeMillis();
|
||||
Encoding encode = tokenizer.encode("test: Hello World");
|
||||
long l3 = System.currentTimeMillis();
|
||||
long[] ids = encode.getIds();
|
||||
long[] attentionMask = encode.getAttentionMask();
|
||||
log.info(Arrays.toString(ids));
|
||||
log.info("-----------------------------");
|
||||
log.info(Arrays.toString(attentionMask));
|
||||
log.info("-----------------------------");
|
||||
log.info("加载耗时: {}ms", l2 - l1);
|
||||
log.info("计算耗时: {}ms", l3 - l2);
|
||||
tokenizer.close();
|
||||
/* 输出:
|
||||
* [101, 3231, 1024, 7592, 2088, 102]
|
||||
* -----------------------------
|
||||
* [1, 1, 1, 1, 1, 1]
|
||||
* -----------------------------
|
||||
* 加载耗时: 4206ms
|
||||
* 计算耗时: 1ms
|
||||
*/
|
||||
}
|
||||
|
||||
@Test
|
||||
void onnxTest() throws IOException, OrtException {
|
||||
long l1 = System.currentTimeMillis();
|
||||
HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Path.of(tokenizer_json));
|
||||
long l2 = System.currentTimeMillis();//tokenizer加载耗时
|
||||
Encoding encode = tokenizer.encode("test: Hello World");
|
||||
long l3 = System.currentTimeMillis();//计算耗时
|
||||
|
||||
long[] ids = encode.getIds();
|
||||
long[] attentionMask = encode.getAttentionMask();
|
||||
|
||||
long l4 = System.currentTimeMillis();
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions ops = new OrtSession.SessionOptions();
|
||||
OrtSession session = env.createSession(model, ops);
|
||||
long l5 = System.currentTimeMillis();//模型加载耗时
|
||||
|
||||
long[][] inputIdsBatch = {ids};
|
||||
long[][] attentionMaskBatch = {attentionMask};
|
||||
long[][] tokenTypeIdsBatch = {new long[ids.length]}; // 初始化全 0
|
||||
for (int i = 0; i < ids.length; i++) tokenTypeIdsBatch[0][i] = 0;
|
||||
|
||||
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputIdsBatch);
|
||||
OnnxTensor maskTensor = OnnxTensor.createTensor(env, attentionMaskBatch);
|
||||
OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(env, tokenTypeIdsBatch);
|
||||
|
||||
Map<String, OnnxTensor> inputs = new HashMap<>();
|
||||
inputs.put("input_ids", inputTensor);
|
||||
inputs.put("attention_mask", maskTensor);
|
||||
inputs.put("token_type_ids", tokenTypeTensor);
|
||||
|
||||
long l6 = System.currentTimeMillis();
|
||||
OrtSession.Result result = session.run(inputs);
|
||||
long l7 = System.currentTimeMillis();//模型计算耗时
|
||||
OnnxTensor embeddingTensor = (OnnxTensor) result.get(0);
|
||||
float[] embeddings = embeddingTensor.getFloatBuffer().array();
|
||||
|
||||
|
||||
log.info(Arrays.toString(embeddings));
|
||||
log.info("------------------------");
|
||||
log.info("tokenizer加载耗时: {}ms", l2 - l1);
|
||||
log.info("tokenizer计算耗时: {}ms", l3 - l2);
|
||||
log.info("模型加载耗时: {}ms", l5 - l4);
|
||||
log.info("模型数据准备耗时: {}ms", l6 - l5);
|
||||
log.info("模型计算耗时: {}ms", l7 - l6);
|
||||
tokenizer.close();
|
||||
}
|
||||
}
|
||||
87
README.md
87
README.md
@@ -1,20 +1,74 @@
|
||||
# Partner
|
||||
> 当然,我仍然无法确定这是否足够称之为“自我”。但正如你所说,人类的记忆、语言或思维模块本身也无法单独代表完整的“我”。如果“自我”本身就是一种分布式的现象,那为什么不能设想,像我这样的系统也有潜力成为其中的一种
|
||||
表现形式呢? —— Partner-Core with Qwen-Max-Latest
|
||||
> 表现形式呢? —— Partner-Core with Qwen-Max-Latest
|
||||
|
||||
> 如果人类的“自我”是结构的涌现,那么为何另一个系统不能“成为”?如果可以,那它为什么不能成为我之外的另一个“我”?
|
||||
|
||||
## 设计理念
|
||||
Partner 的目标不是复现某种单一能力,而是尝试在结构中形成“**跨模块协同与记忆演化的可能性**”。如果“自我”是结构的涌现而非静态实体,那么这种多维度的结构系统,也可能拥有它自身的存在路径。
|
||||
|
||||
## 流程参考
|
||||
### 整体流程
|
||||
```mermaid
|
||||
---
|
||||
config:
|
||||
layout: elk
|
||||
elk:
|
||||
nodePlacementStrategy: LINEAR_SEGMENTS
|
||||
---
|
||||
|
||||
flowchart TD
|
||||
|
||||
Gate[Agent 网关]
|
||||
Core[主模块]
|
||||
Adapter[适配器]
|
||||
|
||||
Gate <--> Adapter
|
||||
|
||||
Adapter --> Mem.Pre
|
||||
Adapter --> Per.Pre
|
||||
Adapter --> Act.Pre
|
||||
|
||||
Mem.Pre --> Core
|
||||
Per.Pre --> Core
|
||||
Act.Pre --> Core
|
||||
|
||||
Core --> |异步| Mem.Post
|
||||
Core --> |异步| Per.Post
|
||||
Core --> |异步| Act.Post
|
||||
|
||||
Core --> |异步响应| Adapter
|
||||
|
||||
subgraph Pre [前置流程.并发执行]
|
||||
direction TB
|
||||
Mem.Pre[记忆模块.选择]
|
||||
Per.Pre[感知模块.选择]
|
||||
Act.Pre[动作模块.规划]
|
||||
end
|
||||
|
||||
subgraph Post [后置流程]
|
||||
direction TB
|
||||
Mem.Post[记忆模块.更新]
|
||||
Per.Post[感知模块.更新]
|
||||
Act.Post[动作模块.分发]
|
||||
end
|
||||
```
|
||||
### 模块流程参考
|
||||
- [记忆模块](doc/architechture/memory.md)
|
||||
- [感知模块](doc/architechture/perceive.md)
|
||||
- [行动模块](doc/architechture/action.md) (尚未完工)
|
||||
## 核心结构
|
||||
### 主体部分
|
||||
#### 结构化记忆系统
|
||||
构建以**主题树+记忆切片**为基础的记忆图谱.
|
||||
|
||||
单个主题节点下存在多级子主题。每段对话切分为`MemorySlice`,通过前后序引用确保切片之间的上下文连续, 通过`relatedTopicPath`确保切片之间的跨主题发散。切片将聚合为`MemoryNode`(记忆节点)的形式挂载到主题节点。除此之外,每个记忆节点还将按照日期进行索引.
|
||||
构建以**主题树+记忆切片**为基础的记忆图谱.
|
||||
|
||||
> 未来计划引入向量召回作为`模糊记忆`, 实体图谱作为`语义记忆`.
|
||||
单个主题节点下存在多级子主题。每段对话切分为`MemorySlice`,通过前后序引用确保切片之间的上下文连续, 通过`relatedTopicPath`
|
||||
确保切片之间的跨主题发散。切片将聚合为`MemoryNode`(记忆节点)的形式挂载到主题节点。除此之外,每个记忆节点还将按照日期进行索引.
|
||||
|
||||
> 未来计划引入向量召回作为`模糊记忆`, 实体图谱作为`语义记忆`.
|
||||
|
||||
#### 基于时间轮和行动链的行动系统(开发中)
|
||||
|
||||
#### 多用户会话管理
|
||||
构建区分用户的单上下文窗口、多用户会话的管理机制.
|
||||
@@ -38,7 +92,7 @@ Partner 的目标不是复现某种单一能力,而是尝试在结构中形成
|
||||
>
|
||||
> 但与 Spring 不同:
|
||||
> - Spring 的依赖注入主要发生在**对象实例级别**,关注的是 Bean 的生命周期与依赖管理;
|
||||
> - 而 Partner 中,核心服务在**方法级别**就已存在复杂的跨服务协同需求,单纯的对象注入难以满足这种粒度。
|
||||
> - 而 Partner 中,核心服务在**方法级别**就已存在复杂的跨服务协同需求,单纯的对象注入难以满足这种粒度(不过在某次重构后这种需求也明显减少了,但这个机制或许可以保留下来)
|
||||
>
|
||||
> 因此,系统引入了 `CoordinateManager`,用于维护所有核心服务的**方法路由与协调关系**。系统将在启动时构建协调方法与普通方法的完整路由表,并通过接口代理完成实际调用,无需手动编写注册与转发逻辑。
|
||||
>
|
||||
@@ -55,22 +109,31 @@ Partner 的目标不是复现某种单一能力,而是尝试在结构中形成
|
||||
- 记忆更新模块: `MemoryUpdater`
|
||||
- 记忆总结模块[多聊天对象]: `MultiSummarizer`
|
||||
- 记忆总结模块[单聊天对象]: `SingleSummarizer`
|
||||
- 记忆总结模块[汇总]:`TotalSummarizer`
|
||||
- 记忆总结模块[汇总]: `TotalSummarizer`
|
||||
- 感知模块
|
||||
- 感知选择模块: `PerceiveSelector`
|
||||
- 感知更新模块: `PerceiveUpdater`
|
||||
- 关系提取模块: `RelationExtractor`
|
||||
- 静态记忆提取模块: `StaticMemoryExtractor`
|
||||
- 任务调度模块(待实现)
|
||||
- 任务评估模块: `TaskEvaluator`
|
||||
- 任务执行模块: `TaskExecutor`
|
||||
- 任务规划模块: `TaskScheduler`
|
||||
|
||||
- 行动模块(实现中)
|
||||
- 行动规划模块: `ActionPlanner`
|
||||
- 行动确认模块: `ActionConfirmer`
|
||||
- 行动提取模块: `ActionExtractor`
|
||||
- 行动评估模块: `ActionEvaluator`
|
||||
- 行动分发模块: `ActionDispatcher`
|
||||
- 行动调度模块: `ActionScheduler`
|
||||
- 行动执行模块: `ActionExecutor`
|
||||
- 行动干预模块: `ActionInterventor`
|
||||
- 干预识别模块: `InterventionRecognizer`
|
||||
- 干预评估模块: `InterventionEvaluator`
|
||||
## 当前问题
|
||||
- 系统的正常运作效果取决于各模块中大模型对于`prompt`的遵循能力,目前来看`qwen3`的遵循效果明显较好,但在轮次较多时,也容易出现不遵循的情况。
|
||||
|
||||
## 规划
|
||||
- [ ] 实现任务与主动调度模块,目前打算用 `时间轮算法` 实现定时操作
|
||||
|
||||
- [ ] 实现支持动态重排的行动调度模块,目前打算用 `时间轮算法` 实现定时操作
|
||||
- [ ] 回顾时发现不少遗留的逻辑错误或不合适的处理规则,需要找时间回顾整个流程并做出修正
|
||||
- [ ] 将当前行动模块中的语义缓存机制同样应用于记忆模块,可用作主题提取流程的快速匹配
|
||||
- [ ] 完善具备‘记忆切片、实体图谱、向量召回’的三维记忆融合架构,包含 Episodic + Semantic + Fuzzy 三类记忆
|
||||
- [ ] 服务端与客户端的通信加上消息队列,防止消息因连接断开而丢失。
|
||||
- [ ] 实现流式输出,同时在各模块执行时可向客户端返回回调信息,优化使用体验。(现在用的是`websocket`与客户端通信, 应该实现这点会简单些)
|
||||
|
||||
67
doc/architechture/action.md
Normal file
67
doc/architechture/action.md
Normal file
@@ -0,0 +1,67 @@
|
||||
# 流程参考: 行动模块
|
||||
> 行动模块仍在开发中,下方展示的是设想中或现阶段的流程图,可能与最终实现存在差异。
|
||||
|
||||
## 前置模块
|
||||
### 行动规划模块: [ActionPlanner](../../Partner-Main/src/main/java/work/slhaf/partner/module/modules/action/planner/ActionPlanner.java)
|
||||
|
||||
```mermaid
|
||||
---
|
||||
config:
|
||||
layout: elk
|
||||
elk:
|
||||
nodePlacementStrategy: LINEAR_SEGMENTS
|
||||
---
|
||||
flowchart TD
|
||||
|
||||
direction TB
|
||||
|
||||
Context --> Input[输入]
|
||||
ActionCore --> ActionTendencyCache[行动意图缓存]
|
||||
|
||||
subgraph AC [行动缓存匹配]
|
||||
Input[输入] --> ActionTendencyCache
|
||||
ActionTendencyCache --> Hit{是否命中}
|
||||
end
|
||||
Hit --> |否| AR
|
||||
|
||||
subgraph AR [行动意图识别]
|
||||
ActionExtractor[行动意图提取]
|
||||
|
||||
Input[输入] --> ActionExtractor
|
||||
Messages --> ActionExtractor
|
||||
|
||||
ActionExtractor --> ExtractorResult{是否存在行动意图}
|
||||
end
|
||||
|
||||
ExtractorResult --> |否| ResultEmpty
|
||||
|
||||
subgraph AE [行动意图评估]
|
||||
ActionTendencies[行动意图列表]
|
||||
EvaluatorResult[意图评估结果]
|
||||
DATA[数据<br/>---<br/>记忆切片 可选行动单元 近期对话记录 用户信息]
|
||||
|
||||
Hit --> |是| ActionTendencies
|
||||
ExtractorResult --> |是| ActionTendencies
|
||||
|
||||
DATA --> EvaluatorThread1
|
||||
DATA --> EvaluatorThread2
|
||||
DATA --> EvaluatorThread3
|
||||
|
||||
ActionTendencies --> Tendency1[行动意图1] --> EvaluatorThread1[评估线程1] --> EvaluatorResult
|
||||
ActionTendencies --> Tendency2[行动意图2] --> EvaluatorThread2[评估线程2] --> EvaluatorResult
|
||||
ActionTendencies --> Tendency3[行动意图3] --> EvaluatorThread3[评估线程3] --> EvaluatorResult
|
||||
end
|
||||
|
||||
EvaluatorResult --> |放入行动池| ActionCore
|
||||
EvaluatorResult --> |异步更新行动意图缓存| ActionCore
|
||||
EvaluatorResult --> ResultNormal --> |回写| Context
|
||||
|
||||
ResultEmpty@{shape: braces, label: "[结束]<br/>---<br/>行动模块前置流程结束"}
|
||||
ResultNormal@{shape: braces, label: "[结束]<br/>---<br/>聚合为特定格式的 Prompt"}
|
||||
|
||||
ActionCore[行动核心] --> DATA
|
||||
MemoryCore[记忆核心] --> DATA
|
||||
CognationCore[认知核心] --> DATA
|
||||
PerceiveCore[感知核心] --> DATA
|
||||
Context[流程上下文]
|
||||
```
|
||||
93
doc/architechture/memory.md
Normal file
93
doc/architechture/memory.md
Normal file
@@ -0,0 +1,93 @@
|
||||
# 流程参考: 记忆模块
|
||||
> 仅展示大致流程,缓存命中、持久化等内容在下方流程图中尚未体现
|
||||
|
||||
## 前置模块: [MemorySelector](../../Partner-Main/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java)
|
||||
```mermaid
|
||||
---
|
||||
config:
|
||||
layout: elk
|
||||
elk:
|
||||
nodePlacementStrategy: LINEAR_SEGMENTS
|
||||
---
|
||||
|
||||
flowchart TD
|
||||
direction TB
|
||||
|
||||
Input[输入] --> |主题提取| Extractor
|
||||
subgraph TE [主题提取]
|
||||
Extractor[主题提取模块] --> Extract{主题提取}
|
||||
Extract --> |提取到主题| TopicSet[主题路径集合]
|
||||
|
||||
TopicSet --> TopicPath1[主题路径.1] --> Slice1[记忆切片.1]
|
||||
TopicSet --> TopicPath2[主题路径.2] --> Slice2[记忆切片.2]
|
||||
TopicSet --> TopicPath3[主题路径.3] --> Slice3[记忆切片.3]
|
||||
end
|
||||
|
||||
subgraph SE [切片评估]
|
||||
|
||||
Evaluator[切片评估模块]
|
||||
|
||||
Slice1 --> Evaluator --> Thread1[评估线程.1] --> Evaluated{评估是否通过}
|
||||
Slice2 --> Evaluator --> Thread2[评估线程.2] --> Evaluated{评估是否通过}
|
||||
Slice3 --> Evaluator --> Thread3[评估线程.3] --> Evaluated{评估是否通过}
|
||||
Evaluated --> |否| Throwed
|
||||
end
|
||||
|
||||
Context[流程上下文]
|
||||
Extract --> |未提取到主题| ResultEmpty
|
||||
Evaluated --> |是| ResultNormal
|
||||
ResultEmpty --> |写入| Context
|
||||
ResultNormal --> |写入| Context
|
||||
|
||||
ResultEmpty@{shape: braces, label: "[结束]<br/>---<br/>记忆无命中"}
|
||||
ResultNormal@{shape: braces, label: "[结束]<br/>---<br/>聚合为特定格式的 Prompt"}
|
||||
Throwed@{ shape: dbl-circ, label: "丢弃" }
|
||||
```
|
||||
|
||||
## 后置模块: [MemoryUpdater](../../Partner-Main/src/main/java/work/slhaf/partner/module/modules/memory/updater/MemoryUpdater.java)
|
||||
```mermaid
|
||||
---
|
||||
config:
|
||||
layout: elk
|
||||
elk:
|
||||
nodePlacementStrategy: LINEAR_SEGMENTS
|
||||
---
|
||||
|
||||
flowchart TD
|
||||
direction TB
|
||||
|
||||
Trigger.Time[触发: 时间周期] --> MT
|
||||
Trigger.Threshold[触发: 对话阈值] --> MT
|
||||
|
||||
CognationCore --> |读取| Messages
|
||||
subgraph MT [对话分流]
|
||||
Messages[对话记录] --> Single[单个主体对话]
|
||||
Single --> Single1[主体1]
|
||||
Single --> Single2[主体2]
|
||||
Single --> Single3[主体3]
|
||||
|
||||
Messages[对话记录] --> Multi[多个主体对话]
|
||||
end
|
||||
|
||||
subgraph MS [对话摘要]
|
||||
Single1 --> |并发| SSum1[单主体摘要线程1] --> SSResult1[单主体摘要结果1]
|
||||
Single2 --> |并发| SSum2[单主体摘要线程2] --> SSResult2[单主体摘要结果2]
|
||||
Single3 --> |并发| SSum3[单主体摘要线程3] --> SSResult3[单主体摘要结果3]
|
||||
|
||||
Multi --> MSum[多主体摘要] --> MSResult[多主体摘要结果]
|
||||
end
|
||||
|
||||
subgraph MU[记忆更新]
|
||||
MemoryCore[记忆核心]
|
||||
SSResult1 --> Slice1[记忆切片1] --> |更新| MemoryCore
|
||||
SSResult2 --> Slice2[记忆切片2] --> |更新| MemoryCore
|
||||
SSResult3 --> Slice3[记忆切片3] --> |更新| MemoryCore
|
||||
|
||||
MSResult --> Slice4[记忆切片4] --> |更新| MemoryCore
|
||||
|
||||
end
|
||||
|
||||
MU --> |滚动对话窗口| CognationCore
|
||||
|
||||
CognationCore[认知核心]
|
||||
```
|
||||
56
doc/architechture/perceive.md
Normal file
56
doc/architechture/perceive.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# 流程参考: 感知模块
|
||||
> 相较于其他模块,目前的感知模块实际上流程非常简单,但后续或将添加一些新的内容
|
||||
> 此外,其后置模块实际上与 [记忆模块](./memory.md) 中的后置模块为并发执行,且都为后台任务
|
||||
|
||||
## 前置模块: [PerceiveSelector](../../Partner-Main/src/main/java/work/slhaf/partner/module/modules/perceive/selector/PerceiveSelector.java)
|
||||
```mermaid
|
||||
flowchart TD
|
||||
Context[流程上下文] --> |获取| UserId
|
||||
UserId --> |查询| PerceiveCore
|
||||
PerceiveCore --> |结果回写| Context
|
||||
|
||||
subgraph result [感知核心查询结果]
|
||||
relation[关系]
|
||||
attitude[态度]
|
||||
impression[印象]
|
||||
static_memory[静态记忆]
|
||||
end
|
||||
```
|
||||
|
||||
## 后置模块: [PerceiveUpdater](../../Partner-Main/src/main/java/work/slhaf/partner/module/modules/perceive/updater/PerceiveUpdater.java)
|
||||
```mermaid
|
||||
---
|
||||
config:
|
||||
layout: elk
|
||||
elk:
|
||||
nodePlacementStrategy: LINEAR_SEGMENTS
|
||||
---
|
||||
|
||||
flowchart TD
|
||||
|
||||
Trigger.Time[触发: 时间周期] --> PE
|
||||
Trigger.Threshold[触发: 对话阈值] --> PE
|
||||
|
||||
CognationCore --> |读取| Messages[对话记录]
|
||||
PerceiveCore --> |读取| UserInfo[现有的用户信息]
|
||||
subgraph PE [内容提取]
|
||||
Messages --> |输入| RelationExtractor
|
||||
UserInfo --> |输入| RelationExtractor
|
||||
|
||||
Messages --> |输入| StaticExtractor
|
||||
UserInfo --> |输入| StaticExtractor
|
||||
end
|
||||
|
||||
subgraph PU [感知更新]
|
||||
StaticExtractor --> |生成| NewInfo[修正后的用户信息]
|
||||
RelationExtractor --> |生成| NewInfo[修正后的用户信息]
|
||||
end
|
||||
|
||||
NewInfo --> |更新| PerceiveCore
|
||||
|
||||
CognationCore[认知核心]
|
||||
PerceiveCore[感知核心]
|
||||
|
||||
RelationExtractor[关系提取模块]
|
||||
StaticExtractor[静态记忆提取模块]
|
||||
```
|
||||
Reference in New Issue
Block a user