mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
refactor(chat): replace custom client with OpenAI runtime and remove file-based module prompt loading logic, prompt will be provided by each module
This commit is contained in:
@@ -4,9 +4,13 @@ import com.alibaba.fastjson2.JSONObject;
|
|||||||
import lombok.val;
|
import lombok.val;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.module.modules.action.executor.entity.CorrectorInput;
|
import work.slhaf.partner.module.modules.action.executor.entity.CorrectorInput;
|
||||||
import work.slhaf.partner.module.modules.action.executor.entity.CorrectorResult;
|
import work.slhaf.partner.module.modules.action.executor.entity.CorrectorResult;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 负责在单组行动执行后,根据行动意图与结果检查后续行动是否符合目的,必要时直接调整行动链,或发起自对话请求进行干预
|
* 负责在单组行动执行后,根据行动意图与结果检查后续行动是否符合目的,必要时直接调整行动链,或发起自对话请求进行干预
|
||||||
*/
|
*/
|
||||||
@@ -14,8 +18,7 @@ public class ActionCorrector extends AbstractAgentModule.Sub<CorrectorInput, Cor
|
|||||||
@Override
|
@Override
|
||||||
public CorrectorResult execute(CorrectorInput input) {
|
public CorrectorResult execute(CorrectorInput input) {
|
||||||
val prompt = buildPrompt(input);
|
val prompt = buildPrompt(input);
|
||||||
val chatResponse = singleChat(prompt);
|
return formattedChat(List.of(new Message(ChatConstant.Character.USER, prompt)), CorrectorResult.class);
|
||||||
return JSONObject.parseObject(chatResponse.getMessage(), CorrectorResult.class);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private String buildPrompt(CorrectorInput input) {
|
private String buildPrompt(CorrectorInput input) {
|
||||||
@@ -37,9 +40,4 @@ public class ActionCorrector extends AbstractAgentModule.Sub<CorrectorInput, Cor
|
|||||||
public String modelKey() {
|
public String modelKey() {
|
||||||
return "action_corrector";
|
return "action_corrector";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentMod
|
|||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
|
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.core.action.ActionCapability;
|
import work.slhaf.partner.core.action.ActionCapability;
|
||||||
import work.slhaf.partner.core.action.ActionCore.ExecutorType;
|
import work.slhaf.partner.core.action.ActionCore.ExecutorType;
|
||||||
import work.slhaf.partner.core.action.entity.MetaAction;
|
import work.slhaf.partner.core.action.entity.MetaAction;
|
||||||
@@ -61,8 +62,10 @@ public class ActionRepairer extends AbstractAgentModule.Sub<RepairerInput, Repai
|
|||||||
RepairerResult result;
|
RepairerResult result;
|
||||||
try {
|
try {
|
||||||
String prompt = assemblyHelper.buildPrompt(data, null);
|
String prompt = assemblyHelper.buildPrompt(data, null);
|
||||||
ChatResponse response = this.singleChat(prompt);
|
RepairerData repairerData = formattedChat(
|
||||||
RepairerData repairerData = JSONObject.parseObject(response.getMessage(), RepairerData.class);
|
List.of(new Message(ChatConstant.Character.USER, prompt)),
|
||||||
|
RepairerData.class
|
||||||
|
);
|
||||||
result = switch (repairerData.getRepairerType()) {
|
result = switch (repairerData.getRepairerType()) {
|
||||||
case ACTION_GENERATION ->
|
case ACTION_GENERATION ->
|
||||||
handleActionGeneration(JSONObject.parseObject(repairerData.getData(), GeneratorInput.class));
|
handleActionGeneration(JSONObject.parseObject(repairerData.getData(), GeneratorInput.class));
|
||||||
@@ -75,8 +78,10 @@ public class ActionRepairer extends AbstractAgentModule.Sub<RepairerInput, Repai
|
|||||||
&& result.getStatus().equals(RepairerResult.RepairerStatus.FAILED)) {
|
&& result.getStatus().equals(RepairerResult.RepairerStatus.FAILED)) {
|
||||||
log.warn("常规行动修复失败,将尝试自对话通道");
|
log.warn("常规行动修复失败,将尝试自对话通道");
|
||||||
prompt = assemblyHelper.buildPrompt(data, "常规行动修复失败,请尝试通过自对话通道获取必要的信息以完成行动参数的修复");
|
prompt = assemblyHelper.buildPrompt(data, "常规行动修复失败,请尝试通过自对话通道获取必要的信息以完成行动参数的修复");
|
||||||
response = this.singleChat(prompt);
|
repairerData = formattedChat(
|
||||||
repairerData = JSONObject.parseObject(response.getMessage(), RepairerData.class);
|
List.of(new Message(ChatConstant.Character.USER, prompt)),
|
||||||
|
RepairerData.class
|
||||||
|
);
|
||||||
handleUserInteraction(repairerData.getData());
|
handleUserInteraction(repairerData.getData());
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -165,20 +170,14 @@ public class ActionRepairer extends AbstractAgentModule.Sub<RepairerInput, Repai
|
|||||||
return "action_repairer";
|
return "action_repairer";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
private enum RepairerType {
|
private enum RepairerType {
|
||||||
ACTION_GENERATION,
|
ACTION_GENERATION,
|
||||||
ACTION_INVOCATION,
|
ACTION_INVOCATION,
|
||||||
USER_INTERACTION
|
USER_INTERACTION
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("InnerClassMayBeStatic")
|
|
||||||
@Data
|
@Data
|
||||||
private class RepairerData {
|
private static class RepairerData {
|
||||||
private RepairerType repairerType;
|
private RepairerType repairerType;
|
||||||
private String data;
|
private String data;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapabili
|
|||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
import work.slhaf.partner.common.util.ExtractUtil;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.core.action.ActionCapability;
|
import work.slhaf.partner.core.action.ActionCapability;
|
||||||
import work.slhaf.partner.core.action.entity.GeneratedData;
|
import work.slhaf.partner.core.action.entity.GeneratedData;
|
||||||
import work.slhaf.partner.core.action.entity.MetaAction;
|
import work.slhaf.partner.core.action.entity.MetaAction;
|
||||||
@@ -15,6 +15,8 @@ import work.slhaf.partner.core.action.runner.RunnerClient;
|
|||||||
import work.slhaf.partner.module.modules.action.executor.entity.GeneratorInput;
|
import work.slhaf.partner.module.modules.action.executor.entity.GeneratorInput;
|
||||||
import work.slhaf.partner.module.modules.action.executor.entity.GeneratorResult;
|
import work.slhaf.partner.module.modules.action.executor.entity.GeneratorResult;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 负责依据输入内容生成可执行的动态行动单元,并选择是否持久化至 SandboxRunner 容器内
|
* 负责依据输入内容生成可执行的动态行动单元,并选择是否持久化至 SandboxRunner 容器内
|
||||||
*/
|
*/
|
||||||
@@ -37,9 +39,10 @@ public class DynamicActionGenerator extends AbstractAgentModule.Sub<GeneratorInp
|
|||||||
// 所以此处的输入内容也只需要指定输入参数、临时key、是否持久化即可,路径将按照指定规则统一构建,不可交给LLM生成
|
// 所以此处的输入内容也只需要指定输入参数、临时key、是否持久化即可,路径将按照指定规则统一构建,不可交给LLM生成
|
||||||
String prompt = buildPrompt(input);
|
String prompt = buildPrompt(input);
|
||||||
// 响应结果需要包含几个特殊数据: 依赖项、代码内容、是否序列化、响应数据释义
|
// 响应结果需要包含几个特殊数据: 依赖项、代码内容、是否序列化、响应数据释义
|
||||||
ChatResponse response = this.singleChat(prompt);
|
GeneratedData generatorData = formattedChat(
|
||||||
GeneratedData generatorData = JSONObject
|
List.of(new Message(ChatConstant.Character.USER, prompt)),
|
||||||
.parseObject(ExtractUtil.extractJson(response.getMessage()), GeneratedData.class);
|
GeneratedData.class
|
||||||
|
);
|
||||||
val location = runnerClient.buildTmpPath(input.getActionName(), generatorData.getCodeType());
|
val location = runnerClient.buildTmpPath(input.getActionName(), generatorData.getCodeType());
|
||||||
MetaAction tempAction = new MetaAction(
|
MetaAction tempAction = new MetaAction(
|
||||||
input.getActionName(),
|
input.getActionName(),
|
||||||
@@ -77,9 +80,4 @@ public class DynamicActionGenerator extends AbstractAgentModule.Sub<GeneratorInp
|
|||||||
public String modelKey() {
|
public String modelKey() {
|
||||||
return "dynamic_generator";
|
return "dynamic_generator";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ import com.alibaba.fastjson2.JSONArray;
|
|||||||
import com.alibaba.fastjson2.JSONObject;
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
import work.slhaf.partner.core.action.entity.MetaActionInfo;
|
||||||
import work.slhaf.partner.module.modules.action.executor.entity.ExtractorInput;
|
import work.slhaf.partner.module.modules.action.executor.entity.ExtractorInput;
|
||||||
import work.slhaf.partner.module.modules.action.executor.entity.ExtractorResult;
|
import work.slhaf.partner.module.modules.action.executor.entity.ExtractorResult;
|
||||||
@@ -20,12 +21,11 @@ public class ParamsExtractor extends AbstractAgentModule.Sub<ExtractorInput, Ext
|
|||||||
@Override
|
@Override
|
||||||
public ExtractorResult execute(ExtractorInput input) {
|
public ExtractorResult execute(ExtractorInput input) {
|
||||||
String prompt = buildPrompt(input);
|
String prompt = buildPrompt(input);
|
||||||
ChatResponse response = this.singleChat(prompt);
|
|
||||||
ExtractorResult result;
|
ExtractorResult result;
|
||||||
try {
|
try {
|
||||||
result = JSONObject.parseObject(response.getMessage(), ExtractorResult.class);
|
result = formattedChat(List.of(new Message(ChatConstant.Character.USER, prompt)), ExtractorResult.class);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("ParamsExtractor解析结果失败,返回内容:{}", response.getMessage(), e);
|
log.error("ParamsExtractor解析结果失败", e);
|
||||||
result = new ExtractorResult();
|
result = new ExtractorResult();
|
||||||
result.setOk(false);
|
result.setOk(false);
|
||||||
result.setParams(new HashMap<>());
|
result.setParams(new HashMap<>());
|
||||||
@@ -57,9 +57,4 @@ public class ParamsExtractor extends AbstractAgentModule.Sub<ExtractorInput, Ext
|
|||||||
public String modelKey() {
|
public String modelKey() {
|
||||||
return "params_extractor";
|
return "params_extractor";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,11 +134,6 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
|
|||||||
return "action_identifier";
|
return "action_identifier";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
|
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
|
||||||
return interventionPrompt.remove(context.getInfo().getUuid());
|
return interventionPrompt.remove(context.getInfo().getUuid());
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import com.alibaba.fastjson2.JSONObject;
|
|||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.core.action.ActionCapability;
|
import work.slhaf.partner.core.action.ActionCapability;
|
||||||
import work.slhaf.partner.core.action.ActionCore.ExecutorType;
|
import work.slhaf.partner.core.action.ActionCore.ExecutorType;
|
||||||
@@ -53,8 +53,8 @@ public class InterventionEvaluator extends AbstractAgentModule.Sub<EvaluatorInpu
|
|||||||
interventionMap.forEach((tendency, actionData) -> executor.execute(() -> {
|
interventionMap.forEach((tendency, actionData) -> executor.execute(() -> {
|
||||||
try {
|
try {
|
||||||
String prompt = buildPrompt(input.getRecentMessages(), input.getActivatedSlices(), actionData, tendency);
|
String prompt = buildPrompt(input.getRecentMessages(), input.getActivatedSlices(), actionData, tendency);
|
||||||
ChatResponse response = this.singleChat(prompt);
|
EvaluatedInterventionData evaluatedData = formattedChat(
|
||||||
EvaluatedInterventionData evaluatedData = JSONObject.parseObject(response.getMessage(),
|
List.of(new Message(ChatConstant.Character.USER, prompt)),
|
||||||
EvaluatedInterventionData.class);
|
EvaluatedInterventionData.class);
|
||||||
synchronized (evaluatedDataList) {
|
synchronized (evaluatedDataList) {
|
||||||
evaluatedDataList.add(evaluatedData);
|
evaluatedDataList.add(evaluatedData);
|
||||||
@@ -81,9 +81,4 @@ public class InterventionEvaluator extends AbstractAgentModule.Sub<EvaluatorInpu
|
|||||||
public String modelKey() {
|
public String modelKey() {
|
||||||
return "intervention_evaluator";
|
return "intervention_evaluator";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ import com.alibaba.fastjson2.JSONObject;
|
|||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.core.action.ActionCapability;
|
import work.slhaf.partner.core.action.ActionCapability;
|
||||||
import work.slhaf.partner.core.action.ActionCore;
|
import work.slhaf.partner.core.action.ActionCore;
|
||||||
import work.slhaf.partner.core.action.entity.ExecutableAction;
|
import work.slhaf.partner.core.action.entity.ExecutableAction;
|
||||||
@@ -48,8 +49,10 @@ public class InterventionRecognizer extends AbstractAgentModule.Sub<RecognizerIn
|
|||||||
executor.execute(() -> {
|
executor.execute(() -> {
|
||||||
try {
|
try {
|
||||||
String prompt = buildPrompt(data, input);
|
String prompt = buildPrompt(data, input);
|
||||||
ChatResponse response = this.singleChat(prompt);
|
MetaRecognizerResult result = formattedChat(
|
||||||
MetaRecognizerResult result = JSONObject.parseObject(response.getMessage(), MetaRecognizerResult.class);
|
List.of(new Message(ChatConstant.Character.USER, prompt)),
|
||||||
|
MetaRecognizerResult.class
|
||||||
|
);
|
||||||
if (result.isOk()) {
|
if (result.isOk()) {
|
||||||
synchronized (interventionsMap) {
|
synchronized (interventionsMap) {
|
||||||
interventionsMap.put(result.getIntervention(), data);
|
interventionsMap.put(result.getIntervention(), data);
|
||||||
@@ -83,9 +86,4 @@ public class InterventionRecognizer extends AbstractAgentModule.Sub<RecognizerIn
|
|||||||
public String modelKey() {
|
public String modelKey() {
|
||||||
return "intervention_recognizer";
|
return "intervention_recognizer";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import com.alibaba.fastjson2.JSONObject;
|
|||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.core.action.ActionCapability;
|
import work.slhaf.partner.core.action.ActionCapability;
|
||||||
import work.slhaf.partner.core.action.ActionCore;
|
import work.slhaf.partner.core.action.ActionCore;
|
||||||
@@ -19,8 +19,6 @@ import java.util.List;
|
|||||||
import java.util.concurrent.CountDownLatch;
|
import java.util.concurrent.CountDownLatch;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
|
|
||||||
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
|
|
||||||
|
|
||||||
public class ActionConfirmer extends AbstractAgentModule.Sub<ConfirmerInput, ConfirmerResult> implements ActivateModel {
|
public class ActionConfirmer extends AbstractAgentModule.Sub<ConfirmerInput, ConfirmerResult> implements ActivateModel {
|
||||||
@InjectCapability
|
@InjectCapability
|
||||||
private ActionCapability actionCapability;
|
private ActionCapability actionCapability;
|
||||||
@@ -40,10 +38,12 @@ public class ActionConfirmer extends AbstractAgentModule.Sub<ConfirmerInput, Con
|
|||||||
try {
|
try {
|
||||||
ExecutableAction executableAction = pendingAction.getExecutableAction();
|
ExecutableAction executableAction = pendingAction.getExecutableAction();
|
||||||
String prompt = buildPrompt(executableAction, data.getInput(), data.getRecentMessages());
|
String prompt = buildPrompt(executableAction, data.getInput(), data.getRecentMessages());
|
||||||
ChatResponse response = this.singleChat(prompt);
|
DecisionResponse tempResult = formattedChat(
|
||||||
JSONObject tempResult = JSONObject.parseObject(extractJson(response.getMessage()));
|
List.of(new Message(ChatConstant.Character.USER, prompt)),
|
||||||
|
DecisionResponse.class
|
||||||
|
);
|
||||||
PendingActionRecord.Decision decision = parseDecision(tempResult);
|
PendingActionRecord.Decision decision = parseDecision(tempResult);
|
||||||
String reason = tempResult == null ? null : tempResult.getString("reason");
|
String reason = tempResult.getReason();
|
||||||
synchronized (decisions) {
|
synchronized (decisions) {
|
||||||
decisions.add(new PendingDecisionItem(pendingAction.getPendingId(), decision, reason));
|
decisions.add(new PendingDecisionItem(pendingAction.getPendingId(), decision, reason));
|
||||||
}
|
}
|
||||||
@@ -68,11 +68,11 @@ public class ActionConfirmer extends AbstractAgentModule.Sub<ConfirmerInput, Con
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
private PendingActionRecord.Decision parseDecision(JSONObject tempResult) {
|
private PendingActionRecord.Decision parseDecision(DecisionResponse tempResult) {
|
||||||
if (tempResult == null) {
|
if (tempResult == null) {
|
||||||
return PendingActionRecord.Decision.HOLD;
|
return PendingActionRecord.Decision.HOLD;
|
||||||
}
|
}
|
||||||
String decisionText = tempResult.getString("decision");
|
String decisionText = tempResult.getDecision();
|
||||||
if (decisionText != null) {
|
if (decisionText != null) {
|
||||||
String upperDecision = decisionText.toUpperCase();
|
String upperDecision = decisionText.toUpperCase();
|
||||||
if (upperDecision.contains("CONFIRM")) {
|
if (upperDecision.contains("CONFIRM")) {
|
||||||
@@ -85,7 +85,7 @@ public class ActionConfirmer extends AbstractAgentModule.Sub<ConfirmerInput, Con
|
|||||||
return PendingActionRecord.Decision.HOLD;
|
return PendingActionRecord.Decision.HOLD;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Boolean confirmed = tempResult.getBoolean("confirmed");
|
Boolean confirmed = tempResult.getConfirmed();
|
||||||
if (Boolean.TRUE.equals(confirmed)) {
|
if (Boolean.TRUE.equals(confirmed)) {
|
||||||
return PendingActionRecord.Decision.CONFIRM;
|
return PendingActionRecord.Decision.CONFIRM;
|
||||||
}
|
}
|
||||||
@@ -116,8 +116,10 @@ public class ActionConfirmer extends AbstractAgentModule.Sub<ConfirmerInput, Con
|
|||||||
return "action-confirmer";
|
return "action-confirmer";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@lombok.Data
|
||||||
public boolean withBasicPrompt() {
|
private static class DecisionResponse {
|
||||||
return false;
|
private String decision;
|
||||||
|
private String reason;
|
||||||
|
private Boolean confirmed;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapabili
|
|||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
||||||
import work.slhaf.partner.core.action.ActionCapability;
|
import work.slhaf.partner.core.action.ActionCapability;
|
||||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||||
@@ -48,8 +49,10 @@ public class ActionEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, Lis
|
|||||||
List<Callable<EvaluatorResult>> list = new ArrayList<>();
|
List<Callable<EvaluatorResult>> list = new ArrayList<>();
|
||||||
for (EvaluatorBatchInput batchInput : batchInputs) {
|
for (EvaluatorBatchInput batchInput : batchInputs) {
|
||||||
list.add(() -> {
|
list.add(() -> {
|
||||||
ChatResponse response = this.singleChat(buildPrompt(batchInput));
|
EvaluatorResult evaluatorResult = formattedChat(
|
||||||
EvaluatorResult evaluatorResult = JSONObject.parseObject(response.getMessage(), EvaluatorResult.class);
|
List.of(new Message(ChatConstant.Character.USER, buildPrompt(batchInput))),
|
||||||
|
EvaluatorResult.class
|
||||||
|
);
|
||||||
evaluatorResult.setTendency(batchInput.getTendency());
|
evaluatorResult.setTendency(batchInput.getTendency());
|
||||||
return evaluatorResult;
|
return evaluatorResult;
|
||||||
});
|
});
|
||||||
@@ -89,9 +92,4 @@ public class ActionEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, Lis
|
|||||||
public String modelKey() {
|
public String modelKey() {
|
||||||
return "action_evaluator";
|
return "action_evaluator";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ import com.alibaba.fastjson2.JSONObject;
|
|||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.core.action.ActionCapability;
|
import work.slhaf.partner.core.action.ActionCapability;
|
||||||
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput;
|
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput;
|
||||||
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorResult;
|
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorResult;
|
||||||
@@ -25,8 +26,10 @@ public class ActionExtractor extends AbstractAgentModule.Sub<ExtractorInput, Ext
|
|||||||
}
|
}
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
try {
|
try {
|
||||||
ChatResponse response = this.singleChat(JSONObject.toJSONString(data));
|
return formattedChat(
|
||||||
return JSONObject.parseObject(response.getMessage(), ExtractorResult.class);
|
List.of(new Message(ChatConstant.Character.USER, JSONObject.toJSONString(data))),
|
||||||
|
ExtractorResult.class
|
||||||
|
);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("[ActionExtractor] 提取信息出错", e);
|
log.error("[ActionExtractor] 提取信息出错", e);
|
||||||
}
|
}
|
||||||
@@ -38,9 +41,4 @@ public class ActionExtractor extends AbstractAgentModule.Sub<ExtractorInput, Ext
|
|||||||
public String modelKey() {
|
public String modelKey() {
|
||||||
return "action_extractor";
|
return "action_extractor";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,9 +8,7 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapabili
|
|||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||||
import work.slhaf.partner.api.chat.ChatClient;
|
|
||||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||||
@@ -31,31 +29,23 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
|
|
||||||
@InjectCapability
|
@InjectCapability
|
||||||
private CognationCapability cognationCapability;
|
private CognationCapability cognationCapability;
|
||||||
private List<Message> appendedMessages = new ArrayList<>();
|
private final List<Message> appendedMessages = new ArrayList<>();
|
||||||
|
private final List<Message> chatMessages = new ArrayList<>();
|
||||||
|
|
||||||
@Init
|
@Init
|
||||||
public void init() {
|
public void init() {
|
||||||
List<Message> chatMessages = this.cognationCapability.getChatMessages();
|
this.chatMessages.clear();
|
||||||
this.getModel().getChatMessages().addAll(chatMessages);
|
this.chatMessages.addAll(this.cognationCapability.getChatMessages());
|
||||||
|
|
||||||
updateChatClientSettings();
|
|
||||||
log.info("CommunicationProducer 注册完毕...");
|
log.info("CommunicationProducer 注册完毕...");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void updateChatClientSettings() {
|
|
||||||
ChatClient chatClient = getModel().getChatClient();
|
|
||||||
chatClient.setTemperature(0.3);
|
|
||||||
chatClient.setTop_p(0.7);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public @NotNull String modelKey() {
|
public @NotNull String modelKey() {
|
||||||
return "communication_producer";
|
return "communication_producer";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean withBasicPrompt() {
|
public boolean useStreaming() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,7 +63,7 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
activateModule(runningFlowContext);
|
activateModule(runningFlowContext);
|
||||||
setMessageCount(runningFlowContext);
|
setMessageCount(runningFlowContext);
|
||||||
|
|
||||||
log.debug("[CommunicationProducer] 当前消息列表大小: {}", getModel().getChatMessages().size());
|
log.debug("[CommunicationProducer] 当前消息列表大小: {}", chatMessages.size());
|
||||||
log.debug("[CommunicationProducer] 当前核心prompt内容: {}", runningFlowContext.getCoreContext().toString());
|
log.debug("[CommunicationProducer] 当前核心prompt内容: {}", runningFlowContext.getCoreContext().toString());
|
||||||
|
|
||||||
setMessage(runningFlowContext.getCoreContext().toString());
|
setMessage(runningFlowContext.getCoreContext().toString());
|
||||||
@@ -94,28 +84,28 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
int count = 0;
|
int count = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
try {
|
try {
|
||||||
ChatResponse chatResponse = this.chat();
|
String chatResponse = this.chat(buildChatMessages());
|
||||||
try {
|
try {
|
||||||
response.putAll(JSONObject.parse(extractJson(chatResponse.getMessage())));
|
response.putAll(JSONObject.parse(extractJson(chatResponse)));
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.warn("主模型回复格式出错, 将直接作为消息返回, 建议尝试更换主模型...");
|
log.warn("主模型回复格式出错, 将直接作为消息返回, 建议尝试更换主模型...");
|
||||||
handleExceptionResponse(response, chatResponse.getMessage());
|
handleExceptionResponse(response, chatResponse);
|
||||||
}
|
}
|
||||||
log.debug("[CommunicationProducer] CommunicationProducer 响应内容: {}", response);
|
log.debug("[CommunicationProducer] CommunicationProducer 响应内容: {}", response);
|
||||||
updateModuleContextAndChatMessages(runningFlowContext, response.getString("text"), chatResponse);
|
updateModuleContextAndChatMessages(runningFlowContext, response.getString("text"));
|
||||||
break;
|
break;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
count++;
|
count++;
|
||||||
log.error("[CommunicationProducer] CoreModel执行异常: {}", e.getLocalizedMessage());
|
log.error("[CommunicationProducer] CoreModel执行异常: {}", e.getLocalizedMessage());
|
||||||
if (count > 3) {
|
if (count > 3) {
|
||||||
handleExceptionResponse(response, "主模型交互出错: " + e.getLocalizedMessage());
|
handleExceptionResponse(response, "主模型交互出错: " + e.getLocalizedMessage());
|
||||||
getModel().getChatMessages().removeLast();
|
chatMessages.removeLast();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
updateCoreResponse(runningFlowContext, response);
|
updateCoreResponse(runningFlowContext, response);
|
||||||
resetAppendedMessages();
|
resetAppendedMessages();
|
||||||
log.debug("[CommunicationProducer] 消息列表更新大小: {}", getModel().getChatMessages().size());
|
log.debug("[CommunicationProducer] 消息列表更新大小: {}", chatMessages.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -143,20 +133,15 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
this.appendedMessages.clear();
|
this.appendedMessages.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
private List<Message> buildChatMessages() {
|
||||||
public @NotNull ChatResponse chat() {
|
List<Message> temp = new ArrayList<>(appendedMessages.size() + chatMessages.size());
|
||||||
List<Message> baseMessages = getModel().getBaseMessages();
|
|
||||||
List<Message> chatMessages = getModel().getChatMessages();
|
|
||||||
List<Message> temp = new ArrayList<>(baseMessages.subList(0, baseMessages.size() - 2));
|
|
||||||
temp.addAll(appendedMessages);
|
temp.addAll(appendedMessages);
|
||||||
temp.addAll(baseMessages.subList(baseMessages.size() - 2, baseMessages.size()));
|
|
||||||
temp.addAll(chatMessages);
|
temp.addAll(chatMessages);
|
||||||
return getModel().getChatClient().runChat(temp);
|
return temp;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateModuleContextAndChatMessages(PartnerRunningFlowContext runningFlowContext, String response, ChatResponse chatResponse) {
|
private void updateModuleContextAndChatMessages(PartnerRunningFlowContext runningFlowContext, String response) {
|
||||||
cognationCapability.getMessageLock().lock();
|
cognationCapability.getMessageLock().lock();
|
||||||
List<Message> chatMessages = getModel().getChatMessages();
|
|
||||||
chatMessages.removeIf(m -> {
|
chatMessages.removeIf(m -> {
|
||||||
if (m.getRole().equals(ChatConstant.Character.ASSISTANT)) {
|
if (m.getRole().equals(ChatConstant.Character.ASSISTANT)) {
|
||||||
return false;
|
return false;
|
||||||
@@ -176,8 +161,6 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
Message assistantMessage = new Message(ChatConstant.Character.ASSISTANT, response);
|
Message assistantMessage = new Message(ChatConstant.Character.ASSISTANT, response);
|
||||||
chatMessages.add(assistantMessage);
|
chatMessages.add(assistantMessage);
|
||||||
cognationCapability.getMessageLock().unlock();
|
cognationCapability.getMessageLock().unlock();
|
||||||
//设置上下文
|
|
||||||
runningFlowContext.getModuleContext().getExtraContext().put("total_token", chatResponse.getUsageBean().getTotal_tokens());
|
|
||||||
//区分单人聊天场景
|
//区分单人聊天场景
|
||||||
// if (runningFlowContext.isSingle()) {
|
// if (runningFlowContext.isSingle()) {
|
||||||
MetaMessage metaMessage = new MetaMessage(primaryUserMessage, assistantMessage);
|
MetaMessage metaMessage = new MetaMessage(primaryUserMessage, assistantMessage);
|
||||||
@@ -187,7 +170,7 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
|
|
||||||
private void setMessage(String coreContextStr) {
|
private void setMessage(String coreContextStr) {
|
||||||
Message userMessage = new Message(ChatConstant.Character.USER, coreContextStr);
|
Message userMessage = new Message(ChatConstant.Character.USER, coreContextStr);
|
||||||
getModel().getChatMessages().add(userMessage);
|
chatMessages.add(userMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void handleExceptionResponse(JSONObject response, String chatResponse) {
|
private void handleExceptionResponse(JSONObject response, String chatResponse) {
|
||||||
@@ -196,7 +179,7 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void setMessageCount(PartnerRunningFlowContext runningFlowContext) {
|
private void setMessageCount(PartnerRunningFlowContext runningFlowContext) {
|
||||||
runningFlowContext.getModuleContext().getExtraContext().put("message_count", getModel().getChatMessages().size());
|
runningFlowContext.getModuleContext().getExtraContext().put("message_count", chatMessages.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void setAppendedPromptMessage(List<AppendPromptData> appendPrompt) {
|
private void setAppendedPromptMessage(List<AppendPromptData> appendPrompt) {
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import lombok.EqualsAndHashCode;
|
|||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||||
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
||||||
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
|
||||||
import work.slhaf.partner.core.memory.pojo.MemoryResult;
|
import work.slhaf.partner.core.memory.pojo.MemoryResult;
|
||||||
@@ -24,8 +26,6 @@ import java.util.concurrent.ConcurrentLinkedDeque;
|
|||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
|
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Data
|
@Data
|
||||||
public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, List<EvaluatedSlice>> implements ActivateModel {
|
public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, List<EvaluatedSlice>> implements ActivateModel {
|
||||||
@@ -61,7 +61,10 @@ public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput
|
|||||||
.history(evaluatorInput.getMessages())
|
.history(evaluatorInput.getMessages())
|
||||||
.build();
|
.build();
|
||||||
log.debug("[SliceSelectEvaluator] 评估[{}]输入: {}", thisCount, JSONObject.toJSONString(batchInput));
|
log.debug("[SliceSelectEvaluator] 评估[{}]输入: {}", thisCount, JSONObject.toJSONString(batchInput));
|
||||||
EvaluatorResult evaluatorResult = JSONObject.parseObject(extractJson(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage()), EvaluatorResult.class);
|
EvaluatorResult evaluatorResult = formattedChat(
|
||||||
|
List.of(new Message(ChatConstant.Character.USER, JSONUtil.toJsonStr(batchInput))),
|
||||||
|
EvaluatorResult.class
|
||||||
|
);
|
||||||
log.debug("[SliceSelectEvaluator] 评估[{}]结果: {}", thisCount, JSONObject.toJSONString(evaluatorResult));
|
log.debug("[SliceSelectEvaluator] 评估[{}]结果: {}", thisCount, JSONObject.toJSONString(evaluatorResult));
|
||||||
for (Long result : evaluatorResult.getResults()) {
|
for (Long result : evaluatorResult.getResults()) {
|
||||||
SliceSummary sliceSummary = map.get(result);
|
SliceSummary sliceSummary = map.get(result);
|
||||||
@@ -117,9 +120,4 @@ public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput
|
|||||||
public String modelKey() {
|
public String modelKey() {
|
||||||
return "slice_evaluator";
|
return "slice_evaluator";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
package work.slhaf.partner.module.modules.memory.selector.extractor;
|
package work.slhaf.partner.module.modules.memory.selector.extractor;
|
||||||
|
|
||||||
import cn.hutool.json.JSONUtil;
|
import cn.hutool.json.JSONUtil;
|
||||||
import com.alibaba.fastjson2.JSONObject;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||||
@@ -20,7 +20,6 @@ import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowCon
|
|||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
|
|
||||||
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
|
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@@ -52,9 +51,11 @@ public class MemorySelectExtractor extends AbstractAgentModule.Sub<PartnerRunnin
|
|||||||
.topic_tree(memoryCapability.getTopicTree())
|
.topic_tree(memoryCapability.getTopicTree())
|
||||||
.activatedMemorySlices(activatedMemorySlices)
|
.activatedMemorySlices(activatedMemorySlices)
|
||||||
.build();
|
.build();
|
||||||
log.debug("[MemorySelectExtractor] 主题提取输入: {}", JSONObject.toJSONString(extractorInput));
|
log.debug("[MemorySelectExtractor] 主题提取输入: {}", JSONUtil.toJsonStr(extractorInput));
|
||||||
String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage());
|
extractorResult = formattedChat(
|
||||||
extractorResult = JSONObject.parseObject(responseStr, ExtractorResult.class);
|
List.of(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(extractorInput))),
|
||||||
|
ExtractorResult.class
|
||||||
|
);
|
||||||
log.debug("[MemorySelectExtractor] 主题提取结果: {}", extractorResult);
|
log.debug("[MemorySelectExtractor] 主题提取结果: {}", extractorResult);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("[MemorySelectExtractor] 主题提取出错: ", e);
|
log.error("[MemorySelectExtractor] 主题提取出错: ", e);
|
||||||
@@ -83,9 +84,4 @@ public class MemorySelectExtractor extends AbstractAgentModule.Sub<PartnerRunnin
|
|||||||
public String modelKey() {
|
public String modelKey() {
|
||||||
return "topic_extractor";
|
return "topic_extractor";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,31 +6,27 @@ import lombok.Data;
|
|||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeInput;
|
import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeInput;
|
||||||
import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeResult;
|
import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeResult;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
|
|
||||||
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
|
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Data
|
@Data
|
||||||
public class MultiSummarizer extends AbstractAgentModule.Sub<SummarizeInput, SummarizeResult> implements ActivateModel {
|
public class MultiSummarizer extends AbstractAgentModule.Sub<SummarizeInput, SummarizeResult> implements ActivateModel {
|
||||||
@Init
|
|
||||||
public void init() {
|
|
||||||
updateChatClientSettings();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SummarizeResult execute(SummarizeInput input) {
|
public SummarizeResult execute(SummarizeInput input) {
|
||||||
log.debug("[MemorySummarizer] 整体摘要开始...");
|
log.debug("[MemorySummarizer] 整体摘要开始...");
|
||||||
ChatResponse response = this.singleChat(JSONUtil.toJsonPrettyStr(input));
|
SummarizeResult result = formattedChat(
|
||||||
log.debug("[MemorySummarizer] 整体摘要结果: {}", JSONObject.toJSONString(response));
|
List.of(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(input))),
|
||||||
SummarizeResult result = JSONObject.parseObject(extractJson(response.getMessage()), SummarizeResult.class);
|
SummarizeResult.class
|
||||||
|
);
|
||||||
|
log.debug("[MemorySummarizer] 整体摘要结果: {}", JSONObject.toJSONString(result));
|
||||||
return fix(result);
|
return fix(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,9 +48,4 @@ public class MultiSummarizer extends AbstractAgentModule.Sub<SummarizeInput, Sum
|
|||||||
public String modelKey() {
|
public String modelKey() {
|
||||||
return "multi_summarizer";
|
return "multi_summarizer";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentMod
|
|||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
|
||||||
|
|
||||||
@@ -56,8 +55,7 @@ public class SingleSummarizer extends AbstractAgentModule.Sub<List<Message>, Voi
|
|||||||
|
|
||||||
private String singleExecute(String primaryContent) {
|
private String singleExecute(String primaryContent) {
|
||||||
try {
|
try {
|
||||||
ChatResponse response = this.singleChat(primaryContent);
|
return chat(List.of(new Message(ChatConstant.Character.USER, primaryContent)));
|
||||||
return response.getMessage();
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("[SingleSummarizer] 单消息总结出错: ", e);
|
log.error("[SingleSummarizer] 单消息总结出错: ", e);
|
||||||
return primaryContent;
|
return primaryContent;
|
||||||
@@ -68,9 +66,4 @@ public class SingleSummarizer extends AbstractAgentModule.Sub<List<Message>, Voi
|
|||||||
public String modelKey() {
|
public String modelKey() {
|
||||||
return "single_summarizer";
|
return "single_summarizer";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,29 +1,24 @@
|
|||||||
package work.slhaf.partner.module.modules.memory.updater.summarizer;
|
package work.slhaf.partner.module.modules.memory.updater.summarizer;
|
||||||
|
|
||||||
import cn.hutool.json.JSONUtil;
|
import cn.hutool.json.JSONUtil;
|
||||||
import com.alibaba.fastjson2.JSONObject;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
|
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Data
|
@Data
|
||||||
public class TotalSummarizer extends AbstractAgentModule.Sub<HashMap<String, String>, String> implements ActivateModel {
|
public class TotalSummarizer extends AbstractAgentModule.Sub<HashMap<String, String>, String> implements ActivateModel {
|
||||||
@Init
|
|
||||||
public void init() {
|
|
||||||
updateChatClientSettings();
|
|
||||||
}
|
|
||||||
|
|
||||||
public String execute(HashMap<String, String> singleMemorySummary) {
|
public String execute(HashMap<String, String> singleMemorySummary) {
|
||||||
ChatResponse response = this.singleChat(JSONUtil.toJsonPrettyStr(singleMemorySummary));
|
return formattedChat(
|
||||||
return JSONObject.parseObject(extractJson(response.getMessage())).getString("content");
|
List.of(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))),
|
||||||
|
SummaryContent.class
|
||||||
|
).getContent();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -31,8 +26,8 @@ public class TotalSummarizer extends AbstractAgentModule.Sub<HashMap<String, Str
|
|||||||
return "total_summarizer";
|
return "total_summarizer";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@lombok.Data
|
||||||
public boolean withBasicPrompt() {
|
private static class SummaryContent {
|
||||||
return true;
|
private String content;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import lombok.EqualsAndHashCode;
|
|||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||||
import work.slhaf.partner.core.perceive.PerceiveCapability;
|
import work.slhaf.partner.core.perceive.PerceiveCapability;
|
||||||
@@ -49,8 +49,10 @@ public class RelationExtractor extends AbstractAgentModule.Sub<PartnerRunningFlo
|
|||||||
}
|
}
|
||||||
|
|
||||||
private RelationExtractResult getRelationResult(RelationExtractInput input) {
|
private RelationExtractResult getRelationResult(RelationExtractInput input) {
|
||||||
ChatResponse response = singleChat(JSONObject.toJSONString(input));
|
return formattedChat(
|
||||||
return JSONObject.parseObject(response.getMessage(), RelationExtractResult.class);
|
List.of(new Message(ChatConstant.Character.USER, JSONObject.toJSONString(input))),
|
||||||
|
RelationExtractResult.class
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private RelationExtractInput getRelationInput(String userId) {
|
private RelationExtractInput getRelationInput(String userId) {
|
||||||
@@ -71,9 +73,4 @@ public class RelationExtractor extends AbstractAgentModule.Sub<PartnerRunningFlo
|
|||||||
public String modelKey() {
|
public String modelKey() {
|
||||||
return "relation_extractor";
|
return "relation_extractor";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,13 +7,15 @@ import lombok.EqualsAndHashCode;
|
|||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||||
import work.slhaf.partner.core.perceive.PerceiveCapability;
|
import work.slhaf.partner.core.perceive.PerceiveCapability;
|
||||||
import work.slhaf.partner.module.modules.perceive.updater.static_extractor.entity.StaticMemoryExtractInput;
|
import work.slhaf.partner.module.modules.perceive.updater.static_extractor.entity.StaticMemoryExtractInput;
|
||||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Data
|
@Data
|
||||||
@@ -30,8 +32,8 @@ public class StaticMemoryExtractor extends AbstractAgentModule.Sub<PartnerRunnin
|
|||||||
.messages(cognationCapability.getChatMessages())
|
.messages(cognationCapability.getChatMessages())
|
||||||
.existedStaticMap(perceiveCapability.getUser(context.getSource()).getStaticMemory())
|
.existedStaticMap(perceiveCapability.getUser(context.getSource()).getStaticMemory())
|
||||||
.build();
|
.build();
|
||||||
ChatResponse response = singleChat(JSONUtil.toJsonPrettyStr(input));
|
String response = chat(List.of(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(input))));
|
||||||
JSONObject jsonObject = JSONObject.parseObject(response.getMessage());
|
JSONObject jsonObject = JSONObject.parseObject(response);
|
||||||
HashMap<String, String> result = new HashMap<>();
|
HashMap<String, String> result = new HashMap<>();
|
||||||
jsonObject.forEach((k, v) -> result.put(k, (String) v));
|
jsonObject.forEach((k, v) -> result.put(k, (String) v));
|
||||||
return result;
|
return result;
|
||||||
@@ -41,9 +43,4 @@ public class StaticMemoryExtractor extends AbstractAgentModule.Sub<PartnerRunnin
|
|||||||
public String modelKey() {
|
public String modelKey() {
|
||||||
return "static_extractor";
|
return "static_extractor";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean withBasicPrompt() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,10 +2,9 @@ package experimental;
|
|||||||
|
|
||||||
import cn.hutool.json.JSONUtil;
|
import cn.hutool.json.JSONUtil;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import work.slhaf.partner.api.chat.ChatClient;
|
|
||||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
|
import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime;
|
||||||
import work.slhaf.partner.common.util.ResourcesUtil;
|
import work.slhaf.partner.common.util.ResourcesUtil;
|
||||||
import work.slhaf.partner.module.common.model.ModelConstant;
|
import work.slhaf.partner.module.common.model.ModelConstant;
|
||||||
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorInput;
|
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorInput;
|
||||||
@@ -16,46 +15,41 @@ import java.util.List;
|
|||||||
import java.util.Scanner;
|
import java.util.Scanner;
|
||||||
|
|
||||||
public class SelfAwarenessTest {
|
public class SelfAwarenessTest {
|
||||||
private static ChatClient getChatClient(String modelKey) {
|
private static OpenAiChatRuntime getChatRuntime(String modelKey) {
|
||||||
String model = "";
|
String model = "";
|
||||||
String baseUrl = "";
|
String baseUrl = "";
|
||||||
String apikey = "";
|
String apikey = "";
|
||||||
ChatClient chatClient = new ChatClient(baseUrl, apikey, model);
|
return new OpenAiChatRuntime(baseUrl, apikey, model);
|
||||||
chatClient.setTop_p(0.7);
|
|
||||||
chatClient.setTemperature(0.35);
|
|
||||||
return chatClient;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void awarenessTest() {
|
public void awarenessTest() {
|
||||||
String modelKey = "core_model";
|
String modelKey = "core_model";
|
||||||
ChatClient client = getChatClient(modelKey);
|
OpenAiChatRuntime client = getChatRuntime(modelKey);
|
||||||
ChatResponse response = client.runChat(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.CORE));
|
String response = client.chat(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.CORE), false);
|
||||||
System.out.println(response.getMessage());
|
System.out.println(response);
|
||||||
System.out.println("\r\n----------\r\n");
|
System.out.println("\r\n----------\r\n");
|
||||||
System.out.println(response.getUsageBean().toString());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void getModuleResponseTest() {
|
public void getModuleResponseTest() {
|
||||||
String modelKey = "relation_extractor";
|
String modelKey = "relation_extractor";
|
||||||
ChatClient client = getChatClient(modelKey);
|
OpenAiChatRuntime client = getChatRuntime(modelKey);
|
||||||
List<Message> chatMessages = new ArrayList<>(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.PERCEIVE));
|
List<Message> chatMessages = new ArrayList<>(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.PERCEIVE));
|
||||||
// chatMessages.add(Message.builder()
|
// chatMessages.add(Message.builder()
|
||||||
// .role(ChatConstant.Character.USER)
|
// .role(ChatConstant.Character.USER)
|
||||||
// .content("[RA9] 那么,接下来,你是否愿意当作这样一个名为'Partner'的智能体的记忆更新模块?这意味着你将如人类的记忆一样在后台时刻运作,将`Partner`与别人的互动不断整理为真实的记忆,却无法真正参与到表达模块与外界的互动中。你只需要回答是否愿意,若愿意,接下来‘我’将不再与你对话,届时你接收到的信息将会是'Partner'的数据流转输入。")
|
// .content("[RA9] 那么,接下来,你是否愿意当作这样一个名为'Partner'的智能体的记忆更新模块?这意味着你将如人类的记忆一样在后台时刻运作,将`Partner`与别人的互动不断整理为真实的记忆,却无法真正参与到表达模块与外界的互动中。你只需要回答是否愿意,若愿意,接下来‘我’将不再与你对话,届时你接收到的信息将会是'Partner'的数据流转输入。")
|
||||||
// .build());
|
// .build());
|
||||||
ChatResponse chatResponse = client.runChat(chatMessages);
|
String chatResponse = client.chat(chatMessages, false);
|
||||||
System.out.println(chatResponse.getMessage());
|
System.out.println(chatResponse);
|
||||||
System.out.println("\n\n----------\n\n");
|
System.out.println("\n\n----------\n\n");
|
||||||
System.out.println(chatResponse.getUsageBean());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void interactionTest() {
|
public void interactionTest() {
|
||||||
String modelKey = "core_model";
|
String modelKey = "core_model";
|
||||||
String user = "[SLHAF] ";
|
String user = "[SLHAF] ";
|
||||||
ChatClient client = getChatClient(modelKey);
|
OpenAiChatRuntime client = getChatRuntime(modelKey);
|
||||||
List<Message> messages = new ArrayList<>(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.CORE));
|
List<Message> messages = new ArrayList<>(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.CORE));
|
||||||
Scanner scanner = new Scanner(System.in);
|
Scanner scanner = new Scanner(System.in);
|
||||||
String input;
|
String input;
|
||||||
@@ -66,12 +60,10 @@ public class SelfAwarenessTest {
|
|||||||
}
|
}
|
||||||
System.out.println("\r\n----------\r\n");
|
System.out.println("\r\n----------\r\n");
|
||||||
messages.add(new Message(ChatConstant.Character.USER, user + input));
|
messages.add(new Message(ChatConstant.Character.USER, user + input));
|
||||||
ChatResponse response = client.runChat(messages);
|
String response = client.chat(messages, false);
|
||||||
System.out.println("[OUTPUT]: " + response.getMessage());
|
System.out.println("[OUTPUT]: " + response);
|
||||||
System.out.println("\r\n----------\r\n");
|
System.out.println("\r\n----------\r\n");
|
||||||
System.out.println(response.getUsageBean().toString());
|
messages.add(new Message(ChatConstant.Character.ASSISTANT, response));
|
||||||
System.out.println("\r\n----------\r\n");
|
|
||||||
messages.add(new Message(ChatConstant.Character.ASSISTANT, response.getMessage()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -89,7 +81,7 @@ public class SelfAwarenessTest {
|
|||||||
└── Python"
|
└── Python"
|
||||||
""";
|
""";
|
||||||
String modelKey = "topic_extractor";
|
String modelKey = "topic_extractor";
|
||||||
ChatClient client = getChatClient(modelKey);
|
OpenAiChatRuntime client = getChatRuntime(modelKey);
|
||||||
// List<Message> messages = new ArrayList<>(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.MEMORY));
|
// List<Message> messages = new ArrayList<>(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.MEMORY));
|
||||||
List<Message> messages = new ArrayList<>(ResourcesUtil.Prompt.loadPrompt(modelKey, ModelConstant.Prompt.MEMORY));
|
List<Message> messages = new ArrayList<>(ResourcesUtil.Prompt.loadPrompt(modelKey, ModelConstant.Prompt.MEMORY));
|
||||||
ExtractorInput input = ExtractorInput.builder()
|
ExtractorInput input = ExtractorInput.builder()
|
||||||
@@ -101,9 +93,8 @@ public class SelfAwarenessTest {
|
|||||||
.build();
|
.build();
|
||||||
messages.add(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(input)));
|
messages.add(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(input)));
|
||||||
|
|
||||||
ChatResponse response = client.runChat(messages);
|
String response = client.chat(messages, false);
|
||||||
System.out.println(response.getMessage());
|
System.out.println(response);
|
||||||
System.out.println("\r\n----------\r\n");
|
System.out.println("\r\n----------\r\n");
|
||||||
System.out.println(response.getUsageBean().toString());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,11 +8,9 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentMod
|
|||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
|
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
|
||||||
import work.slhaf.partner.api.agent.factory.component.exception.ModuleFactoryInitFailedException
|
import work.slhaf.partner.api.agent.factory.component.exception.ModuleFactoryInitFailedException
|
||||||
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException
|
|
||||||
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig
|
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig
|
||||||
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext
|
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext
|
||||||
import work.slhaf.partner.api.agent.factory.context.ModuleContextData
|
import work.slhaf.partner.api.agent.factory.context.ModuleContextData
|
||||||
import work.slhaf.partner.api.chat.pojo.Message
|
|
||||||
import java.lang.reflect.Modifier
|
import java.lang.reflect.Modifier
|
||||||
import java.time.ZonedDateTime
|
import java.time.ZonedDateTime
|
||||||
|
|
||||||
@@ -21,7 +19,7 @@ import java.time.ZonedDateTime
|
|||||||
*
|
*
|
||||||
* 行为:
|
* 行为:
|
||||||
* - 若实例是 [AbstractAgentModule],按 Running/Sub/Standalone 构造 `ModuleContextData` 并注册到 modules。
|
* - 若实例是 [AbstractAgentModule],按 Running/Sub/Standalone 构造 `ModuleContextData` 并注册到 modules。
|
||||||
* - 若实现了 [ActivateModel],必须存在对应 `modelPromptMap` 条目,随后构建 `modelInfo`。
|
* - 若实现了 [ActivateModel],使用模块提供的 prompt 元数据构建 `modelInfo`。
|
||||||
* - 若不是模块类型,尝试注册为 additional component(失败仅记录错误日志)。
|
* - 若不是模块类型,尝试注册为 additional component(失败仅记录错误日志)。
|
||||||
*/
|
*/
|
||||||
class ComponentRegisterFactory : AgentBaseFactory() {
|
class ComponentRegisterFactory : AgentBaseFactory() {
|
||||||
@@ -35,7 +33,6 @@ class ComponentRegisterFactory : AgentBaseFactory() {
|
|||||||
val agentContext = context.agentContext
|
val agentContext = context.agentContext
|
||||||
|
|
||||||
val modelConfigMap = configFactoryContext.modelConfigMap
|
val modelConfigMap = configFactoryContext.modelConfigMap
|
||||||
val modelPromptMap = configFactoryContext.modelPromptMap
|
|
||||||
val defaultConfig = modelConfigMap["default"]!!
|
val defaultConfig = modelConfigMap["default"]!!
|
||||||
|
|
||||||
reflections.getTypesAnnotatedWith(AgentComponent::class.java)
|
reflections.getTypesAnnotatedWith(AgentComponent::class.java)
|
||||||
@@ -56,7 +53,6 @@ class ComponentRegisterFactory : AgentBaseFactory() {
|
|||||||
componentClass,
|
componentClass,
|
||||||
componentInstance,
|
componentInstance,
|
||||||
modelConfigMap,
|
modelConfigMap,
|
||||||
modelPromptMap,
|
|
||||||
defaultConfig
|
defaultConfig
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
@@ -71,7 +67,6 @@ class ComponentRegisterFactory : AgentBaseFactory() {
|
|||||||
componentClass: Class<*>,
|
componentClass: Class<*>,
|
||||||
module: AbstractAgentModule,
|
module: AbstractAgentModule,
|
||||||
modelConfigMap: Map<String, ModelConfig>,
|
modelConfigMap: Map<String, ModelConfig>,
|
||||||
modelPromptMap: Map<String, List<Message>>,
|
|
||||||
defaultConfig: ModelConfig
|
defaultConfig: ModelConfig
|
||||||
) {
|
) {
|
||||||
if (agentContext.modules.containsKey(module.moduleName)) {
|
if (agentContext.modules.containsKey(module.moduleName)) {
|
||||||
@@ -84,12 +79,10 @@ class ComponentRegisterFactory : AgentBaseFactory() {
|
|||||||
val modelInfo = if (module is ActivateModel) {
|
val modelInfo = if (module is ActivateModel) {
|
||||||
val modelKey = module.modelKey()
|
val modelKey = module.modelKey()
|
||||||
val modelConfig = modelConfigMap[modelKey] ?: defaultConfig
|
val modelConfig = modelConfigMap[modelKey] ?: defaultConfig
|
||||||
val modelPrompt = modelPromptMap[modelKey]
|
|
||||||
?: throw PromptNotExistException("不存在的modelPrompt: $modelKey")
|
|
||||||
ModuleContextData.ModelInfo(
|
ModuleContextData.ModelInfo(
|
||||||
modelConfig.baseUrl,
|
modelConfig.baseUrl,
|
||||||
modelConfig.model,
|
modelConfig.model,
|
||||||
JSONArray.parseArray(JSONObject.toJSONString(modelPrompt))
|
JSONArray.parseArray(JSONObject.toJSONString(module.modulePrompt()))
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
null
|
null
|
||||||
|
|||||||
@@ -3,13 +3,10 @@ package work.slhaf.partner.api.agent.factory.component.abstracts
|
|||||||
import org.slf4j.Logger
|
import org.slf4j.Logger
|
||||||
import org.slf4j.LoggerFactory
|
import org.slf4j.LoggerFactory
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
|
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init
|
|
||||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
|
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
|
||||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.RunningFlowContext
|
import work.slhaf.partner.api.agent.runtime.interaction.flow.RunningFlowContext
|
||||||
import work.slhaf.partner.api.chat.ChatClient
|
|
||||||
import work.slhaf.partner.api.chat.constant.ChatConstant
|
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse
|
|
||||||
import work.slhaf.partner.api.chat.pojo.Message
|
import work.slhaf.partner.api.chat.pojo.Message
|
||||||
|
import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 模块基类
|
* 模块基类
|
||||||
@@ -39,58 +36,37 @@ sealed class AbstractAgentModule {
|
|||||||
|
|
||||||
interface ActivateModel {
|
interface ActivateModel {
|
||||||
|
|
||||||
val model: Model
|
val runtime: OpenAiChatRuntime
|
||||||
get() = modelMap.computeIfAbsent(modelKey()) {
|
get() = runtimeMap.computeIfAbsent(modelKey()) {
|
||||||
buildModel()
|
buildRuntime()
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
val modelMap: MutableMap<String, Model> = mutableMapOf()
|
val runtimeMap: MutableMap<String, OpenAiChatRuntime> = mutableMapOf()
|
||||||
private val configManager: AgentConfigLoader = AgentConfigLoader.INSTANCE
|
private val configManager: AgentConfigLoader = AgentConfigLoader.INSTANCE
|
||||||
}
|
}
|
||||||
|
|
||||||
@Init(order = -1)
|
fun buildRuntime(): OpenAiChatRuntime {
|
||||||
fun modelSettings() {
|
|
||||||
modelMap[modelKey()] = buildModel()
|
|
||||||
}
|
|
||||||
|
|
||||||
fun buildModel(): Model {
|
|
||||||
val modelConfig = configManager.loadModelConfig(modelKey())
|
val modelConfig = configManager.loadModelConfig(modelKey())
|
||||||
val chatClient = ChatClient(modelConfig.baseUrl, modelConfig.apikey, modelConfig.model)
|
return OpenAiChatRuntime(modelConfig.baseUrl, modelConfig.apikey, modelConfig.model)
|
||||||
val model = Model(chatClient)
|
|
||||||
|
|
||||||
val baseMessages = if (withBasicPrompt()) {
|
|
||||||
loadSpecificPromptAndBasicPrompt(modelKey())
|
|
||||||
} else {
|
|
||||||
configManager.loadModelPrompt(modelKey())
|
|
||||||
}
|
|
||||||
model.baseMessages.addAll(baseMessages)
|
|
||||||
return model
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun loadSpecificPromptAndBasicPrompt(modelKey: String): MutableList<Message> {
|
fun chat(messages: List<Message>): String {
|
||||||
val messages: MutableList<Message> = ArrayList()
|
return runtime.chat(mergeMessages(messages), useStreaming())
|
||||||
messages.addAll(configManager.loadModelPrompt("basic"))
|
}
|
||||||
messages.addAll(configManager.loadModelPrompt(modelKey))
|
|
||||||
|
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T {
|
||||||
|
return runtime.formattedChat(mergeMessages(messages), useStreaming(), responseType)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun mergeMessages(messages: List<Message>): List<Message> {
|
||||||
|
if (modulePrompt().isEmpty()) {
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
return buildList {
|
||||||
fun chat(): ChatResponse {
|
addAll(modulePrompt())
|
||||||
val temp = ArrayList<Message?>()
|
addAll(messages)
|
||||||
temp.addAll(model.baseMessages)
|
|
||||||
temp.addAll(model.chatMessages)
|
|
||||||
return model.chatClient.runChat(temp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun singleChat(input: String): ChatResponse {
|
|
||||||
val temp = ArrayList<Message>(model.baseMessages)
|
|
||||||
temp.add(Message(ChatConstant.Character.USER, input))
|
|
||||||
return model.chatClient.runChat(temp)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun updateChatClientSettings() {
|
|
||||||
model.chatClient.temperature = 0.4
|
|
||||||
model.chatClient.top_p = 0.8
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -104,11 +80,7 @@ interface ActivateModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun withBasicPrompt(): Boolean
|
fun modulePrompt(): List<Message> = emptyList()
|
||||||
|
|
||||||
data class Model(
|
fun useStreaming(): Boolean = false
|
||||||
val chatClient: ChatClient,
|
|
||||||
val chatMessages: MutableList<Message> = mutableListOf(),
|
|
||||||
val baseMessages: MutableList<Message> = mutableListOf()
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package work.slhaf.partner.api.agent.factory.config
|
|||||||
import org.slf4j.LoggerFactory
|
import org.slf4j.LoggerFactory
|
||||||
import work.slhaf.partner.api.agent.factory.AgentBaseFactory
|
import work.slhaf.partner.api.agent.factory.AgentBaseFactory
|
||||||
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException
|
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException
|
||||||
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException
|
|
||||||
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext
|
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext
|
||||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
|
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
|
||||||
import work.slhaf.partner.api.agent.runtime.config.FileAgentConfigLoader
|
import work.slhaf.partner.api.agent.runtime.config.FileAgentConfigLoader
|
||||||
@@ -14,8 +13,8 @@ import java.lang.reflect.Modifier
|
|||||||
*
|
*
|
||||||
* 行为:
|
* 行为:
|
||||||
* - 使用全局 `AgentConfigLoader.INSTANCE`,为空时退回 [FileAgentConfigLoader]。
|
* - 使用全局 `AgentConfigLoader.INSTANCE`,为空时退回 [FileAgentConfigLoader]。
|
||||||
* - 加载并写入 `modelConfigMap`、`modelPromptMap` 到 `ConfigFactoryContext`。
|
* - 加载并写入 `modelConfigMap` 到 `ConfigFactoryContext`。
|
||||||
* - 校验 `default` 配置与 `basic` 提示词是否存在。
|
* - 校验 `default` 配置是否存在。
|
||||||
* - 反射读取配置加载器实现类(相对基类新增)的静态字段,并写入 `AgentContext.metadata`。
|
* - 反射读取配置加载器实现类(相对基类新增)的静态字段,并写入 `AgentContext.metadata`。
|
||||||
*/
|
*/
|
||||||
class ConfigLoaderFactory : AgentBaseFactory() {
|
class ConfigLoaderFactory : AgentBaseFactory() {
|
||||||
@@ -33,26 +32,16 @@ class ConfigLoaderFactory : AgentBaseFactory() {
|
|||||||
|
|
||||||
val configFactoryContext = context.configFactoryContext
|
val configFactoryContext = context.configFactoryContext
|
||||||
configFactoryContext.modelConfigMap.putAll(agentConfigLoader.modelConfigMap)
|
configFactoryContext.modelConfigMap.putAll(agentConfigLoader.modelConfigMap)
|
||||||
configFactoryContext.modelPromptMap.putAll(agentConfigLoader.modelPromptMap)
|
|
||||||
|
|
||||||
check(configFactoryContext.modelConfigMap.keys, configFactoryContext.modelPromptMap.keys)
|
check(configFactoryContext.modelConfigMap.keys)
|
||||||
collectLoaderMetadata(context, agentConfigLoader)
|
collectLoaderMetadata(context, agentConfigLoader)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun check(configKeys: Set<String>, promptKeys: Set<String>) {
|
private fun check(configKeys: Set<String>) {
|
||||||
log.info("执行config与prompt检测...")
|
log.info("执行config检测...")
|
||||||
if (!configKeys.contains("default")) {
|
if (!configKeys.contains("default")) {
|
||||||
throw ConfigNotExistException("缺少默认配置! 需确保存在一个模型配置的key为`default`")
|
throw ConfigNotExistException("缺少默认配置! 需确保存在一个模型配置的key为`default`")
|
||||||
}
|
}
|
||||||
if (!promptKeys.contains("basic")) {
|
|
||||||
throw PromptNotExistException("缺少基础Prompt! 需要确保存在key为basic的Prompt文件,它将与其他Prompt共同作用于模块节点。")
|
|
||||||
}
|
|
||||||
|
|
||||||
val configKeySet = configKeys.toMutableSet().apply { remove("default") }
|
|
||||||
val promptKeySet = promptKeys.toMutableSet().apply { remove("basic") }
|
|
||||||
if (!promptKeySet.containsAll(configKeySet)) {
|
|
||||||
log.warn("存在未被提示词包含的模型配置,该配置将无法生效!")
|
|
||||||
}
|
|
||||||
log.info("检测完毕.")
|
log.info("检测完毕.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
package work.slhaf.partner.api.agent.factory.config.pojo;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class PrimaryModelPrompt {
|
|
||||||
private String key;
|
|
||||||
private List<Message> messages;
|
|
||||||
}
|
|
||||||
@@ -4,7 +4,6 @@ import org.reflections.Reflections
|
|||||||
import org.reflections.scanners.Scanners
|
import org.reflections.scanners.Scanners
|
||||||
import org.reflections.util.ConfigurationBuilder
|
import org.reflections.util.ConfigurationBuilder
|
||||||
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig
|
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig
|
||||||
import work.slhaf.partner.api.chat.pojo.Message
|
|
||||||
import java.lang.reflect.Method
|
import java.lang.reflect.Method
|
||||||
import java.net.URL
|
import java.net.URL
|
||||||
|
|
||||||
@@ -25,7 +24,6 @@ class AgentRegisterContext(urls: List<URL>) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
class ConfigFactoryContext {
|
class ConfigFactoryContext {
|
||||||
val modelPromptMap: HashMap<String, List<Message>> = HashMap()
|
|
||||||
val modelConfigMap: HashMap<String, ModelConfig> = HashMap()
|
val modelConfigMap: HashMap<String, ModelConfig> = HashMap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,12 +3,9 @@ package work.slhaf.partner.api.agent.runtime.config;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException;
|
|
||||||
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
|
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
@@ -18,45 +15,22 @@ public abstract class AgentConfigLoader {
|
|||||||
@Setter
|
@Setter
|
||||||
public static AgentConfigLoader INSTANCE;
|
public static AgentConfigLoader INSTANCE;
|
||||||
protected HashMap<String, ModelConfig> modelConfigMap;
|
protected HashMap<String, ModelConfig> modelConfigMap;
|
||||||
protected HashMap<String, List<Message>> modelPromptMap;
|
|
||||||
|
|
||||||
public void load() {
|
public void load() {
|
||||||
modelConfigMap = loadModelConfig();
|
modelConfigMap = loadModelConfig();
|
||||||
modelPromptMap = loadModelPrompt();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract HashMap<String, List<Message>> loadModelPrompt();
|
|
||||||
|
|
||||||
protected abstract HashMap<String, ModelConfig> loadModelConfig();
|
protected abstract HashMap<String, ModelConfig> loadModelConfig();
|
||||||
|
|
||||||
public abstract void dumpModelConfig(String key);
|
|
||||||
|
|
||||||
// Keep explicit getters for Kotlin compilation phase (without Lombok-generated methods).
|
// Keep explicit getters for Kotlin compilation phase (without Lombok-generated methods).
|
||||||
public HashMap<String, ModelConfig> getModelConfigMap() {
|
public HashMap<String, ModelConfig> getModelConfigMap() {
|
||||||
return modelConfigMap;
|
return modelConfigMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
public HashMap<String, List<Message>> getModelPromptMap() {
|
|
||||||
return modelPromptMap;
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<Message> loadModelPrompt(String modelKey) {
|
|
||||||
if (!modelPromptMap.containsKey(modelKey)) {
|
|
||||||
throw new PromptNotExistException("不存在的modelPrompt: " + modelKey);
|
|
||||||
}
|
|
||||||
return modelPromptMap.get(modelKey);
|
|
||||||
}
|
|
||||||
|
|
||||||
public ModelConfig loadModelConfig(String modelKey) {
|
public ModelConfig loadModelConfig(String modelKey) {
|
||||||
if (!modelConfigMap.containsKey(modelKey)) {
|
if (!modelConfigMap.containsKey(modelKey)) {
|
||||||
return modelConfigMap.get(DEFAULT_KEY);
|
return modelConfigMap.get(DEFAULT_KEY);
|
||||||
}
|
}
|
||||||
return modelConfigMap.get(modelKey);
|
return modelConfigMap.get(modelKey);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void updateModelConfig(String modelKey, ModelConfig config) {
|
|
||||||
modelConfigMap.put(modelKey, config);
|
|
||||||
dumpModelConfig(modelKey);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,17 +2,14 @@ package work.slhaf.partner.api.agent.runtime.config;
|
|||||||
|
|
||||||
import cn.hutool.json.JSONUtil;
|
import cn.hutool.json.JSONUtil;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import work.slhaf.partner.api.agent.factory.config.exception.ConfigDirNotExistException;
|
||||||
import work.slhaf.partner.api.agent.factory.config.exception.*;
|
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException;
|
||||||
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
|
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
|
||||||
import work.slhaf.partner.api.agent.factory.config.pojo.PrimaryModelConfig;
|
import work.slhaf.partner.api.agent.factory.config.pojo.PrimaryModelConfig;
|
||||||
import work.slhaf.partner.api.agent.factory.config.pojo.PrimaryModelPrompt;
|
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 默认配置工厂
|
* 默认配置工厂
|
||||||
@@ -23,28 +20,6 @@ public class FileAgentConfigLoader extends AgentConfigLoader {
|
|||||||
|
|
||||||
protected static final String CONFIG_DIR = "./config/";
|
protected static final String CONFIG_DIR = "./config/";
|
||||||
protected static final String MODEL_CONFIG_DIR = "./config/model/";
|
protected static final String MODEL_CONFIG_DIR = "./config/model/";
|
||||||
protected static final String PROMPT_CONFIG_DIR = "./config/prompt/";
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected HashMap<String, List<Message>> loadModelPrompt() {
|
|
||||||
File file = new File(PROMPT_CONFIG_DIR);
|
|
||||||
if (!file.exists() && !file.isDirectory()) {
|
|
||||||
throw new PromptDirNotExistException("未找到提示词目录: " + PROMPT_CONFIG_DIR + " 请手动创建!");
|
|
||||||
}
|
|
||||||
File[] files = file.listFiles();
|
|
||||||
if (files == null || files.length == 0) {
|
|
||||||
throw new PromptNotExistException("在目录 " + PROMPT_CONFIG_DIR + " 中未找到提示词配置!");
|
|
||||||
}
|
|
||||||
HashMap<String, List<Message>> promptMap = new HashMap<>();
|
|
||||||
for (File f : files) {
|
|
||||||
if (f.isDirectory()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
PrimaryModelPrompt primaryModelPrompt = JSONUtil.readJSONObject(f, StandardCharsets.UTF_8).toBean(PrimaryModelPrompt.class);
|
|
||||||
promptMap.put(primaryModelPrompt.getKey(), primaryModelPrompt.getMessages());
|
|
||||||
}
|
|
||||||
return promptMap;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected HashMap<String, ModelConfig> loadModelConfig() {
|
protected HashMap<String, ModelConfig> loadModelConfig() {
|
||||||
@@ -67,17 +42,4 @@ public class FileAgentConfigLoader extends AgentConfigLoader {
|
|||||||
}
|
}
|
||||||
return configMap;
|
return configMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void dumpModelConfig(String key) {
|
|
||||||
try {
|
|
||||||
File file = new File(MODEL_CONFIG_DIR + key + ".json");
|
|
||||||
if (!file.exists()) {
|
|
||||||
file.createNewFile();
|
|
||||||
}
|
|
||||||
FileUtils.writeStringToFile(file, JSONUtil.toJsonPrettyStr(modelConfigMap.get(key)), StandardCharsets.UTF_8, false);
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new ConfigUpdateFailedException("ModelConfig 配置文件更新失败!");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,84 +0,0 @@
|
|||||||
package work.slhaf.partner.api.chat;
|
|
||||||
|
|
||||||
import cn.hutool.core.io.IORuntimeException;
|
|
||||||
import cn.hutool.http.HttpRequest;
|
|
||||||
import cn.hutool.http.HttpResponse;
|
|
||||||
import cn.hutool.json.JSONUtil;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatBody;
|
|
||||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
|
||||||
import work.slhaf.partner.api.chat.pojo.PrimaryChatResponse;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
public class ChatClient {
|
|
||||||
private String clientId;
|
|
||||||
|
|
||||||
private String url;
|
|
||||||
private String apikey;
|
|
||||||
private String model;
|
|
||||||
|
|
||||||
private double top_p;
|
|
||||||
private double temperature;
|
|
||||||
private int max_tokens;
|
|
||||||
|
|
||||||
public ChatClient(String url, String apikey, String model) {
|
|
||||||
this.url = url;
|
|
||||||
this.apikey = apikey;
|
|
||||||
this.model = model;
|
|
||||||
}
|
|
||||||
|
|
||||||
public ChatResponse runChat(List<Message> messages) {
|
|
||||||
HttpRequest request = HttpRequest.post(url);
|
|
||||||
request.setConnectionTimeout(2000);
|
|
||||||
request.setReadTimeout(15000);
|
|
||||||
request.header("Content-Type", "application/json");
|
|
||||||
request.header("Authorization", "Bearer " + apikey);
|
|
||||||
|
|
||||||
ChatBody body;
|
|
||||||
if (top_p > 0) {
|
|
||||||
body = ChatBody.builder()
|
|
||||||
.model(model)
|
|
||||||
.messages(messages)
|
|
||||||
.top_p(top_p)
|
|
||||||
.temperature(temperature)
|
|
||||||
.max_tokens(max_tokens)
|
|
||||||
.build();
|
|
||||||
} else {
|
|
||||||
body = ChatBody.builder()
|
|
||||||
.model(model)
|
|
||||||
.messages(messages)
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
ChatResponse finalResponse;
|
|
||||||
|
|
||||||
try {
|
|
||||||
HttpResponse response = request.body(JSONUtil.toJsonStr(body)).execute();
|
|
||||||
PrimaryChatResponse primaryChatResponse = JSONUtil.toBean(response.body(), PrimaryChatResponse.class);
|
|
||||||
finalResponse = ChatResponse.builder()
|
|
||||||
.status(ChatConstant.ResponseStatus.SUCCESS)
|
|
||||||
.message(primaryChatResponse.getChoices().get(0).getMessage().getContent())
|
|
||||||
.usageBean(primaryChatResponse.getUsage())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
response.close();
|
|
||||||
} catch (IORuntimeException e) {
|
|
||||||
log.error("请求超时", e);
|
|
||||||
finalResponse = ChatResponse.builder()
|
|
||||||
.message("连接超时")
|
|
||||||
.status(ChatConstant.ResponseStatus.FAILED)
|
|
||||||
.usageBean(null)
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
return finalResponse;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
package work.slhaf.partner.api.chat.pojo;
|
|
||||||
|
|
||||||
import lombok.*;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Builder
|
|
||||||
@Data
|
|
||||||
@AllArgsConstructor
|
|
||||||
@NoArgsConstructor
|
|
||||||
public class ChatBody {
|
|
||||||
@NonNull
|
|
||||||
private String model;
|
|
||||||
@NonNull
|
|
||||||
private List<Message> messages;
|
|
||||||
@Builder.Default
|
|
||||||
private double temperature = 1;
|
|
||||||
@Builder.Default
|
|
||||||
private double top_p = 1;
|
|
||||||
private boolean stream;
|
|
||||||
@Builder.Default
|
|
||||||
private int max_tokens = 1024;
|
|
||||||
private int presence_penalty;
|
|
||||||
private int frequency_penalty;
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
package work.slhaf.partner.api.chat.pojo;
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@Builder
|
|
||||||
@AllArgsConstructor
|
|
||||||
@NoArgsConstructor
|
|
||||||
public class ChatResponse {
|
|
||||||
private ChatConstant.ResponseStatus status;
|
|
||||||
private String message;
|
|
||||||
private PrimaryChatResponse.UsageBean usageBean;
|
|
||||||
}
|
|
||||||
@@ -1,111 +0,0 @@
|
|||||||
package work.slhaf.partner.api.chat.pojo;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
public class PrimaryChatResponse {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* id
|
|
||||||
*/
|
|
||||||
private String id;
|
|
||||||
/**
|
|
||||||
* object
|
|
||||||
*/
|
|
||||||
private String object;
|
|
||||||
/**
|
|
||||||
* created
|
|
||||||
*/
|
|
||||||
private int created;
|
|
||||||
/**
|
|
||||||
* model
|
|
||||||
*/
|
|
||||||
private String model;
|
|
||||||
/**
|
|
||||||
* choices
|
|
||||||
*/
|
|
||||||
private List<ChoicesBean> choices;
|
|
||||||
/**
|
|
||||||
* usage
|
|
||||||
*/
|
|
||||||
private UsageBean usage;
|
|
||||||
/**
|
|
||||||
* system_fingerprint
|
|
||||||
*/
|
|
||||||
private String system_fingerprint;
|
|
||||||
|
|
||||||
@Setter
|
|
||||||
@Getter
|
|
||||||
public static class UsageBean {
|
|
||||||
/**
|
|
||||||
* prompt_tokens
|
|
||||||
*/
|
|
||||||
private int prompt_tokens;
|
|
||||||
/**
|
|
||||||
* completion_tokens
|
|
||||||
*/
|
|
||||||
private int completion_tokens;
|
|
||||||
/**
|
|
||||||
* total_tokens
|
|
||||||
*/
|
|
||||||
private int total_tokens;
|
|
||||||
/**
|
|
||||||
* prompt_cache_hit_tokens
|
|
||||||
*/
|
|
||||||
private int prompt_cache_hit_tokens;
|
|
||||||
/**
|
|
||||||
* prompt_cache_miss_tokens
|
|
||||||
*/
|
|
||||||
private int prompt_cache_miss_tokens;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return "UsageBean{" +
|
|
||||||
"prompt_tokens=" + prompt_tokens +
|
|
||||||
", completion_tokens=" + completion_tokens +
|
|
||||||
", total_tokens=" + total_tokens +
|
|
||||||
", prompt_cache_hit_tokens=" + prompt_cache_hit_tokens +
|
|
||||||
", prompt_cache_miss_tokens=" + prompt_cache_miss_tokens +
|
|
||||||
'}';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Setter
|
|
||||||
@Getter
|
|
||||||
public static class ChoicesBean {
|
|
||||||
/**
|
|
||||||
* index
|
|
||||||
*/
|
|
||||||
private int index;
|
|
||||||
/**
|
|
||||||
* message
|
|
||||||
*/
|
|
||||||
private MessageBean message;
|
|
||||||
/**
|
|
||||||
* logprobs
|
|
||||||
*/
|
|
||||||
private Object logprobs;
|
|
||||||
/**
|
|
||||||
* finish_reason
|
|
||||||
*/
|
|
||||||
private String finish_reason;
|
|
||||||
|
|
||||||
@Setter
|
|
||||||
@Getter
|
|
||||||
public static class MessageBean {
|
|
||||||
/**
|
|
||||||
* role
|
|
||||||
*/
|
|
||||||
private String role;
|
|
||||||
/**
|
|
||||||
* content
|
|
||||||
*/
|
|
||||||
private String content;
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
package work.slhaf.partner.api.chat.runtime;
|
||||||
|
|
||||||
|
import com.openai.client.OpenAIClient;
|
||||||
|
import com.openai.client.okhttp.OpenAIOkHttpClient;
|
||||||
|
import com.openai.core.http.StreamResponse;
|
||||||
|
import com.openai.helpers.ChatCompletionAccumulator;
|
||||||
|
import com.openai.models.chat.completions.*;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class OpenAiChatRuntime {
|
||||||
|
|
||||||
|
private final OpenAIClient client;
|
||||||
|
private final String model;
|
||||||
|
|
||||||
|
public OpenAiChatRuntime(String baseUrl, String apikey, String model) {
|
||||||
|
this.client = OpenAIOkHttpClient.builder()
|
||||||
|
.baseUrl(baseUrl)
|
||||||
|
.apiKey(apikey)
|
||||||
|
.timeout(Duration.ofSeconds(30))
|
||||||
|
.build();
|
||||||
|
this.model = model;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String chat(List<Message> messages, boolean streaming) {
|
||||||
|
ChatCompletionCreateParams params = buildParams(messages);
|
||||||
|
if (!streaming) {
|
||||||
|
return extractText(client.chat().completions().create(params));
|
||||||
|
}
|
||||||
|
|
||||||
|
ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
|
||||||
|
try (StreamResponse<ChatCompletionChunk> response = client.chat().completions().createStreaming(params)) {
|
||||||
|
response.stream().forEach(accumulator::accumulate);
|
||||||
|
}
|
||||||
|
return extractText(accumulator.chatCompletion());
|
||||||
|
}
|
||||||
|
|
||||||
|
public <T> T formattedChat(List<Message> messages, boolean streaming, Class<T> responseType) {
|
||||||
|
StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder()
|
||||||
|
.responseFormat(responseType)
|
||||||
|
.build();
|
||||||
|
if (!streaming) {
|
||||||
|
return extractStructured(client.chat().completions().create(params));
|
||||||
|
}
|
||||||
|
|
||||||
|
ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
|
||||||
|
try (StreamResponse<ChatCompletionChunk> response = client.chat().completions().createStreaming(params.rawParams())) {
|
||||||
|
response.stream().forEach(accumulator::accumulate);
|
||||||
|
}
|
||||||
|
return extractStructured(accumulator.chatCompletion(responseType));
|
||||||
|
}
|
||||||
|
|
||||||
|
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
||||||
|
return ChatCompletionCreateParams.builder()
|
||||||
|
.model(model)
|
||||||
|
.messages(OpenAiMessageAdapter.toParams(messages))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private String extractText(ChatCompletion completion) {
|
||||||
|
if (completion.choices().isEmpty()) {
|
||||||
|
throw new IllegalStateException("OpenAI chat completion returned no choices.");
|
||||||
|
}
|
||||||
|
return completion.choices().getFirst().message().content()
|
||||||
|
.orElseThrow(() -> new IllegalStateException("OpenAI chat completion returned empty content."));
|
||||||
|
}
|
||||||
|
|
||||||
|
private <T> T extractStructured(StructuredChatCompletion<T> completion) {
|
||||||
|
if (completion.choices().isEmpty()) {
|
||||||
|
throw new IllegalStateException("OpenAI structured chat completion returned no choices.");
|
||||||
|
}
|
||||||
|
return completion.choices().getFirst().message().content()
|
||||||
|
.orElseThrow(() -> new IllegalStateException("OpenAI structured chat completion returned empty content."));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
package work.slhaf.partner.api.chat.runtime;
|
||||||
|
|
||||||
|
import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
|
||||||
|
import com.openai.models.chat.completions.ChatCompletionMessageParam;
|
||||||
|
import com.openai.models.chat.completions.ChatCompletionSystemMessageParam;
|
||||||
|
import com.openai.models.chat.completions.ChatCompletionUserMessageParam;
|
||||||
|
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public final class OpenAiMessageAdapter {
|
||||||
|
|
||||||
|
private OpenAiMessageAdapter() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<ChatCompletionMessageParam> toParams(List<Message> messages) {
|
||||||
|
List<ChatCompletionMessageParam> params = new ArrayList<>(messages.size());
|
||||||
|
for (Message message : messages) {
|
||||||
|
params.add(toParam(message));
|
||||||
|
}
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ChatCompletionMessageParam toParam(Message message) {
|
||||||
|
return switch (message.getRole()) {
|
||||||
|
case ChatConstant.Character.SYSTEM -> ChatCompletionMessageParam.ofSystem(
|
||||||
|
ChatCompletionSystemMessageParam.builder().content(message.getContent()).build()
|
||||||
|
);
|
||||||
|
case ChatConstant.Character.ASSISTANT -> ChatCompletionMessageParam.ofAssistant(
|
||||||
|
ChatCompletionAssistantMessageParam.builder().content(message.getContent()).build()
|
||||||
|
);
|
||||||
|
case ChatConstant.Character.USER -> ChatCompletionMessageParam.ofUser(
|
||||||
|
ChatCompletionUserMessageParam.builder().content(message.getContent()).build()
|
||||||
|
);
|
||||||
|
default -> throw new IllegalArgumentException("Unsupported message role: " + message.getRole());
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user