refactor(framework): unify model invocation result and exception handling

This commit is contained in:
2026-04-11 14:50:12 +08:00
parent 3732555f02
commit b8cb2afbcf
17 changed files with 293 additions and 95 deletions

View File

@@ -30,7 +30,7 @@ public class ActionCorrectionRecognizer extends AbstractAgentModule.Sub<Correcti
resolveContextMessage(), resolveContextMessage(),
resolveTaskMessage(input) resolveTaskMessage(input)
); );
return formattedChat(messages, CorrectionRecognizerResult.class); return formattedChat(messages, CorrectionRecognizerResult.class).getOrThrow();
} }
private Message resolveTaskMessage(CorrectionRecognizerInput input) { private Message resolveTaskMessage(CorrectionRecognizerInput input) {

View File

@@ -30,7 +30,7 @@ public class ActionCorrector extends AbstractAgentModule.Sub<CorrectorInput, Cor
resolveContextMessage(), resolveContextMessage(),
resolveTaskMessage(input) resolveTaskMessage(input)
); );
return formattedChat(messages, CorrectorResult.class); return formattedChat(messages, CorrectorResult.class).getOrThrow();
} }
private Message resolveTaskMessage(CorrectorInput input) { private Message resolveTaskMessage(CorrectorInput input) {

View File

@@ -11,6 +11,7 @@ import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCa
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.framework.agent.model.ActivateModel; import work.slhaf.partner.framework.agent.model.ActivateModel;
import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.support.Result;
import work.slhaf.partner.module.TaskBlock; import work.slhaf.partner.module.TaskBlock;
import work.slhaf.partner.module.action.executor.entity.ExtractorInput; import work.slhaf.partner.module.action.executor.entity.ExtractorInput;
import work.slhaf.partner.module.action.executor.entity.ExtractorResult; import work.slhaf.partner.module.action.executor.entity.ExtractorResult;
@@ -28,20 +29,19 @@ public class ParamsExtractor extends AbstractAgentModule.Sub<ExtractorInput, Ext
@Override @Override
public ExtractorResult execute(ExtractorInput input) { public ExtractorResult execute(ExtractorInput input) {
ExtractorResult result; List<Message> messages = List.of(
try { resolveContextMessage(),
List<Message> messages = List.of( resolveTaskMessage(input)
resolveContextMessage(), );
resolveTaskMessage(input) Result<ExtractorResult> result = formattedChat(messages, ExtractorResult.class);
); if (result.isFailure()) {
result = formattedChat(messages, ExtractorResult.class); log.error("ParamsExtractor解析结果失败", result.exceptionOrNull());
} catch (Exception e) { ExtractorResult fallback = new ExtractorResult();
log.error("ParamsExtractor解析结果失败", e); fallback.setOk(false);
result = new ExtractorResult(); fallback.setParams(new HashMap<>());
result.setOk(false); return fallback;
result.setParams(new HashMap<>());
} }
return result; return result.getOrThrow();
} }
private Message resolveTaskMessage(ExtractorInput input) { private Message resolveTaskMessage(ExtractorInput input) {

View File

@@ -15,6 +15,7 @@ import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAg
import work.slhaf.partner.framework.agent.factory.component.annotation.Init; import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
import work.slhaf.partner.framework.agent.model.ActivateModel; import work.slhaf.partner.framework.agent.model.ActivateModel;
import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.support.Result;
import work.slhaf.partner.module.action.planner.evaluator.entity.EvaluatorInput; import work.slhaf.partner.module.action.planner.evaluator.entity.EvaluatorInput;
import work.slhaf.partner.module.action.planner.evaluator.entity.EvaluatorResult; import work.slhaf.partner.module.action.planner.evaluator.entity.EvaluatorResult;
@@ -61,10 +62,15 @@ public class ActionEvaluator extends AbstractAgentModule.Sub<EvaluatorInput, Lis
availableMetaActionContext(), availableMetaActionContext(),
new Message(Message.Character.USER, tendency) new Message(Message.Character.USER, tendency)
); );
EvaluatorResult evaluatorResult = formattedChat( Result<EvaluatorResult> result = formattedChat(
messages, messages,
EvaluatorResult.class EvaluatorResult.class
); );
if (result.isFailure()) {
log.error("ActionEvaluator评估失败: {}", tendency, result.exceptionOrNull());
return;
}
EvaluatorResult evaluatorResult = result.getOrThrow();
evaluatorResult.setTendency(tendency); evaluatorResult.setTendency(tendency);
synchronized (evaluatorResults) { synchronized (evaluatorResults) {
evaluatorResults.add(evaluatorResult); evaluatorResults.add(evaluatorResult);

View File

@@ -7,6 +7,7 @@ import work.slhaf.partner.framework.agent.factory.capability.annotation.InjectCa
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.framework.agent.model.ActivateModel; import work.slhaf.partner.framework.agent.model.ActivateModel;
import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.support.Result;
import work.slhaf.partner.module.action.planner.extractor.entity.ExtractorResult; import work.slhaf.partner.module.action.planner.extractor.entity.ExtractorResult;
import java.util.List; import java.util.List;
@@ -18,23 +19,18 @@ public class ActionExtractor extends AbstractAgentModule.Sub<String, ExtractorRe
@Override @Override
public ExtractorResult execute(String input) { public ExtractorResult execute(String input) {
for (int i = 0; i < 3; i++) { List<Message> messages = List.of(
try { cognitionCapability.contextWorkspace().resolve(List.of(
List<Message> messages = List.of( ContextBlock.VisibleDomain.COGNITION,
cognitionCapability.contextWorkspace().resolve(List.of( ContextBlock.VisibleDomain.ACTION
ContextBlock.VisibleDomain.COGNITION, )).encodeToMessage(),
ContextBlock.VisibleDomain.ACTION new Message(Message.Character.USER, input)
)).encodeToMessage(), );
new Message(Message.Character.USER, input) Result<ExtractorResult> result = formattedChat(messages, ExtractorResult.class);
); if (result.isSuccess()) {
return formattedChat( return result.getOrThrow();
messages,
ExtractorResult.class
);
} catch (Exception e) {
log.error("提取信息出错", e);
}
} }
log.error("提取信息出错", result.exceptionOrNull());
return new ExtractorResult(); return new ExtractorResult();
} }

View File

@@ -12,6 +12,7 @@ import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
import work.slhaf.partner.framework.agent.model.ActivateModel; import work.slhaf.partner.framework.agent.model.ActivateModel;
import work.slhaf.partner.framework.agent.model.StreamChatMessageConsumer; import work.slhaf.partner.framework.agent.model.StreamChatMessageConsumer;
import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.support.Result;
import work.slhaf.partner.runtime.PartnerRunningFlowContext; import work.slhaf.partner.runtime.PartnerRunningFlowContext;
import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.parsers.DocumentBuilderFactory;
@@ -28,6 +29,8 @@ import java.util.stream.Collectors;
@Data @Data
public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRunningFlowContext> implements ActivateModel { public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRunningFlowContext> implements ActivateModel {
private static final String INTERRUPTED_MARKER = " [response interrupted due to internal exception]";
private static final String MODULE_PROMPT = """ private static final String MODULE_PROMPT = """
你是 Partner 的表达模块。 你是 Partner 的表达模块。
你接下来收到的消息固定分为三个区段: 你接下来收到的消息固定分为三个区段:
@@ -64,7 +67,11 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
private void executeChat(PartnerRunningFlowContext runningFlowContext) { private void executeChat(PartnerRunningFlowContext runningFlowContext) {
StreamChatMessageConsumer consumer = ReplyDispatcher.INSTANCE.createConsumer(runningFlowContext.getTarget()); StreamChatMessageConsumer consumer = ReplyDispatcher.INSTANCE.createConsumer(runningFlowContext.getTarget());
this.streamChat(buildChatMessages(runningFlowContext), consumer); Result<kotlin.Unit> result = this.streamChat(buildChatMessages(runningFlowContext), consumer);
if (result.isFailure()) {
log.error("Streaming response failed", result.exceptionOrNull());
consumer.onDelta(INTERRUPTED_MARKER);
}
updateChatMessages(runningFlowContext, consumer.collectResponse()); updateChatMessages(runningFlowContext, consumer.collectResponse());
updateContext(); updateContext();
} }

View File

@@ -45,7 +45,7 @@ public class DialogRollingService extends AbstractAgentModule.Standalone impleme
List<Message> messages = List.of( List<Message> messages = List.of(
resolveTaskBlock(snapshotMessages) resolveTaskBlock(snapshotMessages)
); );
return chat(messages); return chat(messages).getOrThrow();
} }
private @NotNull BlockContent buildDialogAbstractBlock(String summary, @Nullable String unitId, @Nullable String sliceId) { private @NotNull BlockContent buildDialogAbstractBlock(String summary, @Nullable String unitId, @Nullable String sliceId) {

View File

@@ -15,6 +15,7 @@ import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAg
import work.slhaf.partner.framework.agent.factory.component.annotation.Init; import work.slhaf.partner.framework.agent.factory.component.annotation.Init;
import work.slhaf.partner.framework.agent.model.ActivateModel; import work.slhaf.partner.framework.agent.model.ActivateModel;
import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.support.Result;
import work.slhaf.partner.module.TaskBlock; import work.slhaf.partner.module.TaskBlock;
import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice; import work.slhaf.partner.module.memory.selector.ActivatedMemorySlice;
import work.slhaf.partner.module.memory.selector.evaluator.entity.EvaluatorBatchInput; import work.slhaf.partner.module.memory.selector.evaluator.entity.EvaluatorBatchInput;
@@ -64,13 +65,16 @@ public class SliceSelectEvaluator extends AbstractAgentModule.Sub<EvaluatorInput
contextMessage, contextMessage,
resolveTaskMessage(batchInput) resolveTaskMessage(batchInput)
); );
EvaluatorBatchResult batchResult = formattedChat(messages, EvaluatorBatchResult.class); Result<EvaluatorBatchResult> batchResult = formattedChat(messages, EvaluatorBatchResult.class);
if (batchResult.isPassed()) { if (batchResult.isFailure()) {
log.debug("切片评估失败,已跳过当前切片", batchResult.exceptionOrNull());
return;
}
if (batchResult.getOrThrow().isPassed()) {
synchronized (result) { synchronized (result) {
result.add(slice); result.add(slice);
} }
} }
} catch (Exception ignore) {
} finally { } finally {
latch.countDown(); latch.countDown();
} }

View File

@@ -13,6 +13,7 @@ import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAg
import work.slhaf.partner.framework.agent.factory.component.annotation.InjectModule; import work.slhaf.partner.framework.agent.factory.component.annotation.InjectModule;
import work.slhaf.partner.framework.agent.model.ActivateModel; import work.slhaf.partner.framework.agent.model.ActivateModel;
import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.support.Result;
import work.slhaf.partner.module.TaskBlock; import work.slhaf.partner.module.TaskBlock;
import work.slhaf.partner.module.memory.runtime.MemoryRuntime; import work.slhaf.partner.module.memory.runtime.MemoryRuntime;
import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorInput; import work.slhaf.partner.module.memory.selector.extractor.entity.ExtractorInput;
@@ -34,18 +35,19 @@ public class MemorySelectExtractor extends AbstractAgentModule.Sub<ExtractorInpu
public ExtractorResult execute(ExtractorInput input) { public ExtractorResult execute(ExtractorInput input) {
log.debug("[MemorySelectExtractor] 主题提取模块开始..."); log.debug("[MemorySelectExtractor] 主题提取模块开始...");
ExtractorResult extractorResult; ExtractorResult extractorResult;
try { List<Message> messages = List.of(
List<Message> messages = List.of( resolveContextMessage(),
resolveContextMessage(), resolveTaskMessage(input)
resolveTaskMessage(input) );
); Result<ExtractorResult> result = formattedChat(
extractorResult = formattedChat( messages,
messages, ExtractorResult.class
ExtractorResult.class );
); if (result.isSuccess()) {
extractorResult = result.getOrThrow();
log.debug("[MemorySelectExtractor] 主题提取结果: {}", extractorResult); log.debug("[MemorySelectExtractor] 主题提取结果: {}", extractorResult);
} catch (Exception e) { } else {
log.error("[MemorySelectExtractor] 主题提取出错: ", e); log.error("[MemorySelectExtractor] 主题提取出错: ", result.exceptionOrNull());
extractorResult = new ExtractorResult(); extractorResult = new ExtractorResult();
extractorResult.setRecall(false); extractorResult.setRecall(false);
extractorResult.setMatches(List.of()); extractorResult.setMatches(List.of());

View File

@@ -28,7 +28,7 @@ public class MultiSummarizer extends AbstractAgentModule.Sub<SummarizeInput, Sum
SummarizeResult result = formattedChat( SummarizeResult result = formattedChat(
List.of(new Message(Message.Character.USER, JSONUtil.toJsonPrettyStr(input))), List.of(new Message(Message.Character.USER, JSONUtil.toJsonPrettyStr(input))),
SummarizeResult.class SummarizeResult.class
); ).getOrThrow();
log.debug("[MemorySummarizer] 整体摘要结果: {}", JSONObject.toJSONString(result)); log.debug("[MemorySummarizer] 整体摘要结果: {}", JSONObject.toJSONString(result));
return fix(result); return fix(result);
} }

View File

@@ -60,7 +60,7 @@ public class SingleSummarizer extends AbstractAgentModule.Sub<List<Message>, Voi
private String singleExecute(String primaryContent) { private String singleExecute(String primaryContent) {
try { try {
return chat(List.of(new Message(Message.Character.USER, primaryContent))); return chat(List.of(new Message(Message.Character.USER, primaryContent))).getOrThrow();
} catch (Exception e) { } catch (Exception e) {
log.error("[SingleSummarizer] 单消息总结出错: ", e); log.error("[SingleSummarizer] 单消息总结出错: ", e);
return primaryContent; return primaryContent;

View File

@@ -17,7 +17,7 @@ public class TotalSummarizer extends AbstractAgentModule.Sub<HashMap<String, Str
return formattedChat( return formattedChat(
List.of(new Message(Message.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))), List.of(new Message(Message.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))),
SummaryContent.class SummaryContent.class
).getContent(); ).getOrThrow().getContent();
} }
@Override @Override

View File

@@ -2,18 +2,22 @@ package work.slhaf.partner.framework.agent.model
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule
import work.slhaf.partner.framework.agent.model.pojo.Message import work.slhaf.partner.framework.agent.model.pojo.Message
import work.slhaf.partner.framework.agent.support.Result
interface ActivateModel { interface ActivateModel {
fun chat(messages: List<Message>): String { fun chat(messages: List<Message>): Result<String> {
return ModelRuntimeRegistry.resolveProvider(modelKey()).chat(mergeMessages(messages)) return ModelRuntimeRegistry.resolveProvider(modelKey()).chat(mergeMessages(messages))
} }
fun streamChat(messages: List<Message>, handler: StreamChatMessageConsumer) { fun streamChat(
ModelRuntimeRegistry.resolveProvider(modelKey()).streamChat(mergeMessages(messages), handler) messages: List<Message>,
handler: StreamChatMessageConsumer
): work.slhaf.partner.framework.agent.support.Result<Unit> {
return ModelRuntimeRegistry.resolveProvider(modelKey()).streamChat(mergeMessages(messages), handler)
} }
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T { fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): Result<T> {
return ModelRuntimeRegistry.resolveProvider(modelKey()).formattedChat(mergeMessages(messages), responseType) return ModelRuntimeRegistry.resolveProvider(modelKey()).formattedChat(mergeMessages(messages), responseType)
} }

View File

@@ -46,7 +46,13 @@ object ModelRuntimeRegistry : Configurable, ConfigRegistration<ModelRuntimeRegis
private fun registerProvider(config: ProviderConfig) { private fun registerProvider(config: ProviderConfig) {
when (config) { when (config) {
is OpenAiCompatibleProviderConfig -> baseProvider[config.name] = is OpenAiCompatibleProviderConfig -> baseProvider[config.name] =
OpenAiCompatibleProvider(config.baseUrl, config.apiKey, config.defaultModel) OpenAiCompatibleProvider(
config.name,
DEFAULT_PROVIDER,
config.baseUrl,
config.apiKey,
config.defaultModel
)
} }
} }
@@ -61,11 +67,7 @@ object ModelRuntimeRegistry : Configurable, ConfigRegistration<ModelRuntimeRegis
val override = config.override val override = config.override
try { try {
runtimeProvider[config.modelKey] = if (override != null) { runtimeProvider[config.modelKey] = provider.fork(config.modelKey, override)
provider.fork(override)
} else {
provider
}
} catch (e: Exception) { } catch (e: Exception) {
throw runtimeModelException( throw runtimeModelException(
"Failed to build runtime provider for model key ${config.modelKey}", "Failed to build runtime provider for model key ${config.modelKey}",

View File

@@ -3,18 +3,21 @@ package work.slhaf.partner.framework.agent.model.provider
import com.alibaba.fastjson2.JSONObject import com.alibaba.fastjson2.JSONObject
import work.slhaf.partner.framework.agent.model.StreamChatMessageConsumer import work.slhaf.partner.framework.agent.model.StreamChatMessageConsumer
import work.slhaf.partner.framework.agent.model.pojo.Message import work.slhaf.partner.framework.agent.model.pojo.Message
import work.slhaf.partner.framework.agent.support.Result
abstract class ModelProvider @JvmOverloads constructor( abstract class ModelProvider @JvmOverloads constructor(
val providerName: String,
val modelKey: String,
val override: ProviderOverride? = null val override: ProviderOverride? = null
) { ) {
abstract fun fork(override: ProviderOverride): ModelProvider abstract fun fork(modelKey: String, override: ProviderOverride? = null): ModelProvider
abstract fun streamChat(messages: List<Message>, consumer: StreamChatMessageConsumer) abstract fun streamChat(messages: List<Message>, consumer: StreamChatMessageConsumer): Result<Unit>
abstract fun chat(messages: List<Message>): String abstract fun chat(messages: List<Message>): Result<String>
abstract fun <T> formattedChat(messages: List<Message>, type: Class<T>): T abstract fun <T> formattedChat(messages: List<Message>, type: Class<T>): Result<T>
} }
data class ProviderOverride( data class ProviderOverride(

View File

@@ -6,24 +6,33 @@ import com.openai.client.okhttp.OpenAIOkHttpClient;
import com.openai.core.JsonValue; import com.openai.core.JsonValue;
import com.openai.core.http.StreamResponse; import com.openai.core.http.StreamResponse;
import com.openai.models.chat.completions.*; import com.openai.models.chat.completions.*;
import kotlin.Unit;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.framework.agent.exception.ModelInvokeException;
import work.slhaf.partner.framework.agent.model.StreamChatMessageConsumer; import work.slhaf.partner.framework.agent.model.StreamChatMessageConsumer;
import work.slhaf.partner.framework.agent.model.pojo.Message; import work.slhaf.partner.framework.agent.model.pojo.Message;
import work.slhaf.partner.framework.agent.model.provider.ModelProvider; import work.slhaf.partner.framework.agent.model.provider.ModelProvider;
import work.slhaf.partner.framework.agent.model.provider.ProviderOverride; import work.slhaf.partner.framework.agent.model.provider.ProviderOverride;
import work.slhaf.partner.framework.agent.support.Result;
import java.time.Duration; import java.time.Duration;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map;
public class OpenAiCompatibleProvider extends ModelProvider { public class OpenAiCompatibleProvider extends ModelProvider {
private static final int MAX_ATTEMPTS = 3;
private final String baseUrl; private final String baseUrl;
private final String apiKey; private final String apiKey;
private final String model; private final String model;
private final OpenAIClient client; private final OpenAIClient client;
public OpenAiCompatibleProvider(String baseUrl, String apikey, String model) { public OpenAiCompatibleProvider(String providerName, String modelKey, String baseUrl, String apikey, String model) {
super(providerName, modelKey, null);
this.client = OpenAIOkHttpClient.builder() this.client = OpenAIOkHttpClient.builder()
.baseUrl(baseUrl) .baseUrl(baseUrl)
.apiKey(apikey) .apiKey(apikey)
@@ -34,8 +43,8 @@ public class OpenAiCompatibleProvider extends ModelProvider {
this.model = model; this.model = model;
} }
public OpenAiCompatibleProvider(String baseUrl, String apikey, String model, ProviderOverride override) { public OpenAiCompatibleProvider(String providerName, String modelKey, String baseUrl, String apikey, String model, ProviderOverride override) {
super(override); super(providerName, modelKey, override);
this.client = OpenAIOkHttpClient.builder() this.client = OpenAIOkHttpClient.builder()
.baseUrl(baseUrl) .baseUrl(baseUrl)
.apiKey(apikey) .apiKey(apikey)
@@ -46,27 +55,59 @@ public class OpenAiCompatibleProvider extends ModelProvider {
this.model = model; this.model = model;
} }
public @NotNull String chat(@NotNull List<Message> messages) { @Override
ChatCompletionCreateParams params = buildParams(messages); public @NotNull Result<String> chat(@NotNull List<Message> messages) {
return extractText(client.chat().completions().create(params)); return executeWithRetry(
"OpenAI-compatible provider failed to complete the chat request after 3 attempts.",
() -> extractText(client.chat().completions().create(buildParams(messages)))
);
} }
public void streamChat(@NotNull List<Message> messages, StreamChatMessageConsumer handler) { @Override
ChatCompletionCreateParams params = buildParams(messages); public @NotNull Result<Unit> streamChat(@NotNull List<Message> messages, @NotNull StreamChatMessageConsumer handler) {
try (StreamResponse<ChatCompletionChunk> streamResponse = client.chat().completions().createStreaming(params)) { Exception lastFailure = null;
streamResponse.stream() int remainingAttempts = MAX_ATTEMPTS;
.flatMap(completion -> completion.choices().stream()) while (remainingAttempts > 0) {
.flatMap(choice -> choice.delta().content().stream()) boolean emitted = false;
.filter(delta -> !delta.isEmpty()) try (StreamResponse<ChatCompletionChunk> streamResponse = client.chat().completions().createStreaming(buildParams(messages))) {
.forEach(handler::onDelta); Iterator<ChatCompletionChunk> iterator = streamResponse.stream().iterator();
while (iterator.hasNext()) {
ChatCompletionChunk chunk = iterator.next();
for (ChatCompletionChunk.Choice choice : chunk.choices()) {
String delta = choice.delta().content().orElse("");
if (delta.isEmpty()) {
continue;
}
emitted = true;
handler.onDelta(delta);
}
}
return Result.success(Unit.INSTANCE);
} catch (Exception e) {
lastFailure = e;
remainingAttempts--;
if (emitted || remainingAttempts == 0) {
break;
}
}
} }
return Result.failure(invokeException(
"OpenAI-compatible provider failed to stream the chat response after 3 attempts.",
lastFailure
));
} }
public <T> T formattedChat(@NotNull List<Message> messages, @NotNull Class<T> responseType) { @Override
StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder() public <T> @NotNull Result<T> formattedChat(@NotNull List<Message> messages, @NotNull Class<T> responseType) {
.responseFormat(responseType) return executeWithRetry(
.build(); "OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.",
return extractStructured(client.chat().completions().create(params)); () -> {
StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder()
.responseFormat(responseType)
.build();
return extractStructured(client.chat().completions().create(params));
}
);
} }
private ChatCompletionCreateParams buildParams(List<Message> messages) { private ChatCompletionCreateParams buildParams(List<Message> messages) {
@@ -87,9 +128,7 @@ public class OpenAiCompatibleProvider extends ModelProvider {
} }
JSONObject extras = override.getExtras(); JSONObject extras = override.getExtras();
if (extras != null) { if (extras != null) {
extras.forEach((key, value) -> { extras.forEach((key, value) -> paramsBuilder.putAdditionalBodyProperty(key, JsonValue.from(value)));
paramsBuilder.putAdditionalBodyProperty(key, JsonValue.from(value));
});
} }
} }
@@ -98,22 +137,81 @@ public class OpenAiCompatibleProvider extends ModelProvider {
private String extractText(ChatCompletion completion) { private String extractText(ChatCompletion completion) {
if (completion.choices().isEmpty()) { if (completion.choices().isEmpty()) {
throw new IllegalStateException("OpenAI chat completion returned no choices."); throw invokeException("OpenAI chat completion returned no choices.", null);
} }
return completion.choices().getFirst().message().content() return completion.choices().getFirst().message().content()
.orElseThrow(() -> new IllegalStateException("OpenAI chat completion returned empty content.")); .orElseThrow(() -> invokeException("OpenAI chat completion returned empty content.", null));
} }
private <T> T extractStructured(StructuredChatCompletion<T> completion) { private <T> T extractStructured(StructuredChatCompletion<T> completion) {
if (completion.choices().isEmpty()) { if (completion.choices().isEmpty()) {
throw new IllegalStateException("OpenAI structured chat completion returned no choices."); throw invokeException("OpenAI structured chat completion returned no choices.", null);
} }
return completion.choices().getFirst().message().content() return completion.choices().getFirst().message().content()
.orElseThrow(() -> new IllegalStateException("OpenAI structured chat completion returned empty content.")); .orElseThrow(() -> invokeException("OpenAI structured chat completion returned empty content.", null));
} }
@Override @Override
public @NotNull ModelProvider fork(@NotNull ProviderOverride override) { public @NotNull ModelProvider fork(@NotNull String modelKey, ProviderOverride override) {
return new OpenAiCompatibleProvider(baseUrl, apiKey, override.getModel(), override); if (override == null) {
return new OpenAiCompatibleProvider(getProviderName(), modelKey, baseUrl, apiKey, model, getOverride());
}
return new OpenAiCompatibleProvider(getProviderName(), modelKey, baseUrl, apiKey, override.getModel(), override);
}
private <T> Result<T> executeWithRetry(String failureMessage, ThrowingSupplier<T> supplier) {
Exception lastFailure = null;
for (int attempt = 1; attempt <= MAX_ATTEMPTS; attempt++) {
Result<T> result = Result.runCatching(supplier::get);
if (result.isSuccess()) {
return result;
}
Throwable throwable = result.exceptionOrNull();
if (throwable instanceof Exception exception) {
lastFailure = exception;
continue;
}
if (throwable instanceof Error error) {
throw error;
}
return Result.failure(invokeException(failureMessage, throwable));
}
return Result.failure(invokeException(failureMessage, lastFailure));
}
private ModelInvokeException invokeException(String message, Throwable cause) {
return new ModelInvokeException(
message,
getProviderName(),
getModelKey(),
baseUrl,
model,
getOverride() == null ? Map.of() : toOverrideReport(getOverride()),
cause
);
}
private Map<String, String> toOverrideReport(ProviderOverride override) {
Map<String, String> result = new LinkedHashMap<>();
result.put("model", override.getModel());
if (override.getTemperature() != null) {
result.put("temperature", override.getTemperature().toString());
}
if (override.getTopP() != null) {
result.put("topP", override.getTopP().toString());
}
if (override.getMaxTokens() != null) {
result.put("maxTokens", override.getMaxTokens().toString());
}
JSONObject extras = override.getExtras();
if (extras != null) {
extras.forEach((key, value) -> result.put("extra." + key, value == null ? "null" : value.toString()));
}
return result;
}
@FunctionalInterface
private interface ThrowingSupplier<T> {
T get() throws Exception;
} }
} }

View File

@@ -0,0 +1,76 @@
package work.slhaf.partner.framework.agent.support
import work.slhaf.partner.framework.agent.exception.AgentRuntimeException
class Result<T> private constructor(
private val value: T?,
private val exception: Throwable?
) {
fun isSuccess(): Boolean = exception == null
fun isFailure(): Boolean = exception != null
fun getOrNull(): T? = value
fun exceptionOrNull(): Throwable? = exception
fun getOrThrow(): T {
if (exception == null) {
@Suppress("UNCHECKED_CAST")
return value as T
}
when (exception) {
is RuntimeException -> throw exception
is Error -> throw exception
else -> throw IllegalStateException(exception.message, exception)
}
}
fun getOrDefault(defaultValue: T): T {
return if (exception == null) {
@Suppress("UNCHECKED_CAST")
value as T
} else {
defaultValue
}
}
override fun toString(): String {
return if (exception == null) {
"Result.success($value)"
} else {
"Result.failure($exception)"
}
}
fun interface ThrowingSupplier<T> {
@Throws(Throwable::class)
fun get(): T
}
companion object {
@JvmStatic
fun <T> success(value: T): Result<T> = Result(value, null)
@JvmStatic
fun <T> failure(exception: Throwable): Result<T> = Result(null, exception)
@JvmStatic
fun <T> runCatching(block: ThrowingSupplier<T>): Result<T> {
return try {
success(block.get())
} catch (throwable: Throwable) {
failure(
when (throwable) {
is AgentRuntimeException, is Error -> throwable
else -> AgentRuntimeException(
throwable.message ?: "Unexpected runtime failure.",
throwable
)
}
)
}
}
}
}