refactor(Project): rename Partner-Api/Partner-Main modules to Partner-Framework/Partner-Core and update Maven dependencies

This commit is contained in:
2026-02-19 10:39:21 +08:00
parent 1244d59fa4
commit 73ab40416d
258 changed files with 12 additions and 12 deletions

View File

@@ -0,0 +1,18 @@
package work.slhaf.partner;
import work.slhaf.partner.api.agent.Agent;
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
import work.slhaf.partner.common.vector.VectorClient;
import work.slhaf.partner.runtime.exception.PartnerExceptionCallback;
import work.slhaf.partner.runtime.interaction.WebSocketGateway;
public class Main {
public static void main(String[] args) {
Agent.newAgent(Main.class)
.setAgentConfigManager(PartnerAgentConfigManager.class)
.setGateway(WebSocketGateway.class)
.setAgentExceptionCallback(PartnerExceptionCallback.class)
.addAfterLaunchRunners(VectorClient::load)
.launch();
}
}

View File

@@ -0,0 +1,10 @@
package work.slhaf.partner.common;
public final class Constant {
public static final class Path {
public static final String DATA = "data";
public static final String MEMORY_DATA = DATA + "/memory";
}
}

View File

@@ -0,0 +1,24 @@
package work.slhaf.partner.common.config;
import lombok.Data;
@Data
public class Config {
private String agentId;
private WebSocketConfig webSocketConfig;
private VectorConfig vectorConfig;
@Data
public static class VectorConfig {
private int type;
private String ollamaEmbeddingUrl;
private String ollamaEmbeddingModel;
private String tokenizerPath;
private String embeddingModelPath;
}
@Data
public static class WebSocketConfig {
private int port;
}
}

View File

@@ -0,0 +1,41 @@
package work.slhaf.partner.common.config;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException;
import work.slhaf.partner.api.agent.runtime.config.FileAgentConfigManager;
import work.slhaf.partner.common.exception.ConfigLoadFailedException;
import java.io.File;
import java.nio.charset.StandardCharsets;
@EqualsAndHashCode(callSuper = true)
@Data
public final class PartnerAgentConfigManager extends FileAgentConfigManager {
private static final String COMMON_CONFIG_FILE = CONFIG_DIR + "common_config.json";
private Config config;
@Override
public void load() {
loadWebSocketConfig();
super.load();
}
private void loadWebSocketConfig() {
File file = new File(COMMON_CONFIG_FILE);
if (!file.exists()) {
throw new ConfigNotExistException("Partner Config Not Exist: " + COMMON_CONFIG_FILE);
}
config = JSONUtil.readJSONObject(file, StandardCharsets.UTF_8).toBean(Config.class);
if (config == null || config.getAgentId() == null) {
throw new ConfigLoadFailedException("Partner Config Load Failed: " + COMMON_CONFIG_FILE);
}
int port = config.getWebSocketConfig().getPort();
if (port <= 0 || port > 65535) {
throw new ConfigLoadFailedException("Invalid Websocket port: " + port);
}
}
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.common.exception;
import work.slhaf.partner.api.agent.factory.config.exception.ConfigFactoryInitFailedException;
public class ConfigLoadFailedException extends ConfigFactoryInitFailedException {
public ConfigLoadFailedException(String message, Throwable cause) {
super(message, cause);
}
public ConfigLoadFailedException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.common.exception;
import work.slhaf.partner.api.agent.runtime.exception.AgentLaunchFailedException;
public class ServiceLoadFailedException extends AgentLaunchFailedException {
public ServiceLoadFailedException(String message, Throwable cause) {
super(message, cause);
}
public ServiceLoadFailedException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,36 @@
package work.slhaf.partner.common.monitor;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
@Slf4j
public class DebugMonitor {
private InteractionThreadPoolExecutor executor;
private static DebugMonitor debugMonitor;
public static void initialize() {
debugMonitor = new DebugMonitor();
debugMonitor.executor = InteractionThreadPoolExecutor.getInstance();
debugMonitor.runMonitor();
}
private void runMonitor() {
executor.execute(() -> {
while (true) {
try {
Thread.sleep(3000);
} catch (InterruptedException e) {
log.error("监测线程报错?");
}
}
});
}
public static DebugMonitor getInstance(){
if (debugMonitor == null) {
initialize();
}
return debugMonitor;
}
}

View File

@@ -0,0 +1,60 @@
package work.slhaf.partner.common.thread;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
public class InteractionThreadPoolExecutor {
private static InteractionThreadPoolExecutor interactionThreadPoolExecutor;
private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor();
public static InteractionThreadPoolExecutor getInstance() {
if (interactionThreadPoolExecutor == null) {
interactionThreadPoolExecutor = new InteractionThreadPoolExecutor();
}
return interactionThreadPoolExecutor;
}
public <T> void invokeAll(List<Callable<T>> tasks, int time, TimeUnit timeUnit) {
try {
executorService.invokeAll(tasks, time, timeUnit);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
public <T> void invokeAll(List<Callable<T>> tasks) {
try {
List<Future<T>> futures = executorService.invokeAll(tasks);
for (Future<T> future : futures) {
future.get();
}
} catch (InterruptedException e) {
throw new RuntimeException(e);
} catch (ExecutionException e) {
throw new RuntimeException(e.getCause());
}
}
public <T> List<T> invokeAllAndReturn(List<Callable<T>> tasks) {
try {
List<Future<T>> futures = executorService.invokeAll(tasks);
List<T> results = new ArrayList<>();
for (Future<T> future : futures) {
results.add(future.get());
}
return results;
} catch (InterruptedException e) {
throw new RuntimeException(e);
} catch (ExecutionException e) {
throw new RuntimeException(e.getCause());
}
}
public void execute(Runnable runnable) {
executorService.execute(runnable);
}
}

View File

@@ -0,0 +1,42 @@
package work.slhaf.partner.common.util;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class ExtractUtil {
public static String extractJson(String jsonStr) {
jsonStr = jsonStr.replace("", "\"").replace("", "\"");
int start = jsonStr.indexOf("{");
int end = jsonStr.lastIndexOf("}");
if (start != -1 && end != -1 && start < end) {
return jsonStr.substring(start, end + 1);
}
return jsonStr;
}
public static String extractUserId(String messageContent) {
Pattern pattern = Pattern.compile("\\[.*\\(([^)]+)\\)\\]");
Matcher matcher = pattern.matcher(messageContent);
if (!matcher.find()) {
return null;
}
return matcher.group(1);
}
public static String fixTopicPath(String topicPath) {
String[] parts = topicPath.split("->");
List<String> cleanedParts = new ArrayList<>();
for (String part : parts) {
// 修正正则表达式,正确移除 [xxx] 部分
String cleaned = part.replaceAll("\\[[^\\]]*\\]", "").trim();
if (!cleaned.isEmpty()) { // 忽略空字符串
cleanedParts.add(cleaned);
}
}
return String.join("->", cleanedParts);
}
}

View File

@@ -0,0 +1,14 @@
package work.slhaf.partner.common.util;
public class PathUtil {
public static String buildPathStr(String... path) {
StringBuilder str = new StringBuilder();
for (int i = 0; i < path.length; i++) {
str.append(path[i]);
if (i < path.length - 1) {
str.append("/");
}
}
return str.toString();
}
}

View File

@@ -0,0 +1,50 @@
package work.slhaf.partner.common.util;
import com.alibaba.fastjson2.JSONArray;
import work.slhaf.partner.api.agent.Agent;
import work.slhaf.partner.api.chat.pojo.Message;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
public class ResourcesUtil {
private static final ClassLoader classloader = Agent.class.getClassLoader();
public static class Prompt {
private static final String SELF_AWARENESS_PATH = "prompt/basic_prompt.json";
private static final String MODULE_PROMPT_PREFIX_PATH = "prompt/module/";
public static List<Message> loadPromptWithSelfAwareness(String modelKey, String promptType) {
//加载人格引导
List<Message> messages = new ArrayList<>(loadSelfAwareness());
//加载常规提示
String path = MODULE_PROMPT_PREFIX_PATH + promptType + "/" + modelKey + ".json";
messages.addAll(readPromptFromResources(path));
return messages;
}
public static List<Message> loadSelfAwareness() {
return readPromptFromResources(SELF_AWARENESS_PATH);
}
public static List<Message> loadPrompt(String modelKey,String promptType){
return new ArrayList<>(readPromptFromResources(MODULE_PROMPT_PREFIX_PATH+promptType+"/"+modelKey+".json"));
}
private static List<Message> readPromptFromResources(String filePath) {
try {
InputStream inputStream = classloader.getResourceAsStream(filePath);
String content = new String(inputStream.readAllBytes(), StandardCharsets.UTF_8);
JSONArray array = JSONArray.parse(content);
inputStream.close();
return array.toJavaList(Message.class);
} catch (Exception e) {
throw new RuntimeException("读取Resource失败: " + filePath, e);
}
}
}
}

View File

@@ -0,0 +1,48 @@
package work.slhaf.partner.common.vector;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
import java.util.Map;
@Slf4j
public class OllamaVectorClient extends VectorClient {
private final String ollamaEmbeddingUrl;
private final String ollamaEmbeddingModel;
protected OllamaVectorClient(String url, String model) {
this.ollamaEmbeddingUrl = url;
this.ollamaEmbeddingModel = model;
compute("test");
}
@Override
protected float[] doCompute(String input) {
Map<String, String> param = Map.of("model", ollamaEmbeddingModel, "input", input);
HttpRequest request = HttpRequest.get(ollamaEmbeddingUrl).body(JSONObject.toJSONString(param));
try (HttpResponse response = request.execute()) {
if (!response.isOk())
throw new VectorClientExecuteException("嵌入模型执行出错");
String resStr = response.body();
EmbeddingModelResponse embeddingResponse = JSONObject.parseObject(resStr, EmbeddingModelResponse.class);
return embeddingResponse.getEmbeddings()[0];
} catch (Exception e) {
throw new VectorClientExecuteException("嵌入模型执行出错", e);
}
}
@Data
private static class EmbeddingModelResponse {
private String model;
private float[][] embeddings;
private long total_duration;
private long load_duration;
private int prompt_eval_count;
}
}

View File

@@ -0,0 +1,82 @@
package work.slhaf.partner.common.vector;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
@SuppressWarnings("FieldMayBeFinal")
public class OnnxVectorClient extends VectorClient {
private String tokenizerPath;
private String modelPath;
private HuggingFaceTokenizer tokenizer;
private OrtSession session;
private OrtEnvironment env;
protected OnnxVectorClient(String tokenizer, String model) {
this.tokenizerPath = tokenizer;
this.modelPath = model;
loadTokenizer();
loadModel();
compute("test");
}
private void loadModel() {
try {
env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions ops = new OrtSession.SessionOptions();
session = env.createSession(modelPath, ops);
} catch (Exception e) {
throw new VectorClientLoadFailedException("加载ONNX模型失败", e);
}
}
private void loadTokenizer() {
try {
tokenizer = HuggingFaceTokenizer.newInstance(Path.of(tokenizerPath));
} catch (Exception e) {
throw new VectorClientLoadFailedException("加载Tokenizer失败", e);
}
}
@Override
protected float[] doCompute(String input) {
try {
Encoding encode = tokenizer.encode(input);
long[] ids = encode.getIds();
long[] attentionMask = encode.getAttentionMask();
long[][] inputIdsBatch = {ids};
long[][] attentionMaskBatch = {attentionMask};
long[][] tokenTypeIdsBatch = {new long[ids.length]}; // 初始化全 0
for (int i = 0; i < ids.length; i++)
tokenTypeIdsBatch[0][i] = 0;
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputIdsBatch);
OnnxTensor maskTensor = OnnxTensor.createTensor(env, attentionMaskBatch);
OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(env, tokenTypeIdsBatch);
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input_ids", inputTensor);
inputs.put("attention_mask", maskTensor);
inputs.put("token_type_ids", tokenTypeTensor);
OrtSession.Result result = session.run(inputs);
OnnxTensor embeddingTensor = (OnnxTensor) result.get(0);
return embeddingTensor.getFloatBuffer().array();
} catch (Exception e) {
throw new VectorClientExecuteException("嵌入模型执行出错", e);
}
}
}

View File

@@ -0,0 +1,88 @@
package work.slhaf.partner.common.vector;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
import work.slhaf.partner.common.config.Config.VectorConfig;
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
import work.slhaf.partner.common.exception.ServiceLoadFailedException;
import work.slhaf.partner.common.vector.exception.VectorClientExecuteException;
import work.slhaf.partner.common.vector.exception.VectorClientLoadFailedException;
@Slf4j
public abstract class VectorClient {
public static boolean status;
public static VectorClient INSTANCE;
public static void load() {
PartnerAgentConfigManager configManager = (PartnerAgentConfigManager) AgentConfigManager.INSTANCE;
VectorConfig vectorConfig = configManager.getConfig().getVectorConfig();
int type = vectorConfig.getType();
try {
switch (type) {
case 0:
status = false;
break;
case 1:
status = true;
INSTANCE = new OllamaVectorClient(vectorConfig.getOllamaEmbeddingUrl(),
vectorConfig.getOllamaEmbeddingModel());
break;
case 2:
status = true;
INSTANCE = new OnnxVectorClient(vectorConfig.getTokenizerPath(),
vectorConfig.getEmbeddingModelPath());
break;
default:
throw new ServiceLoadFailedException(
"加载向量客户端失败! type: 0 -> 不启用语义缓存; type: 1 -> ollama; type: 2 -> ONNX RUNTIME");
}
log.info("向量客户端加载完毕");
} catch (VectorClientLoadFailedException | VectorClientExecuteException exception) {
status = false;
log.error("向量客户端加载失败", exception);
}
}
public float[] compute(String input) {
if (!status) {
return null;
}
return doCompute(input);
}
protected abstract float[] doCompute(String input);
public double compare(float[] v1, float[] v2) {
if (!status) {
return 0;
}
try (INDArray a1 = Nd4j.create(v1); INDArray a2 = Nd4j.create(v2)) {
return Transforms.cosineSim(a1, a2);
}
}
public float[] weightedAverage(float[] newVector, float[] primaryVector) {
try (INDArray primary = Nd4j.create(primaryVector);
INDArray latest = Nd4j.create(newVector)) {
// 1⃣ 计算余弦相似度
double similarity = Transforms.cosineSim(primary, latest);
// 2⃣ 根据相似度决定更新比例 α(差异越大,新输入影响越强)
double alpha = (1.0 - similarity) * 0.5;
alpha = Math.max(0.05, Math.min(alpha, 0.5));
// 3⃣ 按比例混合旧向量与新向量
INDArray updated = primary.mul(1 - alpha).add(latest.mul(alpha));
// 4⃣ 归一化结果(保持方向空间一致)
updated = updated.div(updated.norm2Number());
return updated.toFloatVector();
}
}
}

View File

@@ -0,0 +1,15 @@
package work.slhaf.partner.common.vector.exception;
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
public class VectorClientExecuteException extends AgentRuntimeException {
public VectorClientExecuteException(String message) {
super(message);
}
public VectorClientExecuteException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,15 @@
package work.slhaf.partner.common.vector.exception;
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
public class VectorClientLoadFailedException extends AgentRuntimeException {
public VectorClientLoadFailedException(String message) {
super(message);
}
public VectorClientLoadFailedException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,56 @@
package work.slhaf.partner.core;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.CoordinateManager;
import work.slhaf.partner.api.agent.factory.capability.annotation.Coordinated;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.core.cognation.CognationCore;
import work.slhaf.partner.core.memory.MemoryCore;
import java.io.Serial;
import java.io.Serializable;
import java.util.HashSet;
import java.util.Set;
import static work.slhaf.partner.common.util.ExtractUtil.extractUserId;
@Data
@Slf4j
@CoordinateManager
public class CoordinatedManager implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
//在框架将自动注入core,详见CapabilityRegistryFactory
private CognationCore cognationCore;
private MemoryCore memoryCore;
private boolean isCacheSingleUser() {
return memoryCore.getUserDialogMap().size() <= 1;
}
@Coordinated(capability = "cognation")
public boolean isSingleUser() {
return isCacheSingleUser() && isChatMessagesSingleUser();
}
private boolean isChatMessagesSingleUser() {
Set<String> userIdSet = new HashSet<>();
cognationCore.getChatMessages().forEach(m -> {
if (m.getRole().equals(ChatConstant.Character.ASSISTANT)) {
return;
}
String userId = extractUserId(m.getContent());
if (userId == null || userId.isEmpty()) {
return;
}
userIdSet.add(userId);
});
return userIdSet.size() <= 1;
}
}

View File

@@ -0,0 +1,94 @@
package work.slhaf.partner.core;
import cn.hutool.core.bean.BeanUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
import work.slhaf.partner.api.common.entity.PersistableObject;
import work.slhaf.partner.common.config.PartnerAgentConfigManager;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import static work.slhaf.partner.common.Constant.Path.MEMORY_DATA;
@Slf4j
public abstract class PartnerCore<T extends PartnerCore<T>> extends PersistableObject {
private final String id = ((PartnerAgentConfigManager) AgentConfigManager.INSTANCE).getConfig().getAgentId();
public PartnerCore() throws IOException, ClassNotFoundException {
createStorageDirectory();
Path filePath = getFilePath(id);
if (Files.exists(filePath)) {
T deserialize = deserialize();
setupData(deserialize, (T) this);
} else {
FileUtils.createParentDirectories(filePath.toFile().getParentFile());
this.serialize();
}
setupHook(this);
log.info("[{}] 注册完毕", getCoreKey());
}
private void setupHook(PartnerCore<T> temp) {
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
temp.serialize();
log.info("[{}] 已保存", getCoreKey());
} catch (IOException e) {
log.error("[{}] 保存失败: ", getCoreKey(), e);
}
}));
}
private void setupData(T source, T current) {
BeanUtil.copyProperties(source, current);
}
public void serialize() throws IOException {
//先写入到临时文件,如果正常写入则覆盖原文件
Path filePath = getFilePath(id + "-temp");
Files.createDirectories(Path.of(MEMORY_DATA));
try {
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(filePath.toFile()));
oos.writeObject(this);
oos.close();
Path path = getFilePath(id);
Files.move(filePath, path, StandardCopyOption.REPLACE_EXISTING);
log.info("[{}] 已保存到: {}", getCoreKey(), path);
} catch (IOException e) {
Files.delete(filePath);
log.error("[{}] 序列化保存失败: {}", getCoreKey(), e.getMessage());
}
}
private T deserialize() throws IOException, ClassNotFoundException {
Path filePath = getFilePath(id);
try (ObjectInputStream ois = new ObjectInputStream(
new FileInputStream(filePath.toFile()))) {
T graph = (T) ois.readObject();
log.info("[{}] 已从文件加载: {}", getCoreKey(), filePath);
return graph;
}
}
private Path getFilePath(String s) {
return Paths.get(MEMORY_DATA, s + "-" + getCoreKey() + ".memory");
}
private void createStorageDirectory() {
try {
Files.createDirectories(Paths.get(MEMORY_DATA));
} catch (IOException e) {
log.error("[{}]创建存储目录失败: {}", getCoreKey(), e.getMessage());
}
}
protected abstract String getCoreKey();
}

View File

@@ -0,0 +1,59 @@
package work.slhaf.partner.core.action;
import lombok.NonNull;
import org.jetbrains.annotations.Nullable;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.core.action.entity.MetaAction;
import work.slhaf.partner.core.action.entity.MetaActionInfo;
import work.slhaf.partner.core.action.entity.PhaserRecord;
import work.slhaf.partner.core.action.entity.cache.CacheAdjustData;
import work.slhaf.partner.core.action.runner.RunnerClient;
import work.slhaf.partner.module.modules.action.interventor.entity.MetaIntervention;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Phaser;
@Capability(value = "action")
public interface ActionCapability {
void putAction(@NonNull ExecutableAction executableAction);
Set<ExecutableAction> listActions(@Nullable ExecutableAction.Status status, @Nullable String source);
List<ExecutableAction> popPendingAction(String userId);
List<ExecutableAction> listPendingAction(String userId);
void putPendingActions(String userId, ExecutableAction executableAction);
List<String> selectTendencyCache(String input);
void updateTendencyCache(CacheAdjustData data);
ExecutorService getExecutor(ActionCore.ExecutorType type);
PhaserRecord putPhaserRecord(Phaser phaser, ExecutableAction executableAction);
void removePhaserRecord(Phaser phaser);
List<PhaserRecord> listPhaserRecords();
PhaserRecord getPhaserRecord(String tendency, String source);
MetaAction loadMetaAction(@NonNull String actionKey);
MetaActionInfo loadMetaActionInfo(@NonNull String actionKey);
Map<String, MetaActionInfo> listAvailableMetaActions();
boolean checkExists(String... actionKeys);
RunnerClient runnerClient();
void handleInterventions(List<MetaIntervention> interventions, ExecutableAction data);
}

View File

@@ -0,0 +1,446 @@
package work.slhaf.partner.core.action;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.jetbrains.annotations.Nullable;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.common.vector.VectorClient;
import work.slhaf.partner.core.PartnerCore;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.core.action.entity.MetaAction;
import work.slhaf.partner.core.action.entity.MetaActionInfo;
import work.slhaf.partner.core.action.entity.PhaserRecord;
import work.slhaf.partner.core.action.entity.cache.ActionCacheData;
import work.slhaf.partner.core.action.entity.cache.CacheAdjustData;
import work.slhaf.partner.core.action.entity.cache.CacheAdjustMetaData;
import work.slhaf.partner.core.action.exception.ActionDataNotFoundException;
import work.slhaf.partner.core.action.exception.MetaActionNotFoundException;
import work.slhaf.partner.core.action.runner.RunnerClient;
import work.slhaf.partner.core.action.runner.SandboxRunnerClient;
import work.slhaf.partner.module.modules.action.interventor.entity.InterventionType;
import work.slhaf.partner.module.modules.action.interventor.entity.MetaIntervention;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
@SuppressWarnings("FieldMayBeFinal")
@CapabilityCore(value = "action")
@Slf4j
public class ActionCore extends PartnerCore<ActionCore> {
/**
* 持久行动池
*/
private CopyOnWriteArraySet<ExecutableAction> actionPool = new CopyOnWriteArraySet<>();
/**
* 待确认任务以userId区分不同用户因为需要跨请求确认
*/
private HashMap<String, List<ExecutableAction>> pendingActions = new HashMap<>();
/**
* 语义缓存与行为倾向映射
*/
private List<ActionCacheData> actionCache = new ArrayList<>();
private final Lock cacheLock = new ReentrantLock();
// 由于当前的执行器逻辑实现,平台线程池大小不得小于 2这里规定为最小为 4
private final ExecutorService platformExecutor = Executors
.newFixedThreadPool(Math.max(Runtime.getRuntime().availableProcessors(), 4));
private final ExecutorService virtualExecutor = Executors.newVirtualThreadPerTaskExecutor();
/**
* 已存在的行动程序,键格式为‘<MCP-ServerName>::<Tool-Name>’,值为 MCP Server 通过 Resources 相关渠道传递的行动程序元信息
*/
private final ConcurrentHashMap<String, MetaActionInfo> existedMetaActions = new ConcurrentHashMap<>();
private final List<PhaserRecord> phaserRecords = new ArrayList<>();
private RunnerClient runnerClient;
public ActionCore() throws IOException, ClassNotFoundException {
// TODO 通过 AgentConfigManager指定采用何种 runnerClient
runnerClient = new SandboxRunnerClient(existedMetaActions, virtualExecutor);
setupShutdownHook();
}
private void setupShutdownHook() {
// 将执行中的行动状态置为失败
val executingActionSet = listActions(ExecutableAction.Status.EXECUTING, null);
for (ExecutableAction executableAction : executingActionSet) {
executableAction.setStatus(ExecutableAction.Status.FAILED);
executableAction.setResult("由于系统中断而失败");
}
}
@CapabilityMethod
public void putAction(@NonNull ExecutableAction executableAction) {
actionPool.removeIf(data -> data.getUuid().equals(executableAction.getUuid())); // 用来应对 ScheduledActionData 的重新排列
actionPool.add(executableAction);
}
@CapabilityMethod
public Set<ExecutableAction> listActions(@Nullable ExecutableAction.Status status, @Nullable String source) {
return actionPool.stream()
.filter(actionData -> status == null || actionData.getStatus().equals(status))
.filter(actionData -> source == null || actionData.getSource().equals(source))
.collect(Collectors.toSet());
}
@CapabilityMethod
public synchronized void putPendingActions(String userId, ExecutableAction executableAction) {
pendingActions.computeIfAbsent(userId, k -> {
List<ExecutableAction> temp = new ArrayList<>();
temp.add(executableAction);
return temp;
});
}
@CapabilityMethod
public synchronized List<ExecutableAction> popPendingAction(String userId) {
List<ExecutableAction> infos = pendingActions.get(userId);
pendingActions.remove(userId);
return infos;
}
@CapabilityMethod
public List<ExecutableAction> listPendingAction(String userId) {
return pendingActions.get(userId);
}
/**
* 计算输入内容的语义向量,根据与{@link ActionCacheData#getInputVector()}的相似度挑取缓存,后续将根据评估结果来更新计数
*
* @param input 本次输入内容
* @return 命中的行为倾向集合
*/
@CapabilityMethod
public List<String> selectTendencyCache(String input) {
if (!VectorClient.status) {
return null;
}
VectorClient vectorClient = VectorClient.INSTANCE;
// 计算本次输入的向量
float[] vector = vectorClient.compute(input);
if (vector == null)
return null;
// 与现有缓存比对,将匹配到的收集并返回
return actionCache.parallelStream()
.filter(ActionCacheData::isActivated)
.filter(data -> {
double compared = vectorClient.compare(vector, data.getInputVector());
return compared > data.getThreshold();
})
.map(ActionCacheData::getTendency)
.collect(Collectors.toList());
}
@CapabilityMethod
public void updateTendencyCache(CacheAdjustData data) {
VectorClient vectorClient = VectorClient.INSTANCE;
List<CacheAdjustMetaData> list = data.getMetaDataList();
String input = data.getInput();
float[] inputVector = vectorClient.compute(input);
List<CacheAdjustMetaData> matchAndPassed = new ArrayList<>();
List<CacheAdjustMetaData> matchNotPassed = new ArrayList<>();
List<CacheAdjustMetaData> notMatchPassed = new ArrayList<>();
for (CacheAdjustMetaData metaData : list) {
if (metaData.isHit() && metaData.isPassed()) {
matchAndPassed.add(metaData);
} else if (metaData.isHit()) {
matchNotPassed.add(metaData);
} else if (!metaData.isPassed()) {
notMatchPassed.add(metaData);
}
}
platformExecutor.execute(() -> adjustMatchAndPassed(matchAndPassed, inputVector, input, vectorClient));
platformExecutor.execute(() -> adjustMatchNotPassed(matchNotPassed, vectorClient));
platformExecutor.execute(() -> adjustNotMatchPassed(notMatchPassed, inputVector, input, vectorClient));
}
@CapabilityMethod
public ExecutorService getExecutor(ExecutorType type) {
return switch (type) {
case VIRTUAL -> virtualExecutor;
case PLATFORM -> platformExecutor;
};
}
@CapabilityMethod
public Map<String, MetaActionInfo> listAvailableActions() {
return existedMetaActions;
}
@CapabilityMethod
public synchronized PhaserRecord putPhaserRecord(Phaser phaser, ExecutableAction executableAction) {
PhaserRecord record = new PhaserRecord(phaser, executableAction);
phaserRecords.add(record);
return record;
}
@CapabilityMethod
public synchronized void removePhaserRecord(Phaser phaser) {
PhaserRecord remove = null;
for (PhaserRecord record : phaserRecords) {
if (record.phaser().equals(phaser)) {
remove = record;
}
}
if (remove != null) {
phaserRecords.remove(remove);
}
}
@CapabilityMethod
public PhaserRecord getPhaserRecord(String tendency, String source) {
for (PhaserRecord record : phaserRecords) {
ExecutableAction data = record.executableAction();
if (data.getTendency().equals(tendency) && data.getSource().equals(source)) {
return record;
}
}
throw new ActionDataNotFoundException("未找到对应的 Phaser 记录: tendency=" + tendency + ", source=" + source);
}
@CapabilityMethod
public MetaAction loadMetaAction(@NonNull String actionKey) {
MetaActionInfo metaActionInfo = existedMetaActions.get(actionKey);
if (metaActionInfo == null) {
throw new MetaActionNotFoundException("未找到对应的行动程序信息" + actionKey);
}
String[] split = actionKey.split("::");
if (split.length < 2) {
throw new MetaActionNotFoundException("未找到对应的行动程序,原因: 传入的 actionKey(" + actionKey + ") 存在异常");
}
return new MetaAction(
split[1],
metaActionInfo.isIo(),
MetaAction.Type.MCP,
split[0]
);
}
@CapabilityMethod
public List<PhaserRecord> listPhaserRecords() {
return phaserRecords;
}
@CapabilityMethod
public MetaActionInfo loadMetaActionInfo(@NonNull String actionKey) {
MetaActionInfo info = existedMetaActions.get(actionKey);
if (info == null) {
throw new MetaActionNotFoundException("未找到对应的行动程序描述信息: " + actionKey);
}
return info;
}
@CapabilityMethod
public boolean checkExists(String... actionKeys) {
return existedMetaActions.keySet().containsAll(Arrays.asList(actionKeys));
}
@CapabilityMethod
public RunnerClient runnerClient() {
return runnerClient;
}
@CapabilityMethod
public void handleInterventions(List<MetaIntervention> interventions, ExecutableAction executableAction) {
// 加载数据
if (executableAction == null) {
return;
}
// 加锁确保同步
synchronized (executableAction.getStatus()) {
applyInterventions(interventions, executableAction);
}
}
private void applyInterventions(List<MetaIntervention> interventions, ExecutableAction executableAction) {
boolean rebuildCleanTag = false;
interventions.sort(Comparator.comparingInt(MetaIntervention::getOrder));
for (MetaIntervention intervention : interventions) {
List<MetaAction> actions = intervention.getActions()
.stream()
.map(this::loadMetaAction)
.toList();
switch (intervention.getType()) {
case InterventionType.APPEND -> handleAppend(executableAction, intervention.getOrder(), actions);
case InterventionType.INSERT -> handleInsert(executableAction, intervention.getOrder(), actions);
case InterventionType.DELETE -> handleDelete(executableAction, intervention.getOrder(), actions);
case InterventionType.CANCEL -> handleCancel(executableAction);
case InterventionType.REBUILD -> {
if (!rebuildCleanTag) {
cleanActionData(executableAction);
rebuildCleanTag = true;
}
handleRebuild(executableAction, intervention.getOrder(), actions);
}
}
}
}
/**
* 在未进入执行阶段的行动单元组新增新的行动
*/
private void handleAppend(ExecutableAction executableAction, int order, List<MetaAction> actions) {
if (order <= executableAction.getExecutingStage())
return;
executableAction.getActionChain().put(order, actions);
}
/**
* 在未进入执行阶段和正处于行动阶段的行动单元组插入新的行动
*/
private void handleInsert(ExecutableAction executableAction, int order, List<MetaAction> actions) {
if (order < executableAction.getExecutingStage())
return;
executableAction.getActionChain().computeIfAbsent(order, k -> new ArrayList<>()).addAll(actions);
}
private void handleDelete(ExecutableAction executableAction, int order, List<MetaAction> actions) {
if (order <= executableAction.getExecutingStage())
return;
Map<Integer, List<MetaAction>> actionChain = executableAction.getActionChain();
if (actionChain.containsKey(order)) {
actionChain.get(order).removeAll(actions);
if (actionChain.get(order).isEmpty()) {
actionChain.remove(order);
}
}
}
private void handleCancel(ExecutableAction executableAction) {
executableAction.setStatus(ExecutableAction.Status.FAILED);
executableAction.setResult("行动取消");
}
private void handleRebuild(ExecutableAction executableAction, int order, List<MetaAction> actions) {
Map<Integer, List<MetaAction>> actionChain = executableAction.getActionChain();
actionChain.put(order, actions);
}
private void cleanActionData(ExecutableAction executableAction) {
executableAction.getActionChain().clear();
executableAction.setExecutingStage(0);
executableAction.setStatus(ExecutableAction.Status.PREPARE);
executableAction.getHistory().clear();
}
/**
* 命中缓存且评估通过时
*
* @param matchAndPassed 该类型的带调整缓存信息列表
* @param inputVector 本次输入内容的语义向量
* @param vectorClient 向量客户端
*/
private void adjustMatchAndPassed(List<CacheAdjustMetaData> matchAndPassed, float[] inputVector, String input,
VectorClient vectorClient) {
matchAndPassed.forEach(adjustData -> {
// 获取原始缓存条目
String tendency = adjustData.getTendency();
ActionCacheData primaryCacheData = selectCacheData(tendency);
if (primaryCacheData == null) {
return;
}
primaryCacheData.updateAfterMatchAndPassed(inputVector, vectorClient, input);
});
}
/**
* 针对命中缓存、但评估未通过的条目与输入进行处理
*
* @param matchNotPassed 该类型的带调整缓存信息列表
* @param vectorClient 向量客户端
*/
private void adjustMatchNotPassed(List<CacheAdjustMetaData> matchNotPassed, VectorClient vectorClient) {
List<ActionCacheData> toRemove = new ArrayList<>();
matchNotPassed.forEach(adjustData -> {
// 获取原始缓存条目
String tendency = adjustData.getTendency();
ActionCacheData primaryCacheData = selectCacheData(tendency);
if (primaryCacheData == null) {
return;
}
boolean remove = primaryCacheData.updateAfterMatchNotPassed(vectorClient);
if (remove) {
toRemove.add(primaryCacheData);
}
});
cacheLock.lock();
actionCache.removeAll(toRemove);
cacheLock.unlock();
}
/**
* 针对未命中但评估通过的缓存做出调整:
* <ol>
* <h3>如果存在缓存条目</h3>
* <li>
* 若已生效,但此时未匹配到则说明尚未生效或者阈值、向量{@link ActionCacheData#getInputVector()}存在问题,调低阈值,同时带权移动平均
* </li>
* <li>
* 若未生效,则只增加计数并带权移动平均
* </li>
* </ol>
* 如果不存在缓存条目,则新增并填充字段
*
* @param notMatchPassed 该类型的带调整缓存信息列表
* @param inputVector 本次输入内容的语义向量
* @param input 本次输入内容
* @param vectorClient 向量客户端
*/
private void adjustNotMatchPassed(List<CacheAdjustMetaData> notMatchPassed, float[] inputVector, String input,
VectorClient vectorClient) {
notMatchPassed.forEach(adjustData -> {
// 获取原始缓存条目
String tendency = adjustData.getTendency();
ActionCacheData primaryCacheData = selectCacheData(tendency);
float[] tendencyVector = vectorClient.compute(tendency);
if (primaryCacheData == null) {
actionCache.add(new ActionCacheData(tendency, tendencyVector, inputVector, input));
return;
}
primaryCacheData.updateAfterNotMatchPassed(input, inputVector, tendencyVector, vectorClient);
});
}
private ActionCacheData selectCacheData(String tendency) {
for (ActionCacheData actionCacheData : actionCache) {
if (actionCacheData.getTendency().equals(tendency)) {
return actionCacheData;
}
}
log.warn("[{}] 未找到行为倾向[{}]对应的缓存条目,可能是代码逻辑存在错误", getCoreKey(), tendency);
return null;
}
@Override
protected String getCoreKey() {
return "action-core";
}
public enum ExecutorType {
VIRTUAL, PLATFORM
}
}

View File

@@ -0,0 +1,192 @@
package work.slhaf.partner.core.action.entity
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.HistoryAction
import java.time.ZonedDateTime
import java.util.*
sealed class Action {
/**
* 行动ID
*/
val uuid: String = UUID.randomUUID().toString()
/**
* 行动来源
*/
abstract val source: String
/**
* 行动原因
*/
abstract val reason: String
/**
* 行动描述
*/
abstract val description: String
}
sealed interface Schedulable {
val scheduleType: ScheduleType
val scheduleContent: String
val uuid: String
enum class ScheduleType {
CYCLE,
ONCE
}
}
/**
* 行动模块传递的行动数据包含行动uuid、倾向、状态、行动链、结果、发起原因、行动描述等信息。
*/
sealed class ExecutableAction : Action() {
/**
* 行动倾向
*/
abstract val tendency: String
/**
* 行动状态
*/
var status: Status = Status.PREPARE
/**
* 行动链
*/
abstract val actionChain: MutableMap<Int, MutableList<MetaAction>>
/**
* 行动阶段(当前阶段)
*/
var executingStage: Int = 0
/**
* 行动结果
*/
lateinit var result: String
val history: MutableMap<Int, MutableList<HistoryAction>> = mutableMapOf()
/**
* 修复上下文
*/
val additionalContext: MutableMap<Int, MutableList<String>> = mutableMapOf()
enum class Status {
/**
* 执行成功
*/
SUCCESS,
/**
* 执行失败
*/
FAILED,
/**
* 执行中
*/
EXECUTING,
/**
* 暂时中断
*/
INTERRUPTED,
/**
* 预备执行
*/
PREPARE
}
}
/**
* 计划行动数据类,继承自[Action],扩展了[Schedulable]相关调度属性,用于标识计划类型(单次还是周期性任务)和计划内容
*/
data class SchedulableExecutableAction(
override val tendency: String,
override val actionChain: MutableMap<Int, MutableList<MetaAction>>,
override val reason: String,
override val description: String,
override val source: String,
override val scheduleType: Schedulable.ScheduleType,
override val scheduleContent: String
) : ExecutableAction(), Schedulable {
val scheduleHistories = ArrayList<ScheduleHistory>()
fun recordAndReset() {
val newHistory = ScheduleHistory(ZonedDateTime.now(), result, history.toMap())
scheduleHistories.add(newHistory)
additionalContext.clear()
executingStage = 0
for (entry in actionChain) {
for (action in entry.value) {
action.params.clear()
action.result.reset()
}
}
status = Status.PREPARE
}
data class ScheduleHistory(
val endTime: ZonedDateTime,
val result: String,
val history: Map<Int, List<HistoryAction>>
)
}
/**
* 即时行动数据类
*/
data class ImmediateExecutableAction(
override val tendency: String,
override val actionChain: MutableMap<Int, MutableList<MetaAction>>,
override val reason: String,
override val description: String,
override val source: String,
) : ExecutableAction()
/**
* 用于计时的一次性触发或者针对某一数据源进行内容更新的行动
*/
data class StateAction(
override val source: String,
override val reason: String,
override val description: String,
override val scheduleType: Schedulable.ScheduleType,
override val scheduleContent: String,
val trigger: Trigger
) : Action(), Schedulable {
sealed interface Trigger {
fun onTrigger()
/**
* State 更新触发
*/
class Update<T>(val stateSource: T, val update: (stateSource: T) -> Unit) : Trigger {
override fun onTrigger() {
update(stateSource)
}
}
/**
* 常规逻辑触发
*/
class Call(val call: () -> Unit) : Trigger {
override fun onTrigger() {
call()
}
}
}
}

View File

@@ -0,0 +1,10 @@
package work.slhaf.partner.core.action.entity;
import lombok.Data;
@Data
public class ActionFileMetaData {
private String content;
private String name;
private String ext;
}

View File

@@ -0,0 +1,15 @@
package work.slhaf.partner.core.action.entity;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import java.util.List;
@Data
public class GeneratedData {
private List<String> dependencies;
private String code;
private String codeType;
private boolean serialize;
private JSONObject responseSchema;
}

View File

@@ -0,0 +1,5 @@
package work.slhaf.partner.core.action.entity;
public class McpData {
}

View File

@@ -0,0 +1,73 @@
package work.slhaf.partner.core.action.entity
/**
* 行动链中的单一元素,封装了调用外部行动程序的必要信息与结果容器,可被[work.slhaf.partner.core.action.ActionCapability]执行
*/
data class MetaAction(
/**
* 行动name用于标识行动程序
*/
val name: String,
/**
* 是否IO密集用于决定使用何种线程池
*/
val io: Boolean = false,
/**
* 行动程序类型,可分为 MCP、ORIGIN 两种,前者对应读取到的 MCP Tool、后者对应生成的临时行动程序
*/
val type: Type,
/**
* 当类型为 MCP 时,该字段对应相应 MCP Client 注册时生成的 id;
* 当类型为 ORIGIN 时,该字段对应相应的磁盘路径字符串
*/
val location: String,
) {
/**
* 行动程序可接受的参数,由调用处设置
*/
val params: MutableMap<String, Any> = mutableMapOf()
/**
* 行动结果,包括执行状态和相应内容(执行结果或者错误信息)
*/
val result = Result()
val key: String
/**
* actionKey 将由 location+name 共同定位
*
* @return actionKey
*/
get() = "$location::$name"
class Result {
var status = Status.WAITING
var data: String? = null
fun reset() {
status = Status.WAITING
data = null
}
enum class Status {
SUCCESS,
FAILED,
WAITING
}
}
enum class Type {
/**
* 将调用的 MCP 工具,可包括远程、本地任意服务
*/
MCP,
/**
* 适用于‘临时生成’的行动程序,在生成后根据序列化选项及执行情况,进行持久化
*/
ORIGIN
}
}

View File

@@ -0,0 +1,25 @@
package work.slhaf.partner.core.action.entity;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import java.util.List;
import java.util.Map;
@Data
public class MetaActionInfo {
private boolean io;
private Map<String, Object> params;
private String description;
private List<String> tags;
private List<String> preActions;
private List<String> postActions;
/**
* 是否严格依赖前置行动的成功执行若为true且前置行动失败则不执行该行动后置任务多为触发式。默认即执行。
*/
private boolean strictDependencies;
private JSONObject responseSchema;
}

View File

@@ -0,0 +1,33 @@
package work.slhaf.partner.core.action.entity;
import work.slhaf.partner.core.action.entity.ExecutableAction.Status;
import java.util.concurrent.Phaser;
public record PhaserRecord(Phaser phaser, ExecutableAction executableAction) {
public void fail() {
executableAction.setStatus(Status.FAILED);
}
/**
* 负责将 ActionData 的状态设置为 INTERRUPTED
* 同时循环检查进行阻塞
*/
public void interrupt() {
executableAction.setStatus(Status.INTERRUPTED);
while (executableAction().getStatus() == Status.INTERRUPTED) {
try {
Thread.sleep(500);
} catch (InterruptedException ignored) {
}
}
}
/**
* 将状态重新设置为 EXECUTING ,恢复 interrupt 阻塞状态
*/
public void complete() {
executableAction().setStatus(Status.EXECUTING);
}
}

View File

@@ -0,0 +1,181 @@
package work.slhaf.partner.core.action.entity.cache;
import lombok.Data;
import work.slhaf.partner.common.vector.VectorClient;
import java.util.ArrayList;
import java.util.List;
@Data
public class ActionCacheData {
private boolean activated = false;
private int inputMatchCount = 1;
private float[] inputVector;
private float[] tendencyVector;
private String tendency;
private double threshold = 0.75;
private List<String> validSamples = new ArrayList<>();
private int failedCount = 0;
private Type type = Type.PRIMARY;
public ActionCacheData(String tendency, float[] tendencyVector, float[] inputVector, String input) {
this.tendency = tendency;
this.inputVector = inputVector;
this.tendencyVector = tendencyVector;
this.validSamples.add(input);
}
/**
* 命中缓存且评估通过时,根据输入内容的语义向量与现有的输入语义向量进行带权移动平均,以相似度为权重,同时降低失败计数,为零时置为上一级缓存类型{@link ActionCacheData.Type}
*
* @param inputVector 本次输入内容对应的语义向量
* @param vectorClient 向量客户端
* @param input 本次输入内容
*/
public synchronized void updateAfterMatchAndPassed(float[] inputVector, VectorClient vectorClient, String input) {
updateInputVector(inputVector, vectorClient);
addValidSample(input);
reduceFailedCount();
updateType();
addInputMatchCount();
}
private void updateType() {
if (this.failedCount == 0) {
this.type = switch (type) {
case PRIMARY, REBUILD_V1 -> ActionCacheData.Type.PRIMARY;
case REBUILD_V2 -> ActionCacheData.Type.REBUILD_V1;
case REBUILD_V3 -> ActionCacheData.Type.REBUILD_V2;
};
}
}
private void reduceFailedCount() {
this.failedCount = Math.max(this.failedCount - 1, 0);
}
private void addValidSample(String input) {
if (this.validSamples.size() == 12) {
this.validSamples.removeFirst();
}
this.validSamples.add(input);
}
private void updateInputVector(float[] inputVector, VectorClient vectorClient) {
this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
}
/**
* 针对命中缓存、但评估未通过的条目与输入进行处理: 增加失败计数(必要时重建并更新类型等级)、调高阈值(0.02),由于缓存匹配但评估未通过,所以不进行带权移动平均
*
* @param vectorClient 向量客户端
* @return 是否需要删除(已在REBUILD_V3状态且达到最大误判次数的)
*/
public synchronized boolean updateAfterMatchNotPassed(VectorClient vectorClient) {
adjustThreshold();
addFailedCount();
if (this.failedCount < 3) {
return false;
}
if (this.type == Type.REBUILD_V3) {
return true;
}
rebuildAndSwitchType(vectorClient);
return false;
}
private void rebuildAndSwitchType(VectorClient vectorClient) {
this.type = switch (this.type) {
case PRIMARY -> {
//样本顺序反转后,以全部样本重建
this.validSamples = this.validSamples.reversed();
rebuildWithSamples(vectorClient);
yield Type.REBUILD_V1;
}
case REBUILD_V1 -> {
//截取后一半样本,反转后以此重建
List<String> temp = this.validSamples.subList(this.validSamples.size() / 2, this.validSamples.size());
this.validSamples = temp.reversed();
rebuildWithSamples(vectorClient);
yield Type.REBUILD_V2;
}
case REBUILD_V2 -> {
//截取后四分之一样本,反转后以此重建
List<String> temp = this.validSamples.subList(this.validSamples.size() / 4, this.validSamples.size());
this.validSamples = temp.reversed();
rebuildWithSamples(vectorClient);
yield Type.REBUILD_V3;
}
case REBUILD_V3 -> null;
};
//阈值减0.05,防止重建后一直升高
this.threshold = Math.max(this.threshold - 0.05, 0.75);
this.failedCount = 0;
}
private void rebuildWithSamples(VectorClient vectorClient) {
for (int i = 0; i < this.validSamples.size(); i++) {
String sample = this.validSamples.get(i);
if (i == 0) {
this.inputVector = vectorClient.compute(sample);
} else {
float[] newSampleVector = vectorClient.compute(sample);
this.inputVector = vectorClient.weightedAverage(this.inputVector, newSampleVector);
}
}
}
private void addFailedCount() {
this.failedCount = Math.min(this.failedCount + 1, 3);
}
private void adjustThreshold() {
double newThreshold = this.threshold + 0.03;
this.threshold = Math.min(newThreshold, 0.95);
}
/**
* 针对未命中但评估通过的已存在缓存做出调整:
* <ol>
* <li>
* 若已生效,但此时未匹配到则说明阈值或者向量{@link ActionCacheData#getInputVector()}存在问题,调低阈值,同时带权移动平均
* </li>
* <li>
* 若未生效,则只增加计数并带权移动平均
* </li>
* </ol>
*
* @param input 本次输入内容
* @param inputVector 本次输入内容对应的语义向量
* @param tendencyVector 本次倾向对应的语义向量
* @param vectorClient 向量客户端
*/
public synchronized void updateAfterNotMatchPassed(String input, float[] inputVector, float[] tendencyVector, VectorClient vectorClient) {
if (this.activated) {
reduceThreshold();
this.inputVector = vectorClient.weightedAverage(inputVector, this.inputVector);
} else {
addValidSample(input);
this.tendencyVector = vectorClient.weightedAverage(tendencyVector, this.tendencyVector);
addInputMatchCount();
}
}
private void reduceThreshold() {
double newThreshold = this.threshold - 0.02;
this.threshold = Math.max(newThreshold, 0.75);
}
private void addInputMatchCount() {
this.inputMatchCount += 1;
if (inputMatchCount >= 6) {
this.activated = true;
}
}
public enum Type {
PRIMARY, REBUILD_V1, REBUILD_V2, REBUILD_V3
}
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.core.action.entity.cache;
import lombok.Data;
import java.util.List;
@Data
public class CacheAdjustData {
private String input;
private List<CacheAdjustMetaData> metaDataList;
}

View File

@@ -0,0 +1,10 @@
package work.slhaf.partner.core.action.entity.cache;
import lombok.Data;
@Data
public class CacheAdjustMetaData {
private String tendency;
private boolean passed;
private boolean hit;
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.core.action.exception;
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
public class ActionDataNotFoundException extends AgentRuntimeException {
public ActionDataNotFoundException(String message) {
super(message);
}
public ActionDataNotFoundException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.core.action.exception;
import work.slhaf.partner.api.agent.runtime.exception.AgentLaunchFailedException;
public class ActionInitFailedException extends AgentLaunchFailedException {
public ActionInitFailedException(String message, Throwable cause) {
super(message, cause);
}
public ActionInitFailedException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.core.action.exception;
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
public class ActionLoadFailedException extends AgentRuntimeException {
public ActionLoadFailedException(String message) {
super(message);
}
public ActionLoadFailedException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.core.action.exception;
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
public class ActionSerializeFailedException extends AgentRuntimeException {
public ActionSerializeFailedException(String message) {
super(message);
}
public ActionSerializeFailedException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.core.action.exception;
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
public class MetaActionNotFoundException extends AgentRuntimeException {
public MetaActionNotFoundException(String message) {
super(message);
}
public MetaActionNotFoundException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,107 @@
package work.slhaf.partner.core.action.runner;
import com.alibaba.fastjson2.JSONObject;
import io.modelcontextprotocol.server.McpStatelessAsyncServer;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.jetbrains.annotations.Nullable;
import work.slhaf.partner.core.action.entity.ActionFileMetaData;
import work.slhaf.partner.core.action.entity.MetaAction;
import work.slhaf.partner.core.action.entity.MetaAction.Result;
import work.slhaf.partner.core.action.entity.MetaActionInfo;
import work.slhaf.partner.core.action.exception.ActionInitFailedException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import static work.slhaf.partner.common.Constant.Path.DATA;
import static work.slhaf.partner.common.util.PathUtil.buildPathStr;
/**
* 执行客户端抽象类
* <br/>
* 只负责暴露序列化、执行等相应接口,具体逻辑交给下游实现
* <br/>
* 默认存在两类实现,{@link LocalRunnerClient} 和 {@link SandboxRunnerClient}
* <ol>
* LocalRunnerClient:
* <li>
* 对应本地运行环境,可在本地启动 MCP 客户端将 RunnerClient 暴露的能力接口转发至本地 MCP Client 并执行
* </li>
* SandboxRunnerClient:
* <li>
* 对应沙盒运行环境,该 Client 仅作为沙盒环境的客户端,不持有额外能力,仅保持远端连接已存在行动的内容更新
* </li>
* </ol>
*/
@Slf4j
public abstract class RunnerClient {
protected final String ACTION_PATH;
protected final ConcurrentHashMap<String, MetaActionInfo> existedMetaActions;
protected final ExecutorService executor;
//TODO 仍可提供内部 MCP但调用方式需要结合 AgentContext来获取否则生命周期不合
protected McpStatelessAsyncServer innerMcpServer;
/**
* ActionCore 将注入虚拟线程池
*/
public RunnerClient(ConcurrentHashMap<String, MetaActionInfo> existedMetaActions, ExecutorService executor, @Nullable String baseActionPath) {
this.existedMetaActions = existedMetaActions;
this.executor = executor;
baseActionPath = baseActionPath == null ? DATA : baseActionPath;
this.ACTION_PATH = buildPathStr(baseActionPath, "action");
createPath(ACTION_PATH);
}
/**
* 执行行动程序
*/
public void submit(MetaAction metaAction) {
// 获取已存在行动列表
Result result = metaAction.getResult();
if (!result.getStatus().equals(Result.Status.WAITING)) {
return;
}
RunnerResponse response = doRun(metaAction);
result.setData(response.getData());
result.setStatus(response.isOk() ? Result.Status.SUCCESS : Result.Status.FAILED);
}
protected abstract RunnerResponse doRun(MetaAction metaAction);
public abstract String buildTmpPath(String actionKey, String codeType);
public abstract void tmpSerialize(MetaAction tempAction, String code, String codeType) throws IOException;
public abstract void persistSerialize(MetaActionInfo metaActionInfo, ActionFileMetaData fileMetaData);
protected void createPath(String pathStr) {
val path = Path.of(pathStr);
try {
Files.createDirectory(path);
} catch (IOException e) {
if (!Files.exists(path)) {
throw new ActionInitFailedException("目录创建失败: " + pathStr, e);
}
}
}
/**
* 列出执行环境下的系统依赖情况
*/
public abstract JSONObject listSysDependencies();
@Data
public static class RunnerResponse {
private boolean ok;
private String data;
}
}

View File

@@ -0,0 +1,57 @@
package work.slhaf.partner.core.action.runner;
import com.alibaba.fastjson2.JSONObject;
import work.slhaf.partner.core.action.entity.ActionFileMetaData;
import work.slhaf.partner.core.action.entity.MetaAction;
import work.slhaf.partner.core.action.entity.MetaActionInfo;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
/**
* 基于 Http 与 WebSocket 的沙盒执行器客户端,负责:
* <ul>
* <li>
* 发送行动单元数据
* </li>
* <li>
* 实时更新获取已存在行动列表
* </li>
* <li>
* 向传入的 MetaAction 回写执行结果
* </li>
* </ul>
*/
public class SandboxRunnerClient extends RunnerClient {
public SandboxRunnerClient(ConcurrentHashMap<String, MetaActionInfo> existedMetaActions, ExecutorService executor) { // 连接沙盒执行器(websocket)
super(existedMetaActions, executor, null);
}
protected RunnerResponse doRun(MetaAction metaAction) {
// 调用沙盒执行器
return null;
}
@Override
public JSONObject listSysDependencies() {
return null;
}
@Override
public String buildTmpPath(String actionKey, String codeType) {
throw new UnsupportedOperationException("Unimplemented method 'buildTmpPath'");
}
@Override
public void tmpSerialize(MetaAction tempAction, String code, String codeType) throws IOException {
throw new UnsupportedOperationException("Unimplemented method 'tmpSerialize'");
}
@Override
public void persistSerialize(MetaActionInfo metaActionInfo, ActionFileMetaData fileMetaData) {
throw new UnsupportedOperationException("Unimplemented method 'persistSerialize'");
}
}

View File

@@ -0,0 +1,28 @@
package work.slhaf.partner.core.cognation;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.api.agent.factory.capability.annotation.ToCoordinated;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.locks.Lock;
@Capability("cognation")
public interface CognationCapability {
List<Message> getChatMessages();
void cleanMessage(List<Message> messages);
Lock getMessageLock();
void addMetaMessage(String userId, MetaMessage metaMessage);
List<Message> unpackAndClear(String userId);
void refreshMemoryId();
void resetLastUpdatedTime();
long getLastUpdatedTime();
HashMap<String,List<MetaMessage>> getSingleMetaMessageMap();
String getCurrentMemoryId();
@ToCoordinated
boolean isSingleUser();
}

View File

@@ -0,0 +1,116 @@
package work.slhaf.partner.core.cognation;
import com.alibaba.fastjson2.JSONObject;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.core.PartnerCore;
import java.io.IOException;
import java.io.Serial;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@EqualsAndHashCode(callSuper = true)
@Slf4j
@CapabilityCore(value = "cognation")
@Getter
@Setter
public class CognationCore extends PartnerCore<CognationCore> {
@Serial
private static final long serialVersionUID = 1L;
private final ReentrantLock messageLock = new ReentrantLock();
/**
* 主模型的聊天记录
*/
private List<Message> chatMessages = new ArrayList<>();
private HashMap<String /*startUserId*/, List<MetaMessage>> singleMetaMessageMap = new HashMap<>();
private String currentMemoryId;
private long lastUpdatedTime;
public CognationCore() throws IOException, ClassNotFoundException {
}
@CapabilityMethod
public List<Message> getChatMessages() {
return chatMessages;
}
@CapabilityMethod
public long getLastUpdatedTime(){
return lastUpdatedTime;
}
@CapabilityMethod
public HashMap<String,List<MetaMessage>> getSingleMetaMessageMap(){
return singleMetaMessageMap;
}
@CapabilityMethod
public String getCurrentMemoryId(){
return currentMemoryId;
}
@CapabilityMethod
public void cleanMessage(List<Message> messages) {
messageLock.lock();
this.getChatMessages().removeAll(messages);
messageLock.unlock();
}
@CapabilityMethod
public Lock getMessageLock() {
return messageLock;
}
@CapabilityMethod
public void addMetaMessage(String userId, MetaMessage metaMessage) {
log.debug("[{}] 当前会话历史: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap));
if (singleMetaMessageMap.containsKey(userId)) {
singleMetaMessageMap.get(userId).add(metaMessage);
} else {
singleMetaMessageMap.put(userId, new java.util.ArrayList<>());
singleMetaMessageMap.get(userId).add(metaMessage);
}
log.debug("[{}] 会话历史更新: {}", getCoreKey(), JSONObject.toJSONString(singleMetaMessageMap));
}
@CapabilityMethod
public List<Message> unpackAndClear(String userId) {
List<Message> messages = new ArrayList<>();
for (MetaMessage metaMessage : singleMetaMessageMap.get(userId)) {
messages.add(metaMessage.getUserMessage());
messages.add(metaMessage.getAssistantMessage());
}
singleMetaMessageMap.remove(userId);
return messages;
}
@CapabilityMethod
public void refreshMemoryId() {
currentMemoryId = UUID.randomUUID().toString();
}
@CapabilityMethod
public void resetLastUpdatedTime() {
lastUpdatedTime = System.currentTimeMillis();
}
@Override
protected String getCoreKey() {
return "cognation-core";
}
}

View File

@@ -0,0 +1,7 @@
package work.slhaf.partner.core.cognation.exception;
public class UserNotExistsException extends RuntimeException {
public UserNotExistsException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,51 @@
package work.slhaf.partner.core.memory;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.MemoryResult;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
@Capability(value = "memory")
public interface MemoryCapability {
void cleanSelectedSliceFilter();
String getTopicTree();
HashMap<LocalDateTime, String> getDialogMap();
ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId);
void updateDialogMap(LocalDateTime dateTime, String newDialogCache);
String getDialogMapStr();
String getUserDialogMapStr(String userId);
void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices);
String getActivatedSlicesStr(String userId);
HashMap<String, List<EvaluatedSlice>> getActivatedSlices();
void clearActivatedSlices(String userId);
boolean hasActivatedSlices(String userId);
int getActivatedSlicesSize(String userId);
List<EvaluatedSlice> getActivatedSlices(String userId);
MemoryResult selectMemory(String topicPathStr);
MemoryResult selectMemory(LocalDate date);
void insertSlice(MemorySlice memorySlice, String topicPath);
}

View File

@@ -0,0 +1,613 @@
package work.slhaf.partner.core.memory;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.core.PartnerCore;
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.MemoryResult;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.core.memory.pojo.MemorySliceResult;
import work.slhaf.partner.core.memory.pojo.node.MemoryNode;
import work.slhaf.partner.core.memory.pojo.node.TopicNode;
import java.io.IOException;
import java.io.Serial;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@EqualsAndHashCode(callSuper = true)
@CapabilityCore(value = "memory")
@Getter
@Setter
@Slf4j
public class MemoryCore extends PartnerCore<MemoryCore> {
@Serial
private static final long serialVersionUID = 1L;
/**
* key: 根主题名称 value: 根主题节点
*/
private HashMap<String, TopicNode> topicNodes = new HashMap<>();
/**
* 用于存储已存在的主题列表,便于记忆查找, 使用根主题名称作为键, 子主题名称集合为值
* 该部分在'主题提取LLM'的system prompt中常驻
*/
private HashMap<String /*根主题名*/, LinkedHashSet<String> /*子主题列表*/> existedTopics = new HashMap<>();
/**
* 临时的同一对话切片容器, 用于为同一对话内的不同切片提供更新上下文的场所
*/
private HashMap<String /*对话id, 即slice中的字段'memoryId'*/, List<MemorySlice>> currentDateDialogSlices = new HashMap<>();
/**
* 记忆节点的日期索引, 同一日期内按照对话id区分
*/
private HashMap<LocalDate, Set<String>> dateIndex = new HashMap<>();
/**
* 已被选中的切片时间戳集合,需要及时清理
*/
private Set<Long> selectedSlices = new HashSet<>();
private HashMap<String, List<String>> userIndex = new HashMap<>();
private MemoryCache cache = new MemoryCache();
private final Lock sliceInsertLock = new ReentrantLock();
public MemoryCore() throws IOException, ClassNotFoundException {
}
@CapabilityMethod
public MemoryResult selectMemory(LocalDate date) throws IOException, ClassNotFoundException {
MemoryResult memoryResult = new MemoryResult();
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
//加载节点并获取记忆切片列表
List<List<MemorySlice>> currentDateDialogSlices = loadSlicesByDate(date);
for (List<MemorySlice> value : currentDateDialogSlices) {
for (MemorySlice memorySlice : value) {
if (selectedSlices.contains(memorySlice.getTimestamp())) {
continue;
}
MemorySliceResult memorySliceResult = new MemorySliceResult();
memorySliceResult.setMemorySlice(memorySlice);
targetSliceList.add(memorySliceResult);
selectedSlices.add(memorySlice.getTimestamp());
}
}
memoryResult.setMemorySliceResult(targetSliceList);
return cacheFilter(memoryResult);
}
@CapabilityMethod
public void insertSlice(MemorySlice memorySlice, String topicPath) {
sliceInsertLock.lock();
List<String> topicPathList = Arrays.stream(topicPath.split("->")).toList();
try {
//检查是否存在当天对应的memorySlice并确定是否插入
//每日刷新缓存
checkCacheDate();
//如果topicPath在memorySliceCache中存在对应缓存由于进行的插入操作则需要移除该缓存但不清除相关计数
clearCacheByTopicPath(topicPathList);
insertMemory(topicPathList, memorySlice);
if (!memorySlice.isPrivate()) {
updateUserDialogMap(memorySlice);
}
} catch (Exception e) {
log.error("[CoordinatedManager] 插入记忆时出错: ", e);
}
log.debug("[CoordinatedManager] 插入切片: {}, 路径: {}", memorySlice, topicPath);
sliceInsertLock.unlock();
}
@CapabilityMethod
public String getTopicTree() {
StringBuilder stringBuilder = new StringBuilder();
for (Map.Entry<String, TopicNode> entry : topicNodes.entrySet()) {
String rootName = entry.getKey();
TopicNode rootNode = entry.getValue();
stringBuilder.append(rootName).append("[root]").append("\r\n");
printSubTopicsTreeFormat(rootNode, "", stringBuilder);
}
return stringBuilder.toString();
}
@CapabilityMethod
public void updateDialogMap(LocalDateTime dateTime, String newDialogCache) {
List<LocalDateTime> keysToRemove = new ArrayList<>();
HashMap<LocalDateTime, String> dialogMap = cache.dialogMap;
dialogMap.forEach((k, v) -> {
if (dateTime.minusDays(2).isAfter(k)) {
keysToRemove.add(k);
}
});
for (LocalDateTime temp : keysToRemove) {
dialogMap.remove(temp);
}
keysToRemove.clear();
//放入新缓存
dialogMap.put(dateTime, newDialogCache);
}
@CapabilityMethod
public HashMap<LocalDateTime, String> getDialogMap() {
return cache.dialogMap;
}
@CapabilityMethod
public ConcurrentHashMap<LocalDateTime, String> getUserDialogMap(String userId) {
return cache.userDialogMap.get(userId);
}
@CapabilityMethod
public String getDialogMapStr() {
StringBuilder str = new StringBuilder();
this.getDialogMap().forEach((dateTime, dialog) -> str.append("\n\n").append("[").append(dateTime).append("]\n")
.append(dialog));
return str.toString();
}
@CapabilityMethod
public String getUserDialogMapStr(String userId) {
ConcurrentHashMap<String, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = cache.userDialogMap;
if (userDialogMap.containsKey(userId)) {
StringBuilder str = new StringBuilder();
Collection<String> dialogMapValues = this.getDialogMap().values();
userDialogMap.get(userId).forEach((dateTime, dialog) -> {
if (dialogMapValues.contains(dialog)) {
return;
}
str.append("\n\n").append("[").append(dateTime).append("]\n")
.append(dialog);
});
return str.toString();
} else {
return null;
}
}
@CapabilityMethod
public MemoryResult selectMemory(String topicPathStr) {
MemoryResult memoryResult;
List<String> topicPath = List.of(topicPathStr.split("->"));
try {
List<String> path = new ArrayList<>(topicPath);
//每日刷新缓存
checkCacheDate();
//检测缓存并更新计数, 查看是否需要放入缓存
updateCacheCounter(path);
//查看是否存在缓存,如果存在,则直接返回
if ((memoryResult = selectCache(path)) != null) {
return memoryResult;
}
memoryResult = selectMemory(path);
//尝试更新缓存
updateCache(topicPath, memoryResult);
} catch (Exception e) {
log.error("[{}] selectMemory error: ", getCoreKey(), e);
log.error("[{}] 路径: {}", getCoreKey(), topicPathStr);
log.error("[{}] 主题树: {}", getCoreKey(), getTopicTree());
memoryResult = new MemoryResult();
memoryResult.setRelatedMemorySliceResult(new ArrayList<>());
memoryResult.setMemorySliceResult(new CopyOnWriteArrayList<>());
}
return cacheFilter(memoryResult);
}
@CapabilityMethod
public void updateActivatedSlices(String userId, List<EvaluatedSlice> memorySlices) {
cache.activatedSlices.put(userId, memorySlices);
log.debug("[{}] 已更新激活切片, userId: {}", getCoreKey(), userId);
}
@CapabilityMethod
public String getActivatedSlicesStr(String userId) {
HashMap<String, List<EvaluatedSlice>> activatedSlices = cache.activatedSlices;
if (activatedSlices.containsKey(userId)) {
StringBuilder str = new StringBuilder();
activatedSlices.get(userId).forEach(slice -> str.append("\n\n").append("[").append(slice.getDate()).append("]\n")
.append(slice.getSummary()));
return str.toString();
} else {
return null;
}
}
@CapabilityMethod
public HashMap<String, List<EvaluatedSlice>> getActivatedSlices() {
return cache.activatedSlices;
}
@CapabilityMethod
public void clearActivatedSlices(String userId) {
cache.activatedSlices.remove(userId);
}
@CapabilityMethod
public boolean hasActivatedSlices(String userId) {
HashMap<String, List<EvaluatedSlice>> activatedSlices = cache.activatedSlices;
if (!activatedSlices.containsKey(userId)) {
return false;
}
return !activatedSlices.get(userId).isEmpty();
}
@CapabilityMethod
public int getActivatedSlicesSize(String userId) {
return cache.activatedSlices.get(userId).size();
}
@CapabilityMethod
public List<EvaluatedSlice> getActivatedSlices(String userId) {
return cache.activatedSlices.get(userId);
}
@CapabilityMethod
public void cleanSelectedSliceFilter() {
this.selectedSlices.clear();
}
private List<List<MemorySlice>> loadSlicesByDate(LocalDate date) throws IOException, ClassNotFoundException {
if (!dateIndex.containsKey(date)) {
throw new UnExistedDateIndexException("不存在的日期索引: " + date);
}
List<List<MemorySlice>> list = new ArrayList<>();
for (String memoryNodeId : dateIndex.get(date)) {
MemoryNode memoryNode = new MemoryNode();
memoryNode.setMemoryNodeId(memoryNodeId);
list.add(memoryNode.loadMemorySliceList());
}
return list;
}
private void printSubTopicsTreeFormat(TopicNode node, String prefix, StringBuilder stringBuilder) {
if (node.getTopicNodes() == null || node.getTopicNodes().isEmpty()) return;
List<Map.Entry<String, TopicNode>> entries = new ArrayList<>(node.getTopicNodes().entrySet());
for (int i = 0; i < entries.size(); i++) {
boolean last = (i == entries.size() - 1);
Map.Entry<String, TopicNode> entry = entries.get(i);
stringBuilder.append(prefix).append(last ? "└── " : "├── ").append(entry.getKey()).append("[").append(entry.getValue().getMemoryNodes().size()).append("]").append("\r\n");
printSubTopicsTreeFormat(entry.getValue(), prefix + (last ? " " : ""), stringBuilder);
}
}
private void insertMemory(List<String> topicPath, MemorySlice slice) throws IOException, ClassNotFoundException {
LocalDate now = LocalDate.now();
boolean hasSlice = false;
MemoryNode node = null;
TopicNode lastTopicNode = generateTopicPath(topicPath);
for (MemoryNode memoryNode : lastTopicNode.getMemoryNodes()) {
if (now.equals(memoryNode.getLocalDate())) {
hasSlice = true;
node = memoryNode;
break;
}
}
if (!hasSlice) {
node = new MemoryNode();
node.setLocalDate(now);
node.setMemoryNodeId(UUID.randomUUID().toString());
node.setMemorySliceList(new CopyOnWriteArrayList<>());
lastTopicNode.getMemoryNodes().add(node);
lastTopicNode.getMemoryNodes().sort(null);
}
node.loadMemorySliceList().add(slice);
//生成relatedTopicPath
for (List<String> relatedTopic : slice.getRelatedTopics()) {
generateTopicPath(relatedTopic);
}
updateSlicePrecedent(slice);
updateDateIndex(slice);
updateUserIndex(slice);
node.saveMemorySliceList();
}
private void updateUserIndex(MemorySlice slice) {
String memoryId = slice.getMemoryId();
String userId = slice.getStartUserId();
if (!userIndex.containsKey(userId)) {
List<String> memoryIdSet = new ArrayList<>();
memoryIdSet.add(memoryId);
userIndex.put(userId, memoryIdSet);
} else {
userIndex.get(userId).add(memoryId);
}
}
private TopicNode generateTopicPath(List<String> topicPath) {
topicPath = new ArrayList<>(topicPath);
//查看是否存在根主题节点
String rootTopic = topicPath.getFirst();
topicPath.removeFirst();
if (!topicNodes.containsKey(rootTopic)) {
synchronized (this) {
if (!topicNodes.containsKey(rootTopic)) {
TopicNode rootNode = new TopicNode();
topicNodes.put(rootTopic, rootNode);
existedTopics.put(rootTopic, new LinkedHashSet<>());
}
}
}
TopicNode current = topicNodes.get(rootTopic);
Set<String> existedTopicNodes = existedTopics.get(rootTopic);
for (String topic : topicPath) {
if (existedTopicNodes.contains(topic) && current.getTopicNodes().containsKey(topic)) {
current = current.getTopicNodes().get(topic);
} else {
TopicNode newNode = new TopicNode();
current.getTopicNodes().put(topic, newNode);
current = newNode;
current.setMemoryNodes(new CopyOnWriteArrayList<>());
current.setTopicNodes(new ConcurrentHashMap<>());
existedTopicNodes.add(topic);
}
}
return current;
}
private void updateSlicePrecedent(MemorySlice slice) {
String memoryId = slice.getMemoryId();
//查看是否切换了memoryId
if (!currentDateDialogSlices.containsKey(memoryId)) {
List<MemorySlice> memorySliceList = new ArrayList<>();
currentDateDialogSlices.clear();
currentDateDialogSlices.put(memoryId, memorySliceList);
}
//处理上下文关系
List<MemorySlice> memorySliceList = currentDateDialogSlices.get(memoryId);
if (memorySliceList.isEmpty()) {
memorySliceList.add(slice);
} else {
//排序
memorySliceList.sort(null);
MemorySlice tempSlice = memorySliceList.getLast();
//设置私密状态一致
tempSlice.setPrivate(slice.isPrivate());
//末尾切片添加当前切片的引用
tempSlice.setSliceAfter(slice);
//当前切片添加前序切片的引用
slice.setSliceBefore(tempSlice);
}
}
private void updateDateIndex(MemorySlice slice) {
String memoryId = slice.getMemoryId();
LocalDate date = LocalDate.now();
if (!dateIndex.containsKey(date)) {
HashSet<String> memoryIdSet = new HashSet<>();
memoryIdSet.add(memoryId);
dateIndex.put(date, memoryIdSet);
} else {
dateIndex.get(date).add(memoryId);
}
}
public MemoryResult selectMemory(List<String> path) throws IOException, ClassNotFoundException {
MemoryResult memoryResult = new MemoryResult();
CopyOnWriteArrayList<MemorySliceResult> targetSliceList = new CopyOnWriteArrayList<>();
String targetTopic = path.getLast();
TopicNode targetParentNode = getTargetParentNode(path, targetTopic);
List<List<String>> relatedTopics = new ArrayList<>();
//终点记忆节点
MemorySliceResult sliceResult = new MemorySliceResult();
for (MemoryNode memoryNode : targetParentNode.getTopicNodes().get(targetTopic).getMemoryNodes()) {
List<MemorySlice> endpointMemorySliceList = memoryNode.loadMemorySliceList();
for (MemorySlice memorySlice : endpointMemorySliceList) {
if (selectedSlices.contains(memorySlice.getTimestamp())) {
continue;
}
sliceResult.setSliceBefore(memorySlice.getSliceBefore());
sliceResult.setMemorySlice(memorySlice);
sliceResult.setSliceAfter(memorySlice.getSliceAfter());
targetSliceList.add(sliceResult);
selectedSlices.add(memorySlice.getTimestamp());
}
for (MemorySlice memorySlice : endpointMemorySliceList) {
if (memorySlice.getRelatedTopics() != null) {
relatedTopics.addAll(memorySlice.getRelatedTopics());
}
}
}
memoryResult.setMemorySliceResult(targetSliceList);
//邻近节点
List<MemorySlice> relatedMemorySlice = new ArrayList<>();
//邻近记忆节点 联系
for (List<String> relatedTopic : relatedTopics) {
List<String> tempTopicPath = new ArrayList<>(relatedTopic);
String tempTargetTopic = tempTopicPath.getLast();
TopicNode tempTargetParentNode = getTargetParentNode(tempTopicPath, tempTargetTopic);
//获取终点节点及其最新记忆节点
TopicNode tempTargetNode = tempTargetParentNode.getTopicNodes().get(tempTopicPath.getLast());
setRelatedMemorySlices(tempTargetNode, relatedMemorySlice);
}
//邻近记忆节点 父级
setRelatedMemorySlices(targetParentNode, relatedMemorySlice);
//将上述结果包装为MemoryResult
memoryResult.setRelatedMemorySliceResult(relatedMemorySlice);
return memoryResult;
}
private void setRelatedMemorySlices(TopicNode targetParentNode, List<MemorySlice> relatedMemorySlice) throws IOException, ClassNotFoundException {
List<MemoryNode> targetParentMemoryNodes = targetParentNode.getMemoryNodes();
if (!targetParentMemoryNodes.isEmpty()) {
for (MemorySlice memorySlice : targetParentMemoryNodes.getFirst().loadMemorySliceList()) {
if (selectedSlices.contains(memorySlice.getTimestamp())) {
continue;
}
relatedMemorySlice.add(memorySlice);
selectedSlices.add(memorySlice.getTimestamp());
}
}
}
private TopicNode getTargetParentNode(List<String> topicPath, String targetTopic) {
String topTopic = topicPath.getFirst();
if (!existedTopics.containsKey(topTopic)) {
throw new UnExistedTopicException("不存在的主题: " + topTopic);
}
TopicNode targetParentNode = topicNodes.get(topTopic);
topicPath.removeFirst();
for (String topic : topicPath) {
if (!existedTopics.get(topTopic).contains(topic)) {
throw new UnExistedTopicException("不存在的主题: " + topTopic);
}
}
//逐层查找目标主题
while (!targetParentNode.getTopicNodes().containsKey(targetTopic)) {
targetParentNode = targetParentNode.getTopicNodes().get(topicPath.getFirst());
topicPath.removeFirst();
}
return targetParentNode;
}
private void updateCacheCounter(List<String> topicPath) {
ConcurrentHashMap<List<String>, Integer> memoryNodeCacheCounter = cache.memoryNodeCacheCounter;
if (memoryNodeCacheCounter.containsKey(topicPath)) {
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
memoryNodeCacheCounter.put(topicPath, ++tempCount);
} else {
memoryNodeCacheCounter.put(topicPath, 1);
}
}
private void checkCacheDate() {
if (cache.cacheDate == null || cache.cacheDate.isBefore(LocalDate.now())) {
cache.memorySliceCache.clear();
cache.memoryNodeCacheCounter.clear();
cache.cacheDate = LocalDate.now();
}
}
private void updateCache(List<String> topicPath, MemoryResult memoryResult) {
ConcurrentHashMap<List<String>, Integer> memoryNodeCacheCounter = cache.memoryNodeCacheCounter;
Integer tempCount = memoryNodeCacheCounter.get(topicPath);
if (tempCount == null) {
log.warn("[CacheCore] tempCount为null? memoryNodeCacheCounter: {}; topicPath: {}", memoryNodeCacheCounter, topicPath);
return;
}
if (tempCount >= 5) {
cache.memorySliceCache.put(topicPath, memoryResult);
}
}
private void updateUserDialogMap(MemorySlice slice) {
String summary = slice.getSummary();
LocalDateTime now = LocalDateTime.now();
ConcurrentHashMap<String, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = cache.userDialogMap;
//更新userDialogMap
//移除两天前上下文缓存(切片总结)
List<LocalDateTime> keysToRemove = new ArrayList<>();
userDialogMap.forEach((k, v) -> v.forEach((i, j) -> {
if (now.minusDays(2).isAfter(i)) {
keysToRemove.add(i);
}
}));
for (LocalDateTime dateTime : keysToRemove) {
userDialogMap.forEach((k, v) -> v.remove(dateTime));
}
//放入新缓存
userDialogMap
.computeIfAbsent(slice.getStartUserId(), k -> new ConcurrentHashMap<>())
.merge(now, summary, (oldVal, newVal) -> oldVal + " " + newVal);
}
private void clearCacheByTopicPath(List<String> topicPath) {
cache.memorySliceCache.remove(topicPath);
}
private MemoryResult selectCache(List<String> path) {
ConcurrentHashMap<List<String>, MemoryResult> memorySliceCache = cache.memorySliceCache;
if (memorySliceCache.containsKey(path)) {
return memorySliceCache.get(path);
}
return null;
}
@Override
protected String getCoreKey() {
return "memory-core";
}
public ConcurrentHashMap<String, ConcurrentHashMap<LocalDateTime, String>> getUserDialogMap() {
return cache.userDialogMap;
}
private MemoryResult cacheFilter(MemoryResult memoryResult) {
//过滤掉与缓存重复的切片
CopyOnWriteArrayList<MemorySliceResult> memorySliceResult = memoryResult.getMemorySliceResult();
List<MemorySlice> relatedMemorySliceResult = memoryResult.getRelatedMemorySliceResult();
cache.dialogMap.forEach((k, v) -> {
memorySliceResult.removeIf(m -> m.getMemorySlice().getSummary().equals(v));
relatedMemorySliceResult.removeIf(m -> m.getSummary().equals(v));
});
return memoryResult;
}
@SuppressWarnings("FieldMayBeFinal")
private static class MemoryCache {
/**
* 近两日的对话总结缓存, 用于为大模型提供必要的记忆补充, hashmap以切片的存储时间为键总结为值
* 该部分作为'主LLM'system prompt常驻
* 该部分作为近两日的整体对话缓存, 不区分用户
*/
private HashMap<LocalDateTime, String> dialogMap = new HashMap<>();
/**
* 近两日的区分用户的对话总结缓存在prompt结构上比dialogMap层级深一层, dialogMap更具近两日整体对话的摘要性质
*/
private ConcurrentHashMap<String/*userId*/, ConcurrentHashMap<LocalDateTime, String>> userDialogMap = new ConcurrentHashMap<>();
/**
* memorySliceCache计数器每日清空
*/
private ConcurrentHashMap<List<String> /*触发查询的主题列表*/, Integer> memoryNodeCacheCounter = new ConcurrentHashMap<>();
/**
* 记忆切片缓存,每日清空
* 用于记录作为终点节点调用次数最多的记忆节点的切片数据
*/
private ConcurrentHashMap<List<String> /*主题路径*/, MemoryResult /*切片列表*/> memorySliceCache = new ConcurrentHashMap<>();
/**
* 缓存日期
*/
private LocalDate cacheDate;
private HashMap<String, List<EvaluatedSlice>> activatedSlices = new HashMap<>();
private MemoryCache() {
}
}
}

View File

@@ -0,0 +1,7 @@
package work.slhaf.partner.core.memory.exception;
public class NullSliceListException extends RuntimeException {
public NullSliceListException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,7 @@
package work.slhaf.partner.core.memory.exception;
public class UnExistedDateIndexException extends RuntimeException {
public UnExistedDateIndexException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,7 @@
package work.slhaf.partner.core.memory.exception;
public class UnExistedTopicException extends RuntimeException {
public UnExistedTopicException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,22 @@
package work.slhaf.partner.core.memory.pojo;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.common.entity.PersistableObject;
import java.io.Serial;
import java.time.LocalDate;
@EqualsAndHashCode(callSuper = true)
@Data
@Builder
public class EvaluatedSlice extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
// private List<Message> chatMessages;
private LocalDate date;
private String summary;
}

View File

@@ -0,0 +1,26 @@
package work.slhaf.partner.core.memory.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.common.entity.PersistableObject;
import java.io.Serial;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemoryResult extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private CopyOnWriteArrayList<MemorySliceResult> memorySliceResult;
private List<MemorySlice> relatedMemorySliceResult;
public boolean isEmpty(){
boolean a = memorySliceResult == null || memorySliceResult.isEmpty();
boolean b = relatedMemorySliceResult == null || relatedMemorySliceResult.isEmpty();
return a && b;
}
}

View File

@@ -0,0 +1,83 @@
package work.slhaf.partner.core.memory.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.common.entity.PersistableObject;
import java.io.Serial;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemorySlice extends PersistableObject implements Comparable<MemorySlice> {
@Serial
private static final long serialVersionUID = 1L;
/**
* 关联的完整对话的id
*/
private String memoryId;
/**
* 该切片在关联的完整对话中的顺序, 由时间戳确定
*/
private Long timestamp;
/**
* 格式为"<日期>.slice", 如2025-04-11.slice
*/
private String summary;
private List<Message> chatMessages;
/**
* 关联的其他主题, 即"邻近节点(联系)"
*/
private List<List<String>> relatedTopics;
/**
* 关联完整对话中的前序切片, 排序为键,完整路径为值
*/
@ToString.Exclude
private MemorySlice sliceBefore, sliceAfter;
/**
* 多用户设定
* 发起该切片对话的用户
*/
private String startUserId;
/**
* 该切片涉及到的用户uuid
*/
private List<String> involvedUserIds;
/**
* 是否仅供发起用户作为记忆参考
*/
private boolean isPrivate;
/**
* 摘要向量化结果
*/
private float[] summaryEmbedding;
/**
* 是否向量化
*/
private boolean embedded;
@Override
public int compareTo(MemorySlice memorySlice) {
if (memorySlice.getTimestamp() > this.getTimestamp()) {
return -1;
} else if (memorySlice.getTimestamp() < this.timestamp) {
return 1;
}
return 0;
}
}

View File

@@ -0,0 +1,24 @@
package work.slhaf.partner.core.memory.pojo;
import com.alibaba.fastjson2.annotation.JSONField;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.common.entity.PersistableObject;
import java.io.Serial;
@EqualsAndHashCode(callSuper = true)
@Data
public class MemorySliceResult extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
@JSONField(serialize = false)
private MemorySlice sliceBefore;
private MemorySlice memorySlice;
@JSONField(serialize = false)
private MemorySlice sliceAfter;
}

View File

@@ -0,0 +1,82 @@
package work.slhaf.partner.core.memory.pojo.node;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.common.entity.PersistableObject;
import work.slhaf.partner.core.memory.exception.NullSliceListException;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.LocalDate;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class MemoryNode extends PersistableObject implements Comparable<MemoryNode> {
@Serial
private static final long serialVersionUID = 1L;
private static String SLICE_DATA_DIR = "./data/memory/slice/";
/**
* 记忆节点唯一标识, 用于作为实际文件名, 如(xxxx-xxxxx-xxxxx.slice)
*/
private String memoryNodeId;
/**
* 记忆节点所属日期
*/
private LocalDate localDate;
/**
* 该日期对应的全部记忆切片
*/
private CopyOnWriteArrayList<MemorySlice> memorySliceList;
@Override
public int compareTo(MemoryNode memoryNode) {
if (memoryNode.getLocalDate().isAfter(this.localDate)) {
return -1;
} else if (memoryNode.getLocalDate().isBefore(this.localDate)) {
return 1;
}
return 0;
}
public List<MemorySlice> loadMemorySliceList() throws IOException, ClassNotFoundException {
//检查是否存在对应文件
File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
if (file.exists()){
this.memorySliceList = deserialize(file);
}else {
//逻辑正常的话这部分应该不会出现除非在insertMemory中进行save操作之前出现异常中断了方法但程序却没有结束
this.memorySliceList = new CopyOnWriteArrayList<>();
}
return this.memorySliceList;
}
public void saveMemorySliceList() throws IOException {
if (memorySliceList == null){
throw new NullSliceListException("memorySliceList为NULL! 检查实现逻辑!");
}
File file = new File(SLICE_DATA_DIR+this.getMemoryNodeId()+".slice");
Files.createDirectories(Path.of(SLICE_DATA_DIR));
try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(file))){
oos.writeObject(this.memorySliceList);
}
//取消切片挂载, 释放内存
this.memorySliceList = null;
}
private CopyOnWriteArrayList<MemorySlice> deserialize(File file) throws IOException, ClassNotFoundException {
try(ObjectInputStream ois = new ObjectInputStream(new FileInputStream(file))) {
return (CopyOnWriteArrayList<MemorySlice>) ois.readObject();
}
}
}

View File

@@ -0,0 +1,20 @@
package work.slhaf.partner.core.memory.pojo.node;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.common.entity.PersistableObject;
import java.io.Serial;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
@EqualsAndHashCode(callSuper = true)
@Data
public class TopicNode extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private ConcurrentHashMap<String,TopicNode> topicNodes = new ConcurrentHashMap<>();
private CopyOnWriteArrayList<MemoryNode> memoryNodes = new CopyOnWriteArrayList<>();
}

View File

@@ -0,0 +1,12 @@
package work.slhaf.partner.core.perceive;
import work.slhaf.partner.api.agent.factory.capability.annotation.Capability;
import work.slhaf.partner.core.perceive.pojo.User;
@Capability(value = "perceive")
public interface PerceiveCapability {
User getUser(String userInfo, String client);
User getUser(String id);
User addUser(String userInfo, String platform, String userNickName);
void updateUser(User user);
}

View File

@@ -0,0 +1,99 @@
package work.slhaf.partner.core.perceive;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityCore;
import work.slhaf.partner.api.agent.factory.capability.annotation.CapabilityMethod;
import work.slhaf.partner.core.PartnerCore;
import work.slhaf.partner.core.cognation.exception.UserNotExistsException;
import work.slhaf.partner.core.perceive.pojo.User;
import java.io.IOException;
import java.io.Serial;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.locks.ReentrantLock;
@EqualsAndHashCode(callSuper = true)
@CapabilityCore(value = "perceive")
@Getter
@Setter
public class PerceiveCore extends PartnerCore<PerceiveCore> {
@Serial
private static final long serialVersionUID = 1L;
private static final ReentrantLock usersLock = new ReentrantLock();
/**
* 用户列表
*/
private List<User> users = new ArrayList<>();
public PerceiveCore() throws IOException, ClassNotFoundException {
}
@CapabilityMethod
public User getUser(String userInfo, String platform) {
User resultUser = null;
usersLock.lock();
for (User user : users) {
HashMap<String, String> info = user.getInfo();
if (info.containsKey(platform)) {
if (info.get(platform).equals(userInfo)) {
resultUser = user;
}
}
}
usersLock.unlock();
return resultUser;
}
@CapabilityMethod
public User addUser(String userInfo, String platform, String userNickName) {
User user = new User();
user.addInfo(platform, userInfo);
user.setNickName(userNickName);
user.setUuid(UUID.randomUUID().toString());
usersLock.lock();
users.add(user);
usersLock.unlock();
return user;
}
@CapabilityMethod
public User getUser(String id) {
usersLock.lock();
User resultUser = null;
for (User user : users) {
if (user.getUuid().equals(id)) {
resultUser = user;
}
}
usersLock.unlock();
if (resultUser == null) {
throw new UserNotExistsException("[PerceiveCore] 用户不存在: " + id);
}
return resultUser;
}
@CapabilityMethod
public void updateUser(User temp) {
usersLock.lock();
User user = getUser(temp.getUuid());
user.setRelation(temp.getRelation());
user.setImpressions(temp.getImpressions());
user.setAttitude(temp.getAttitude());
user.setStaticMemory(temp.getStaticMemory());
user.updateRelationChange(user.getRelationChange());
usersLock.unlock();
}
@Override
protected String getCoreKey() {
return "perceive-core";
}
}

View File

@@ -0,0 +1,51 @@
package work.slhaf.partner.core.perceive.pojo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.common.entity.PersistableObject;
import java.io.Serial;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
public class User extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private String uuid;
private String nickName;
private HashMap<String/*platform*/, String> info = new HashMap<>();
private String relation = Constant.Relation.STRANGER;
// private HashMap<LocalDate, String> events = new HashMap<>();
private List<String> impressions = new ArrayList<>();
private List<String> attitude = new ArrayList<>();
private LinkedHashMap<LocalDate,String> relationChange = new LinkedHashMap<>();
private HashMap<String,String> staticMemory = new HashMap<>();
public void addInfo(String platform, String userInfo) {
this.info.put(platform, userInfo);
}
public void updateRelationChange(String changeReason){
relationChange.put(LocalDate.now(),changeReason);
}
public void updateRelationChange(LocalDate date, String changeReason){
relationChange.put(date,changeReason);
}
public void updateRelationChange(LinkedHashMap<LocalDate,String> tempRelationChange){
relationChange.putAll(tempRelationChange);
}
public static class Constant {
public static class Relation {
public static final String STRANGER = "陌生";
}
}
}

View File

@@ -0,0 +1,19 @@
package work.slhaf.partner.module.common.entity;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.common.entity.PersistableObject;
import java.io.Serial;
import java.util.Map;
@EqualsAndHashCode(callSuper = true)
@Data
public class AppendPromptData extends PersistableObject {
@Serial
private static final long serialVersionUID = 1L;
private String moduleName;
private Map<String, String> appendedPrompt;
}

View File

@@ -0,0 +1,16 @@
package work.slhaf.partner.module.common.model;
public class ModelConstant {
public static class Prompt {
public static final String MEMORY = "memory";
public static final String SCHEDULE = "schedule";
public static final String CORE = "core";
public static final String PERCEIVE = "perceive";
}
public static class CharacterPrefix {
public static final String SYSTEM = "[SYSTEM] ";
}
}

View File

@@ -0,0 +1,20 @@
package work.slhaf.partner.module.common.module;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
public abstract class PostRunningModule extends AgentRunningModule<PartnerRunningFlowContext> {
@Override
public final void execute(PartnerRunningFlowContext context) {
boolean trigger = context.getModuleContext().getExtraContext().getBoolean("post_process_trigger");
if (!trigger && relyOnMessage()) {
return;
}
doExecute(context);
}
public abstract void doExecute(PartnerRunningFlowContext context);
protected abstract boolean relyOnMessage();
}

View File

@@ -0,0 +1,43 @@
package work.slhaf.partner.module.common.module;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
import work.slhaf.partner.module.common.entity.AppendPromptData;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.Map;
/**
* 前置模块抽象类
*/
@Slf4j
public abstract class PreRunningModule extends AgentRunningModule<PartnerRunningFlowContext> {
private synchronized void setAppendedPrompt(PartnerRunningFlowContext context) {
AppendPromptData data = new AppendPromptData();
data.setModuleName(moduleName());
Map<String, String> map = getPromptDataMap(context);
data.setAppendedPrompt(map);
context.setAppendedPrompt(data);
}
private synchronized void setActiveModule(PartnerRunningFlowContext context) {
context.getCoreContext().addActiveModule(moduleName());
}
protected abstract Map<String, String> getPromptDataMap(PartnerRunningFlowContext context);
/**
* 用于在CoreModule接收到的模块Prompt中标识模块名称
*/
protected abstract String moduleName();
@Override
public final void execute(PartnerRunningFlowContext context) {
doExecute(context); // 子类实现差异化逻辑
setAppendedPrompt(context); // 通用逻辑
setActiveModule(context); // 通用逻辑
}
protected abstract void doExecute(PartnerRunningFlowContext context);
}

View File

@@ -0,0 +1,70 @@
package work.slhaf.partner.module.modules.action.dispatcher;
import lombok.val;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.core.action.entity.ImmediateExecutableAction;
import work.slhaf.partner.core.action.entity.SchedulableExecutableAction;
import work.slhaf.partner.module.common.module.PostRunningModule;
import work.slhaf.partner.module.modules.action.dispatcher.executor.ActionExecutor;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.ActionExecutorInput;
import work.slhaf.partner.module.modules.action.dispatcher.scheduler.ActionScheduler;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ExecutorService;
@AgentModule(name = "action_dispatcher", order = 7)
public class ActionDispatcher extends PostRunningModule {
@InjectCapability
private ActionCapability actionCapability;
@InjectModule
private ActionExecutor actionExecutor;
@InjectModule
private ActionScheduler actionScheduler;
private ExecutorService executor;
@Init
public void init() {
executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
}
@Override
public void doExecute(PartnerRunningFlowContext context) {
// 只需要处理prepared action因为pending action在用户确认后也将变为prepared action
// 将PLANNING action放入时间轮中IMMEDIATE action直接进入并发执行流
// 对于将触发的PLANNING
// action理想做法是将执行工具做成执行链的形式模型的自对话流程、是否通知用户都做成与普通工具同等的通用可选能力避免绑定固定流程
executor.execute(() -> {
String userId = context.getUserId();
val preparedActions = actionCapability.listActions(ExecutableAction.Status.PREPARE, userId);
// 分类成PLANNING和IMMEDIATE两类
Set<SchedulableExecutableAction> scheduledActions = new HashSet<>();
Set<ImmediateExecutableAction> immediateActions = new HashSet<>();
for (ExecutableAction preparedAction : preparedActions) {
if (preparedAction instanceof SchedulableExecutableAction actionInfo) {
scheduledActions.add(actionInfo);
} else if (preparedAction instanceof ImmediateExecutableAction actionInfo) {
immediateActions.add(actionInfo);
}
}
actionExecutor.execute(new ActionExecutorInput(immediateActions));
actionScheduler.execute(scheduledActions);
});
}
@Override
protected boolean relyOnMessage() {
return false;
}
}

View File

@@ -0,0 +1,52 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor;
import com.alibaba.fastjson2.JSONObject;
import lombok.val;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.CorrectorInput;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.CorrectorResult;
/**
* 负责在单组行动执行后,根据行动意图与结果检查后续行动是否符合目的,必要时直接调整行动链,或发起自对话请求进行干预
*/
@AgentSubModule
public class ActionCorrector extends AgentRunningSubModule<CorrectorInput, CorrectorResult> implements ActivateModel {
@Override
public CorrectorResult execute(CorrectorInput input) {
val prompt = buildPrompt(input);
val chatResponse = singleChat(prompt);
return JSONObject.parseObject(chatResponse.getMessage(), CorrectorResult.class);
}
private String buildPrompt(CorrectorInput input) {
val prompt = new JSONObject();
prompt.put("[行动来源]", input.getSource());
prompt.put("[行动倾向]", input.getTendency());
prompt.put("[行动描述]", input.getDescription());
prompt.put("[行动原因]", input.getReason());
val messages = prompt.putArray("[近期对话]");
messages.addAll(input.getRecentMessages());
val memory = prompt.putArray("[已激活记忆]");
memory.addAll(input.getActivatedSlices());
val history = prompt.putArray("[已执行情况]");
history.addAll(input.getHistory());
return prompt.toJSONString();
}
@Override
public String modelKey() {
return "action_corrector";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -0,0 +1,324 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;
import work.slhaf.partner.core.action.entity.*;
import work.slhaf.partner.core.action.entity.ExecutableAction.Status;
import work.slhaf.partner.core.action.runner.RunnerClient;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.*;
import work.slhaf.partner.module.modules.action.dispatcher.scheduler.ActionScheduler;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Phaser;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
@Slf4j
@AgentSubModule
public class ActionExecutor extends AgentRunningSubModule<ActionExecutorInput, Void> {
@InjectCapability
private ActionCapability actionCapability;
@InjectCapability
private MemoryCapability memoryCapability;
@InjectCapability
private CognationCapability cognationCapability;
@InjectModule
private ParamsExtractor paramsExtractor;
@InjectModule
private ActionRepairer actionRepairer;
@InjectModule
private ActionCorrector actionCorrector;
@InjectModule
private ActionScheduler actionScheduler;
private ExecutorService virtualExecutor;
private ExecutorService platformExecutor;
private RunnerClient runnerClient;
private final AssemblyHelper assemblyHelper = new AssemblyHelper();
@Init
public void init() {
virtualExecutor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
platformExecutor = actionCapability.getExecutor(ActionCore.ExecutorType.PLATFORM);
runnerClient = actionCapability.runnerClient();
}
/**
* 执行行动
*
* @param input ActionExecutor 输入内容
* @return 无返回,执行结果回写至 input 内部携带的 actionData 中
*/
@Override
public Void execute(ActionExecutorInput input) {
val actions = input.getActions();
// 异步执行所有行动
for (ExecutableAction executableAction : actions) {
platformExecutor.execute(() -> {
val source = executableAction.getSource();
if (executableAction.getStatus() != Status.PREPARE) {
return;
}
val actionChain = executableAction.getActionChain();
if (actionChain.isEmpty()) {
executableAction.setStatus(Status.FAILED);
executableAction.setResult("行动链为空");
return;
}
// 注册执行中行动
val phaser = new Phaser();
val phaserRecord = actionCapability.putPhaserRecord(phaser, executableAction);
executableAction.setStatus(Status.EXECUTING);
// 开始执行
val stageCursor = new Object() {
int stageCount;
boolean executingStageUpdated;
boolean stageCountUpdated;
void init() {
stageCount = 0;
executingStageUpdated = false;
stageCountUpdated = false;
update();
}
void requestAdvance() {
if (!stageCountUpdated) {
stageCount++;
stageCountUpdated = true;
}
if (stageCount < actionChain.size() && !executingStageUpdated) {
update();
executingStageUpdated = true;
}
}
boolean next() {
executingStageUpdated = false;
stageCountUpdated = false;
return stageCount < actionChain.size();
}
void update() {
val orderList = new ArrayList<>(actionChain.keySet());
orderList.sort(Integer::compareTo);
executableAction.setExecutingStage(orderList.get(stageCount));
}
};
stageCursor.init();
do {
val metaActions = actionChain.get(executableAction.getExecutingStage());
val listeningRecord = executeAndListening(metaActions, phaserRecord, source);
phaser.awaitAdvance(listeningRecord.phase());
// synchronized 同步防止 accepting 循环间、phase guard 判定后发生 stage 推进
// 导致新行动的 phaser 投放阶段错乱无法阻塞的场景
// 该 synchronized 将阶段推进与 accepting 监听 loop 捆绑为互斥的原子事件,避免了细粒度的 phaser 阶段竞态问题
synchronized (listeningRecord.accepting()) {
listeningRecord.accepting().set(false);
// 立即尝试推进,本次推进中,如果前方仍有未执行 stage将执行一次阶段推进
stageCursor.requestAdvance();
}
try {
// 针对行动链进行修正,修正需要传入执行历史、行动目标等内容
// 如果后续运行 corrector 触发频率较高,可考虑增加重试机制
val correctorInput = assemblyHelper.buildCorrectorInput(executableAction, source);
val correctorResult = actionCorrector.execute(correctorInput);
actionCapability.handleInterventions(correctorResult.getMetaInterventionList(), executableAction);
} catch (Exception ignored) {
}
// 第二次尝试进行阶段推进,本次负责补充上一次在不存在 stage时但 corrector 执行期间发生了 actionChain 的插入事件
// 如果第一次已经推进完毕,本次将会跳过
stageCursor.requestAdvance();
} while (stageCursor.next());
// 结束
actionCapability.removePhaserRecord(phaser);
if (executableAction.getStatus() != Status.FAILED) {
// 如果是 ScheduledActionData, 则重置 ActionData 内容,记录执行历史与最终结果
if (executableAction instanceof SchedulableExecutableAction scheduledActionData) {
scheduledActionData.recordAndReset();
actionScheduler.execute(Set.of(scheduledActionData));
} else {
executableAction.setStatus(Status.SUCCESS);
}
// TODO 执行过后需要回写至任务上下文recentCompletedTask同时触发自对话信号进行确认并记录以及是否通知用户触发与否需要机制进行匹配在模块链路可增加 interaction gate 门控,判断此次对话作用于谁、由谁发出、何种性质、是否需要回应等)
}
});
}
return null;
}
private MetaActionsListeningRecord executeAndListening(List<MetaAction> metaActions, PhaserRecord phaserRecord, String source) {
AtomicBoolean accepting = new AtomicBoolean(true);
AtomicInteger cursor = new AtomicInteger();
CountDownLatch latch = new CountDownLatch(1);
val phaser = phaserRecord.phaser();
val phase = phaser.register();
platformExecutor.execute(() -> {
boolean first = true;
while (accepting.get()) {
synchronized (accepting) {
MetaAction next = null;
synchronized (metaActions) {
if (cursor.get() < metaActions.size()) {
next = metaActions.get(cursor.getAndIncrement());
}
}
if (next == null) {
Thread.onSpinWait();
continue;
}
if (phaser.getPhase() != phase) {
metaActions.remove(next);
log.warn("行动阶段已推进,丢弃该行动: {}", next);
continue;
}
ExecutorService executor = next.getIo() ? virtualExecutor : platformExecutor;
executor.execute(buildMataActionTask(next, phaserRecord, source));
if (first) {
phaser.arriveAndDeregister();
latch.countDown();
first = false;
}
}
}
});
try {
// 确保执行一次,防止没来得及注册任务就已经结束
latch.await();
} catch (InterruptedException ignored) {
}
return new MetaActionsListeningRecord(accepting, phase);
}
private Runnable buildMataActionTask(MetaAction metaAction, PhaserRecord phaserRecord, String source) {
val phaser = phaserRecord.phaser();
phaser.register();
return () -> {
val actionKey = metaAction.getKey();
try {
val result = metaAction.getResult();
do {
val actionData = phaserRecord.executableAction();
val executingStage = actionData.getExecutingStage();
val historyActionResults = actionData.getHistory().get(executingStage);
val additionalContext = actionData.getAdditionalContext().get(executingStage);
val extractorInput = assemblyHelper.buildExtractorInput(metaAction, source, historyActionResults, additionalContext);
val extractorResult = paramsExtractor.execute(extractorInput);
if (extractorResult.isOk()) {
metaAction.getParams().putAll(extractorResult.getParams());
runnerClient.submit(metaAction);
val historyAction = new HistoryAction(actionKey, actionCapability.loadMetaActionInfo(actionKey).getDescription(), metaAction.getResult().getData());
actionData.getHistory()
.computeIfAbsent(executingStage, integer -> new ArrayList<>())
.add(historyAction);
} else {
val repairerInput = assemblyHelper.buildRepairerInput(historyActionResults, metaAction, source);
val repairerResult = actionRepairer.execute(repairerInput);
switch (repairerResult.getStatus()) {
// 如果本次修复被认为成功,则将补充的信息添加至 additionalContext
case RepairerResult.RepairerStatus.OK -> {
additionalContext.addAll(repairerResult.getFixedData());
result.setStatus(MetaAction.Result.Status.WAITING);
}
// 此处的修复失败来自系统内部的执行失败:其余方式均不可行时将回退至当前分支
case RepairerResult.RepairerStatus.FAILED -> {
result.setStatus(MetaAction.Result.Status.FAILED);
result.setData("行动执行失败");
}
// 此处对应已在 repairer 内发起外部请求,故在此处进行阻塞
case RepairerResult.RepairerStatus.ACQUIRE -> {
phaserRecord.interrupt();
result.setStatus(MetaAction.Result.Status.WAITING);
}
}
}
} while (result.getStatus().equals(MetaAction.Result.Status.WAITING));
} catch (Exception e) {
log.error("Action executing failed: {}", actionKey, e);
} finally {
phaser.arriveAndDeregister();
}
};
}
private record MetaActionsListeningRecord(AtomicBoolean accepting, int phase) {
}
@SuppressWarnings("InnerClassMayBeStatic")
private class AssemblyHelper {
private AssemblyHelper() {
}
private RepairerInput buildRepairerInput(List<HistoryAction> historyActionsResults, MetaAction action, String userId) {
RepairerInput input = new RepairerInput();
MetaActionInfo metaActionInfo = actionCapability.loadMetaActionInfo(action.getKey());
input.setHistoryActionResults(historyActionsResults);
input.setParams(metaActionInfo.getParams());
input.setRecentMessages(cognationCapability.getChatMessages());
input.setActionDescription(metaActionInfo.getDescription());
input.setUserId(userId);
return input;
}
private ExtractorInput buildExtractorInput(MetaAction action, String source, List<HistoryAction> historyActionResults,
List<String> additionalContext) {
ExtractorInput input = new ExtractorInput();
input.setEvaluatedSlices(memoryCapability.getActivatedSlices(source));
input.setRecentMessages(cognationCapability.getChatMessages());
input.setMetaActionInfo(actionCapability.loadMetaActionInfo(action.getKey()));
input.setHistoryActionResults(historyActionResults);
input.setAdditionalContext(additionalContext);
return input;
}
private CorrectorInput buildCorrectorInput(ExecutableAction executableAction, String source) {
return CorrectorInput.builder()
.tendency(executableAction.getTendency())
.source(executableAction.getSource())
.reason(executableAction.getReason())
.description(executableAction.getDescription())
.history(executableAction.getHistory().get(executableAction.getExecutingStage()))
.status(executableAction.getStatus())
.recentMessages(cognationCapability.getChatMessages())
.activatedSlices(memoryCapability.getActivatedSlices(source))
.build();
}
}
}

View File

@@ -0,0 +1,228 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import com.alibaba.fastjson2.TypeReference;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore.ExecutorType;
import work.slhaf.partner.core.action.entity.MetaAction;
import work.slhaf.partner.core.action.entity.MetaAction.Result;
import work.slhaf.partner.core.action.runner.RunnerClient;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.GeneratorInput;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.GeneratorResult;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.RepairerInput;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.RepairerResult;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.RepairerResult.RepairerStatus;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
/**
* 负责识别行动链的修复
* <ol>
* <li>
* 可通过协调 {@link DynamicActionGenerator} 生成新的行动单元并调用,获取所需的参数信息(必要时持久化);
* </li>
* <li>
* 也可以直接调用已存在的行动程序获取信息;
* </li>
* <li>
* 如果上述都无法满足,将发起自对话借助干预模块进行操作或者借助自对话通道向用户发起沟通请求,该请求的目的一般为行动程序生成/调用指导或者用户侧的信息补充,后续还需要再走一遍参数修复流程
* </li>
* </ol>
*/
@Slf4j
@AgentSubModule
public class ActionRepairer extends AgentRunningSubModule<RepairerInput, RepairerResult> implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
@InjectCapability
private CognationCapability cognationCapability;
@InjectModule
private DynamicActionGenerator dynamicActionGenerator;
private final AssemblyHelper assemblyHelper = new AssemblyHelper();
private RunnerClient runnerClient;
@Init
void init() {
runnerClient = actionCapability.runnerClient();
}
@Override
public RepairerResult execute(RepairerInput data) {
RepairerResult result;
try {
String prompt = assemblyHelper.buildPrompt(data, null);
ChatResponse response = this.singleChat(prompt);
RepairerData repairerData = JSONObject.parseObject(response.getMessage(), RepairerData.class);
result = switch (repairerData.getRepairerType()) {
case ACTION_GENERATION ->
handleActionGeneration(JSONObject.parseObject(repairerData.getData(), GeneratorInput.class));
case ACTION_INVOCATION -> handleActionInvocation(
JSONObject.parseObject(repairerData.getData(), new TypeReference<List<String>>() {
}));
case USER_INTERACTION -> handleUserInteraction(repairerData.getData());
};
if (!repairerData.getRepairerType().equals(RepairerType.USER_INTERACTION)
&& result.getStatus().equals(RepairerResult.RepairerStatus.FAILED)) {
log.warn("常规行动修复失败,将尝试自对话通道");
prompt = assemblyHelper.buildPrompt(data, "常规行动修复失败,请尝试通过自对话通道获取必要的信息以完成行动参数的修复");
response = this.singleChat(prompt);
repairerData = JSONObject.parseObject(response.getMessage(), RepairerData.class);
handleUserInteraction(repairerData.getData());
}
} catch (Exception e) {
result = new RepairerResult();
result.setStatus(RepairerStatus.FAILED);
}
return result;
}
/**
* 负责根据输入内容进行行动单元的参数信息修复
*
* @param generatorInput 生成的行动单元参考内容,最好包含行动单元的执行逻辑
* @return 修复后的行动单元结果
*/
private RepairerResult handleActionGeneration(GeneratorInput generatorInput) {
RepairerResult result = new RepairerResult();
GeneratorResult generatorResult = dynamicActionGenerator.execute(generatorInput);
MetaAction tempAction = generatorResult.getTempAction();
if (tempAction == null) {
result.setStatus(RepairerStatus.FAILED);
return result;
}
runnerClient.submit(tempAction);
// 根据 tempAction 的执行状态设置修复结果
Result actionResult = tempAction.getResult();
if (actionResult.getStatus() != MetaAction.Result.Status.SUCCESS) {
result.setStatus(RepairerStatus.FAILED);
return result;
}
result.setStatus(RepairerStatus.OK);
result.getFixedData().add(actionResult.getData());
return result;
}
/**
* 负责根据输入内容进行行动单元的参数信息修复
*
* @param actionKeys 需要调用的行动单元Key列表
* @return 修复后的行动单元结果
*/
private RepairerResult handleActionInvocation(List<String> actionKeys) {
RepairerResult result = new RepairerResult();
CountDownLatch latch = new CountDownLatch(actionKeys.size());
ExecutorService virtual = actionCapability.getExecutor(ExecutorType.VIRTUAL);
ExecutorService platform = actionCapability.getExecutor(ExecutorType.PLATFORM);
ExecutorService executor;
AtomicInteger failedCount = new AtomicInteger(0);
for (String key : actionKeys) {
MetaAction action = actionCapability.loadMetaAction(key);
executor = action.getIo() ? virtual : platform;
executor.execute(() -> {
try {
runnerClient.submit(action);
result.getFixedData().add(action.getResult().getData());
} catch (Exception e) {
log.error("行动单元执行失败: {}", key, e);
failedCount.incrementAndGet();
} finally {
latch.countDown();
}
});
}
try {
latch.await();
} catch (Exception e) {
log.warn("CountDownLatch 已中断");
}
if (actionKeys.size() - failedCount.get() > 0) {
result.setStatus(RepairerStatus.OK);
} else {
result.setStatus(RepairerStatus.FAILED);
}
return result;
}
private RepairerResult handleUserInteraction(String acquireContent) {
RepairerResult result = new RepairerResult();
result.setStatus(RepairerStatus.ACQUIRE);
// 发送自对话请求
return result;
}
@Override
public String modelKey() {
return "action_repairer";
}
@Override
public boolean withBasicPrompt() {
return false;
}
@SuppressWarnings("InnerClassMayBeStatic")
@Data
private class RepairerData {
private RepairerType repairerType;
private String data;
}
private enum RepairerType {
ACTION_GENERATION,
ACTION_INVOCATION,
USER_INTERACTION
}
@SuppressWarnings("InnerClassMayBeStatic")
private class AssemblyHelper {
private AssemblyHelper() {
}
private String buildPrompt(RepairerInput data, String specialInstruction) {
JSONObject prompt = new JSONObject();
JSONObject actionData = prompt.putObject("[本次行动信息]");
actionData.put("[行动描述]", data.getActionDescription());
JSONObject actionParamsData = actionData.putObject("[行动参数说明]");
actionParamsData.putAll(data.getParams());
JSONArray historyData = prompt.putArray("[历史行动执行结果]");
data.getHistoryActionResults().forEach(historyAction -> {
JSONObject historyItem = new JSONObject();
historyItem.put("[行动Key]", historyAction.actionKey());
historyItem.put("[行动描述]", historyAction.description());
historyItem.put("[行动结果]", historyAction.result());
historyData.add(historyItem);
});
JSONArray messageData = prompt.putArray("[最近消息列表]");
messageData.addAll(data.getRecentMessages());
if (specialInstruction != null) {
prompt.put("[特殊指令]", specialInstruction);
}
return prompt.toString();
}
}
}

View File

@@ -0,0 +1,91 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor;
import com.alibaba.fastjson2.JSONObject;
import lombok.val;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.common.util.ExtractUtil;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.entity.GeneratedData;
import work.slhaf.partner.core.action.entity.MetaAction;
import work.slhaf.partner.core.action.runner.RunnerClient;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.GeneratorInput;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.GeneratorResult;
/**
* 负责依据输入内容生成可执行的动态行动单元,并选择是否持久化至 SandboxRunner 容器内
*/
@AgentSubModule
public class DynamicActionGenerator extends AgentRunningSubModule<GeneratorInput, GeneratorResult>
implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
private RunnerClient runnerClient;
@Init
void init() {
runnerClient = actionCapability.runnerClient();
}
@Override
public GeneratorResult execute(GeneratorInput input) {
GeneratorResult result = new GeneratorResult();
try {
// 由于 SCRIPT 类型程序都是在 SandboxRunner 内部的磁盘上加载然后执行的,
// 所以此处的输入内容也只需要指定输入参数、临时key、是否持久化即可路径将按照指定规则统一构建不可交给LLM生成
String prompt = buildPrompt(input);
// 响应结果需要包含几个特殊数据: 依赖项、代码内容、是否序列化、响应数据释义
ChatResponse response = this.singleChat(prompt);
GeneratedData generatorData = JSONObject
.parseObject(ExtractUtil.extractJson(response.getMessage()), GeneratedData.class);
val location = runnerClient.buildTmpPath(input.getActionName(), generatorData.getCodeType());
MetaAction tempAction = new MetaAction(
input.getActionName(),
true,
MetaAction.Type.ORIGIN,
location
);
// 将临时行动单元序列化至临时文件夹,并设置程序路径、放置在队列中,等待执行状态变化,并根据序列化选项选择是否补充 MetaActionInfo 并持久序列化
// 通过 ActionCapability 暴露的接口序列化至临时文件夹同时返回Path对象并设置。队列建议交给 SandboxRunner
// 持有,包括监听与序列化线程
runnerClient.tmpSerialize(tempAction, generatorData.getCode(), generatorData.getCodeType());
if (generatorData.isSerialize()) {
waitingSerialize();
}
result.setTempAction(tempAction);
} catch (Exception e) {
result.setTempAction(null);
}
return result;
}
private void waitingSerialize() {
throw new UnsupportedOperationException("Unimplemented method 'waitingSerialize'");
}
private String buildPrompt(GeneratorInput data) {
JSONObject prompt = new JSONObject();
prompt.put("[行动描述]", data.getDescription());
// prompt.putObject("[行动参数]").putAll(data.getParams());
prompt.putObject("[行动参数描述]").putAll(data.getParamsDescription());
return prompt.toString();
}
@Override
public String modelKey() {
return "dynamic_generator";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -0,0 +1,75 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.core.action.entity.MetaActionInfo;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.ExtractorInput;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.ExtractorResult;
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.HistoryAction;
import java.util.HashMap;
import java.util.List;
/**
* 负责依据输入内容进行行动单元的参数信息提取
*/
@Slf4j
@AgentSubModule
public class ParamsExtractor extends AgentRunningSubModule<ExtractorInput, ExtractorResult> implements ActivateModel {
@Override
public ExtractorResult execute(ExtractorInput input) {
String prompt = buildPrompt(input);
ChatResponse response = this.singleChat(prompt);
ExtractorResult result;
try {
result = JSONObject.parseObject(response.getMessage(), ExtractorResult.class);
} catch (Exception e) {
log.error("ParamsExtractor解析结果失败返回内容{}", response.getMessage(), e);
result = new ExtractorResult();
result.setOk(false);
result.setParams(new HashMap<>());
}
return result;
}
private String buildPrompt(ExtractorInput input) {
JSONObject prompt = new JSONObject();
JSONObject actionData = prompt.putObject("[本次行动信息]");
MetaActionInfo actionInfo = input.getMetaActionInfo();
actionData.put("[行动描述]", actionInfo.getDescription());
actionData.put("[行动参数说明]", actionInfo.getParams());
JSONArray historyData = prompt.putArray("[历史行动执行结果]");
List<HistoryAction> historyActions = input.getHistoryActionResults();
for (HistoryAction historyAction : historyActions) {
JSONObject historyItem = new JSONObject();
historyItem.put("[行动Key]", historyAction.actionKey());
historyItem.put("[行动描述]", historyAction.description());
historyItem.put("[行动结果]", historyAction.result());
historyData.add(historyItem);
}
JSONArray messageData = prompt.putArray("[最近消息列表]");
messageData.addAll(input.getRecentMessages());
return prompt.toString();
}
@Override
public String modelKey() {
return "params_extractor";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -0,0 +1,5 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor.entity
import work.slhaf.partner.core.action.entity.ExecutableAction
data class ActionExecutorInput(val actions: Set<ExecutableAction>)

View File

@@ -0,0 +1,24 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor.entity;
import lombok.Builder;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import java.util.List;
@Data
@Builder
public class CorrectorInput {
private String tendency;
private String source;
private String reason;
private String description;
private List<HistoryAction> history;
private ExecutableAction.Status status;
private List<Message> recentMessages;
private List<EvaluatedSlice> activatedSlices;
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor.entity;
import lombok.Data;
import work.slhaf.partner.module.modules.action.interventor.entity.MetaIntervention;
import java.util.List;
@Data
public class CorrectorResult {
private List<MetaIntervention> metaInterventionList;
}

View File

@@ -0,0 +1,32 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor.entity;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.entity.MetaActionInfo;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import java.util.List;
@Data
public class ExtractorInput {
/**
* 目标 MetaActionInfo
*/
private MetaActionInfo metaActionInfo;
/**
* 可参考的记忆切片
*/
private List<EvaluatedSlice> evaluatedSlices;
/**
* 历史行动执行结果
*/
private List<HistoryAction> historyActionResults;
/**
* 最近的消息列表
*/
private List<Message> recentMessages;
/**
* 额外的上下文信息(可来自修复器等)
*/
private List<String> additionalContext;
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor.entity;
import lombok.Data;
import java.util.Map;
@Data
public class ExtractorResult {
private boolean ok;
private Map<String, Object> params;
}

View File

@@ -0,0 +1,13 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor.entity;
import lombok.Data;
import java.util.Map;
@Data
public class GeneratorInput {
private String actionName;
private Map<String, Object> params;
private String description;
private Map<String, String> paramsDescription;
}

View File

@@ -0,0 +1,9 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor.entity;
import lombok.Data;
import work.slhaf.partner.core.action.entity.MetaAction;
@Data
public class GeneratorResult {
private MetaAction tempAction;
}

View File

@@ -0,0 +1,4 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor.entity;
public record HistoryAction(String actionKey, String description, String result) {
}

View File

@@ -0,0 +1,17 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor.entity;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import java.util.List;
import java.util.Map;
@Data
public class RepairerInput {
private String userId;
private List<Message> recentMessages;
private Map<String, Object> params;
private String actionDescription;
private List<HistoryAction> historyActionResults;
}

View File

@@ -0,0 +1,30 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor.entity;
import lombok.Data;
import java.util.List;
/**
* 行动修复结果,包含行动状态和修复后的参数
*/
@Data
public class RepairerResult {
private RepairerStatus status;
private List<String> fixedData;
public enum RepairerStatus {
/**
* 成功修复: 携带修复后参数; 此种情况对应 Repairer 通过某种方式获取到了完整的参数(调用额外的行动)
*/
OK,
/**
* 发送了自对话请求干预行动,这类一般是补充信息或者提供行动指导,后续必须再步入修复进程,但需要设置层级
*/
ACQUIRE,
/**
* 修复失败(简单修复、自对话通道均出现错误,正常情况不应该出现)
*/
FAILED
}
}

View File

@@ -0,0 +1,14 @@
package work.slhaf.partner.module.modules.action.dispatcher.executor.exception;
import work.slhaf.partner.api.agent.runtime.exception.AgentRuntimeException;
public class ActionExecutingFailedException extends AgentRuntimeException {
public ActionExecutingFailedException(String message) {
super(message);
}
public ActionExecutingFailedException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -0,0 +1,423 @@
package work.slhaf.partner.module.modules.action.dispatcher.scheduler
import com.cronutils.model.CronType
import com.cronutils.model.definition.CronDefinition
import com.cronutils.model.definition.CronDefinitionBuilder
import com.cronutils.model.time.ExecutionTime
import com.cronutils.parser.CronParser
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import org.slf4j.LoggerFactory
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule
import work.slhaf.partner.api.agent.factory.module.annotation.Init
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule
import work.slhaf.partner.core.action.ActionCapability
import work.slhaf.partner.core.action.ActionCore
import work.slhaf.partner.core.action.entity.Schedulable
import work.slhaf.partner.core.action.entity.SchedulableExecutableAction
import work.slhaf.partner.core.action.entity.StateAction
import work.slhaf.partner.module.modules.action.dispatcher.executor.ActionExecutor
import work.slhaf.partner.module.modules.action.dispatcher.executor.entity.ActionExecutorInput
import java.io.Closeable
import java.time.Duration
import java.time.ZonedDateTime
import java.time.temporal.ChronoUnit
import java.util.stream.Collectors
import kotlin.jvm.optionals.getOrNull
@AgentSubModule
class ActionScheduler : AgentRunningSubModule<Set<Schedulable>, Void>() {
@InjectCapability
private lateinit var actionCapability: ActionCapability
@InjectModule
private lateinit var actionExecutor: ActionExecutor
private lateinit var timeWheel: TimeWheel
private val schedulerScope =
CoroutineScope(Dispatchers.Default + SupervisorJob() + CoroutineName("ActionScheduler"))
companion object {
private val log = LoggerFactory.getLogger(ActionScheduler::class.java)
}
@Init
fun init() {
fun loadScheduledActions() {
val listScheduledActions: () -> Set<SchedulableExecutableAction> = {
actionCapability.listActions(null, null)
.stream()
.filter { it is SchedulableExecutableAction }
.map { it as SchedulableExecutableAction }
.collect(Collectors.toSet())
}
val onTrigger: (Set<Schedulable>) -> Unit = { schedulableSet ->
val executableActions = mutableSetOf<SchedulableExecutableAction>()
val stateActions = mutableSetOf<StateAction>()
for (schedulable in schedulableSet) {
when (schedulable) {
is SchedulableExecutableAction -> executableActions.add(schedulable)
is StateAction -> stateActions.add(schedulable)
}
}
actionExecutor.execute(ActionExecutorInput(executableActions))
actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL)
.execute { stateActions.forEach { it.trigger.onTrigger() } }
}
timeWheel = TimeWheel(listScheduledActions, onTrigger)
}
loadScheduledActions()
setupShutdownHook()
}
private fun setupShutdownHook() {
Runtime.getRuntime().addShutdownHook(Thread {
timeWheel.close()
schedulerScope.cancel()
})
}
override fun execute(schedulableSet: Set<Schedulable>?): Void? {
schedulerScope.launch {
schedulableSet?.run {
for (schedulableData in schedulableSet) {
log.debug("New data to schedule: {}", schedulableData)
timeWheel.schedule(schedulableData)
if (schedulableData is SchedulableExecutableAction) {
actionCapability.putAction(schedulableData)
}
}
}
}
return null
}
private class TimeWheel(
val listSource: () -> Set<Schedulable>,
val onTrigger: (toTrigger: Set<Schedulable>) -> Unit
) : Closeable {
private val schedulableGroupByHour = Array<MutableSet<Schedulable>>(24) { mutableSetOf() }
private val wheel = Array<MutableSet<Schedulable>>(60 * 60) { mutableSetOf() }
private var recordHour: Int = -1
private var recordDay: Int = -1
private val state = MutableStateFlow(WheelState.SLEEPING)
private val wheelActionsLock = Mutex()
private val timeWheelScope = CoroutineScope(SupervisorJob() + Dispatchers.Default + CoroutineName("TimeWheel"))
private val cronDefinition: CronDefinition = CronDefinitionBuilder.instanceDefinitionFor(CronType.QUARTZ)
private val cronParser: CronParser = CronParser(cronDefinition)
init {
// 启动时间轮
wheel()
}
suspend fun schedule(schedulableData: Schedulable) {
checkThenExecute {
val parseToZonedDateTime = parseToZonedDateTime(
schedulableData.scheduleType,
schedulableData.scheduleContent,
it
) ?: run {
logFailedStatus(schedulableData)
return@checkThenExecute
}
log.debug("Action next execution time: {}", parseToZonedDateTime)
val hour = parseToZonedDateTime.hour
schedulableGroupByHour[hour].add(schedulableData)
log.debug("Action scheduled at {}", hour)
if (it.hour == hour) {
val wheelOffset = parseToZonedDateTime.minute * 60 + parseToZonedDateTime.second
wheel[wheelOffset].add(schedulableData)
state.value = WheelState.ACTIVE
log.debug("Action scheduled at wheel offset {}", wheelOffset)
}
}
}
private fun wheel() {
data class WheelStepResult(
val toTrigger: Set<Schedulable>?,
val shouldBreak: Boolean
)
fun collectToTrigger(tick: Int, previousTick: Int, triggerHour: Int): Set<Schedulable>? {
if (tick > previousTick) {
val toTrigger = mutableSetOf<Schedulable>()
for (i in previousTick..tick) {
val bucket = wheel[i]
if (bucket.isNotEmpty()) {
toTrigger.addAll(bucket)
val bucketUuids = bucket.asSequence().map { it.uuid }.toHashSet()
schedulableGroupByHour[triggerHour].removeIf { it.uuid in bucketUuids }
bucket.clear() // 避免重复触发
}
}
return toTrigger
}
return null
}
suspend fun CoroutineScope.wheel(launchingTime: ZonedDateTime, primaryTickAdvanceTime: Long) {
val launchingHour = launchingTime.hour
var tick = launchingTime.minute * 60 + launchingTime.second
// 让节拍器从“启动时刻的下一秒”开始(避免立即 step=0
var nextTickNanos = primaryTickAdvanceTime + 1_000_000_000L
while (isActive) {
// 1) 计算落后多少秒:至少 1正常推进也可能 >1追赶
val now0 = System.nanoTime()
val lagNanos = now0 - nextTickNanos
val step = if (lagNanos < 0) 1 else (lagNanos / 1_000_000_000L).toInt() + 1
val previousTick = tick
tick = (tick + step).coerceAtMost(wheel.lastIndex)
// 2) 推进节拍器:按“理论秒”前进 step 次
nextTickNanos += step.toLong() * 1_000_000_000L
val stepResult = run {
var shouldBreak = false
var toTrigger: Set<Schedulable>? = null
checkThenExecute(false) {
if (it.hour != launchingHour) {
shouldBreak = true
toTrigger = collectToTrigger(wheel.lastIndex, previousTick, launchingHour)
log.debug(
"Hour changed, previousTick: {}, tick: {}, toTriggerSize: {}",
previousTick,
tick,
toTrigger?.size
)
return@checkThenExecute
}
toTrigger = collectToTrigger(tick, previousTick, launchingHour)
if (tick >= wheel.lastIndex || schedulableGroupByHour[launchingHour].isEmpty()) {
state.value = WheelState.SLEEPING
shouldBreak = true
}
}
WheelStepResult(toTrigger, shouldBreak)
}
stepResult.toTrigger?.let { trigger ->
timeWheelScope.launch {
onTrigger(trigger)
}
}
if (stepResult.shouldBreak) {
log.debug("Wheel stopped at tick {}", tick)
break
}
// 3) 精确睡到下一次理论 tick用最新 nanoTime
val now1 = System.nanoTime()
val sleepNanos = nextTickNanos - now1
if (sleepNanos > 0) {
delay(sleepNanos / 1_000_000L) // 毫秒级 delay 足够;剩余 nanos 不必忙等
}
}
}
suspend fun wait(currentTime: ZonedDateTime) {
val nextHour = currentTime.truncatedTo(ChronoUnit.HOURS).plusHours(1)
val seconds = Duration.between(
currentTime, nextHour
).toMillis()
// withTimeoutOrNull 内部已处理 seconds 小于 0 的情况
log.debug("Start waiting {} ms at {}, target time: {}", seconds, currentTime, nextHour)
withTimeoutOrNull(seconds) {
state.first { it == WheelState.ACTIVE }
}
log.debug("Waiting ended at {}", ZonedDateTime.now())
}
timeWheelScope.launch {
while (isActive) {
// 判断是否该步入下一小时
var shouldWait: Boolean? = null
var currentTime: ZonedDateTime? = null
var primaryTickAdvanceTime: Long? = null
checkThenExecute {
currentTime = it
shouldWait = schedulableGroupByHour[it.hour].isEmpty()
// 由于 wheel 的启动时间可能存在延迟,而时内推进由 nanoTime 保证不会漏发,
// 正常的时序结束又由 tick 是否触顶、当前时是否存在额外任务触发,
// 而启动时无触发保障,此时一并初始化 tick 推进时间,足以应对 check 与 wheel 间的这段时间间隔
primaryTickAdvanceTime = System.nanoTime()
}
// 如果该时无任务则等待,插入事件可提前唤醒
if (shouldWait!!) {
// 计算距离下一小时的时间,等待
currentTime?.let { wait(it) }
continue
}
// 唤醒进行时间轮循环
wheel(currentTime!!, primaryTickAdvanceTime!!)
}
}
}
suspend fun checkThenExecute(finallyToExecute: Boolean = true, then: (currentTime: ZonedDateTime) -> Unit) =
wheelActionsLock.withLock {
fun loadActions(
source: Set<Schedulable>,
now: ZonedDateTime,
load: (latestExecutingTime: ZonedDateTime, schedulableData: Schedulable) -> Unit,
repair: () -> Unit
) {
val runLoading = {
for (schedulableData in source) {
val nextExecutingTime =
parseToZonedDateTime(
schedulableData.scheduleType,
schedulableData.scheduleContent,
now
) ?: run {
logFailedStatus(schedulableData)
continue
}
load(nextExecutingTime, schedulableData)
}
}
repair()
runLoading()
}
fun loadHourActions(currentTime: ZonedDateTime) {
val load: (ZonedDateTime, Schedulable) -> Unit =
{ latestExecutionTime, schedulableData ->
val secondsTime = latestExecutionTime.minute * 60 + latestExecutionTime.second
wheel[secondsTime].add(schedulableData)
log.debug("Action loaded to hour: {}", schedulableData)
}
val repair: () -> Unit = {
for (set in wheel) {
set.clear()
}
}
loadActions(schedulableGroupByHour[currentTime.hour], currentTime, load, repair)
}
fun loadDayActions(currentTime: ZonedDateTime) {
val load: (ZonedDateTime, Schedulable) -> Unit =
{ latestExecutingTime, schedulableData ->
schedulableGroupByHour[latestExecutingTime.hour].add(schedulableData)
log.debug("Action loaded to day: {}", schedulableData)
}
val repair: () -> Unit = {
for (set in schedulableGroupByHour) {
set.clear()
}
}
loadActions(listSource(), currentTime, load, repair)
}
fun refreshIfNeeded(now: ZonedDateTime) {
val d = now.dayOfMonth
val h = now.hour
if (d != recordDay) {
recordDay = d
recordHour = h
loadDayActions(now)
loadHourActions(now)
} else if (h != recordHour) {
recordHour = h
loadHourActions(now)
}
}
val now = ZonedDateTime.now()
if (finallyToExecute) {
refreshIfNeeded(now)
then(now)
} else {
then(now)
refreshIfNeeded(now)
}
}
private fun parseToZonedDateTime(
scheduleType: Schedulable.ScheduleType,
scheduleContent: String,
now: ZonedDateTime
): ZonedDateTime? {
return when (scheduleType) {
Schedulable.ScheduleType.CYCLE
-> {
val cron = try {
cronParser.parse(scheduleContent).validate()
} catch (_: Exception) {
return null
}
val executionTime = ExecutionTime.forCron(cron)
executionTime.nextExecution(now).getOrNull()
}
Schedulable.ScheduleType.ONCE -> {
val executionTime = try {
ZonedDateTime.parse(scheduleContent)
} catch (_: Exception) {
return null
}
if (executionTime.plusSeconds(1).isBefore(now) || executionTime.dayOfMonth != now.dayOfMonth)
null
else
executionTime
}
}
}
private fun logFailedStatus(scheduleData: Schedulable) {
log.warn(
"行动未加载scheduleType: {}, scheduleContent: {}",
scheduleData.scheduleType,
scheduleData.scheduleContent,
)
}
override fun close() {
timeWheelScope.cancel()
}
private enum class WheelState {
ACTIVE,
SLEEPING,
}
}
}

View File

@@ -0,0 +1,251 @@
package work.slhaf.partner.module.modules.action.interventor;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import lombok.val;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.core.action.entity.PhaserRecord;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.module.common.module.PreRunningModule;
import work.slhaf.partner.module.modules.action.interventor.entity.InterventionType;
import work.slhaf.partner.module.modules.action.interventor.entity.MetaIntervention;
import work.slhaf.partner.module.modules.action.interventor.evaluator.InterventionEvaluator;
import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.EvaluatorInput;
import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.EvaluatorResult;
import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.EvaluatorResult.EvaluatedInterventionData;
import work.slhaf.partner.module.modules.action.interventor.recognizer.InterventionRecognizer;
import work.slhaf.partner.module.modules.action.interventor.recognizer.entity.RecognizerInput;
import work.slhaf.partner.module.modules.action.interventor.recognizer.entity.RecognizerResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
/**
* 负责识别潜在的行动干预信息,作用于正在进行或已存在的行动池中内容
*/
@AgentModule(name = "action_identifier", order = 2)
public class ActionInterventor extends PreRunningModule implements ActivateModel {
@InjectModule
private InterventionRecognizer interventionRecognizer;
@InjectModule
private InterventionEvaluator interventionEvaluator;
@InjectCapability
private ActionCapability actionCapability;
@InjectCapability
private CognationCapability cognationCapability;
@InjectCapability
private MemoryCapability memoryCapability;
private final AssemblyHelper assemblyHelper = new AssemblyHelper();
private final PromptHelper promptHelper = new PromptHelper();
/**
* 键: 本次调用uuid
* 值本次调用对应的prompt
*/
private final Map<String, Map<String, String>> interventionPrompt = new HashMap<>();
@Override
protected void doExecute(PartnerRunningFlowContext context) {
// 综合当前正在进行的行动链信息、用户交互历史、激活的记忆切片,尝试识别出是否存在行动干预意图
// 首先通过recognizer进行快速意图识别识别成功则步入评估阶段评估成功则直接作用于目标行动链
// 进行快速意图识别时必须结合近期对话与进行中行动链情况
// 干预意图识别
String uuid = context.getUuid();
String userId = context.getUserId();
RecognizerResult recognizerResult = interventionRecognizer
.execute(assemblyHelper.buildRecognizerInput(userId, context.getInput())); // 此处的输入内容携带了所有 PhaserRecord
if (!recognizerResult.isOk()) {
promptHelper.setupNoInterventionPrompt(uuid);
return;
}
// 干预意图评估
EvaluatorResult evaluatorResult = interventionEvaluator
.execute(assemblyHelper.buildEvaluatorInput(recognizerResult, userId));
List<EvaluatedInterventionData> executingDataList = evaluatorResult.getExecutingDataList();
List<EvaluatedInterventionData> preparedDataList = evaluatorResult.getPreparedDataList();
// 意图评估结果处理
if (evaluatorResult.isOk()) {
// 对存在异常ActionKey的评估结果列表进行过滤
invalidActionKeysFilter(executingDataList);
invalidActionKeysFilter(preparedDataList);
// 同步写入prompt异步处理干预行为异步在处理流程中体现
promptHelper.setupInterventionPrompt(uuid, executingDataList, preparedDataList);
handleInterventions(executingDataList, recognizerResult.getExecutingInterventions());
handleInterventions(preparedDataList, recognizerResult.getPreparedInterventions());
} else {
promptHelper.setupInterventionIgnoredPrompt(uuid, executingDataList, preparedDataList);
}
}
private void handleInterventions(List<EvaluatedInterventionData> interventionDataList, Map<String, ExecutableAction> interventionDataMap) {
val executor = actionCapability.getExecutor(ActionCore.ExecutorType.PLATFORM);
executor.execute(() -> {
for (EvaluatedInterventionData interventionData : interventionDataList) {
// 此处拿到的为 ActionData 或者 PhaserRecord, 来自 Recognizer 的封装
val data = interventionDataMap.get(interventionData.getTendency());
actionCapability.handleInterventions(interventionData.getMetaInterventionList(), data);
}
});
}
private void invalidActionKeysFilter(List<EvaluatedInterventionData> interventions) {
List<EvaluatedInterventionData> toRemove = new ArrayList<>();
for (EvaluatedInterventionData intervention : interventions) {
List<MetaIntervention> interventionData = intervention.getMetaInterventionList();
List<String> actions = new ArrayList<>();
for (MetaIntervention metaData : interventionData) {
actions.addAll(metaData.getActions());
}
// 如果存在异常行动key则可视为该评估结果存在问题直接忽略该结果
if (!actionCapability.checkExists(actions.toArray(String[]::new))) {
toRemove.add(intervention);
}
// 针对 REBUILD 类型进行特殊校验, REBUILD 类型必须满足所有 MetaIntervention 的类型均为 REBUILD
if (!checkRebuildType(interventionData)) {
toRemove.add(intervention);
}
}
interventions.removeAll(toRemove);
}
private boolean checkRebuildType(List<MetaIntervention> interventionData) {
boolean hasRebuild = false;
for (MetaIntervention meta : interventionData) {
if (meta.getType() == InterventionType.REBUILD) {
hasRebuild = true;
} else if (hasRebuild) {
// 已经存在REBUILD类型但又发现了非REBUILD类型不合法
return false;
}
}
return true;
}
@Override
public String modelKey() {
return "action_identifier";
}
@Override
public boolean withBasicPrompt() {
return false;
}
@Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
return interventionPrompt.remove(context.getUuid());
}
@Override
protected String moduleName() {
return "[行动干预识别模块]";
}
private final class AssemblyHelper {
private AssemblyHelper() {
}
private RecognizerInput buildRecognizerInput(String userId, String input) {
RecognizerInput recognizerInput = new RecognizerInput();
recognizerInput.setInput(input);
recognizerInput.setUserDialogMapStr(memoryCapability.getUserDialogMapStr(userId));
// 参考的对话列表大小或需调整
recognizerInput.setRecentMessages(cognationCapability.getChatMessages());
recognizerInput.setExecutingActions(actionCapability.listPhaserRecords().stream().map(PhaserRecord::executableAction).toList());
recognizerInput.setPreparedActions(actionCapability.listActions(ExecutableAction.Status.PREPARE, userId).stream().toList());
return recognizerInput;
}
private EvaluatorInput buildEvaluatorInput(RecognizerResult recognizerResult, String userId) {
EvaluatorInput input = new EvaluatorInput();
input.setExecutingInterventions(recognizerResult.getExecutingInterventions());
input.setPreparedInterventions(recognizerResult.getPreparedInterventions());
input.setRecentMessages(cognationCapability.getChatMessages());
input.setActivatedSlices(memoryCapability.getActivatedSlices(userId));
return input;
}
}
private final class PromptHelper {
private PromptHelper() {
}
private void setupInterventionIgnoredPrompt(String uuid, List<EvaluatedInterventionData> executingDataList, List<EvaluatedInterventionData> preparedDataList) {
List<EvaluatedInterventionData> total = Stream.concat(executingDataList.stream(), preparedDataList.stream()).toList();
JSONArray reasons = new JSONArray();
for (EvaluatedInterventionData data : total) {
JSONObject reason = reasons.addObject();
reason.put("[干预倾向]", data.getTendency());
reason.put("[未采用原因]", data.getDescription());
}
synchronized (interventionPrompt) {
interventionPrompt.put(uuid, Map.of(
"[识别状态] <是否识别到干预已存在行动的意图>", "识别到,但都未采用",
"[忽略原因] <各个意图被忽略的原因>", reasons.toString(),
"[干预行动] <将对已存在行动做出的行为>", "无行为"));
}
}
private void setupInterventionPrompt(String uuid, List<EvaluatedInterventionData> executingDataList,
List<EvaluatedInterventionData> preparedDataList) {
JSONArray contents = new JSONArray();
List<EvaluatedInterventionData> temp = Stream.concat(executingDataList.stream(), preparedDataList.stream()).toList();
for (EvaluatedInterventionData data : temp) {
if (!data.isOk()) {
continue;
}
String tendency = data.getTendency();
JSONObject newElement = contents.addObject();
newElement.put("[干预倾向]", tendency);
JSONArray changes = newElement.putArray("[行动链变动情况]");
for (MetaIntervention intervention : data.getMetaInterventionList()) {
JSONObject change = changes.addObject();
change.put("[干预类型]", intervention.getType());
change.put("[干预序号]", intervention.getOrder());
change.putArray("[干预内容]").addAll(intervention.getActions());
}
}
synchronized (interventionPrompt) {
interventionPrompt.put(uuid, Map.of(
"[识别状态] <是否识别到干预已存在行动的意图>", "识别到,将采用",
"[干预内容] <将对已存在行动做出的行为>", contents.toString()));
}
}
private void setupNoInterventionPrompt(String uuid) {
interventionPrompt.put(uuid, Map.of(
"[识别状态] <是否识别到干预已存在行动的意图>", "未识别到干预意图",
"[干预行动] <将对已存在行动做出的行为>", "无行动"));
}
}
}

View File

@@ -0,0 +1,28 @@
package work.slhaf.partner.module.modules.action.interventor.entity;
public enum InterventionType {
/**
* 追加行动: 追加至指定行动链序列之后才执行
*/
APPEND,
/**
* 插入行动: 指定行动链序列执行过程中即时新增并执行
*/
INSERT,
/**
* 重建行动: 重建指定行动链序列之后的所有行动内容
*/
REBUILD,
/**
* 删除行动: 删除指定行动链序列上的指定行动单元
*/
DELETE,
/**
* 取消行动链: 中断并取消指定行动链的执行
*/
CANCEL
}

View File

@@ -0,0 +1,21 @@
package work.slhaf.partner.module.modules.action.interventor.entity;
import lombok.Data;
import java.util.List;
@Data
public class MetaIntervention {
/**
* 干预数据类型
*/
private InterventionType type;
/**
* 干预数据对应的行动链序列
*/
private int order;
/**
* 干预数据所需的行动key列表
*/
private List<String> actions;
}

View File

@@ -0,0 +1,99 @@
package work.slhaf.partner.module.modules.action.interventor.evaluator;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore.ExecutorType;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.EvaluatorInput;
import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.EvaluatorResult;
import work.slhaf.partner.module.modules.action.interventor.evaluator.entity.EvaluatorResult.EvaluatedInterventionData;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
@Slf4j
@AgentSubModule
public class InterventionEvaluator extends AgentRunningSubModule<EvaluatorInput, EvaluatorResult>
implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
/**
* 基于干预意图、记忆切片、交互上下文、已有行动程序综合评估,尝试评估并选取出合适的行动程序,交付给 ActionInterventor
*/
@Override
public EvaluatorResult execute(EvaluatorInput input) {
// 获取必须数据
ExecutorService executor = actionCapability.getExecutor(ExecutorType.VIRTUAL);
Map<String, ExecutableAction> executingInterventions = input.getExecutingInterventions();
Map<String, ExecutableAction> preparedInterventions = input.getPreparedInterventions();
CountDownLatch latch = new CountDownLatch(executingInterventions.size() + preparedInterventions.size());
// 创建结果容器
EvaluatorResult result = new EvaluatorResult();
List<EvaluatedInterventionData> executingDataList = result.getExecutingDataList();
List<EvaluatedInterventionData> preparedDataList = result.getPreparedDataList();
// 并发评估
evaluateIntervention(executingDataList, executingInterventions, input, executor, latch);
evaluateIntervention(preparedDataList, preparedInterventions, input, executor, latch);
try {
latch.await();
} catch (InterruptedException e) {
log.warn("CountDownLatch阻塞已中断");
}
return result;
}
private void evaluateIntervention(List<EvaluatedInterventionData> evaluatedDataList, Map<String, ExecutableAction> interventionMap, EvaluatorInput input, ExecutorService executor, CountDownLatch latch) {
interventionMap.forEach((tendency, actionData) -> executor.execute(() -> {
try {
String prompt = buildPrompt(input.getRecentMessages(), input.getActivatedSlices(), actionData, tendency);
ChatResponse response = this.singleChat(prompt);
EvaluatedInterventionData evaluatedData = JSONObject.parseObject(response.getMessage(),
EvaluatedInterventionData.class);
synchronized (evaluatedDataList) {
evaluatedDataList.add(evaluatedData);
}
} catch (Exception e) {
log.error("干预意图评估出错: {}", tendency, e);
} finally {
latch.countDown();
}
}));
}
private String buildPrompt(List<Message> recentMessages, List<EvaluatedSlice> activatedSlices,
ExecutableAction executableAction, String tendency) {
JSONObject json = new JSONObject();
json.put("干预倾向", tendency);
json.putArray("近期对话").addAll(recentMessages);
json.putArray("参考记忆").addAll(activatedSlices);
json.put("将干预的行动", JSONObject.toJSONString(executableAction));
return json.toJSONString();
}
@Override
public String modelKey() {
return "intervention_evaluator";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -0,0 +1,17 @@
package work.slhaf.partner.module.modules.action.interventor.evaluator.entity;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import java.util.List;
import java.util.Map;
@Data
public class EvaluatorInput {
private Map<String, ExecutableAction> executingInterventions;
private Map<String, ExecutableAction> preparedInterventions;
private List<EvaluatedSlice> activatedSlices;
private List<Message> recentMessages;
}

View File

@@ -0,0 +1,33 @@
package work.slhaf.partner.module.modules.action.interventor.evaluator.entity;
import lombok.Data;
import work.slhaf.partner.module.modules.action.interventor.entity.MetaIntervention;
import java.util.List;
/**
* 干预倾向评估结果包含评估通过的倾向文本、对行动链的行为、指定操作的行动单元key、未通过的原因
*/
@Data
public class EvaluatorResult {
/**
* 是否存在通过的干预倾向
*/
private boolean ok;
private List<EvaluatedInterventionData> executingDataList;
private List<EvaluatedInterventionData> preparedDataList;
@Data
public static class EvaluatedInterventionData {
/**
* 是否通过
*/
private boolean ok;
private String tendency;
/**
* 描述信息(包括通过、失败原因)
*/
private String description;
private List<MetaIntervention> metaInterventionList;
}
}

View File

@@ -0,0 +1,102 @@
package work.slhaf.partner.module.modules.action.interventor.recognizer;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.module.modules.action.interventor.recognizer.entity.MetaRecognizerResult;
import work.slhaf.partner.module.modules.action.interventor.recognizer.entity.RecognizerInput;
import work.slhaf.partner.module.modules.action.interventor.recognizer.entity.RecognizerResult;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
@Slf4j
@AgentSubModule
public class InterventionRecognizer extends AgentRunningSubModule<RecognizerInput, RecognizerResult> implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
@Override
public RecognizerResult execute(RecognizerInput input) {
// 获取必须数据
ExecutorService executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
List<ExecutableAction> executingActions = input.getExecutingActions();
List<ExecutableAction> preparedActions = input.getPreparedActions();
CountDownLatch countDownLatch = new CountDownLatch(executingActions.size() + preparedActions.size());
// 创建结果容器
RecognizerResult recognizerResult = new RecognizerResult();
Map<String, ExecutableAction> executingInterventions = recognizerResult.getExecutingInterventions();
Map<String, ExecutableAction> preparedInterventions = recognizerResult.getPreparedInterventions();
// 执行识别操作
recognizeIntervention(executingInterventions, executingActions, executor, input, countDownLatch);
recognizeIntervention(preparedInterventions, preparedActions, executor, input, countDownLatch);
try {
countDownLatch.await();
} catch (InterruptedException e) {
log.warn("CountDownLatch阻塞已中断");
}
return recognizerResult;
}
private void recognizeIntervention(Map<String, ExecutableAction> interventionsMap, List<ExecutableAction> actions, ExecutorService executor, RecognizerInput input, CountDownLatch latch) {
for (ExecutableAction data : actions) {
executor.execute(() -> {
try {
String prompt = buildPrompt(data, input);
ChatResponse response = this.singleChat(prompt);
MetaRecognizerResult result = JSONObject.parseObject(response.getMessage(), MetaRecognizerResult.class);
if (result.isOk()) {
synchronized (interventionsMap) {
interventionsMap.put(result.getIntervention(), data);
}
}
} catch (Exception e) {
log.error("LLM干预意图提取出错", e);
} finally {
latch.countDown();
}
});
}
}
private String buildPrompt(ExecutableAction executableAction, RecognizerInput input) {
JSONObject json = new JSONObject();
JSONObject actionInfo = json.putObject("行动信息");
actionInfo.put("行动倾向", executableAction.getTendency());
actionInfo.put("行动原因", executableAction.getReason());
actionInfo.put("行动描述", executableAction.getDescription());
actionInfo.put("行动状态", executableAction.getStatus());
actionInfo.put("行动来源", executableAction.getSource());
JSONObject interactionInfo = json.putObject("交互信息");
interactionInfo.put("用户输入", input.getInput());
interactionInfo.put("当前对话", input.getRecentMessages());
interactionInfo.put("近期对话", input.getUserDialogMapStr());
return json.toString();
}
@Override
public String modelKey() {
return "intervention_recognizer";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -0,0 +1,9 @@
package work.slhaf.partner.module.modules.action.interventor.recognizer.entity;
import lombok.Data;
@Data
public class MetaRecognizerResult {
private boolean ok;
private String intervention;
}

View File

@@ -0,0 +1,22 @@
package work.slhaf.partner.module.modules.action.interventor.recognizer.entity;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import java.util.List;
@Data
public class RecognizerInput {
private String input;
private List<Message> recentMessages;
/**
* 当前用户对应的近两日对话缓存
*/
private String userDialogMapStr;
/**
* 正在执行的行动-Phaser记录列表在Recognizer中结合本次输入并发评估(考虑到不同行动链之间对LLM的影响)
*/
private List<ExecutableAction> executingActions;
private List<ExecutableAction> preparedActions;
}

View File

@@ -0,0 +1,29 @@
package work.slhaf.partner.module.modules.action.interventor.recognizer.entity;
import lombok.Data;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import java.util.HashMap;
import java.util.Map;
@Data
public class RecognizerResult {
private boolean ok;
/**
* <h4>将被干预的‘执行中行动’</h4>
* key: 干预倾向
* <br/>
* value: 干预倾向将作用的行动数据
*/
private Map<String, ExecutableAction> executingInterventions = new HashMap<>();
/**
* <h4>将被干预的‘等待中行动’</h4>
* key: 干预倾向
* <br/>
* value: 干预倾向将作用的行动数据
*/
private Map<String, ExecutableAction> preparedInterventions = new HashMap<>();
}

View File

@@ -0,0 +1,340 @@
package work.slhaf.partner.module.modules.action.planner;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.common.vector.VectorClient;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;
import work.slhaf.partner.core.action.entity.*;
import work.slhaf.partner.core.action.entity.cache.CacheAdjustData;
import work.slhaf.partner.core.action.entity.cache.CacheAdjustMetaData;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.perceive.PerceiveCapability;
import work.slhaf.partner.module.common.module.PreRunningModule;
import work.slhaf.partner.module.modules.action.planner.confirmer.ActionConfirmer;
import work.slhaf.partner.module.modules.action.planner.confirmer.entity.ConfirmerInput;
import work.slhaf.partner.module.modules.action.planner.confirmer.entity.ConfirmerResult;
import work.slhaf.partner.module.modules.action.planner.evaluator.ActionEvaluator;
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorInput;
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorResult;
import work.slhaf.partner.module.modules.action.planner.extractor.ActionExtractor;
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput;
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* 负责针对本次输入生成基础的行动计划,在主模型传达意愿后,执行行动或者放入计划池
*/
@Slf4j
@AgentModule(name = "action_planner", order = 2)
public class ActionPlanner extends PreRunningModule {
@InjectCapability
private CognationCapability cognationCapability;
@InjectCapability
private ActionCapability actionCapability;
@InjectCapability
private PerceiveCapability perceiveCapability;
@InjectCapability
private MemoryCapability memoryCapability;
@InjectModule
private ActionEvaluator actionEvaluator;
@InjectModule
private ActionExtractor actionExtractor;
@InjectModule
private ActionConfirmer actionConfirmer;
private ExecutorService executor;
private final ActionAssemblyHelper assemblyHelper = new ActionAssemblyHelper();
@Init
public void init() {
executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
}
@Override
protected void doExecute(PartnerRunningFlowContext context) {
try {
List<Callable<Void>> tasks = new ArrayList<>();
addConfirmTask(tasks, context);
addNewActionTask(tasks, context);
executor.invokeAll(tasks);
} catch (Exception e) {
log.error("执行异常", e);
}
}
/**
* 新的提取与评估任务
*
* @param tasks 并发任务列表
* @param context 流程上下文
*/
private void addNewActionTask(List<Callable<Void>> tasks, PartnerRunningFlowContext context) {
tasks.add(() -> {
ExtractorInput extractorInput = assemblyHelper.buildExtractorInput(context);
ExtractorResult extractorResult = actionExtractor.execute(extractorInput);
if (extractorResult.getTendencies().isEmpty()) {
return null;
}
EvaluatorInput evaluatorInput = assemblyHelper.buildEvaluatorInput(extractorResult, context.getUserId());
List<EvaluatorResult> evaluatorResults = actionEvaluator.execute(evaluatorInput); // 并发操作均为访问
putActionData(evaluatorResults, context);
updateTendencyCache(evaluatorResults, context.getInput(), extractorResult);
return null;
});
}
private void updateTendencyCache(List<EvaluatorResult> evaluatorResults, String input,
ExtractorResult extractorResult) {
if (!VectorClient.status) {
return;
}
executor.execute(() -> {
CacheAdjustData data = new CacheAdjustData();
List<CacheAdjustMetaData> list = new ArrayList<>();
List<String> hitTendencies = extractorResult.getTendencies();
for (EvaluatorResult result : evaluatorResults) {
CacheAdjustMetaData metaData = new CacheAdjustMetaData();
metaData.setTendency(result.getTendency());
metaData.setPassed(result.isOk());
metaData.setHit(hitTendencies.contains(result.getTendency()));
list.add(metaData);
}
data.setMetaDataList(list);
data.setInput(input);
actionCapability.updateTendencyCache(data);
});
}
/**
* 待确认行动的判断任务
*
* @param tasks 并发任务列表
* @param context 流程上下文
*/
private void addConfirmTask(List<Callable<Void>> tasks, PartnerRunningFlowContext context) {
tasks.add(() -> {
ConfirmerInput confirmerInput = assemblyHelper.buildConfirmerInput(context);
ConfirmerResult result = actionConfirmer.execute(confirmerInput);
setupConfirmedActionInfo(context, result);
return null;
});
}
private void setupConfirmedActionInfo(PartnerRunningFlowContext context, ConfirmerResult result) {
// TODO 需考虑未确认任务的失效或者拒绝时机在action core中实现
List<String> uuids = result.getUuids();
if (uuids == null) {
return;
}
List<ExecutableAction> pendingActions = actionCapability.popPendingAction(context.getUserId());
for (ExecutableAction executableAction : pendingActions) {
if (uuids.contains(executableAction.getUuid())) {
actionCapability.putAction(executableAction);
}
}
}
private void putActionData(List<EvaluatorResult> evaluatorResults, PartnerRunningFlowContext context) {
for (EvaluatorResult evaluatorResult : evaluatorResults) {
ExecutableAction executableAction = assemblyHelper.buildActionData(evaluatorResult, context.getUserId());
if (evaluatorResult.isNeedConfirm()) {
actionCapability.putPendingActions(context.getUserId(), executableAction);
} else {
actionCapability.putAction(executableAction);
}
}
}
@Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
HashMap<String, String> map = new HashMap<>();
String userId = context.getUserId();
setupPendingActions(map, userId);
setupPreparedActions(map, userId);
return map;
}
private void setupPendingActions(HashMap<String, String> map, String userId) {
List<ExecutableAction> executableActionData = actionCapability.listPendingAction(userId);
if (executableActionData == null || executableActionData.isEmpty()) {
map.put("[待确认行动] <等待用户确认的行动信息>", "无待确认行动");
return;
}
for (int i = 0; i < executableActionData.size(); i++) {
map.put("[待确认行动 " + (i + 1) + " ] <等待用户确认的行动信息>", generateActionStr(executableActionData.get(i)));
}
}
private void setupPreparedActions(HashMap<String, String> map, String userId) {
val preparedActions = actionCapability.listActions(ExecutableAction.Status.PREPARE, userId).stream().toList();
if (preparedActions.isEmpty()) {
map.put("[预备行动] <预备执行或放入计划池的行动信息>", "无预备行动");
return;
}
for (int i = 0; i < preparedActions.size(); i++) {
map.put("[预备行动 " + (i + 1) + " ] <预备执行或放入计划池的行动信息>", generateActionStr(preparedActions.get(i)));
}
}
private String generateActionStr(ExecutableAction executableAction) {
return "<行动倾向>" + " : " + executableAction.getTendency() +
"<行动原因>" + " : " + executableAction.getReason() +
"<工具描述>" + " : " + executableAction.getDescription();
}
@Override
protected String moduleName() {
return "[行动模块]";
}
private final class ActionAssemblyHelper {
private ActionAssemblyHelper() {
}
private ExtractorInput buildExtractorInput(PartnerRunningFlowContext context) {
ExtractorInput input = new ExtractorInput();
input.setInput(context.getInput());
List<Message> chatMessages = cognationCapability.getChatMessages();
List<Message> recentMessages = new ArrayList<>();
if (chatMessages.size() > 5) {
recentMessages.addAll(chatMessages.subList(chatMessages.size() - 5, chatMessages.size() - 1));
} else if (chatMessages.size() > 1) {
recentMessages.addAll(chatMessages.subList(0, chatMessages.size() - 1));
}
input.setRecentMessages(recentMessages);
return input;
}
private EvaluatorInput buildEvaluatorInput(ExtractorResult extractorResult, String userId) {
EvaluatorInput input = new EvaluatorInput();
input.setTendencies(extractorResult.getTendencies());
input.setUser(perceiveCapability.getUser(userId));
input.setRecentMessages(cognationCapability.getChatMessages());
input.setActivatedSlices(memoryCapability.getActivatedSlices(userId));
return input;
}
private ExecutableAction buildActionData(EvaluatorResult evaluatorResult, String userId) {
Map<Integer, List<MetaAction>> actionChain = getActionChain(evaluatorResult);
return switch (evaluatorResult.getType()) {
case PLANNING -> new SchedulableExecutableAction(
evaluatorResult.getTendency(),
actionChain,
evaluatorResult.getReason(),
evaluatorResult.getDescription(),
userId,
evaluatorResult.getScheduleType(),
evaluatorResult.getScheduleContent()
);
case IMMEDIATE -> new ImmediateExecutableAction(
evaluatorResult.getTendency(),
actionChain,
evaluatorResult.getReason(),
evaluatorResult.getDescription(),
userId
);
};
}
private @NotNull Map<Integer, List<MetaAction>> getActionChain(EvaluatorResult evaluatorResult) {
Map<Integer, List<MetaAction>> actionChain = new HashMap<>();
Map<Integer, List<String>> primaryActionChain = evaluatorResult.getPrimaryActionChain();
fixDependencies(primaryActionChain);
primaryActionChain.forEach((order, actionKeys) -> {
List<MetaAction> metaActions = actionKeys.stream()
.map(actionKey -> actionCapability.loadMetaAction(actionKey))
.toList();
actionChain.put(order, metaActions);
});
return actionChain;
}
private void fixDependencies(Map<Integer, List<String>> primaryActionChain) {
// 先将 primaryActionChain 的节点序号修正为从1开始依次增大
fixOrder(primaryActionChain);
List<Integer> fixedOrders = new ArrayList<>(primaryActionChain.keySet().stream().toList());
AtomicBoolean fixed = new AtomicBoolean(false);
do {
Set<Integer> tempOrders = new HashSet<>();
fixedOrders.sort(Integer::compareTo);
for (Integer fixedOrder : fixedOrders) {
int lastOrder = fixedOrder - 1;
List<String> actionKeys = primaryActionChain.get(fixedOrder);
for (String actionKey : actionKeys) {
// 根据 actionKey 加载行动信息,并检查是否存在必需前置依赖
MetaActionInfo metaActionInfo = actionCapability.loadMetaActionInfo(actionKey);
List<String> preActions = metaActionInfo.getPreActions();
boolean preActionsExist = preActions != null && !preActions.isEmpty();
if (!preActionsExist) {
continue;
}
if (!metaActionInfo.isStrictDependencies()) {
continue;
}
if (checkDependenciesExist(lastOrder, preActions, primaryActionChain)) {
continue;
}
// 如果存在前置依赖,则将其放置在当前order之前的位置,
// 放置位置优先选择已存在的上一节点,如果不存在(行动链的头节点时)则需要向行动链新增order
// 不需要检查行动链的当前节点的已存在 Action 是否为新 Action 的依赖项,因为这些 Action 实际来自 LLM
// 的评估结果,并非作为依赖项存在
fixed.set(true);
List<String> actionsInChain = primaryActionChain.computeIfAbsent(lastOrder,
list -> new ArrayList<>());
preActions = new ArrayList<>(preActions);
preActions.removeAll(actionsInChain);
actionsInChain.addAll(preActions);
tempOrders.add(lastOrder);
}
}
fixedOrders.clear();
fixedOrders.addAll(tempOrders);
} while (fixed.getAndSet(false));
}
private void fixOrder(Map<Integer, List<String>> primaryActionChain) {
Map<Integer, List<String>> tempChain = new HashMap<>(primaryActionChain);
primaryActionChain.clear();
int chainSize = tempChain.size();
for (int i = 0; i < chainSize; i++) {
primaryActionChain.put(i, tempChain.get(i));
}
}
private boolean checkDependenciesExist(int lastOrder, List<String> preActions,
Map<Integer, List<String>> primaryActionChain) {
if (!primaryActionChain.containsKey(lastOrder)) {
return false;
}
List<String> existActions = primaryActionChain.get(lastOrder);
//noinspection SlowListContainsAll
return existActions.containsAll(preActions);
}
private ConfirmerInput buildConfirmerInput(PartnerRunningFlowContext context) {
ConfirmerInput confirmerInput = new ConfirmerInput();
confirmerInput.setInput(context.getInput());
List<ExecutableAction> pendingActions = actionCapability.listPendingAction(context.getUserId());
confirmerInput.setExecutableActionData(pendingActions);
return confirmerInput;
}
}
}

View File

@@ -0,0 +1,90 @@
package work.slhaf.partner.module.modules.action.planner.confirmer;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import work.slhaf.partner.module.modules.action.planner.confirmer.entity.ConfirmerInput;
import work.slhaf.partner.module.modules.action.planner.confirmer.entity.ConfirmerResult;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
@Slf4j
@AgentSubModule
public class ActionConfirmer extends AgentRunningSubModule<ConfirmerInput, ConfirmerResult> implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
@Override
public ConfirmerResult execute(ConfirmerInput data) {
List<ExecutableAction> executableActionList = data.getExecutableActionData();
ExecutorService executor = actionCapability.getExecutor(ActionCore.ExecutorType.VIRTUAL);
CountDownLatch latch = new CountDownLatch(executableActionList.size());
ConfirmerResult result = new ConfirmerResult();
List<String> uuids = result.getUuids();
for (ExecutableAction executableAction : executableActionList) {
executor.execute(() -> {
try {
String prompt = buildPrompt(executableAction, data.getInput(), data.getRecentMessages());
ChatResponse response = this.singleChat(prompt);
JSONObject tempResult = JSONObject.parseObject(extractJson(response.getMessage()));
if (tempResult.getBoolean("confirmed")) {
executableAction.setStatus(ExecutableAction.Status.PREPARE);
synchronized (uuids) {
uuids.add(executableAction.getUuid());
}
}
} finally {
latch.countDown();
}
});
}
try {
latch.await();
} catch (InterruptedException e) {
log.warn("CountDownLatch阻塞已中断");
}
return result;
}
private String buildPrompt(ExecutableAction data, String input, List<Message> recentMessages) {
JSONObject prompt = new JSONObject();
prompt.put("[用户输入]", input);
JSONObject actionData = prompt.putObject("[行动数据]");
actionData.put("[行动倾向]", data.getTendency());
actionData.put("[行动原因]", data.getReason());
actionData.put("[行动来源]", data.getSource());
actionData.put("[行动描述]", data.getDescription());
JSONArray messageData = prompt.putArray("[近期对话]");
messageData.addAll(recentMessages);
return prompt.toString();
}
@Override
public String modelKey() {
return "action-confirmer";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -0,0 +1,14 @@
package work.slhaf.partner.module.modules.action.planner.confirmer.entity;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.entity.ExecutableAction;
import java.util.List;
@Data
public class ConfirmerInput {
private String input;
private List<ExecutableAction> executableActionData;
private List<Message> recentMessages;
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.module.modules.action.planner.confirmer.entity;
import lombok.Data;
import java.util.ArrayList;
import java.util.List;
@Data
public class ConfirmerResult {
private List<String> uuids = new ArrayList<>();
}

View File

@@ -0,0 +1,105 @@
package work.slhaf.partner.module.modules.action.planner.evaluator;
import cn.hutool.core.bean.BeanUtil;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorBatchInput;
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorInput;
import work.slhaf.partner.module.modules.action.planner.evaluator.entity.EvaluatorResult;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
@AgentSubModule
public class ActionEvaluator extends AgentRunningSubModule<EvaluatorInput, List<EvaluatorResult>> implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
private InteractionThreadPoolExecutor executor;
@Init
public void init() {
executor = InteractionThreadPoolExecutor.getInstance();
}
/**
* 对输入的行为倾向进行评估,并根据评估结果,对缓存做出调整
*
* @param data 评估输入内容,包含提取/命中缓存的行动倾向、近几条聊天记录,正在生效的记忆切片内容
* @return 评估结果集合
*/
@Override
public List<EvaluatorResult> execute(EvaluatorInput data) {
List<EvaluatorBatchInput> batchInputs = buildEvaluatorBatchInput(data);
List<Callable<EvaluatorResult>> tasks = getTasks(batchInputs);
return executor.invokeAllAndReturn(tasks);
}
private List<Callable<EvaluatorResult>> getTasks(List<EvaluatorBatchInput> batchInputs) {
List<Callable<EvaluatorResult>> list = new ArrayList<>();
for (EvaluatorBatchInput batchInput : batchInputs) {
list.add(() -> {
ChatResponse response = this.singleChat(buildPrompt(batchInput));
EvaluatorResult evaluatorResult = JSONObject.parseObject(response.getMessage(), EvaluatorResult.class);
evaluatorResult.setTendency(batchInput.getTendency());
return evaluatorResult;
});
}
return list;
}
private List<EvaluatorBatchInput> buildEvaluatorBatchInput(EvaluatorInput data) {
List<EvaluatorBatchInput> list = new ArrayList<>();
for (String tendency : data.getTendencies()) {
EvaluatorBatchInput temp = new EvaluatorBatchInput();
BeanUtil.copyProperties(data, temp);
temp.setTendency(tendency);
Map<String, String> availableActions = new HashMap<>();
actionCapability.listAvailableMetaActions().forEach((key, info) -> availableActions.put(key, info.getDescription()));
temp.setAvailableActions(availableActions);
list.add(temp);
}
return list;
}
private String buildPrompt(EvaluatorBatchInput batchInput) {
JSONObject prompt = new JSONObject();
prompt.put("[行动倾向]", batchInput.getTendency());
JSONArray memoryData = prompt.putArray("[相关记忆切片]");
for (EvaluatedSlice evaluatedSlice : batchInput.getActivatedSlices()) {
JSONObject memory = memoryData.addObject();
memory.put("[日期]", evaluatedSlice.getDate());
memory.put("[摘要]", evaluatedSlice.getSummary());
}
JSONObject availableActionData = prompt.putObject("[可用行动单元]");
availableActionData.putAll(batchInput.getAvailableActions());
return prompt.toString();
}
@Override
public String modelKey() {
return "action_evaluator";
}
@Override
public boolean withBasicPrompt() {
return true;
}
}

View File

@@ -0,0 +1,16 @@
package work.slhaf.partner.module.modules.action.planner.evaluator.entity;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import java.util.List;
import java.util.Map;
@Data
public class EvaluatorBatchInput {
private List<Message> recentMessages;
private List<EvaluatedSlice> activatedSlices;
private Map<String, String> availableActions;
private String tendency;
}

View File

@@ -0,0 +1,16 @@
package work.slhaf.partner.module.modules.action.planner.evaluator.entity;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.perceive.pojo.User;
import java.util.List;
@Data
public class EvaluatorInput {
private List<Message> recentMessages;
private User user;
private List<EvaluatedSlice> activatedSlices;
private List<String> tendencies;
}

View File

@@ -0,0 +1,24 @@
package work.slhaf.partner.module.modules.action.planner.evaluator.entity;
import lombok.Data;
import work.slhaf.partner.core.action.entity.SchedulableExecutableAction;
import java.util.List;
import java.util.Map;
@Data
public class EvaluatorResult {
private boolean ok;
private boolean needConfirm;
private ActionType type;
private String scheduleContent;
private SchedulableExecutableAction.ScheduleType scheduleType;
private Map<Integer, List<String>> primaryActionChain;
private String tendency;
private String reason;
private String description;
public enum ActionType {
IMMEDIATE, PLANNING
}
}

View File

@@ -0,0 +1,53 @@
package work.slhaf.partner.module.modules.action.planner.extractor;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentSubModule;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningSubModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput;
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorResult;
import java.util.List;
@Slf4j
@AgentSubModule
public class ActionExtractor extends AgentRunningSubModule<ExtractorInput, ExtractorResult> implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
@Override
public ExtractorResult execute(ExtractorInput data) {
ExtractorResult result = new ExtractorResult();
List<String> tendencyCache = actionCapability.selectTendencyCache(data.getInput());
if (tendencyCache != null && !tendencyCache.isEmpty()) {
result.setTendencies(tendencyCache);
return result;
}
for (int i = 0; i < 3; i++) {
try {
ChatResponse response = this.singleChat(JSONObject.toJSONString(data));
return JSONObject.parseObject(response.getMessage(), ExtractorResult.class);
} catch (Exception e) {
log.error("[ActionExtractor] 提取信息出错", e);
}
}
return new ExtractorResult();
}
@Override
public String modelKey() {
return "action_extractor";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -0,0 +1,12 @@
package work.slhaf.partner.module.modules.action.planner.extractor.entity;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import java.util.List;
@Data
public class ExtractorInput {
private String input;
private List<Message> recentMessages;
}

View File

@@ -0,0 +1,11 @@
package work.slhaf.partner.module.modules.action.planner.extractor.entity;
import lombok.Data;
import java.util.ArrayList;
import java.util.List;
@Data
public class ExtractorResult {
private List<String> tendencies = new ArrayList<>();
}

View File

@@ -0,0 +1,250 @@
package work.slhaf.partner.module.modules.core;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.CoreModule;
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.runtime.interaction.flow.abstracts.AgentRunningModule;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.module.common.entity.AppendPromptData;
import work.slhaf.partner.module.common.model.ModelConstant;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@CoreModule
public class CoreModel extends AgentRunningModule<PartnerRunningFlowContext> implements ActivateModel {
@InjectCapability
private CognationCapability cognationCapability;
private List<Message> appendedMessages;
@Init
public void init(){
List<Message> chatMessages = this.cognationCapability.getChatMessages();
this.getModel().setChatMessages(chatMessages);
this.appendedMessages = new ArrayList<>();
updateChatClientSettings();
log.info("[CoreModel] CoreModel注册完毕...");
}
@Override
public void updateChatClientSettings() {
chatClient().setTemperature(0.3);
chatClient().setTop_p(0.7);
}
@Override
public String modelKey() {
return "core_model";
}
@Override
public boolean withBasicPrompt() {
return true;
}
@Override
public void execute(PartnerRunningFlowContext runningFlowContext) {
String userId = runningFlowContext.getUserId();
log.debug("[CoreModel] 主对话流程开始: {}", userId);
beforeChat(runningFlowContext);
executeChat(runningFlowContext);
log.debug("[CoreModel] 主对话流程({})结束...", userId);
}
private void beforeChat(PartnerRunningFlowContext runningFlowContext) {
setAppendedPromptMessage(runningFlowContext);
activateModule(runningFlowContext);
setMessageCount(runningFlowContext);
log.debug("[CoreModel] 当前消息列表大小: {}", chatMessages().size());
log.debug("[CoreModel] 当前核心prompt内容: {}", runningFlowContext.getCoreContext().toString());
setMessage(runningFlowContext.getCoreContext().toString());
}
private void setAppendedPromptMessage(PartnerRunningFlowContext runningFlowContext) {
List<AppendPromptData> appendedPrompt = runningFlowContext.getModuleContext().getAppendedPrompt();
int appendedPromptSize = getAppendedPromptSize(appendedPrompt);
if (appendedPromptSize > 0) {
setAppendedPromptMessage(appendedPrompt);
}
}
private void executeChat(PartnerRunningFlowContext runningFlowContext) {
JSONObject response = new JSONObject();
int count = 0;
while (true) {
try {
ChatResponse chatResponse = this.chat();
try {
response.putAll(JSONObject.parse(extractJson(chatResponse.getMessage())));
} catch (Exception e) {
log.warn("主模型回复格式出错, 将直接作为消息返回, 建议尝试更换主模型...");
handleExceptionResponse(response, chatResponse.getMessage());
}
log.debug("[CoreModel] CoreModel 响应内容: {}", response);
updateModuleContextAndChatMessages(runningFlowContext, response.getString("text"), chatResponse);
break;
} catch (Exception e) {
count++;
log.error("[CoreModel] CoreModel执行异常: {}", e.getLocalizedMessage());
if (count > 3) {
handleExceptionResponse(response, "主模型交互出错: " + e.getLocalizedMessage());
chatMessages().removeLast();
break;
}
} finally {
updateCoreResponse(runningFlowContext, response);
resetAppendedMessages();
log.debug("[CoreModel] 消息列表更新大小: {}", chatMessages().size());
}
}
}
private int getAppendedPromptSize(List<AppendPromptData> appendedPrompt) {
int size = 0;
for (AppendPromptData data : appendedPrompt) {
size += data.getAppendedPrompt().size();
}
return size;
}
private void activateModule(PartnerRunningFlowContext context) {
for (AppendPromptData data : context.getModuleContext().getAppendedPrompt()) {
if (data.getAppendedPrompt().isEmpty()) continue;
context.getCoreContext().activateModule(data.getModuleName());
}
}
private void updateCoreResponse(PartnerRunningFlowContext runningFlowContext, JSONObject response) {
runningFlowContext.getCoreResponse().put("text", response.getString("text"));
}
private void resetAppendedMessages() {
this.appendedMessages.clear();
}
@Override
public ChatResponse chat() {
List<Message> temp = new ArrayList<>(baseMessages().subList(0, baseMessages().size() - 2));
temp.addAll(appendedMessages);
temp.addAll(baseMessages().subList(baseMessages().size() - 2, baseMessages().size()));
temp.addAll(chatMessages());
return chatClient().runChat(temp);
}
private void updateModuleContextAndChatMessages(PartnerRunningFlowContext runningFlowContext, String response, ChatResponse chatResponse) {
cognationCapability.getMessageLock().lock();
chatMessages().removeIf(m -> {
if (m.getRole().equals(ChatConstant.Character.ASSISTANT)) {
return false;
}
try {
JSONObject.parseObject(extractJson(m.getContent()));
return true;
} catch (Exception e) {
return false;
}
});
//添加时间标志
String dateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("\r\n**[yyyy-MM-dd HH:mm:ss]"));
Message primaryUserMessage = new Message(ChatConstant.Character.USER, runningFlowContext.getCoreContext().getText() + dateTime);
chatMessages().add(primaryUserMessage);
Message assistantMessage = new Message(ChatConstant.Character.ASSISTANT, response);
chatMessages().add(assistantMessage);
cognationCapability.getMessageLock().unlock();
//设置上下文
runningFlowContext.getModuleContext().getExtraContext().put("total_token", chatResponse.getUsageBean().getTotal_tokens());
//区分单人聊天场景
if (runningFlowContext.isSingle()) {
MetaMessage metaMessage = new MetaMessage(primaryUserMessage, assistantMessage);
cognationCapability.addMetaMessage(runningFlowContext.getUserId(), metaMessage);
}
}
private void setMessage(String coreContextStr) {
Message userMessage = new Message(ChatConstant.Character.USER, coreContextStr);
chatMessages().add(userMessage);
}
private void handleExceptionResponse(JSONObject response, String chatResponse) {
response.put("text", chatResponse);
// interactionContext.setFinished(true);
}
private void setMessageCount(PartnerRunningFlowContext runningFlowContext) {
runningFlowContext.getModuleContext().getExtraContext().put("message_count", chatMessages().size());
}
private void setAppendedPromptMessage(List<AppendPromptData> appendPrompt) {
Message appendDeclareMessage = Message.builder()
.role(ChatConstant.Character.USER)
.content(ModelConstant.CharacterPrefix.SYSTEM + "认知补充开始")
.build();
this.appendedMessages.add(appendDeclareMessage);
for (AppendPromptData data : appendPrompt) {
setStartMessage(data);
setContentMessage(data);
setEndMessage(data);
setAssistantMessage();
}
Message appendEndMessage = Message.builder()
.role(ChatConstant.Character.USER)
.content(ModelConstant.CharacterPrefix.SYSTEM + "认知补充结束")
.build();
this.appendedMessages.add(appendEndMessage);
}
private void setAssistantMessage() {
appendedMessages.add(Message.builder()
.role(ChatConstant.Character.ASSISTANT)
.content("嗯,明白了")
.build());
}
private void setEndMessage(AppendPromptData data) {
Message endMessage = Message.builder()
.role(ChatConstant.Character.USER)
.content(ModelConstant.CharacterPrefix.SYSTEM + data.getModuleName() + "认知补充结束.")
.build();
appendedMessages.add(endMessage);
}
private void setContentMessage(AppendPromptData data) {
data.getAppendedPrompt().forEach((k, v) -> {
Message contentMessage = Message.builder()
.role(ChatConstant.Character.USER)
.content(ModelConstant.CharacterPrefix.SYSTEM + k + v + "\r\n")
.build();
appendedMessages.add(contentMessage);
});
}
private void setStartMessage(AppendPromptData data) {
Message startMessage = Message.builder()
.role(ChatConstant.Character.USER)
.content(ModelConstant.CharacterPrefix.SYSTEM + data.getModuleName() + "以下为" + data.getModuleName() + "相关认知.")
.build();
appendedMessages.add(startMessage);
}
}

View File

@@ -0,0 +1,153 @@
package work.slhaf.partner.module.modules.memory.selector;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.module.annotation.AgentModule;
import work.slhaf.partner.api.agent.factory.module.annotation.InjectModule;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.memory.MemoryCapability;
import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException;
import work.slhaf.partner.core.memory.exception.UnExistedTopicException;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.MemoryResult;
import work.slhaf.partner.core.memory.pojo.MemorySlice;
import work.slhaf.partner.module.common.module.PreRunningModule;
import work.slhaf.partner.module.modules.memory.selector.evaluator.SliceSelectEvaluator;
import work.slhaf.partner.module.modules.memory.selector.evaluator.entity.EvaluatorInput;
import work.slhaf.partner.module.modules.memory.selector.extractor.MemorySelectExtractor;
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorMatchData;
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorResult;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.time.LocalDate;
import java.util.*;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@AgentModule(name = "memory_selector", order = 2)
public class MemorySelector extends PreRunningModule {
@InjectCapability
private MemoryCapability memoryCapability;
@InjectCapability
private CognationCapability cognationCapability;
@InjectModule
private SliceSelectEvaluator sliceSelectEvaluator;
@InjectModule
private MemorySelectExtractor memorySelectExtractor;
@Override
public void doExecute(PartnerRunningFlowContext runningFlowContext) {
String userId = runningFlowContext.getUserId();
//获取主题路径
ExtractorResult extractorResult = memorySelectExtractor.execute(runningFlowContext);
if (extractorResult.isRecall() || !extractorResult.getMatches().isEmpty()) {
memoryCapability.clearActivatedSlices(userId);
List<EvaluatedSlice> evaluatedSlices = selectAndEvaluateMemory(runningFlowContext, extractorResult);
memoryCapability.updateActivatedSlices(userId, evaluatedSlices);
}
setModuleContextRecall(runningFlowContext);
}
private List<EvaluatedSlice> selectAndEvaluateMemory(PartnerRunningFlowContext runningFlowContext, ExtractorResult extractorResult) {
log.debug("[MemorySelector] 触发记忆回溯...");
//查找切片
String userId = runningFlowContext.getUserId();
List<MemoryResult> memoryResultList = new ArrayList<>();
setMemoryResultList(memoryResultList, extractorResult.getMatches(), userId);
//评估切片
EvaluatorInput evaluatorInput = EvaluatorInput.builder()
.input(runningFlowContext.getInput())
.memoryResults(memoryResultList)
.messages(cognationCapability.getChatMessages())
.build();
log.debug("[MemorySelector] 切片评估输入: {}", JSONObject.toJSONString(evaluatorInput));
List<EvaluatedSlice> memorySlices = sliceSelectEvaluator.execute(evaluatorInput);
log.debug("[MemorySelector] 切片评估结果: {}", JSONObject.toJSONString(memorySlices));
return memorySlices;
}
private void setModuleContextRecall(PartnerRunningFlowContext runningFlowContext) {
String userId = runningFlowContext.getUserId();
boolean recall = memoryCapability.hasActivatedSlices(userId);
runningFlowContext.getModuleContext().getExtraContext().put("recall", recall);
if (recall) {
runningFlowContext.getModuleContext().getExtraContext().put("recall_count", memoryCapability.getActivatedSlicesSize(userId));
}
}
private void setMemoryResultList(List<MemoryResult> memoryResultList, List<ExtractorMatchData> matches, String userId) {
for (ExtractorMatchData match : matches) {
try {
MemoryResult memoryResult = switch (match.getType()) {
case ExtractorMatchData.Constant.TOPIC -> memoryCapability.selectMemory(match.getText());
case ExtractorMatchData.Constant.DATE ->
memoryCapability.selectMemory(LocalDate.parse(match.getText()));
default -> null;
};
if (memoryResult == null || memoryResult.isEmpty()) continue;
removeDuplicateSlice(memoryResult);
memoryResultList.add(memoryResult);
} catch (UnExistedDateIndexException | UnExistedTopicException e) {
log.error("[MemorySelector] 不存在的记忆索引! 请尝试更换更合适的主题提取LLM!", e);
log.error("[MemorySelector] 错误索引: {}", match.getText());
}
}
//清理切片记录
memoryCapability.cleanSelectedSliceFilter();
//根据userInfo过滤是否为私人记忆
for (MemoryResult memoryResult : memoryResultList) {
//过滤终点记忆
memoryResult.getMemorySliceResult().removeIf(m -> removeOrNot(m.getMemorySlice(), userId));
//过滤邻近记忆
memoryResult.getRelatedMemorySliceResult().removeIf(m -> removeOrNot(m, userId));
}
}
private void removeDuplicateSlice(MemoryResult memoryResult) {
Collection<String> values = memoryCapability.getDialogMap().values();
memoryResult.getRelatedMemorySliceResult().removeIf(m -> values.contains(m.getSummary()));
memoryResult.getMemorySliceResult().removeIf(m -> values.contains(m.getMemorySlice().getSummary()));
}
private boolean removeOrNot(MemorySlice memorySlice, String userId) {
if (memorySlice.isPrivate()) {
return memorySlice.getStartUserId().equals(userId);
}
return false;
}
@Override
public String moduleName() {
return "[记忆模块]";
}
@Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
HashMap<String, String> map = new HashMap<>();
String userId = context.getUserId();
String dialogMapStr = memoryCapability.getDialogMapStr();
if (!dialogMapStr.isEmpty()) {
map.put("[记忆缓存] <你最近两日和所有聊天者的对话记忆印象>", dialogMapStr);
}
String userDialogMapStr = memoryCapability.getUserDialogMapStr(userId);
if (userDialogMapStr != null && !userDialogMapStr.isEmpty() && !cognationCapability.isSingleUser()) {
map.put("[用户记忆缓存] <与最新一条消息的发送者的近两天对话记忆印象, 可能与[记忆缓存]稍有重复>", userDialogMapStr);
}
String sliceStr = memoryCapability.getActivatedSlicesStr(userId);
if (sliceStr != null && !sliceStr.isEmpty()) {
map.put("[记忆切片] <你与最新一条消息的发送者的相关回忆, 不会与[记忆缓存]重复, 如果有重复你也可以指出来>", sliceStr);
}
return map;
}
}

Some files were not shown because too many files have changed in this diff Show More