From 94adf9a3683d268ad23af14faac601795cbfb64d Mon Sep 17 00:00:00 2001 From: slhafzjw Date: Sat, 25 Apr 2026 22:12:07 +0800 Subject: [PATCH] fix(openai-provider): add cached structured-output downgrade to prompt-only JSON fallback --- .../openai/OpenAiCompatibleProvider.java | 88 +++++++++++++++++-- .../StructuredOutputFailureClassifier.java | 87 ++++++++++++++++++ 2 files changed, 169 insertions(+), 6 deletions(-) create mode 100644 Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/StructuredOutputFailureClassifier.java diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/OpenAiCompatibleProvider.java b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/OpenAiCompatibleProvider.java index 02593dcd..c5382eb1 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/OpenAiCompatibleProvider.java +++ b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/OpenAiCompatibleProvider.java @@ -1,5 +1,6 @@ package work.slhaf.partner.framework.agent.model.provider.openai; +import com.alibaba.fastjson2.JSON; import com.alibaba.fastjson2.JSONObject; import com.openai.client.OpenAIClient; import com.openai.client.okhttp.OpenAIOkHttpClient; @@ -18,10 +19,13 @@ import work.slhaf.partner.framework.agent.support.Result; import java.time.Duration; import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; public class OpenAiCompatibleProvider extends ModelProvider { private static final int MAX_ATTEMPTS = 3; + private static final ConcurrentMap STRUCTURED_OUTPUT_MODE_CACHE = new ConcurrentHashMap<>(); private final String baseUrl; private final String apiKey; @@ -99,15 +103,52 @@ public class OpenAiCompatibleProvider extends ModelProvider { public @NotNull Result formattedChat(@NotNull List messages, @NotNull Class responseType) { return executeWithRetry( "OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.", - () -> { - StructuredChatCompletionCreateParams params = buildParams(ensureJsonInstruction(messages, responseType)).toBuilder() - .responseFormat(responseType) - .build(); - return extractStructured(client.chat().completions().create(params)); - } + () -> formattedChatByCachedMode(messages, responseType) ); } + private T formattedChatByCachedMode(List messages, Class responseType) { + StructuredOutputCacheKey cacheKey = new StructuredOutputCacheKey(baseUrl, model); + StructuredOutputMode mode = STRUCTURED_OUTPUT_MODE_CACHE.getOrDefault(cacheKey, StructuredOutputMode.UNKNOWN); + if (mode == StructuredOutputMode.PROMPT_ONLY_JSON) { + return promptOnlyFormattedChat(messages, responseType); + } + return strictThenPromptFallback(messages, responseType, cacheKey); + } + + private T strictThenPromptFallback(List messages, Class responseType, StructuredOutputCacheKey cacheKey) { + try { + T result = strictFormattedChat(messages, responseType); + STRUCTURED_OUTPUT_MODE_CACHE.put(cacheKey, StructuredOutputMode.STRICT_RESPONSE_FORMAT); + return result; + } catch (Exception structuredFailure) { + try { + T result = promptOnlyFormattedChat(messages, responseType); + if (StructuredOutputFailureClassifier.shouldDowngradeToPromptOnlyJson(structuredFailure)) { + STRUCTURED_OUTPUT_MODE_CACHE.put(cacheKey, StructuredOutputMode.PROMPT_ONLY_JSON); + } + return result; + } catch (Exception fallbackFailure) { + structuredFailure.addSuppressed(fallbackFailure); + throw structuredFailure; + } + } + } + + private T strictFormattedChat(List messages, Class responseType) { + StructuredChatCompletionCreateParams params = buildParams(ensureJsonInstruction(messages, responseType)).toBuilder() + .responseFormat(responseType) + .build(); + return extractStructured(client.chat().completions().create(params)); + } + + private T promptOnlyFormattedChat(List messages, Class responseType) { + ChatCompletionCreateParams params = buildParams(ensureJsonInstruction(messages, responseType)); + String rawText = extractText(client.chat().completions().create(params)); + String jsonText = extractJsonObject(rawText); + return JSON.parseObject(jsonText, responseType); + } + private List ensureJsonInstruction(List messages, Class responseType) { String jsonInstruction = JsonShapeInstructionBuilder.build(responseType); @@ -132,6 +173,32 @@ public class OpenAiCompatibleProvider extends ModelProvider { } + private String extractJsonObject(String text) { + String trimmed = text == null ? "" : text.trim(); + if (trimmed.isBlank()) { + throw invokeException("OpenAI-compatible provider returned empty content in prompt-only JSON fallback.", null); + } + trimmed = stripMarkdownFence(trimmed); + if (trimmed.startsWith("{") && trimmed.endsWith("}")) { + return trimmed; + } + int start = trimmed.indexOf('{'); + int end = trimmed.lastIndexOf('}'); + if (start >= 0 && end > start) { + return trimmed.substring(start, end + 1).trim(); + } + throw invokeException("OpenAI-compatible provider prompt-only JSON fallback returned no JSON object.", null); + } + + private String stripMarkdownFence(String text) { + String trimmed = text.trim(); + if (!trimmed.startsWith("```")) { + return trimmed; + } + String withoutOpeningFence = trimmed.replaceFirst("^```[a-zA-Z0-9_-]*\\s*", ""); + return withoutOpeningFence.replaceFirst("\\s*```$", "").trim(); + } + private ChatCompletionCreateParams buildParams(List messages) { ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder() .model(model) @@ -225,6 +292,15 @@ public class OpenAiCompatibleProvider extends ModelProvider { return result; } + private enum StructuredOutputMode { + UNKNOWN, + STRICT_RESPONSE_FORMAT, + PROMPT_ONLY_JSON + } + + private record StructuredOutputCacheKey(String baseUrl, String model) { + } + @FunctionalInterface private interface ThrowingSupplier { T get() throws Exception; diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/StructuredOutputFailureClassifier.java b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/StructuredOutputFailureClassifier.java new file mode 100644 index 00000000..db0710d6 --- /dev/null +++ b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/StructuredOutputFailureClassifier.java @@ -0,0 +1,87 @@ +package work.slhaf.partner.framework.agent.model.provider.openai; + +import java.util.*; + +final class StructuredOutputFailureClassifier { + + private static final List TRANSIENT_FAILURE_PATTERNS = List.of( + "timeout", + "error reading response", + "streamresetexception", + "interruptedioexception", + "sockettimeoutexception", + "connectexception", + "connection reset", + "503", + "502", + "504", + "429", + "internalserverexception", + "provider returned error", + "openaiioexception", + "request failed" + ); + private static final List AUTH_OR_CONFIG_FAILURE_PATTERNS = List.of( + "没有权限", + "not activated", + "permission", + "unauthorized", + "forbidden", + "invalid api key", + "sslexception", + "unsupported or unrecognized ssl message" + ); + private static final List STRUCTURED_COMPATIBILITY_FAILURE_PATTERNS = List.of( + "response_format type is unavailable", + "messages must contain the word 'json'", + "messages must contain the word json", + "structured chat completion returned empty content", + "error parsing json:", + "unrecognizedpropertyexception", + "mismatchedinputexception", + "invalidformatexception" + ); + + private StructuredOutputFailureClassifier() { + } + + static boolean shouldDowngradeToPromptOnlyJson(Throwable failure) { + FailureSnapshot snapshot = FailureSnapshot.from(failure); + if (snapshot.containsAny(TRANSIENT_FAILURE_PATTERNS)) { + return false; + } + if (snapshot.containsAny(AUTH_OR_CONFIG_FAILURE_PATTERNS)) { + return false; + } + return snapshot.containsAny(STRUCTURED_COMPATIBILITY_FAILURE_PATTERNS); + } + + private record FailureSnapshot(String text) { + + static FailureSnapshot from(Throwable failure) { + Set visited = Collections.newSetFromMap(new IdentityHashMap<>()); + StringBuilder builder = new StringBuilder(); + collect(failure, builder, visited); + return new FailureSnapshot(builder.toString().toLowerCase(Locale.ROOT)); + } + + private static void collect(Throwable throwable, StringBuilder builder, Set visited) { + if (throwable == null || visited.contains(throwable)) { + return; + } + visited.add(throwable); + builder.append(throwable.getClass().getName()).append('\n'); + if (throwable.getMessage() != null) { + builder.append(throwable.getMessage()).append('\n'); + } + collect(throwable.getCause(), builder, visited); + for (Throwable suppressed : throwable.getSuppressed()) { + collect(suppressed, builder, visited); + } + } + + boolean containsAny(List patterns) { + return patterns.stream().anyMatch(text::contains); + } + } +}