fix(openai-provider): add cached structured-output downgrade to prompt-only JSON fallback

This commit is contained in:
2026-04-25 22:12:07 +08:00
parent dd5ab3aaf3
commit 94adf9a368
2 changed files with 169 additions and 6 deletions

View File

@@ -1,5 +1,6 @@
package work.slhaf.partner.framework.agent.model.provider.openai; package work.slhaf.partner.framework.agent.model.provider.openai;
import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.JSONObject;
import com.openai.client.OpenAIClient; import com.openai.client.OpenAIClient;
import com.openai.client.okhttp.OpenAIOkHttpClient; import com.openai.client.okhttp.OpenAIOkHttpClient;
@@ -18,10 +19,13 @@ import work.slhaf.partner.framework.agent.support.Result;
import java.time.Duration; import java.time.Duration;
import java.util.*; import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
public class OpenAiCompatibleProvider extends ModelProvider { public class OpenAiCompatibleProvider extends ModelProvider {
private static final int MAX_ATTEMPTS = 3; private static final int MAX_ATTEMPTS = 3;
private static final ConcurrentMap<StructuredOutputCacheKey, StructuredOutputMode> STRUCTURED_OUTPUT_MODE_CACHE = new ConcurrentHashMap<>();
private final String baseUrl; private final String baseUrl;
private final String apiKey; private final String apiKey;
@@ -99,13 +103,50 @@ public class OpenAiCompatibleProvider extends ModelProvider {
/**
 * Runs a structured chat request through {@code executeWithRetry}.
 *
 * <p>The retry helper receives the error message used when every attempt fails and a
 * supplier that delegates to {@code formattedChatByCachedMode(messages, responseType)}.
 *
 * @param messages     the conversation to send
 * @param responseType the class the structured reply is mapped to
 * @return whatever {@code executeWithRetry} produces for the supplied task
 */
public <T> @NotNull Result<T> formattedChat(@NotNull List<Message> messages, @NotNull Class<T> responseType) { public <T> @NotNull Result<T> formattedChat(@NotNull List<Message> messages, @NotNull Class<T> responseType) {
return executeWithRetry( return executeWithRetry(
"OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.", "OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.",
() -> { () -> formattedChatByCachedMode(messages, responseType)
);
}
/**
 * Dispatches a structured chat request according to the structured-output mode
 * remembered for this provider's {@code baseUrl}/{@code model} pair.
 *
 * <p>When the cache already says {@code PROMPT_ONLY_JSON}, the strict request is skipped
 * and the prompt-only path is used directly. For every other state (including no cache
 * entry, which reads as {@code UNKNOWN}) the strict request is tried first with the
 * prompt-only path as fallback.
 *
 * @param messages     the conversation to send
 * @param responseType the class the reply is mapped to
 * @return the parsed reply
 */
private <T> T formattedChatByCachedMode(List<Message> messages, Class<T> responseType) {
    StructuredOutputCacheKey key = new StructuredOutputCacheKey(baseUrl, model);
    StructuredOutputMode rememberedMode =
            STRUCTURED_OUTPUT_MODE_CACHE.getOrDefault(key, StructuredOutputMode.UNKNOWN);
    boolean strictKnownToFail = rememberedMode == StructuredOutputMode.PROMPT_ONLY_JSON;
    return strictKnownToFail
            ? promptOnlyFormattedChat(messages, responseType)
            : strictThenPromptFallback(messages, responseType, key);
}
/**
 * Tries the strict structured-output request and, if it throws, retries the same
 * conversation through the prompt-only JSON path.
 *
 * <p>Cache updates:
 * <ul>
 *   <li>strict request succeeds: the key is recorded as {@code STRICT_RESPONSE_FORMAT};</li>
 *   <li>strict request fails but the fallback succeeds: the key is recorded as
 *       {@code PROMPT_ONLY_JSON} only when the classifier judges the strict failure to be
 *       a structured-output compatibility problem;</li>
 *   <li>both fail: the cache is left untouched.</li>
 * </ul>
 *
 * @param messages     the conversation to send
 * @param responseType the class the reply is mapped to
 * @param cacheKey     the cache entry to update
 * @return the parsed reply from whichever path succeeded
 */
private <T> T strictThenPromptFallback(List<Message> messages, Class<T> responseType, StructuredOutputCacheKey cacheKey) {
    try {
        T strictResult = strictFormattedChat(messages, responseType);
        STRUCTURED_OUTPUT_MODE_CACHE.put(cacheKey, StructuredOutputMode.STRICT_RESPONSE_FORMAT);
        return strictResult;
    } catch (Exception strictError) {
        T fallbackResult;
        try {
            fallbackResult = promptOnlyFormattedChat(messages, responseType);
            // Classify only after the fallback has proven to work, so a broken endpoint
            // is never pinned to prompt-only mode.
            boolean incompatible =
                    StructuredOutputFailureClassifier.shouldDowngradeToPromptOnlyJson(strictError);
            if (incompatible) {
                STRUCTURED_OUTPUT_MODE_CACHE.put(cacheKey, StructuredOutputMode.PROMPT_ONLY_JSON);
            }
        } catch (Exception promptError) {
            // Surface the strict failure as the primary error; keep the fallback failure
            // attached so neither diagnosis is lost.
            strictError.addSuppressed(promptError);
            throw strictError;
        }
        return fallbackResult;
    }
}
/**
 * Sends the conversation with a typed response format and returns the structured reply.
 *
 * <p>The messages are first passed through {@code ensureJsonInstruction}, the base
 * parameters come from {@code buildParams}, and {@code responseFormat(responseType)} is
 * applied on top before the completion is created. The reply is unwrapped by
 * {@code extractStructured}.
 *
 * @param messages     the conversation to send
 * @param responseType the class used as the response format
 * @return the value produced by {@code extractStructured}
 */
private <T> T strictFormattedChat(List<Message> messages, Class<T> responseType) {
StructuredChatCompletionCreateParams<T> params = buildParams(ensureJsonInstruction(messages, responseType)).toBuilder() StructuredChatCompletionCreateParams<T> params = buildParams(ensureJsonInstruction(messages, responseType)).toBuilder()
.responseFormat(responseType) .responseFormat(responseType)
.build(); .build();
return extractStructured(client.chat().completions().create(params)); return extractStructured(client.chat().completions().create(params));
} }
);
/**
 * Sends the conversation without a typed response format and parses the JSON object
 * found in the plain-text reply.
 *
 * <p>The messages are passed through {@code ensureJsonInstruction} before the request is
 * built. The reply text is reduced to a JSON object by {@code extractJsonObject} and then
 * deserialized with fastjson2's {@code JSON.parseObject}.
 *
 * @param messages     the conversation to send
 * @param responseType the class the extracted JSON is deserialized into
 * @return the deserialized reply
 */
private <T> T promptOnlyFormattedChat(List<Message> messages, Class<T> responseType) {
ChatCompletionCreateParams params = buildParams(ensureJsonInstruction(messages, responseType));
String rawText = extractText(client.chat().completions().create(params));
String jsonText = extractJsonObject(rawText);
return JSON.parseObject(jsonText, responseType);
} }
private List<Message> ensureJsonInstruction(List<Message> messages, Class<?> responseType) { private List<Message> ensureJsonInstruction(List<Message> messages, Class<?> responseType) {
@@ -132,6 +173,32 @@ public class OpenAiCompatibleProvider extends ModelProvider {
} }
/**
 * Reduces a free-form model reply to the text of a single JSON object.
 *
 * <p>Steps: trim the reply (a {@code null} reply counts as empty), strip a surrounding
 * Markdown code fence, and return the result if it is already brace-delimited. Otherwise
 * return the span from the first <code>{</code> to the last <code>}</code>.
 * The result is not validated as JSON here.
 *
 * @param text the raw reply text, may be {@code null}
 * @return the candidate JSON object text
 * @throws RuntimeException via {@code invokeException} when the reply is empty or has no
 *         brace-delimited span (exact exception type is defined by that helper)
 */
private String extractJsonObject(String text) {
    String candidate = (text != null) ? text.trim() : "";
    if (candidate.isBlank()) {
        throw invokeException("OpenAI-compatible provider returned empty content in prompt-only JSON fallback.", null);
    }

    candidate = stripMarkdownFence(candidate);
    boolean alreadyObject = candidate.startsWith("{") && candidate.endsWith("}");
    if (!alreadyObject) {
        // Tolerate prose around the object: keep everything between the outermost braces.
        int open = candidate.indexOf('{');
        int close = candidate.lastIndexOf('}');
        if (open < 0 || close <= open) {
            throw invokeException("OpenAI-compatible provider prompt-only JSON fallback returned no JSON object.", null);
        }
        candidate = candidate.substring(open, close + 1).trim();
    }
    return candidate;
}
/**
 * Removes a Markdown code fence that wraps the whole text.
 *
 * <p>Only text that starts with three backticks (after trimming) is touched: the opening
 * fence with its optional language tag and following whitespace is dropped, then a closing
 * fence at the very end is dropped. Any other text is returned trimmed but otherwise
 * unchanged.
 *
 * @param text the text to unwrap, must not be {@code null}
 * @return the trimmed text without the surrounding fence
 */
private String stripMarkdownFence(String text) {
    String body = text.trim();
    if (body.startsWith("```")) {
        body = body
                .replaceFirst("^```[a-zA-Z0-9_-]*\\s*", "")
                .replaceFirst("\\s*```$", "")
                .trim();
    }
    return body;
}
private ChatCompletionCreateParams buildParams(List<Message> messages) { private ChatCompletionCreateParams buildParams(List<Message> messages) {
ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder() ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder()
.model(model) .model(model)
@@ -225,6 +292,15 @@ public class OpenAiCompatibleProvider extends ModelProvider {
return result; return result;
} }
/**
 * Structured-output capability remembered in the mode cache for one cache key.
 */
private enum StructuredOutputMode {
// No mode recorded yet; this is the default read from the cache when the key is absent.
UNKNOWN,
// Recorded after a strict (typed response format) request succeeded.
STRICT_RESPONSE_FORMAT,
// Recorded after the strict request failed for a compatibility reason and the
// prompt-only fallback succeeded; later requests skip the strict attempt.
PROMPT_ONLY_JSON
}
/**
 * Key of the structured-output mode cache, built from the provider's base URL and model
 * name so that each combination is tracked separately.
 */
private record StructuredOutputCacheKey(String baseUrl, String model) {
}
@FunctionalInterface @FunctionalInterface
private interface ThrowingSupplier<T> { private interface ThrowingSupplier<T> {
T get() throws Exception; T get() throws Exception;

View File

@@ -0,0 +1,87 @@
package work.slhaf.partner.framework.agent.model.provider.openai;
import java.util.*;
import java.util.regex.Pattern;
/**
 * Decides whether a failed strict structured-output request is evidence that the endpoint
 * cannot handle structured output at all, so callers may switch to prompt-only JSON.
 *
 * <p>Classification is a text heuristic. The class names and messages of the failure, its
 * causes and its suppressed exceptions are concatenated and lower-cased, then checked
 * against three pattern groups in priority order:
 * <ol>
 *   <li>transient failures (network problems, throttling, 5xx) - never downgrade;</li>
 *   <li>authentication / configuration failures - never downgrade;</li>
 *   <li>structured-output compatibility failures - downgrade.</li>
 * </ol>
 * Anything that matches none of the groups does not downgrade.
 */
final class StructuredOutputFailureClassifier {

    /** Lower-case substrings that mark a failure as transient. */
    private static final List<String> TRANSIENT_FAILURE_PATTERNS = List.of(
            "timeout",
            "error reading response",
            "streamresetexception",
            "interruptedioexception",
            "sockettimeoutexception",
            "connectexception",
            "connection reset",
            "internalserverexception",
            "provider returned error",
            "openaiioexception",
            "request failed"
    );

    /**
     * HTTP status codes that mark a failure as transient, matched only as standalone
     * tokens (not directly adjacent to an ASCII letter or digit).
     *
     * <p>A plain substring match would also fire on unrelated numbers such as a parse
     * position of {@code 1503} or a request id, which would hide a genuine compatibility
     * failure behind a false "transient" verdict. A number that happens to stand alone and
     * equal one of these codes (for example {@code column: 429}) still matches; that
     * ambiguity cannot be resolved from text alone.
     */
    private static final Pattern TRANSIENT_STATUS_CODE_PATTERN =
            Pattern.compile("(?<![0-9a-z])(?:429|502|503|504)(?![0-9a-z])");

    /** Lower-case substrings that mark a failure as an authentication or setup problem. */
    private static final List<String> AUTH_OR_CONFIG_FAILURE_PATTERNS = List.of(
            "没有权限",
            "not activated",
            "permission",
            "unauthorized",
            "forbidden",
            "invalid api key",
            "sslexception",
            "unsupported or unrecognized ssl message"
    );

    /** Lower-case substrings that mark a failure as a structured-output incompatibility. */
    private static final List<String> STRUCTURED_COMPATIBILITY_FAILURE_PATTERNS = List.of(
            "response_format type is unavailable",
            "messages must contain the word 'json'",
            "messages must contain the word json",
            "structured chat completion returned empty content",
            "error parsing json:",
            "unrecognizedpropertyexception",
            "mismatchedinputexception",
            "invalidformatexception"
    );

    private StructuredOutputFailureClassifier() {
    }

    /**
     * Reports whether {@code failure} indicates a structured-output incompatibility.
     *
     * @param failure the exception thrown by the strict request; {@code null} yields
     *                {@code false}
     * @return {@code true} only when the failure matches a compatibility pattern and
     *         matches neither a transient nor an authentication/configuration pattern
     */
    static boolean shouldDowngradeToPromptOnlyJson(Throwable failure) {
        FailureSnapshot snapshot = FailureSnapshot.from(failure);
        boolean transientFailure = snapshot.containsAny(TRANSIENT_FAILURE_PATTERNS)
                || snapshot.containsMatch(TRANSIENT_STATUS_CODE_PATTERN);
        if (transientFailure) {
            return false;
        }
        if (snapshot.containsAny(AUTH_OR_CONFIG_FAILURE_PATTERNS)) {
            return false;
        }
        return snapshot.containsAny(STRUCTURED_COMPATIBILITY_FAILURE_PATTERNS);
    }

    /**
     * Lower-cased text of a whole exception graph: one line per class name and one per
     * non-null message, covering causes and suppressed exceptions.
     */
    private record FailureSnapshot(String text) {

        /** Builds the snapshot for {@code failure}; a {@code null} failure gives empty text. */
        static FailureSnapshot from(Throwable failure) {
            // Identity semantics: exception classes may override equals, and cause or
            // suppressed links can form cycles.
            Set<Throwable> visited = Collections.newSetFromMap(new IdentityHashMap<>());
            StringBuilder builder = new StringBuilder();
            collect(failure, builder, visited);
            return new FailureSnapshot(builder.toString().toLowerCase(Locale.ROOT));
        }

        /** Appends {@code throwable} and everything reachable from it, each node once. */
        private static void collect(Throwable throwable, StringBuilder builder, Set<Throwable> visited) {
            if (throwable == null || !visited.add(throwable)) {
                return;
            }
            builder.append(throwable.getClass().getName()).append('\n');
            if (throwable.getMessage() != null) {
                builder.append(throwable.getMessage()).append('\n');
            }
            collect(throwable.getCause(), builder, visited);
            for (Throwable suppressed : throwable.getSuppressed()) {
                collect(suppressed, builder, visited);
            }
        }

        /** Returns whether the snapshot contains at least one of the given substrings. */
        boolean containsAny(List<String> patterns) {
            return patterns.stream().anyMatch(text::contains);
        }

        /** Returns whether {@code pattern} matches anywhere in the snapshot. */
        boolean containsMatch(Pattern pattern) {
            return pattern.matcher(text).find();
        }
    }
}