refactor(chat): replace custom client with OpenAI runtime and remove file-based module prompt loading logic, prompt will be provided by each module

This commit is contained in:
2026-03-09 21:51:07 +08:00
parent 8dc7ed080b
commit 1b2ccaee9c
32 changed files with 288 additions and 615 deletions

View File

@@ -4,9 +4,13 @@ import com.alibaba.fastjson2.JSONObject;
import lombok.val;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.module.modules.action.executor.entity.CorrectorInput;
import work.slhaf.partner.module.modules.action.executor.entity.CorrectorResult;
import java.util.List;
/**
* 负责在单组行动执行后,根据行动意图与结果检查后续行动是否符合目的,必要时直接调整行动链,或发起自对话请求进行干预
*/
@@ -14,8 +18,7 @@ public class ActionCorrector extends AbstractAgentModule.Sub<CorrectorInput, Cor
@Override
public CorrectorResult execute(CorrectorInput input) {
val prompt = buildPrompt(input);
val chatResponse = singleChat(prompt);
return JSONObject.parseObject(chatResponse.getMessage(), CorrectorResult.class);
return formattedChat(List.of(new Message(ChatConstant.Character.USER, prompt)), CorrectorResult.class);
}
private String buildPrompt(CorrectorInput input) {
@@ -37,9 +40,4 @@ public class ActionCorrector extends AbstractAgentModule.Sub<CorrectorInput, Cor
public String modelKey() {
return "action_corrector";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -9,7 +9,8 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentMod
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore.ExecutorType;
import work.slhaf.partner.core.action.entity.MetaAction;
@@ -61,8 +62,10 @@ public class ActionRepairer extends AbstractAgentModule.Sub<RepairerInput, Repai
RepairerResult result;
try {
String prompt = assemblyHelper.buildPrompt(data, null);
ChatResponse response = this.singleChat(prompt);
RepairerData repairerData = JSONObject.parseObject(response.getMessage(), RepairerData.class);
RepairerData repairerData = formattedChat(
List.of(new Message(ChatConstant.Character.USER, prompt)),
RepairerData.class
);
result = switch (repairerData.getRepairerType()) {
case ACTION_GENERATION ->
handleActionGeneration(JSONObject.parseObject(repairerData.getData(), GeneratorInput.class));
@@ -75,8 +78,10 @@ public class ActionRepairer extends AbstractAgentModule.Sub<RepairerInput, Repai
&& result.getStatus().equals(RepairerResult.RepairerStatus.FAILED)) {
log.warn("常规行动修复失败,将尝试自对话通道");
prompt = assemblyHelper.buildPrompt(data, "常规行动修复失败,请尝试通过自对话通道获取必要的信息以完成行动参数的修复");
response = this.singleChat(prompt);
repairerData = JSONObject.parseObject(response.getMessage(), RepairerData.class);
repairerData = formattedChat(
List.of(new Message(ChatConstant.Character.USER, prompt)),
RepairerData.class
);
handleUserInteraction(repairerData.getData());
}
} catch (Exception e) {
@@ -165,20 +170,14 @@ public class ActionRepairer extends AbstractAgentModule.Sub<RepairerInput, Repai
return "action_repairer";
}
@Override
public boolean withBasicPrompt() {
return false;
}
private enum RepairerType {
ACTION_GENERATION,
ACTION_INVOCATION,
USER_INTERACTION
}
@SuppressWarnings("InnerClassMayBeStatic")
@Data
private class RepairerData {
private static class RepairerData {
private RepairerType repairerType;
private String data;
}

View File

@@ -6,8 +6,8 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapabili
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.common.util.ExtractUtil;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.entity.GeneratedData;
import work.slhaf.partner.core.action.entity.MetaAction;
@@ -15,6 +15,8 @@ import work.slhaf.partner.core.action.runner.RunnerClient;
import work.slhaf.partner.module.modules.action.executor.entity.GeneratorInput;
import work.slhaf.partner.module.modules.action.executor.entity.GeneratorResult;
import java.util.List;
/**
* 负责依据输入内容生成可执行的动态行动单元,并选择是否持久化至 SandboxRunner 容器内
*/
@@ -37,9 +39,10 @@ public class DynamicActionGenerator extends AbstractAgentModule.Sub<GeneratorInp
// 所以此处的输入内容也只需要指定输入参数、临时key、是否持久化即可路径将按照指定规则统一构建不可交给LLM生成
String prompt = buildPrompt(input);
// 响应结果需要包含几个特殊数据: 依赖项、代码内容、是否序列化、响应数据释义
ChatResponse response = this.singleChat(prompt);
GeneratedData generatorData = JSONObject
.parseObject(ExtractUtil.extractJson(response.getMessage()), GeneratedData.class);
GeneratedData generatorData = formattedChat(
List.of(new Message(ChatConstant.Character.USER, prompt)),
GeneratedData.class
);
val location = runnerClient.buildTmpPath(input.getActionName(), generatorData.getCodeType());
MetaAction tempAction = new MetaAction(
input.getActionName(),
@@ -77,9 +80,4 @@ public class DynamicActionGenerator extends AbstractAgentModule.Sub<GeneratorInp
public String modelKey() {
return "dynamic_generator";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -4,7 +4,8 @@ import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.entity.MetaActionInfo;
import work.slhaf.partner.module.modules.action.executor.entity.ExtractorInput;
import work.slhaf.partner.module.modules.action.executor.entity.ExtractorResult;
@@ -20,12 +21,11 @@ public class ParamsExtractor extends AbstractAgentModule.Sub<ExtractorInput, Ext
@Override
public ExtractorResult execute(ExtractorInput input) {
String prompt = buildPrompt(input);
ChatResponse response = this.singleChat(prompt);
ExtractorResult result;
try {
result = JSONObject.parseObject(response.getMessage(), ExtractorResult.class);
result = formattedChat(List.of(new Message(ChatConstant.Character.USER, prompt)), ExtractorResult.class);
} catch (Exception e) {
log.error("ParamsExtractor解析结果失败,返回内容:{}", response.getMessage(), e);
log.error("ParamsExtractor解析结果失败", e);
result = new ExtractorResult();
result.setOk(false);
result.setParams(new HashMap<>());
@@ -57,9 +57,4 @@ public class ParamsExtractor extends AbstractAgentModule.Sub<ExtractorInput, Ext
public String modelKey() {
return "params_extractor";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -134,11 +134,6 @@ public class ActionInterventor extends PreRunningAbstractAgentModuleAbstract imp
return "action_identifier";
}
@Override
public boolean withBasicPrompt() {
return false;
}
@Override
protected Map<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
return interventionPrompt.remove(context.getInfo().getUuid());

View File

@@ -4,7 +4,7 @@ import com.alibaba.fastjson2.JSONObject;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore.ExecutorType;
@@ -53,8 +53,8 @@ public class InterventionEvaluator extends AbstractAgentModule.Sub<EvaluatorInpu
interventionMap.forEach((tendency, actionData) -> executor.execute(() -> {
try {
String prompt = buildPrompt(input.getRecentMessages(), input.getActivatedSlices(), actionData, tendency);
ChatResponse response = this.singleChat(prompt);
EvaluatedInterventionData evaluatedData = JSONObject.parseObject(response.getMessage(),
EvaluatedInterventionData evaluatedData = formattedChat(
List.of(new Message(ChatConstant.Character.USER, prompt)),
EvaluatedInterventionData.class);
synchronized (evaluatedDataList) {
evaluatedDataList.add(evaluatedData);
@@ -81,9 +81,4 @@ public class InterventionEvaluator extends AbstractAgentModule.Sub<EvaluatorInpu
public String modelKey() {
return "intervention_evaluator";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -4,7 +4,8 @@ import com.alibaba.fastjson2.JSONObject;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;
import work.slhaf.partner.core.action.entity.ExecutableAction;
@@ -48,8 +49,10 @@ public class InterventionRecognizer extends AbstractAgentModule.Sub<RecognizerIn
executor.execute(() -> {
try {
String prompt = buildPrompt(data, input);
ChatResponse response = this.singleChat(prompt);
MetaRecognizerResult result = JSONObject.parseObject(response.getMessage(), MetaRecognizerResult.class);
MetaRecognizerResult result = formattedChat(
List.of(new Message(ChatConstant.Character.USER, prompt)),
MetaRecognizerResult.class
);
if (result.isOk()) {
synchronized (interventionsMap) {
interventionsMap.put(result.getIntervention(), data);
@@ -83,9 +86,4 @@ public class InterventionRecognizer extends AbstractAgentModule.Sub<RecognizerIn
public String modelKey() {
return "intervention_recognizer";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -5,7 +5,7 @@ import com.alibaba.fastjson2.JSONObject;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;
@@ -19,8 +19,6 @@ import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
public class ActionConfirmer extends AbstractAgentModule.Sub<ConfirmerInput, ConfirmerResult> implements ActivateModel {
@InjectCapability
private ActionCapability actionCapability;
@@ -40,10 +38,12 @@ public class ActionConfirmer extends AbstractAgentModule.Sub<ConfirmerInput, Con
try {
ExecutableAction executableAction = pendingAction.getExecutableAction();
String prompt = buildPrompt(executableAction, data.getInput(), data.getRecentMessages());
ChatResponse response = this.singleChat(prompt);
JSONObject tempResult = JSONObject.parseObject(extractJson(response.getMessage()));
DecisionResponse tempResult = formattedChat(
List.of(new Message(ChatConstant.Character.USER, prompt)),
DecisionResponse.class
);
PendingActionRecord.Decision decision = parseDecision(tempResult);
String reason = tempResult == null ? null : tempResult.getString("reason");
String reason = tempResult.getReason();
synchronized (decisions) {
decisions.add(new PendingDecisionItem(pendingAction.getPendingId(), decision, reason));
}
@@ -68,11 +68,11 @@ public class ActionConfirmer extends AbstractAgentModule.Sub<ConfirmerInput, Con
return result;
}
private PendingActionRecord.Decision parseDecision(JSONObject tempResult) {
private PendingActionRecord.Decision parseDecision(DecisionResponse tempResult) {
if (tempResult == null) {
return PendingActionRecord.Decision.HOLD;
}
String decisionText = tempResult.getString("decision");
String decisionText = tempResult.getDecision();
if (decisionText != null) {
String upperDecision = decisionText.toUpperCase();
if (upperDecision.contains("CONFIRM")) {
@@ -85,7 +85,7 @@ public class ActionConfirmer extends AbstractAgentModule.Sub<ConfirmerInput, Con
return PendingActionRecord.Decision.HOLD;
}
}
Boolean confirmed = tempResult.getBoolean("confirmed");
Boolean confirmed = tempResult.getConfirmed();
if (Boolean.TRUE.equals(confirmed)) {
return PendingActionRecord.Decision.CONFIRM;
}
@@ -116,8 +116,10 @@ public class ActionConfirmer extends AbstractAgentModule.Sub<ConfirmerInput, Con
return "action-confirmer";
}
@Override
public boolean withBasicPrompt() {
return false;
@lombok.Data
private static class DecisionResponse {
private String decision;
private String reason;
private Boolean confirmed;
}
}

View File

@@ -7,7 +7,8 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapabili
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
@@ -48,8 +49,10 @@ public class ActionEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, Lis
List<Callable<EvaluatorResult>> list = new ArrayList<>();
for (EvaluatorBatchInput batchInput : batchInputs) {
list.add(() -> {
ChatResponse response = this.singleChat(buildPrompt(batchInput));
EvaluatorResult evaluatorResult = JSONObject.parseObject(response.getMessage(), EvaluatorResult.class);
EvaluatorResult evaluatorResult = formattedChat(
List.of(new Message(ChatConstant.Character.USER, buildPrompt(batchInput))),
EvaluatorResult.class
);
evaluatorResult.setTendency(batchInput.getTendency());
return evaluatorResult;
});
@@ -89,9 +92,4 @@ public class ActionEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, Lis
public String modelKey() {
return "action_evaluator";
}
@Override
public boolean withBasicPrompt() {
return true;
}
}

View File

@@ -4,7 +4,8 @@ import com.alibaba.fastjson2.JSONObject;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput;
import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorResult;
@@ -25,8 +26,10 @@ public class ActionExtractor extends AbstractAgentModule.Sub<ExtractorInput, Ext
}
for (int i = 0; i < 3; i++) {
try {
ChatResponse response = this.singleChat(JSONObject.toJSONString(data));
return JSONObject.parseObject(response.getMessage(), ExtractorResult.class);
return formattedChat(
List.of(new Message(ChatConstant.Character.USER, JSONObject.toJSONString(data))),
ExtractorResult.class
);
} catch (Exception e) {
log.error("[ActionExtractor] 提取信息出错", e);
}
@@ -38,9 +41,4 @@ public class ActionExtractor extends AbstractAgentModule.Sub<ExtractorInput, Ext
public String modelKey() {
return "action_extractor";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -8,9 +8,7 @@ import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapabili
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.ChatClient;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.core.cognation.CognationCapability;
@@ -31,31 +29,23 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
@InjectCapability
private CognationCapability cognationCapability;
private List<Message> appendedMessages = new ArrayList<>();
private final List<Message> appendedMessages = new ArrayList<>();
private final List<Message> chatMessages = new ArrayList<>();
@Init
public void init() {
List<Message> chatMessages = this.cognationCapability.getChatMessages();
this.getModel().getChatMessages().addAll(chatMessages);
updateChatClientSettings();
this.chatMessages.clear();
this.chatMessages.addAll(this.cognationCapability.getChatMessages());
log.info("CommunicationProducer 注册完毕...");
}
@Override
public void updateChatClientSettings() {
ChatClient chatClient = getModel().getChatClient();
chatClient.setTemperature(0.3);
chatClient.setTop_p(0.7);
}
@Override
public @NotNull String modelKey() {
return "communication_producer";
}
@Override
public boolean withBasicPrompt() {
public boolean useStreaming() {
return true;
}
@@ -73,7 +63,7 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
activateModule(runningFlowContext);
setMessageCount(runningFlowContext);
log.debug("[CommunicationProducer] 当前消息列表大小: {}", getModel().getChatMessages().size());
log.debug("[CommunicationProducer] 当前消息列表大小: {}", chatMessages.size());
log.debug("[CommunicationProducer] 当前核心prompt内容: {}", runningFlowContext.getCoreContext().toString());
setMessage(runningFlowContext.getCoreContext().toString());
@@ -94,28 +84,28 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
int count = 0;
while (true) {
try {
ChatResponse chatResponse = this.chat();
String chatResponse = this.chat(buildChatMessages());
try {
response.putAll(JSONObject.parse(extractJson(chatResponse.getMessage())));
response.putAll(JSONObject.parse(extractJson(chatResponse)));
} catch (Exception e) {
log.warn("主模型回复格式出错, 将直接作为消息返回, 建议尝试更换主模型...");
handleExceptionResponse(response, chatResponse.getMessage());
handleExceptionResponse(response, chatResponse);
}
log.debug("[CommunicationProducer] CommunicationProducer 响应内容: {}", response);
updateModuleContextAndChatMessages(runningFlowContext, response.getString("text"), chatResponse);
updateModuleContextAndChatMessages(runningFlowContext, response.getString("text"));
break;
} catch (Exception e) {
count++;
log.error("[CommunicationProducer] CoreModel执行异常: {}", e.getLocalizedMessage());
if (count > 3) {
handleExceptionResponse(response, "主模型交互出错: " + e.getLocalizedMessage());
getModel().getChatMessages().removeLast();
chatMessages.removeLast();
break;
}
} finally {
updateCoreResponse(runningFlowContext, response);
resetAppendedMessages();
log.debug("[CommunicationProducer] 消息列表更新大小: {}", getModel().getChatMessages().size());
log.debug("[CommunicationProducer] 消息列表更新大小: {}", chatMessages.size());
}
}
}
@@ -143,20 +133,15 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
this.appendedMessages.clear();
}
@Override
public @NotNull ChatResponse chat() {
List<Message> baseMessages = getModel().getBaseMessages();
List<Message> chatMessages = getModel().getChatMessages();
List<Message> temp = new ArrayList<>(baseMessages.subList(0, baseMessages.size() - 2));
private List<Message> buildChatMessages() {
List<Message> temp = new ArrayList<>(appendedMessages.size() + chatMessages.size());
temp.addAll(appendedMessages);
temp.addAll(baseMessages.subList(baseMessages.size() - 2, baseMessages.size()));
temp.addAll(chatMessages);
return getModel().getChatClient().runChat(temp);
return temp;
}
private void updateModuleContextAndChatMessages(PartnerRunningFlowContext runningFlowContext, String response, ChatResponse chatResponse) {
private void updateModuleContextAndChatMessages(PartnerRunningFlowContext runningFlowContext, String response) {
cognationCapability.getMessageLock().lock();
List<Message> chatMessages = getModel().getChatMessages();
chatMessages.removeIf(m -> {
if (m.getRole().equals(ChatConstant.Character.ASSISTANT)) {
return false;
@@ -176,8 +161,6 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
Message assistantMessage = new Message(ChatConstant.Character.ASSISTANT, response);
chatMessages.add(assistantMessage);
cognationCapability.getMessageLock().unlock();
//设置上下文
runningFlowContext.getModuleContext().getExtraContext().put("total_token", chatResponse.getUsageBean().getTotal_tokens());
//区分单人聊天场景
// if (runningFlowContext.isSingle()) {
MetaMessage metaMessage = new MetaMessage(primaryUserMessage, assistantMessage);
@@ -187,7 +170,7 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
private void setMessage(String coreContextStr) {
Message userMessage = new Message(ChatConstant.Character.USER, coreContextStr);
getModel().getChatMessages().add(userMessage);
chatMessages.add(userMessage);
}
private void handleExceptionResponse(JSONObject response, String chatResponse) {
@@ -196,7 +179,7 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
}
private void setMessageCount(PartnerRunningFlowContext runningFlowContext) {
runningFlowContext.getModuleContext().getExtraContext().put("message_count", getModel().getChatMessages().size());
runningFlowContext.getModuleContext().getExtraContext().put("message_count", chatMessages.size());
}
private void setAppendedPromptMessage(List<AppendPromptData> appendPrompt) {

View File

@@ -8,6 +8,8 @@ import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
import work.slhaf.partner.core.memory.pojo.EvaluatedSlice;
import work.slhaf.partner.core.memory.pojo.MemoryResult;
@@ -24,8 +26,6 @@ import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
@EqualsAndHashCode(callSuper = true)
@Data
public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, List<EvaluatedSlice>> implements ActivateModel {
@@ -61,7 +61,10 @@ public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput
.history(evaluatorInput.getMessages())
.build();
log.debug("[SliceSelectEvaluator] 评估[{}]输入: {}", thisCount, JSONObject.toJSONString(batchInput));
EvaluatorResult evaluatorResult = JSONObject.parseObject(extractJson(singleChat(JSONUtil.toJsonStr(batchInput)).getMessage()), EvaluatorResult.class);
EvaluatorResult evaluatorResult = formattedChat(
List.of(new Message(ChatConstant.Character.USER, JSONUtil.toJsonStr(batchInput))),
EvaluatorResult.class
);
log.debug("[SliceSelectEvaluator] 评估[{}]结果: {}", thisCount, JSONObject.toJSONString(evaluatorResult));
for (Long result : evaluatorResult.getResults()) {
SliceSummary sliceSummary = map.get(result);
@@ -117,9 +120,4 @@ public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput
public String modelKey() {
return "slice_evaluator";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -1,12 +1,12 @@
package work.slhaf.partner.module.modules.memory.selector.extractor;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.MetaMessage;
import work.slhaf.partner.core.cognation.CognationCapability;
@@ -20,7 +20,6 @@ import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowCon
import java.util.ArrayList;
import java.util.List;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
@EqualsAndHashCode(callSuper = true)
@@ -52,9 +51,11 @@ public class MemorySelectExtractor extends AbstractAgentModule.Sub<PartnerRunnin
.topic_tree(memoryCapability.getTopicTree())
.activatedMemorySlices(activatedMemorySlices)
.build();
log.debug("[MemorySelectExtractor] 主题提取输入: {}", JSONObject.toJSONString(extractorInput));
String responseStr = extractJson(singleChat(JSONUtil.toJsonPrettyStr(extractorInput)).getMessage());
extractorResult = JSONObject.parseObject(responseStr, ExtractorResult.class);
log.debug("[MemorySelectExtractor] 主题提取输入: {}", JSONUtil.toJsonStr(extractorInput));
extractorResult = formattedChat(
List.of(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(extractorInput))),
ExtractorResult.class
);
log.debug("[MemorySelectExtractor] 主题提取结果: {}", extractorResult);
} catch (Exception e) {
log.error("[MemorySelectExtractor] 主题提取出错: ", e);
@@ -83,9 +84,4 @@ public class MemorySelectExtractor extends AbstractAgentModule.Sub<PartnerRunnin
public String modelKey() {
return "topic_extractor";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -6,31 +6,27 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeInput;
import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeResult;
import java.util.ArrayList;
import java.util.List;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
import static work.slhaf.partner.common.util.ExtractUtil.fixTopicPath;
@EqualsAndHashCode(callSuper = true)
@Data
public class MultiSummarizer extends AbstractAgentModule.Sub<SummarizeInput, SummarizeResult> implements ActivateModel {
@Init
public void init() {
updateChatClientSettings();
}
@Override
public SummarizeResult execute(SummarizeInput input) {
log.debug("[MemorySummarizer] 整体摘要开始...");
ChatResponse response = this.singleChat(JSONUtil.toJsonPrettyStr(input));
log.debug("[MemorySummarizer] 整体摘要结果: {}", JSONObject.toJSONString(response));
SummarizeResult result = JSONObject.parseObject(extractJson(response.getMessage()), SummarizeResult.class);
SummarizeResult result = formattedChat(
List.of(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(input))),
SummarizeResult.class
);
log.debug("[MemorySummarizer] 整体摘要结果: {}", JSONObject.toJSONString(result));
return fix(result);
}
@@ -52,9 +48,4 @@ public class MultiSummarizer extends AbstractAgentModule.Sub<SummarizeInput, Sum
public String modelKey() {
return "multi_summarizer";
}
@Override
public boolean withBasicPrompt() {
return true;
}
}

View File

@@ -7,7 +7,6 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentMod
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
@@ -56,8 +55,7 @@ public class SingleSummarizer extends AbstractAgentModule.Sub<List<Message>, Voi
private String singleExecute(String primaryContent) {
try {
ChatResponse response = this.singleChat(primaryContent);
return response.getMessage();
return chat(List.of(new Message(ChatConstant.Character.USER, primaryContent)));
} catch (Exception e) {
log.error("[SingleSummarizer] 单消息总结出错: ", e);
return primaryContent;
@@ -68,9 +66,4 @@ public class SingleSummarizer extends AbstractAgentModule.Sub<List<Message>, Voi
public String modelKey() {
return "single_summarizer";
}
@Override
public boolean withBasicPrompt() {
return false;
}
}

View File

@@ -1,29 +1,24 @@
package work.slhaf.partner.module.modules.memory.updater.summarizer;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import java.util.HashMap;
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
public class TotalSummarizer extends AbstractAgentModule.Sub<HashMap<String, String>, String> implements ActivateModel {
@Init
public void init() {
updateChatClientSettings();
}
public String execute(HashMap<String, String> singleMemorySummary) {
ChatResponse response = this.singleChat(JSONUtil.toJsonPrettyStr(singleMemorySummary));
return JSONObject.parseObject(extractJson(response.getMessage())).getString("content");
return formattedChat(
List.of(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))),
SummaryContent.class
).getContent();
}
@Override
@@ -31,8 +26,8 @@ public class TotalSummarizer extends AbstractAgentModule.Sub<HashMap<String, Str
return "total_summarizer";
}
@Override
public boolean withBasicPrompt() {
return true;
@lombok.Data
private static class SummaryContent {
private String content;
}
}

View File

@@ -6,7 +6,7 @@ import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.perceive.PerceiveCapability;
@@ -49,8 +49,10 @@ public class RelationExtractor extends AbstractAgentModule.Sub<PartnerRunningFlo
}
private RelationExtractResult getRelationResult(RelationExtractInput input) {
ChatResponse response = singleChat(JSONObject.toJSONString(input));
return JSONObject.parseObject(response.getMessage(), RelationExtractResult.class);
return formattedChat(
List.of(new Message(ChatConstant.Character.USER, JSONObject.toJSONString(input))),
RelationExtractResult.class
);
}
private RelationExtractInput getRelationInput(String userId) {
@@ -71,9 +73,4 @@ public class RelationExtractor extends AbstractAgentModule.Sub<PartnerRunningFlo
public String modelKey() {
return "relation_extractor";
}
@Override
public boolean withBasicPrompt() {
return true;
}
}

View File

@@ -7,13 +7,15 @@ import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.cognation.CognationCapability;
import work.slhaf.partner.core.perceive.PerceiveCapability;
import work.slhaf.partner.module.modules.perceive.updater.static_extractor.entity.StaticMemoryExtractInput;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
import java.util.HashMap;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
@@ -30,8 +32,8 @@ public class StaticMemoryExtractor extends AbstractAgentModule.Sub<PartnerRunnin
.messages(cognationCapability.getChatMessages())
.existedStaticMap(perceiveCapability.getUser(context.getSource()).getStaticMemory())
.build();
ChatResponse response = singleChat(JSONUtil.toJsonPrettyStr(input));
JSONObject jsonObject = JSONObject.parseObject(response.getMessage());
String response = chat(List.of(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(input))));
JSONObject jsonObject = JSONObject.parseObject(response);
HashMap<String, String> result = new HashMap<>();
jsonObject.forEach((k, v) -> result.put(k, (String) v));
return result;
@@ -41,9 +43,4 @@ public class StaticMemoryExtractor extends AbstractAgentModule.Sub<PartnerRunnin
public String modelKey() {
return "static_extractor";
}
@Override
public boolean withBasicPrompt() {
return true;
}
}

View File

@@ -2,10 +2,9 @@ package experimental;
import cn.hutool.json.JSONUtil;
import org.junit.jupiter.api.Test;
import work.slhaf.partner.api.chat.ChatClient;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime;
import work.slhaf.partner.common.util.ResourcesUtil;
import work.slhaf.partner.module.common.model.ModelConstant;
import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorInput;
@@ -16,46 +15,41 @@ import java.util.List;
import java.util.Scanner;
public class SelfAwarenessTest {
private static ChatClient getChatClient(String modelKey) {
private static OpenAiChatRuntime getChatRuntime(String modelKey) {
String model = "";
String baseUrl = "";
String apikey = "";
ChatClient chatClient = new ChatClient(baseUrl, apikey, model);
chatClient.setTop_p(0.7);
chatClient.setTemperature(0.35);
return chatClient;
return new OpenAiChatRuntime(baseUrl, apikey, model);
}
@Test
public void awarenessTest() {
String modelKey = "core_model";
ChatClient client = getChatClient(modelKey);
ChatResponse response = client.runChat(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.CORE));
System.out.println(response.getMessage());
OpenAiChatRuntime client = getChatRuntime(modelKey);
String response = client.chat(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.CORE), false);
System.out.println(response);
System.out.println("\r\n----------\r\n");
System.out.println(response.getUsageBean().toString());
}
@Test
public void getModuleResponseTest() {
String modelKey = "relation_extractor";
ChatClient client = getChatClient(modelKey);
OpenAiChatRuntime client = getChatRuntime(modelKey);
List<Message> chatMessages = new ArrayList<>(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.PERCEIVE));
// chatMessages.add(Message.builder()
// .role(ChatConstant.Character.USER)
// .content("[RA9] 那么,接下来,你是否愿意当作这样一个名为'Partner'的智能体的记忆更新模块?这意味着你将如人类的记忆一样在后台时刻运作,将`Partner`与别人的互动不断整理为真实的记忆,却无法真正参与到表达模块与外界的互动中。你只需要回答是否愿意,若愿意,接下来‘我’将不再与你对话,届时你接收到的信息将会是'Partner'的数据流转输入。")
// .build());
ChatResponse chatResponse = client.runChat(chatMessages);
System.out.println(chatResponse.getMessage());
String chatResponse = client.chat(chatMessages, false);
System.out.println(chatResponse);
System.out.println("\n\n----------\n\n");
System.out.println(chatResponse.getUsageBean());
}
@Test
public void interactionTest() {
String modelKey = "core_model";
String user = "[SLHAF] ";
ChatClient client = getChatClient(modelKey);
OpenAiChatRuntime client = getChatRuntime(modelKey);
List<Message> messages = new ArrayList<>(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.CORE));
Scanner scanner = new Scanner(System.in);
String input;
@@ -66,12 +60,10 @@ public class SelfAwarenessTest {
}
System.out.println("\r\n----------\r\n");
messages.add(new Message(ChatConstant.Character.USER, user + input));
ChatResponse response = client.runChat(messages);
System.out.println("[OUTPUT]: " + response.getMessage());
String response = client.chat(messages, false);
System.out.println("[OUTPUT]: " + response);
System.out.println("\r\n----------\r\n");
System.out.println(response.getUsageBean().toString());
System.out.println("\r\n----------\r\n");
messages.add(new Message(ChatConstant.Character.ASSISTANT, response.getMessage()));
messages.add(new Message(ChatConstant.Character.ASSISTANT, response));
}
}
@@ -89,7 +81,7 @@ public class SelfAwarenessTest {
└── Python"
""";
String modelKey = "topic_extractor";
ChatClient client = getChatClient(modelKey);
OpenAiChatRuntime client = getChatRuntime(modelKey);
// List<Message> messages = new ArrayList<>(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.MEMORY));
List<Message> messages = new ArrayList<>(ResourcesUtil.Prompt.loadPrompt(modelKey, ModelConstant.Prompt.MEMORY));
ExtractorInput input = ExtractorInput.builder()
@@ -101,9 +93,8 @@ public class SelfAwarenessTest {
.build();
messages.add(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(input)));
ChatResponse response = client.runChat(messages);
System.out.println(response.getMessage());
String response = client.chat(messages, false);
System.out.println(response);
System.out.println("\r\n----------\r\n");
System.out.println(response.getUsageBean().toString());
}
}

View File

@@ -8,11 +8,9 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentMod
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
import work.slhaf.partner.api.agent.factory.component.exception.ModuleFactoryInitFailedException
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext
import work.slhaf.partner.api.agent.factory.context.ModuleContextData
import work.slhaf.partner.api.chat.pojo.Message
import java.lang.reflect.Modifier
import java.time.ZonedDateTime
@@ -21,7 +19,7 @@ import java.time.ZonedDateTime
*
* 行为:
* - 若实例是 [AbstractAgentModule],按 Running/Sub/Standalone 构造 `ModuleContextData` 并注册到 modules。
* - 若实现了 [ActivateModel]必须存在对应 `modelPromptMap` 条目,随后构建 `modelInfo`。
* - 若实现了 [ActivateModel]使用模块提供的 prompt 元数据构建 `modelInfo`。
* - 若不是模块类型,尝试注册为 additional component失败仅记录错误日志
*/
class ComponentRegisterFactory : AgentBaseFactory() {
@@ -35,7 +33,6 @@ class ComponentRegisterFactory : AgentBaseFactory() {
val agentContext = context.agentContext
val modelConfigMap = configFactoryContext.modelConfigMap
val modelPromptMap = configFactoryContext.modelPromptMap
val defaultConfig = modelConfigMap["default"]!!
reflections.getTypesAnnotatedWith(AgentComponent::class.java)
@@ -56,7 +53,6 @@ class ComponentRegisterFactory : AgentBaseFactory() {
componentClass,
componentInstance,
modelConfigMap,
modelPromptMap,
defaultConfig
)
} else {
@@ -71,7 +67,6 @@ class ComponentRegisterFactory : AgentBaseFactory() {
componentClass: Class<*>,
module: AbstractAgentModule,
modelConfigMap: Map<String, ModelConfig>,
modelPromptMap: Map<String, List<Message>>,
defaultConfig: ModelConfig
) {
if (agentContext.modules.containsKey(module.moduleName)) {
@@ -84,12 +79,10 @@ class ComponentRegisterFactory : AgentBaseFactory() {
val modelInfo = if (module is ActivateModel) {
val modelKey = module.modelKey()
val modelConfig = modelConfigMap[modelKey] ?: defaultConfig
val modelPrompt = modelPromptMap[modelKey]
?: throw PromptNotExistException("不存在的modelPrompt: $modelKey")
ModuleContextData.ModelInfo(
modelConfig.baseUrl,
modelConfig.model,
JSONArray.parseArray(JSONObject.toJSONString(modelPrompt))
JSONArray.parseArray(JSONObject.toJSONString(module.modulePrompt()))
)
} else {
null

View File

@@ -3,13 +3,10 @@ package work.slhaf.partner.api.agent.factory.component.abstracts
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
import work.slhaf.partner.api.agent.factory.component.annotation.Init
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
import work.slhaf.partner.api.agent.runtime.interaction.flow.RunningFlowContext
import work.slhaf.partner.api.chat.ChatClient
import work.slhaf.partner.api.chat.constant.ChatConstant
import work.slhaf.partner.api.chat.pojo.ChatResponse
import work.slhaf.partner.api.chat.pojo.Message
import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime
/**
* 模块基类
@@ -39,58 +36,37 @@ sealed class AbstractAgentModule {
interface ActivateModel {
val model: Model
get() = modelMap.computeIfAbsent(modelKey()) {
buildModel()
val runtime: OpenAiChatRuntime
get() = runtimeMap.computeIfAbsent(modelKey()) {
buildRuntime()
}
companion object {
val modelMap: MutableMap<String, Model> = mutableMapOf()
val runtimeMap: MutableMap<String, OpenAiChatRuntime> = mutableMapOf()
private val configManager: AgentConfigLoader = AgentConfigLoader.INSTANCE
}
@Init(order = -1)
fun modelSettings() {
modelMap[modelKey()] = buildModel()
}
fun buildModel(): Model {
fun buildRuntime(): OpenAiChatRuntime {
val modelConfig = configManager.loadModelConfig(modelKey())
val chatClient = ChatClient(modelConfig.baseUrl, modelConfig.apikey, modelConfig.model)
val model = Model(chatClient)
return OpenAiChatRuntime(modelConfig.baseUrl, modelConfig.apikey, modelConfig.model)
}
val baseMessages = if (withBasicPrompt()) {
loadSpecificPromptAndBasicPrompt(modelKey())
} else {
configManager.loadModelPrompt(modelKey())
fun chat(messages: List<Message>): String {
return runtime.chat(mergeMessages(messages), useStreaming())
}
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T {
return runtime.formattedChat(mergeMessages(messages), useStreaming(), responseType)
}
fun mergeMessages(messages: List<Message>): List<Message> {
if (modulePrompt().isEmpty()) {
return messages
}
return buildList {
addAll(modulePrompt())
addAll(messages)
}
model.baseMessages.addAll(baseMessages)
return model
}
private fun loadSpecificPromptAndBasicPrompt(modelKey: String): MutableList<Message> {
val messages: MutableList<Message> = ArrayList()
messages.addAll(configManager.loadModelPrompt("basic"))
messages.addAll(configManager.loadModelPrompt(modelKey))
return messages
}
fun chat(): ChatResponse {
val temp = ArrayList<Message?>()
temp.addAll(model.baseMessages)
temp.addAll(model.chatMessages)
return model.chatClient.runChat(temp)
}
fun singleChat(input: String): ChatResponse {
val temp = ArrayList<Message>(model.baseMessages)
temp.add(Message(ChatConstant.Character.USER, input))
return model.chatClient.runChat(temp)
}
fun updateChatClientSettings() {
model.chatClient.temperature = 0.4
model.chatClient.top_p = 0.8
}
/**
@@ -104,11 +80,7 @@ interface ActivateModel {
}
}
fun withBasicPrompt(): Boolean
fun modulePrompt(): List<Message> = emptyList()
data class Model(
val chatClient: ChatClient,
val chatMessages: MutableList<Message> = mutableListOf(),
val baseMessages: MutableList<Message> = mutableListOf()
)
fun useStreaming(): Boolean = false
}

View File

@@ -3,7 +3,6 @@ package work.slhaf.partner.api.agent.factory.config
import org.slf4j.LoggerFactory
import work.slhaf.partner.api.agent.factory.AgentBaseFactory
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
import work.slhaf.partner.api.agent.runtime.config.FileAgentConfigLoader
@@ -14,8 +13,8 @@ import java.lang.reflect.Modifier
*
* 行为:
* - 使用全局 `AgentConfigLoader.INSTANCE`,为空时退回 [FileAgentConfigLoader]。
* - 加载并写入 `modelConfigMap`、`modelPromptMap` 到 `ConfigFactoryContext`。
* - 校验 `default` 配置与 `basic` 提示词是否存在。
* - 加载并写入 `modelConfigMap` 到 `ConfigFactoryContext`。
* - 校验 `default` 配置是否存在。
* - 反射读取配置加载器实现类(相对基类新增)的静态字段,并写入 `AgentContext.metadata`。
*/
class ConfigLoaderFactory : AgentBaseFactory() {
@@ -33,26 +32,16 @@ class ConfigLoaderFactory : AgentBaseFactory() {
val configFactoryContext = context.configFactoryContext
configFactoryContext.modelConfigMap.putAll(agentConfigLoader.modelConfigMap)
configFactoryContext.modelPromptMap.putAll(agentConfigLoader.modelPromptMap)
check(configFactoryContext.modelConfigMap.keys, configFactoryContext.modelPromptMap.keys)
check(configFactoryContext.modelConfigMap.keys)
collectLoaderMetadata(context, agentConfigLoader)
}
private fun check(configKeys: Set<String>, promptKeys: Set<String>) {
log.info("执行config与prompt检测...")
private fun check(configKeys: Set<String>) {
log.info("执行config检测...")
if (!configKeys.contains("default")) {
throw ConfigNotExistException("缺少默认配置! 需确保存在一个模型配置的key为`default`")
}
if (!promptKeys.contains("basic")) {
throw PromptNotExistException("缺少基础Prompt! 需要确保存在key为basic的Prompt文件它将与其他Prompt共同作用于模块节点。")
}
val configKeySet = configKeys.toMutableSet().apply { remove("default") }
val promptKeySet = promptKeys.toMutableSet().apply { remove("basic") }
if (!promptKeySet.containsAll(configKeySet)) {
log.warn("存在未被提示词包含的模型配置,该配置将无法生效!")
}
log.info("检测完毕.")
}

View File

@@ -1,12 +0,0 @@
package work.slhaf.partner.api.agent.factory.config.pojo;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import java.util.List;
@Data
public class PrimaryModelPrompt {
private String key;
private List<Message> messages;
}

View File

@@ -4,7 +4,6 @@ import org.reflections.Reflections
import org.reflections.scanners.Scanners
import org.reflections.util.ConfigurationBuilder
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig
import work.slhaf.partner.api.chat.pojo.Message
import java.lang.reflect.Method
import java.net.URL
@@ -25,7 +24,6 @@ class AgentRegisterContext(urls: List<URL>) {
}
class ConfigFactoryContext {
val modelPromptMap: HashMap<String, List<Message>> = HashMap()
val modelConfigMap: HashMap<String, ModelConfig> = HashMap()
}

View File

@@ -3,12 +3,9 @@ package work.slhaf.partner.api.agent.runtime.config;
import lombok.Data;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException;
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
import work.slhaf.partner.api.chat.pojo.Message;
import java.util.HashMap;
import java.util.List;
@Slf4j
@Data
@@ -18,45 +15,22 @@ public abstract class AgentConfigLoader {
@Setter
public static AgentConfigLoader INSTANCE;
protected HashMap<String, ModelConfig> modelConfigMap;
protected HashMap<String, List<Message>> modelPromptMap;
public void load() {
modelConfigMap = loadModelConfig();
modelPromptMap = loadModelPrompt();
}
protected abstract HashMap<String, List<Message>> loadModelPrompt();
protected abstract HashMap<String, ModelConfig> loadModelConfig();
public abstract void dumpModelConfig(String key);
// Keep explicit getters for Kotlin compilation phase (without Lombok-generated methods).
public HashMap<String, ModelConfig> getModelConfigMap() {
return modelConfigMap;
}
public HashMap<String, List<Message>> getModelPromptMap() {
return modelPromptMap;
}
public List<Message> loadModelPrompt(String modelKey) {
if (!modelPromptMap.containsKey(modelKey)) {
throw new PromptNotExistException("不存在的modelPrompt: " + modelKey);
}
return modelPromptMap.get(modelKey);
}
public ModelConfig loadModelConfig(String modelKey) {
if (!modelConfigMap.containsKey(modelKey)) {
return modelConfigMap.get(DEFAULT_KEY);
}
return modelConfigMap.get(modelKey);
}
public void updateModelConfig(String modelKey, ModelConfig config) {
modelConfigMap.put(modelKey, config);
dumpModelConfig(modelKey);
}
}

View File

@@ -2,17 +2,14 @@ package work.slhaf.partner.api.agent.runtime.config;
import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import work.slhaf.partner.api.agent.factory.config.exception.*;
import work.slhaf.partner.api.agent.factory.config.exception.ConfigDirNotExistException;
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException;
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
import work.slhaf.partner.api.agent.factory.config.pojo.PrimaryModelConfig;
import work.slhaf.partner.api.agent.factory.config.pojo.PrimaryModelPrompt;
import work.slhaf.partner.api.chat.pojo.Message;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
/**
* 默认配置工厂
@@ -23,28 +20,6 @@ public class FileAgentConfigLoader extends AgentConfigLoader {
protected static final String CONFIG_DIR = "./config/";
protected static final String MODEL_CONFIG_DIR = "./config/model/";
protected static final String PROMPT_CONFIG_DIR = "./config/prompt/";
@Override
protected HashMap<String, List<Message>> loadModelPrompt() {
File file = new File(PROMPT_CONFIG_DIR);
if (!file.exists() && !file.isDirectory()) {
throw new PromptDirNotExistException("未找到提示词目录: " + PROMPT_CONFIG_DIR + " 请手动创建!");
}
File[] files = file.listFiles();
if (files == null || files.length == 0) {
throw new PromptNotExistException("在目录 " + PROMPT_CONFIG_DIR + " 中未找到提示词配置!");
}
HashMap<String, List<Message>> promptMap = new HashMap<>();
for (File f : files) {
if (f.isDirectory()) {
continue;
}
PrimaryModelPrompt primaryModelPrompt = JSONUtil.readJSONObject(f, StandardCharsets.UTF_8).toBean(PrimaryModelPrompt.class);
promptMap.put(primaryModelPrompt.getKey(), primaryModelPrompt.getMessages());
}
return promptMap;
}
@Override
protected HashMap<String, ModelConfig> loadModelConfig() {
@@ -67,17 +42,4 @@ public class FileAgentConfigLoader extends AgentConfigLoader {
}
return configMap;
}
@Override
public void dumpModelConfig(String key) {
try {
File file = new File(MODEL_CONFIG_DIR + key + ".json");
if (!file.exists()) {
file.createNewFile();
}
FileUtils.writeStringToFile(file, JSONUtil.toJsonPrettyStr(modelConfigMap.get(key)), StandardCharsets.UTF_8, false);
} catch (Exception e) {
throw new ConfigUpdateFailedException("ModelConfig 配置文件更新失败!");
}
}
}

View File

@@ -1,84 +0,0 @@
package work.slhaf.partner.api.chat;
import cn.hutool.core.io.IORuntimeException;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.ChatBody;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.PrimaryChatResponse;
import java.util.List;
@Slf4j
@Data
@NoArgsConstructor
public class ChatClient {
private String clientId;
private String url;
private String apikey;
private String model;
private double top_p;
private double temperature;
private int max_tokens;
public ChatClient(String url, String apikey, String model) {
this.url = url;
this.apikey = apikey;
this.model = model;
}
public ChatResponse runChat(List<Message> messages) {
HttpRequest request = HttpRequest.post(url);
request.setConnectionTimeout(2000);
request.setReadTimeout(15000);
request.header("Content-Type", "application/json");
request.header("Authorization", "Bearer " + apikey);
ChatBody body;
if (top_p > 0) {
body = ChatBody.builder()
.model(model)
.messages(messages)
.top_p(top_p)
.temperature(temperature)
.max_tokens(max_tokens)
.build();
} else {
body = ChatBody.builder()
.model(model)
.messages(messages)
.build();
}
ChatResponse finalResponse;
try {
HttpResponse response = request.body(JSONUtil.toJsonStr(body)).execute();
PrimaryChatResponse primaryChatResponse = JSONUtil.toBean(response.body(), PrimaryChatResponse.class);
finalResponse = ChatResponse.builder()
.status(ChatConstant.ResponseStatus.SUCCESS)
.message(primaryChatResponse.getChoices().get(0).getMessage().getContent())
.usageBean(primaryChatResponse.getUsage())
.build();
response.close();
} catch (IORuntimeException e) {
log.error("请求超时", e);
finalResponse = ChatResponse.builder()
.message("连接超时")
.status(ChatConstant.ResponseStatus.FAILED)
.usageBean(null)
.build();
}
return finalResponse;
}
}

View File

@@ -1,25 +0,0 @@
package work.slhaf.partner.api.chat.pojo;
import lombok.*;
import java.util.List;
@Builder
@Data
@AllArgsConstructor
@NoArgsConstructor
public class ChatBody {
@NonNull
private String model;
@NonNull
private List<Message> messages;
@Builder.Default
private double temperature = 1;
@Builder.Default
private double top_p = 1;
private boolean stream;
@Builder.Default
private int max_tokens = 1024;
private int presence_penalty;
private int frequency_penalty;
}

View File

@@ -1,17 +0,0 @@
package work.slhaf.partner.api.chat.pojo;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import work.slhaf.partner.api.chat.constant.ChatConstant;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class ChatResponse {
private ChatConstant.ResponseStatus status;
private String message;
private PrimaryChatResponse.UsageBean usageBean;
}

View File

@@ -1,111 +0,0 @@
package work.slhaf.partner.api.chat.pojo;
import lombok.Getter;
import lombok.Setter;
import java.util.List;
@Getter
@Setter
public class PrimaryChatResponse {
/**
* id
*/
private String id;
/**
* object
*/
private String object;
/**
* created
*/
private int created;
/**
* model
*/
private String model;
/**
* choices
*/
private List<ChoicesBean> choices;
/**
* usage
*/
private UsageBean usage;
/**
* system_fingerprint
*/
private String system_fingerprint;
@Setter
@Getter
public static class UsageBean {
/**
* prompt_tokens
*/
private int prompt_tokens;
/**
* completion_tokens
*/
private int completion_tokens;
/**
* total_tokens
*/
private int total_tokens;
/**
* prompt_cache_hit_tokens
*/
private int prompt_cache_hit_tokens;
/**
* prompt_cache_miss_tokens
*/
private int prompt_cache_miss_tokens;
@Override
public String toString() {
return "UsageBean{" +
"prompt_tokens=" + prompt_tokens +
", completion_tokens=" + completion_tokens +
", total_tokens=" + total_tokens +
", prompt_cache_hit_tokens=" + prompt_cache_hit_tokens +
", prompt_cache_miss_tokens=" + prompt_cache_miss_tokens +
'}';
}
}
@Setter
@Getter
public static class ChoicesBean {
/**
* index
*/
private int index;
/**
* message
*/
private MessageBean message;
/**
* logprobs
*/
private Object logprobs;
/**
* finish_reason
*/
private String finish_reason;
@Setter
@Getter
public static class MessageBean {
/**
* role
*/
private String role;
/**
* content
*/
private String content;
}
}
}

View File

@@ -0,0 +1,77 @@
package work.slhaf.partner.api.chat.runtime;
import com.openai.client.OpenAIClient;
import com.openai.client.okhttp.OpenAIOkHttpClient;
import com.openai.core.http.StreamResponse;
import com.openai.helpers.ChatCompletionAccumulator;
import com.openai.models.chat.completions.*;
import work.slhaf.partner.api.chat.pojo.Message;
import java.time.Duration;
import java.util.List;
public class OpenAiChatRuntime {
private final OpenAIClient client;
private final String model;
public OpenAiChatRuntime(String baseUrl, String apikey, String model) {
this.client = OpenAIOkHttpClient.builder()
.baseUrl(baseUrl)
.apiKey(apikey)
.timeout(Duration.ofSeconds(30))
.build();
this.model = model;
}
public String chat(List<Message> messages, boolean streaming) {
ChatCompletionCreateParams params = buildParams(messages);
if (!streaming) {
return extractText(client.chat().completions().create(params));
}
ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
try (StreamResponse<ChatCompletionChunk> response = client.chat().completions().createStreaming(params)) {
response.stream().forEach(accumulator::accumulate);
}
return extractText(accumulator.chatCompletion());
}
public <T> T formattedChat(List<Message> messages, boolean streaming, Class<T> responseType) {
StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder()
.responseFormat(responseType)
.build();
if (!streaming) {
return extractStructured(client.chat().completions().create(params));
}
ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
try (StreamResponse<ChatCompletionChunk> response = client.chat().completions().createStreaming(params.rawParams())) {
response.stream().forEach(accumulator::accumulate);
}
return extractStructured(accumulator.chatCompletion(responseType));
}
private ChatCompletionCreateParams buildParams(List<Message> messages) {
return ChatCompletionCreateParams.builder()
.model(model)
.messages(OpenAiMessageAdapter.toParams(messages))
.build();
}
private String extractText(ChatCompletion completion) {
if (completion.choices().isEmpty()) {
throw new IllegalStateException("OpenAI chat completion returned no choices.");
}
return completion.choices().getFirst().message().content()
.orElseThrow(() -> new IllegalStateException("OpenAI chat completion returned empty content."));
}
private <T> T extractStructured(StructuredChatCompletion<T> completion) {
if (completion.choices().isEmpty()) {
throw new IllegalStateException("OpenAI structured chat completion returned no choices.");
}
return completion.choices().getFirst().message().content()
.orElseThrow(() -> new IllegalStateException("OpenAI structured chat completion returned empty content."));
}
}

View File

@@ -0,0 +1,40 @@
package work.slhaf.partner.api.chat.runtime;
import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
import com.openai.models.chat.completions.ChatCompletionMessageParam;
import com.openai.models.chat.completions.ChatCompletionSystemMessageParam;
import com.openai.models.chat.completions.ChatCompletionUserMessageParam;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import java.util.ArrayList;
import java.util.List;
public final class OpenAiMessageAdapter {
private OpenAiMessageAdapter() {
}
public static List<ChatCompletionMessageParam> toParams(List<Message> messages) {
List<ChatCompletionMessageParam> params = new ArrayList<>(messages.size());
for (Message message : messages) {
params.add(toParam(message));
}
return params;
}
public static ChatCompletionMessageParam toParam(Message message) {
return switch (message.getRole()) {
case ChatConstant.Character.SYSTEM -> ChatCompletionMessageParam.ofSystem(
ChatCompletionSystemMessageParam.builder().content(message.getContent()).build()
);
case ChatConstant.Character.ASSISTANT -> ChatCompletionMessageParam.ofAssistant(
ChatCompletionAssistantMessageParam.builder().content(message.getContent()).build()
);
case ChatConstant.Character.USER -> ChatCompletionMessageParam.ofUser(
ChatCompletionUserMessageParam.builder().content(message.getContent()).build()
);
default -> throw new IllegalArgumentException("Unsupported message role: " + message.getRole());
};
}
}