mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
refactor(framework): unify model invocation result and exception handling
This commit is contained in:
@@ -30,7 +30,7 @@ public class ActionCorrectionRecognizer extends AbstractAgentModule.Sub<Correcti
|
|||||||
resolveContextMessage(),
|
resolveContextMessage(),
|
||||||
resolveTaskMessage(input)
|
resolveTaskMessage(input)
|
||||||
);
|
);
|
||||||
return formattedChat(messages, CorrectionRecognizerResult.class);
|
return formattedChat(messages, CorrectionRecognizerResult.class).getOrThrow();
|
||||||
}
|
}
|
||||||
|
|
||||||
private Message resolveTaskMessage(CorrectionRecognizerInput input) {
|
private Message resolveTaskMessage(CorrectionRecognizerInput input) {
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ public class ActionCorrector extends AbstractAgentModule.Sub<CorrectorInput, Cor
|
|||||||
resolveContextMessage(),
|
resolveContextMessage(),
|
||||||
resolveTaskMessage(input)
|
resolveTaskMessage(input)
|
||||||
);
|
);
|
||||||
return formattedChat(messages, CorrectorResult.class);
|
return formattedChat(messages, CorrectorResult.class).getOrThrow();
|
||||||
}
|
}
|
||||||
|
|
||||||
private Message resolveTaskMessage(CorrectorInput input) {
|
private Message resolveTaskMessage(CorrectorInput input) {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCa
|
|||||||
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||||
|
import work.slhaf.partner.framework.agent.support.Result;
|
||||||
import work.slhaf.partner.module.TaskBlock;
|
import work.slhaf.partner.module.TaskBlock;
|
||||||
import work.slhaf.partner.module.action.executor.entity.ExtractorInput;
|
import work.slhaf.partner.module.action.executor.entity.ExtractorInput;
|
||||||
import work.slhaf.partner.module.action.executor.entity.ExtractorResult;
|
import work.slhaf.partner.module.action.executor.entity.ExtractorResult;
|
||||||
@@ -28,20 +29,19 @@ public class ParamsExtractor extends AbstractAgentModule.Sub<ExtractorInput, Ext
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ExtractorResult execute(ExtractorInput input) {
|
public ExtractorResult execute(ExtractorInput input) {
|
||||||
ExtractorResult result;
|
List<Message> messages = List.of(
|
||||||
try {
|
resolveContextMessage(),
|
||||||
List<Message> messages = List.of(
|
resolveTaskMessage(input)
|
||||||
resolveContextMessage(),
|
);
|
||||||
resolveTaskMessage(input)
|
Result<ExtractorResult> result = formattedChat(messages, ExtractorResult.class);
|
||||||
);
|
if (result.isFailure()) {
|
||||||
result = formattedChat(messages, ExtractorResult.class);
|
log.error("ParamsExtractor解析结果失败", result.exceptionOrNull());
|
||||||
} catch (Exception e) {
|
ExtractorResult fallback = new ExtractorResult();
|
||||||
log.error("ParamsExtractor解析结果失败", e);
|
fallback.setOk(false);
|
||||||
result = new ExtractorResult();
|
fallback.setParams(new HashMap<>());
|
||||||
result.setOk(false);
|
return fallback;
|
||||||
result.setParams(new HashMap<>());
|
|
||||||
}
|
}
|
||||||
return result;
|
return result.getOrThrow();
|
||||||
}
|
}
|
||||||
|
|
||||||
private Message resolveTaskMessage(ExtractorInput input) {
|
private Message resolveTaskMessage(ExtractorInput input) {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAg
|
|||||||
import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
|
||||||
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||||
|
import work.slhaf.partner.framework.agent.support.Result;
|
||||||
import work.slhaf.partner.module.action.planner.evaluator.entity.EvaluatorInput;
|
import work.slhaf.partner.module.action.planner.evaluator.entity.EvaluatorInput;
|
||||||
import work.slhaf.partner.module.action.planner.evaluator.entity.EvaluatorResult;
|
import work.slhaf.partner.module.action.planner.evaluator.entity.EvaluatorResult;
|
||||||
|
|
||||||
@@ -61,10 +62,15 @@ public class ActionEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, Lis
|
|||||||
availableMetaActionContext(),
|
availableMetaActionContext(),
|
||||||
new Message(Message.Character.USER, tendency)
|
new Message(Message.Character.USER, tendency)
|
||||||
);
|
);
|
||||||
EvaluatorResult evaluatorResult = formattedChat(
|
Result<EvaluatorResult> result = formattedChat(
|
||||||
messages,
|
messages,
|
||||||
EvaluatorResult.class
|
EvaluatorResult.class
|
||||||
);
|
);
|
||||||
|
if (result.isFailure()) {
|
||||||
|
log.error("ActionEvaluator评估失败: {}", tendency, result.exceptionOrNull());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
EvaluatorResult evaluatorResult = result.getOrThrow();
|
||||||
evaluatorResult.setTendency(tendency);
|
evaluatorResult.setTendency(tendency);
|
||||||
synchronized (evaluatorResults) {
|
synchronized (evaluatorResults) {
|
||||||
evaluatorResults.add(evaluatorResult);
|
evaluatorResults.add(evaluatorResult);
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCa
|
|||||||
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||||
|
import work.slhaf.partner.framework.agent.support.Result;
|
||||||
import work.slhaf.partner.module.action.planner.extractor.entity.ExtractorResult;
|
import work.slhaf.partner.module.action.planner.extractor.entity.ExtractorResult;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -18,23 +19,18 @@ public class ActionExtractor extends AbstractAgentModule.Sub<String, ExtractorRe
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ExtractorResult execute(String input) {
|
public ExtractorResult execute(String input) {
|
||||||
for (int i = 0; i < 3; i++) {
|
List<Message> messages = List.of(
|
||||||
try {
|
cognitionCapability.contextWorkspace().resolve(List.of(
|
||||||
List<Message> messages = List.of(
|
ContextBlock.VisibleDomain.COGNITION,
|
||||||
cognitionCapability.contextWorkspace().resolve(List.of(
|
ContextBlock.VisibleDomain.ACTION
|
||||||
ContextBlock.VisibleDomain.COGNITION,
|
)).encodeToMessage(),
|
||||||
ContextBlock.VisibleDomain.ACTION
|
new Message(Message.Character.USER, input)
|
||||||
)).encodeToMessage(),
|
);
|
||||||
new Message(Message.Character.USER, input)
|
Result<ExtractorResult> result = formattedChat(messages, ExtractorResult.class);
|
||||||
);
|
if (result.isSuccess()) {
|
||||||
return formattedChat(
|
return result.getOrThrow();
|
||||||
messages,
|
|
||||||
ExtractorResult.class
|
|
||||||
);
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("提取信息出错", e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
log.error("提取信息出错", result.exceptionOrNull());
|
||||||
return new ExtractorResult();
|
return new ExtractorResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
|
|||||||
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
||||||
import work.slhaf.partner.framework.agent.model.StreamChatMessageConsumer;
|
import work.slhaf.partner.framework.agent.model.StreamChatMessageConsumer;
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||||
|
import work.slhaf.partner.framework.agent.support.Result;
|
||||||
import work.slhaf.partner.runtime.PartnerRunningFlowContext;
|
import work.slhaf.partner.runtime.PartnerRunningFlowContext;
|
||||||
|
|
||||||
import javax.xml.parsers.DocumentBuilderFactory;
|
import javax.xml.parsers.DocumentBuilderFactory;
|
||||||
@@ -28,6 +29,8 @@ import java.util.stream.Collectors;
|
|||||||
@Data
|
@Data
|
||||||
public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRunningFlowContext> implements ActivateModel {
|
public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRunningFlowContext> implements ActivateModel {
|
||||||
|
|
||||||
|
private static final String INTERRUPTED_MARKER = " [response interrupted due to internal exception]";
|
||||||
|
|
||||||
private static final String MODULE_PROMPT = """
|
private static final String MODULE_PROMPT = """
|
||||||
你是 Partner 的表达模块。
|
你是 Partner 的表达模块。
|
||||||
你接下来收到的消息固定分为三个区段:
|
你接下来收到的消息固定分为三个区段:
|
||||||
@@ -64,7 +67,11 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
|
|
||||||
private void executeChat(PartnerRunningFlowContext runningFlowContext) {
|
private void executeChat(PartnerRunningFlowContext runningFlowContext) {
|
||||||
StreamChatMessageConsumer consumer = ReplyDispatcher.INSTANCE.createConsumer(runningFlowContext.getTarget());
|
StreamChatMessageConsumer consumer = ReplyDispatcher.INSTANCE.createConsumer(runningFlowContext.getTarget());
|
||||||
this.streamChat(buildChatMessages(runningFlowContext), consumer);
|
Result<kotlin.Unit> result = this.streamChat(buildChatMessages(runningFlowContext), consumer);
|
||||||
|
if (result.isFailure()) {
|
||||||
|
log.error("Streaming response failed", result.exceptionOrNull());
|
||||||
|
consumer.onDelta(INTERRUPTED_MARKER);
|
||||||
|
}
|
||||||
updateChatMessages(runningFlowContext, consumer.collectResponse());
|
updateChatMessages(runningFlowContext, consumer.collectResponse());
|
||||||
updateContext();
|
updateContext();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ public class DialogRollingService extends AbstractAgentModule.Standalone impleme
|
|||||||
List<Message> messages = List.of(
|
List<Message> messages = List.of(
|
||||||
resolveTaskBlock(snapshotMessages)
|
resolveTaskBlock(snapshotMessages)
|
||||||
);
|
);
|
||||||
return chat(messages);
|
return chat(messages).getOrThrow();
|
||||||
}
|
}
|
||||||
|
|
||||||
private @NotNull BlockContent buildDialogAbstractBlock(String summary, @Nullable String unitId, @Nullable String sliceId) {
|
private @NotNull BlockContent buildDialogAbstractBlock(String summary, @Nullable String unitId, @Nullable String sliceId) {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAg
|
|||||||
import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
|
||||||
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||||
|
import work.slhaf.partner.framework.agent.support.Result;
|
||||||
import work.slhaf.partner.module.TaskBlock;
|
import work.slhaf.partner.module.TaskBlock;
|
||||||
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
|
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
|
||||||
import work.slhaf.partner.module.memory.selector.evaluator.entity.EvaluatorBatchInput;
|
import work.slhaf.partner.module.memory.selector.evaluator.entity.EvaluatorBatchInput;
|
||||||
@@ -64,13 +65,16 @@ public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput
|
|||||||
contextMessage,
|
contextMessage,
|
||||||
resolveTaskMessage(batchInput)
|
resolveTaskMessage(batchInput)
|
||||||
);
|
);
|
||||||
EvaluatorBatchResult batchResult = formattedChat(messages, EvaluatorBatchResult.class);
|
Result<EvaluatorBatchResult> batchResult = formattedChat(messages, EvaluatorBatchResult.class);
|
||||||
if (batchResult.isPassed()) {
|
if (batchResult.isFailure()) {
|
||||||
|
log.debug("切片评估失败,已跳过当前切片", batchResult.exceptionOrNull());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (batchResult.getOrThrow().isPassed()) {
|
||||||
synchronized (result) {
|
synchronized (result) {
|
||||||
result.add(slice);
|
result.add(slice);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (Exception ignore) {
|
|
||||||
} finally {
|
} finally {
|
||||||
latch.countDown();
|
latch.countDown();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAg
|
|||||||
import work.slhaf.partner.framework.agent.factory.component.annotation.InjectModule;
|
import work.slhaf.partner.framework.agent.factory.component.annotation.InjectModule;
|
||||||
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
import work.slhaf.partner.framework.agent.model.ActivateModel;
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||||
|
import work.slhaf.partner.framework.agent.support.Result;
|
||||||
import work.slhaf.partner.module.TaskBlock;
|
import work.slhaf.partner.module.TaskBlock;
|
||||||
import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
|
import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
|
||||||
import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorInput;
|
import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorInput;
|
||||||
@@ -34,18 +35,19 @@ public class MemorySelectExtractor extends AbstractAgentModule.Sub<ExtractorInpu
|
|||||||
public ExtractorResult execute(ExtractorInput input) {
|
public ExtractorResult execute(ExtractorInput input) {
|
||||||
log.debug("[MemorySelectExtractor] 主题提取模块开始...");
|
log.debug("[MemorySelectExtractor] 主题提取模块开始...");
|
||||||
ExtractorResult extractorResult;
|
ExtractorResult extractorResult;
|
||||||
try {
|
List<Message> messages = List.of(
|
||||||
List<Message> messages = List.of(
|
resolveContextMessage(),
|
||||||
resolveContextMessage(),
|
resolveTaskMessage(input)
|
||||||
resolveTaskMessage(input)
|
);
|
||||||
);
|
Result<ExtractorResult> result = formattedChat(
|
||||||
extractorResult = formattedChat(
|
messages,
|
||||||
messages,
|
ExtractorResult.class
|
||||||
ExtractorResult.class
|
);
|
||||||
);
|
if (result.isSuccess()) {
|
||||||
|
extractorResult = result.getOrThrow();
|
||||||
log.debug("[MemorySelectExtractor] 主题提取结果: {}", extractorResult);
|
log.debug("[MemorySelectExtractor] 主题提取结果: {}", extractorResult);
|
||||||
} catch (Exception e) {
|
} else {
|
||||||
log.error("[MemorySelectExtractor] 主题提取出错: ", e);
|
log.error("[MemorySelectExtractor] 主题提取出错: ", result.exceptionOrNull());
|
||||||
extractorResult = new ExtractorResult();
|
extractorResult = new ExtractorResult();
|
||||||
extractorResult.setRecall(false);
|
extractorResult.setRecall(false);
|
||||||
extractorResult.setMatches(List.of());
|
extractorResult.setMatches(List.of());
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ public class MultiSummarizer extends AbstractAgentModule.Sub<SummarizeInput, Sum
|
|||||||
SummarizeResult result = formattedChat(
|
SummarizeResult result = formattedChat(
|
||||||
List.of(new Message(Message.Character.USER, JSONUtil.toJsonPrettyStr(input))),
|
List.of(new Message(Message.Character.USER, JSONUtil.toJsonPrettyStr(input))),
|
||||||
SummarizeResult.class
|
SummarizeResult.class
|
||||||
);
|
).getOrThrow();
|
||||||
log.debug("[MemorySummarizer] 整体摘要结果: {}", JSONObject.toJSONString(result));
|
log.debug("[MemorySummarizer] 整体摘要结果: {}", JSONObject.toJSONString(result));
|
||||||
return fix(result);
|
return fix(result);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ public class SingleSummarizer extends AbstractAgentModule.Sub<List<Message>, Voi
|
|||||||
|
|
||||||
private String singleExecute(String primaryContent) {
|
private String singleExecute(String primaryContent) {
|
||||||
try {
|
try {
|
||||||
return chat(List.of(new Message(Message.Character.USER, primaryContent)));
|
return chat(List.of(new Message(Message.Character.USER, primaryContent))).getOrThrow();
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("[SingleSummarizer] 单消息总结出错: ", e);
|
log.error("[SingleSummarizer] 单消息总结出错: ", e);
|
||||||
return primaryContent;
|
return primaryContent;
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ public class TotalSummarizer extends AbstractAgentModule.Sub<HashMap<String, Str
|
|||||||
return formattedChat(
|
return formattedChat(
|
||||||
List.of(new Message(Message.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))),
|
List.of(new Message(Message.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))),
|
||||||
SummaryContent.class
|
SummaryContent.class
|
||||||
).getContent();
|
).getOrThrow().getContent();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -2,18 +2,22 @@ package work.slhaf.partner.framework.agent.model
|
|||||||
|
|
||||||
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule
|
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message
|
import work.slhaf.partner.framework.agent.model.pojo.Message
|
||||||
|
import work.slhaf.partner.framework.agent.support.Result
|
||||||
|
|
||||||
interface ActivateModel {
|
interface ActivateModel {
|
||||||
|
|
||||||
fun chat(messages: List<Message>): String {
|
fun chat(messages: List<Message>): Result<String> {
|
||||||
return ModelRuntimeRegistry.resolveProvider(modelKey()).chat(mergeMessages(messages))
|
return ModelRuntimeRegistry.resolveProvider(modelKey()).chat(mergeMessages(messages))
|
||||||
}
|
}
|
||||||
|
|
||||||
fun streamChat(messages: List<Message>, handler: StreamChatMessageConsumer) {
|
fun streamChat(
|
||||||
ModelRuntimeRegistry.resolveProvider(modelKey()).streamChat(mergeMessages(messages), handler)
|
messages: List<Message>,
|
||||||
|
handler: StreamChatMessageConsumer
|
||||||
|
): work.slhaf.partner.framework.agent.support.Result<Unit> {
|
||||||
|
return ModelRuntimeRegistry.resolveProvider(modelKey()).streamChat(mergeMessages(messages), handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T {
|
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): Result<T> {
|
||||||
return ModelRuntimeRegistry.resolveProvider(modelKey()).formattedChat(mergeMessages(messages), responseType)
|
return ModelRuntimeRegistry.resolveProvider(modelKey()).formattedChat(mergeMessages(messages), responseType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,13 @@ object ModelRuntimeRegistry : Configurable, ConfigRegistration<ModelRuntimeRegis
|
|||||||
private fun registerProvider(config: ProviderConfig) {
|
private fun registerProvider(config: ProviderConfig) {
|
||||||
when (config) {
|
when (config) {
|
||||||
is OpenAiCompatibleProviderConfig -> baseProvider[config.name] =
|
is OpenAiCompatibleProviderConfig -> baseProvider[config.name] =
|
||||||
OpenAiCompatibleProvider(config.baseUrl, config.apiKey, config.defaultModel)
|
OpenAiCompatibleProvider(
|
||||||
|
config.name,
|
||||||
|
DEFAULT_PROVIDER,
|
||||||
|
config.baseUrl,
|
||||||
|
config.apiKey,
|
||||||
|
config.defaultModel
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,11 +67,7 @@ object ModelRuntimeRegistry : Configurable, ConfigRegistration<ModelRuntimeRegis
|
|||||||
val override = config.override
|
val override = config.override
|
||||||
|
|
||||||
try {
|
try {
|
||||||
runtimeProvider[config.modelKey] = if (override != null) {
|
runtimeProvider[config.modelKey] = provider.fork(config.modelKey, override)
|
||||||
provider.fork(override)
|
|
||||||
} else {
|
|
||||||
provider
|
|
||||||
}
|
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
throw runtimeModelException(
|
throw runtimeModelException(
|
||||||
"Failed to build runtime provider for model key ${config.modelKey}",
|
"Failed to build runtime provider for model key ${config.modelKey}",
|
||||||
|
|||||||
@@ -3,18 +3,21 @@ package work.slhaf.partner.framework.agent.model.provider
|
|||||||
import com.alibaba.fastjson2.JSONObject
|
import com.alibaba.fastjson2.JSONObject
|
||||||
import work.slhaf.partner.framework.agent.model.StreamChatMessageConsumer
|
import work.slhaf.partner.framework.agent.model.StreamChatMessageConsumer
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message
|
import work.slhaf.partner.framework.agent.model.pojo.Message
|
||||||
|
import work.slhaf.partner.framework.agent.support.Result
|
||||||
|
|
||||||
abstract class ModelProvider @JvmOverloads constructor(
|
abstract class ModelProvider @JvmOverloads constructor(
|
||||||
|
val providerName: String,
|
||||||
|
val modelKey: String,
|
||||||
val override: ProviderOverride? = null
|
val override: ProviderOverride? = null
|
||||||
) {
|
) {
|
||||||
|
|
||||||
abstract fun fork(override: ProviderOverride): ModelProvider
|
abstract fun fork(modelKey: String, override: ProviderOverride? = null): ModelProvider
|
||||||
|
|
||||||
abstract fun streamChat(messages: List<Message>, consumer: StreamChatMessageConsumer)
|
abstract fun streamChat(messages: List<Message>, consumer: StreamChatMessageConsumer): Result<Unit>
|
||||||
|
|
||||||
abstract fun chat(messages: List<Message>): String
|
abstract fun chat(messages: List<Message>): Result<String>
|
||||||
|
|
||||||
abstract fun <T> formattedChat(messages: List<Message>, type: Class<T>): T
|
abstract fun <T> formattedChat(messages: List<Message>, type: Class<T>): Result<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
data class ProviderOverride(
|
data class ProviderOverride(
|
||||||
|
|||||||
@@ -6,24 +6,33 @@ import com.openai.client.okhttp.OpenAIOkHttpClient;
|
|||||||
import com.openai.core.JsonValue;
|
import com.openai.core.JsonValue;
|
||||||
import com.openai.core.http.StreamResponse;
|
import com.openai.core.http.StreamResponse;
|
||||||
import com.openai.models.chat.completions.*;
|
import com.openai.models.chat.completions.*;
|
||||||
|
import kotlin.Unit;
|
||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
|
import work.slhaf.partner.framework.agent.exception.ModelInvokeException;
|
||||||
import work.slhaf.partner.framework.agent.model.StreamChatMessageConsumer;
|
import work.slhaf.partner.framework.agent.model.StreamChatMessageConsumer;
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||||
import work.slhaf.partner.framework.agent.model.provider.ModelProvider;
|
import work.slhaf.partner.framework.agent.model.provider.ModelProvider;
|
||||||
import work.slhaf.partner.framework.agent.model.provider.ProviderOverride;
|
import work.slhaf.partner.framework.agent.model.provider.ProviderOverride;
|
||||||
|
import work.slhaf.partner.framework.agent.support.Result;
|
||||||
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
public class OpenAiCompatibleProvider extends ModelProvider {
|
public class OpenAiCompatibleProvider extends ModelProvider {
|
||||||
|
|
||||||
|
private static final int MAX_ATTEMPTS = 3;
|
||||||
|
|
||||||
private final String baseUrl;
|
private final String baseUrl;
|
||||||
private final String apiKey;
|
private final String apiKey;
|
||||||
private final String model;
|
private final String model;
|
||||||
|
|
||||||
private final OpenAIClient client;
|
private final OpenAIClient client;
|
||||||
|
|
||||||
public OpenAiCompatibleProvider(String baseUrl, String apikey, String model) {
|
public OpenAiCompatibleProvider(String providerName, String modelKey, String baseUrl, String apikey, String model) {
|
||||||
|
super(providerName, modelKey, null);
|
||||||
this.client = OpenAIOkHttpClient.builder()
|
this.client = OpenAIOkHttpClient.builder()
|
||||||
.baseUrl(baseUrl)
|
.baseUrl(baseUrl)
|
||||||
.apiKey(apikey)
|
.apiKey(apikey)
|
||||||
@@ -34,8 +43,8 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
|||||||
this.model = model;
|
this.model = model;
|
||||||
}
|
}
|
||||||
|
|
||||||
public OpenAiCompatibleProvider(String baseUrl, String apikey, String model, ProviderOverride override) {
|
public OpenAiCompatibleProvider(String providerName, String modelKey, String baseUrl, String apikey, String model, ProviderOverride override) {
|
||||||
super(override);
|
super(providerName, modelKey, override);
|
||||||
this.client = OpenAIOkHttpClient.builder()
|
this.client = OpenAIOkHttpClient.builder()
|
||||||
.baseUrl(baseUrl)
|
.baseUrl(baseUrl)
|
||||||
.apiKey(apikey)
|
.apiKey(apikey)
|
||||||
@@ -46,27 +55,59 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
|||||||
this.model = model;
|
this.model = model;
|
||||||
}
|
}
|
||||||
|
|
||||||
public @NotNull String chat(@NotNull List<Message> messages) {
|
@Override
|
||||||
ChatCompletionCreateParams params = buildParams(messages);
|
public @NotNull Result<String> chat(@NotNull List<Message> messages) {
|
||||||
return extractText(client.chat().completions().create(params));
|
return executeWithRetry(
|
||||||
|
"OpenAI-compatible provider failed to complete the chat request after 3 attempts.",
|
||||||
|
() -> extractText(client.chat().completions().create(buildParams(messages)))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void streamChat(@NotNull List<Message> messages, StreamChatMessageConsumer handler) {
|
@Override
|
||||||
ChatCompletionCreateParams params = buildParams(messages);
|
public @NotNull Result<Unit> streamChat(@NotNull List<Message> messages, @NotNull StreamChatMessageConsumer handler) {
|
||||||
try (StreamResponse<ChatCompletionChunk> streamResponse = client.chat().completions().createStreaming(params)) {
|
Exception lastFailure = null;
|
||||||
streamResponse.stream()
|
int remainingAttempts = MAX_ATTEMPTS;
|
||||||
.flatMap(completion -> completion.choices().stream())
|
while (remainingAttempts > 0) {
|
||||||
.flatMap(choice -> choice.delta().content().stream())
|
boolean emitted = false;
|
||||||
.filter(delta -> !delta.isEmpty())
|
try (StreamResponse<ChatCompletionChunk> streamResponse = client.chat().completions().createStreaming(buildParams(messages))) {
|
||||||
.forEach(handler::onDelta);
|
Iterator<ChatCompletionChunk> iterator = streamResponse.stream().iterator();
|
||||||
|
while (iterator.hasNext()) {
|
||||||
|
ChatCompletionChunk chunk = iterator.next();
|
||||||
|
for (ChatCompletionChunk.Choice choice : chunk.choices()) {
|
||||||
|
String delta = choice.delta().content().orElse("");
|
||||||
|
if (delta.isEmpty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
emitted = true;
|
||||||
|
handler.onDelta(delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Result.success(Unit.INSTANCE);
|
||||||
|
} catch (Exception e) {
|
||||||
|
lastFailure = e;
|
||||||
|
remainingAttempts--;
|
||||||
|
if (emitted || remainingAttempts == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return Result.failure(invokeException(
|
||||||
|
"OpenAI-compatible provider failed to stream the chat response after 3 attempts.",
|
||||||
|
lastFailure
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
public <T> T formattedChat(@NotNull List<Message> messages, @NotNull Class<T> responseType) {
|
@Override
|
||||||
StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder()
|
public <T> @NotNull Result<T> formattedChat(@NotNull List<Message> messages, @NotNull Class<T> responseType) {
|
||||||
.responseFormat(responseType)
|
return executeWithRetry(
|
||||||
.build();
|
"OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.",
|
||||||
return extractStructured(client.chat().completions().create(params));
|
() -> {
|
||||||
|
StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder()
|
||||||
|
.responseFormat(responseType)
|
||||||
|
.build();
|
||||||
|
return extractStructured(client.chat().completions().create(params));
|
||||||
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
||||||
@@ -87,9 +128,7 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
|||||||
}
|
}
|
||||||
JSONObject extras = override.getExtras();
|
JSONObject extras = override.getExtras();
|
||||||
if (extras != null) {
|
if (extras != null) {
|
||||||
extras.forEach((key, value) -> {
|
extras.forEach((key, value) -> paramsBuilder.putAdditionalBodyProperty(key, JsonValue.from(value)));
|
||||||
paramsBuilder.putAdditionalBodyProperty(key, JsonValue.from(value));
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,22 +137,81 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
|||||||
|
|
||||||
private String extractText(ChatCompletion completion) {
|
private String extractText(ChatCompletion completion) {
|
||||||
if (completion.choices().isEmpty()) {
|
if (completion.choices().isEmpty()) {
|
||||||
throw new IllegalStateException("OpenAI chat completion returned no choices.");
|
throw invokeException("OpenAI chat completion returned no choices.", null);
|
||||||
}
|
}
|
||||||
return completion.choices().getFirst().message().content()
|
return completion.choices().getFirst().message().content()
|
||||||
.orElseThrow(() -> new IllegalStateException("OpenAI chat completion returned empty content."));
|
.orElseThrow(() -> invokeException("OpenAI chat completion returned empty content.", null));
|
||||||
}
|
}
|
||||||
|
|
||||||
private <T> T extractStructured(StructuredChatCompletion<T> completion) {
|
private <T> T extractStructured(StructuredChatCompletion<T> completion) {
|
||||||
if (completion.choices().isEmpty()) {
|
if (completion.choices().isEmpty()) {
|
||||||
throw new IllegalStateException("OpenAI structured chat completion returned no choices.");
|
throw invokeException("OpenAI structured chat completion returned no choices.", null);
|
||||||
}
|
}
|
||||||
return completion.choices().getFirst().message().content()
|
return completion.choices().getFirst().message().content()
|
||||||
.orElseThrow(() -> new IllegalStateException("OpenAI structured chat completion returned empty content."));
|
.orElseThrow(() -> invokeException("OpenAI structured chat completion returned empty content.", null));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public @NotNull ModelProvider fork(@NotNull ProviderOverride override) {
|
public @NotNull ModelProvider fork(@NotNull String modelKey, ProviderOverride override) {
|
||||||
return new OpenAiCompatibleProvider(baseUrl, apiKey, override.getModel(), override);
|
if (override == null) {
|
||||||
|
return new OpenAiCompatibleProvider(getProviderName(), modelKey, baseUrl, apiKey, model, getOverride());
|
||||||
|
}
|
||||||
|
return new OpenAiCompatibleProvider(getProviderName(), modelKey, baseUrl, apiKey, override.getModel(), override);
|
||||||
|
}
|
||||||
|
|
||||||
|
private <T> Result<T> executeWithRetry(String failureMessage, ThrowingSupplier<T> supplier) {
|
||||||
|
Exception lastFailure = null;
|
||||||
|
for (int attempt = 1; attempt <= MAX_ATTEMPTS; attempt++) {
|
||||||
|
Result<T> result = Result.runCatching(supplier::get);
|
||||||
|
if (result.isSuccess()) {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
Throwable throwable = result.exceptionOrNull();
|
||||||
|
if (throwable instanceof Exception exception) {
|
||||||
|
lastFailure = exception;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (throwable instanceof Error error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
return Result.failure(invokeException(failureMessage, throwable));
|
||||||
|
}
|
||||||
|
return Result.failure(invokeException(failureMessage, lastFailure));
|
||||||
|
}
|
||||||
|
|
||||||
|
private ModelInvokeException invokeException(String message, Throwable cause) {
|
||||||
|
return new ModelInvokeException(
|
||||||
|
message,
|
||||||
|
getProviderName(),
|
||||||
|
getModelKey(),
|
||||||
|
baseUrl,
|
||||||
|
model,
|
||||||
|
getOverride() == null ? Map.of() : toOverrideReport(getOverride()),
|
||||||
|
cause
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, String> toOverrideReport(ProviderOverride override) {
|
||||||
|
Map<String, String> result = new LinkedHashMap<>();
|
||||||
|
result.put("model", override.getModel());
|
||||||
|
if (override.getTemperature() != null) {
|
||||||
|
result.put("temperature", override.getTemperature().toString());
|
||||||
|
}
|
||||||
|
if (override.getTopP() != null) {
|
||||||
|
result.put("topP", override.getTopP().toString());
|
||||||
|
}
|
||||||
|
if (override.getMaxTokens() != null) {
|
||||||
|
result.put("maxTokens", override.getMaxTokens().toString());
|
||||||
|
}
|
||||||
|
JSONObject extras = override.getExtras();
|
||||||
|
if (extras != null) {
|
||||||
|
extras.forEach((key, value) -> result.put("extra." + key, value == null ? "null" : value.toString()));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@FunctionalInterface
|
||||||
|
private interface ThrowingSupplier<T> {
|
||||||
|
T get() throws Exception;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,76 @@
|
|||||||
|
package work.slhaf.partner.framework.agent.support
|
||||||
|
|
||||||
|
import work.slhaf.partner.framework.agent.exception.AgentRuntimeException
|
||||||
|
|
||||||
|
class Result<T> private constructor(
|
||||||
|
private val value: T?,
|
||||||
|
private val exception: Throwable?
|
||||||
|
) {
|
||||||
|
|
||||||
|
fun isSuccess(): Boolean = exception == null
|
||||||
|
|
||||||
|
fun isFailure(): Boolean = exception != null
|
||||||
|
|
||||||
|
fun getOrNull(): T? = value
|
||||||
|
|
||||||
|
fun exceptionOrNull(): Throwable? = exception
|
||||||
|
|
||||||
|
fun getOrThrow(): T {
|
||||||
|
if (exception == null) {
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
return value as T
|
||||||
|
}
|
||||||
|
when (exception) {
|
||||||
|
is RuntimeException -> throw exception
|
||||||
|
is Error -> throw exception
|
||||||
|
else -> throw IllegalStateException(exception.message, exception)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun getOrDefault(defaultValue: T): T {
|
||||||
|
return if (exception == null) {
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
value as T
|
||||||
|
} else {
|
||||||
|
defaultValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun toString(): String {
|
||||||
|
return if (exception == null) {
|
||||||
|
"Result.success($value)"
|
||||||
|
} else {
|
||||||
|
"Result.failure($exception)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun interface ThrowingSupplier<T> {
|
||||||
|
@Throws(Throwable::class)
|
||||||
|
fun get(): T
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
@JvmStatic
|
||||||
|
fun <T> success(value: T): Result<T> = Result(value, null)
|
||||||
|
|
||||||
|
@JvmStatic
|
||||||
|
fun <T> failure(exception: Throwable): Result<T> = Result(null, exception)
|
||||||
|
|
||||||
|
@JvmStatic
|
||||||
|
fun <T> runCatching(block: ThrowingSupplier<T>): Result<T> {
|
||||||
|
return try {
|
||||||
|
success(block.get())
|
||||||
|
} catch (throwable: Throwable) {
|
||||||
|
failure(
|
||||||
|
when (throwable) {
|
||||||
|
is AgentRuntimeException, is Error -> throwable
|
||||||
|
else -> AgentRuntimeException(
|
||||||
|
throwable.message ?: "Unexpected runtime failure.",
|
||||||
|
throwable
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user