mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
fix(openai-provider): add cached structured-output downgrade to prompt-only JSON fallback
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
package work.slhaf.partner.framework.agent.model.provider.openai;
|
||||
|
||||
import com.alibaba.fastjson2.JSON;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import com.openai.client.OpenAIClient;
|
||||
import com.openai.client.okhttp.OpenAIOkHttpClient;
|
||||
@@ -18,10 +19,13 @@ import work.slhaf.partner.framework.agent.support.Result;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ConcurrentMap;
|
||||
|
||||
public class OpenAiCompatibleProvider extends ModelProvider {
|
||||
|
||||
private static final int MAX_ATTEMPTS = 3;
|
||||
private static final ConcurrentMap<StructuredOutputCacheKey, StructuredOutputMode> STRUCTURED_OUTPUT_MODE_CACHE = new ConcurrentHashMap<>();
|
||||
|
||||
private final String baseUrl;
|
||||
private final String apiKey;
|
||||
@@ -99,15 +103,52 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
||||
public <T> @NotNull Result<T> formattedChat(@NotNull List<Message> messages, @NotNull Class<T> responseType) {
|
||||
return executeWithRetry(
|
||||
"OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.",
|
||||
() -> {
|
||||
StructuredChatCompletionCreateParams<T> params = buildParams(ensureJsonInstruction(messages, responseType)).toBuilder()
|
||||
.responseFormat(responseType)
|
||||
.build();
|
||||
return extractStructured(client.chat().completions().create(params));
|
||||
}
|
||||
() -> formattedChatByCachedMode(messages, responseType)
|
||||
);
|
||||
}
|
||||
|
||||
private <T> T formattedChatByCachedMode(List<Message> messages, Class<T> responseType) {
|
||||
StructuredOutputCacheKey cacheKey = new StructuredOutputCacheKey(baseUrl, model);
|
||||
StructuredOutputMode mode = STRUCTURED_OUTPUT_MODE_CACHE.getOrDefault(cacheKey, StructuredOutputMode.UNKNOWN);
|
||||
if (mode == StructuredOutputMode.PROMPT_ONLY_JSON) {
|
||||
return promptOnlyFormattedChat(messages, responseType);
|
||||
}
|
||||
return strictThenPromptFallback(messages, responseType, cacheKey);
|
||||
}
|
||||
|
||||
private <T> T strictThenPromptFallback(List<Message> messages, Class<T> responseType, StructuredOutputCacheKey cacheKey) {
|
||||
try {
|
||||
T result = strictFormattedChat(messages, responseType);
|
||||
STRUCTURED_OUTPUT_MODE_CACHE.put(cacheKey, StructuredOutputMode.STRICT_RESPONSE_FORMAT);
|
||||
return result;
|
||||
} catch (Exception structuredFailure) {
|
||||
try {
|
||||
T result = promptOnlyFormattedChat(messages, responseType);
|
||||
if (StructuredOutputFailureClassifier.shouldDowngradeToPromptOnlyJson(structuredFailure)) {
|
||||
STRUCTURED_OUTPUT_MODE_CACHE.put(cacheKey, StructuredOutputMode.PROMPT_ONLY_JSON);
|
||||
}
|
||||
return result;
|
||||
} catch (Exception fallbackFailure) {
|
||||
structuredFailure.addSuppressed(fallbackFailure);
|
||||
throw structuredFailure;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private <T> T strictFormattedChat(List<Message> messages, Class<T> responseType) {
|
||||
StructuredChatCompletionCreateParams<T> params = buildParams(ensureJsonInstruction(messages, responseType)).toBuilder()
|
||||
.responseFormat(responseType)
|
||||
.build();
|
||||
return extractStructured(client.chat().completions().create(params));
|
||||
}
|
||||
|
||||
private <T> T promptOnlyFormattedChat(List<Message> messages, Class<T> responseType) {
|
||||
ChatCompletionCreateParams params = buildParams(ensureJsonInstruction(messages, responseType));
|
||||
String rawText = extractText(client.chat().completions().create(params));
|
||||
String jsonText = extractJsonObject(rawText);
|
||||
return JSON.parseObject(jsonText, responseType);
|
||||
}
|
||||
|
||||
private List<Message> ensureJsonInstruction(List<Message> messages, Class<?> responseType) {
|
||||
String jsonInstruction = JsonShapeInstructionBuilder.build(responseType);
|
||||
|
||||
@@ -132,6 +173,32 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
||||
}
|
||||
|
||||
|
||||
private String extractJsonObject(String text) {
|
||||
String trimmed = text == null ? "" : text.trim();
|
||||
if (trimmed.isBlank()) {
|
||||
throw invokeException("OpenAI-compatible provider returned empty content in prompt-only JSON fallback.", null);
|
||||
}
|
||||
trimmed = stripMarkdownFence(trimmed);
|
||||
if (trimmed.startsWith("{") && trimmed.endsWith("}")) {
|
||||
return trimmed;
|
||||
}
|
||||
int start = trimmed.indexOf('{');
|
||||
int end = trimmed.lastIndexOf('}');
|
||||
if (start >= 0 && end > start) {
|
||||
return trimmed.substring(start, end + 1).trim();
|
||||
}
|
||||
throw invokeException("OpenAI-compatible provider prompt-only JSON fallback returned no JSON object.", null);
|
||||
}
|
||||
|
||||
private String stripMarkdownFence(String text) {
|
||||
String trimmed = text.trim();
|
||||
if (!trimmed.startsWith("```")) {
|
||||
return trimmed;
|
||||
}
|
||||
String withoutOpeningFence = trimmed.replaceFirst("^```[a-zA-Z0-9_-]*\\s*", "");
|
||||
return withoutOpeningFence.replaceFirst("\\s*```$", "").trim();
|
||||
}
|
||||
|
||||
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
||||
ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder()
|
||||
.model(model)
|
||||
@@ -225,6 +292,15 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
||||
return result;
|
||||
}
|
||||
|
||||
private enum StructuredOutputMode {
|
||||
UNKNOWN,
|
||||
STRICT_RESPONSE_FORMAT,
|
||||
PROMPT_ONLY_JSON
|
||||
}
|
||||
|
||||
private record StructuredOutputCacheKey(String baseUrl, String model) {
|
||||
}
|
||||
|
||||
@FunctionalInterface
|
||||
private interface ThrowingSupplier<T> {
|
||||
T get() throws Exception;
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
package work.slhaf.partner.framework.agent.model.provider.openai;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
final class StructuredOutputFailureClassifier {
|
||||
|
||||
private static final List<String> TRANSIENT_FAILURE_PATTERNS = List.of(
|
||||
"timeout",
|
||||
"error reading response",
|
||||
"streamresetexception",
|
||||
"interruptedioexception",
|
||||
"sockettimeoutexception",
|
||||
"connectexception",
|
||||
"connection reset",
|
||||
"503",
|
||||
"502",
|
||||
"504",
|
||||
"429",
|
||||
"internalserverexception",
|
||||
"provider returned error",
|
||||
"openaiioexception",
|
||||
"request failed"
|
||||
);
|
||||
private static final List<String> AUTH_OR_CONFIG_FAILURE_PATTERNS = List.of(
|
||||
"没有权限",
|
||||
"not activated",
|
||||
"permission",
|
||||
"unauthorized",
|
||||
"forbidden",
|
||||
"invalid api key",
|
||||
"sslexception",
|
||||
"unsupported or unrecognized ssl message"
|
||||
);
|
||||
private static final List<String> STRUCTURED_COMPATIBILITY_FAILURE_PATTERNS = List.of(
|
||||
"response_format type is unavailable",
|
||||
"messages must contain the word 'json'",
|
||||
"messages must contain the word json",
|
||||
"structured chat completion returned empty content",
|
||||
"error parsing json:",
|
||||
"unrecognizedpropertyexception",
|
||||
"mismatchedinputexception",
|
||||
"invalidformatexception"
|
||||
);
|
||||
|
||||
private StructuredOutputFailureClassifier() {
|
||||
}
|
||||
|
||||
static boolean shouldDowngradeToPromptOnlyJson(Throwable failure) {
|
||||
FailureSnapshot snapshot = FailureSnapshot.from(failure);
|
||||
if (snapshot.containsAny(TRANSIENT_FAILURE_PATTERNS)) {
|
||||
return false;
|
||||
}
|
||||
if (snapshot.containsAny(AUTH_OR_CONFIG_FAILURE_PATTERNS)) {
|
||||
return false;
|
||||
}
|
||||
return snapshot.containsAny(STRUCTURED_COMPATIBILITY_FAILURE_PATTERNS);
|
||||
}
|
||||
|
||||
private record FailureSnapshot(String text) {
|
||||
|
||||
static FailureSnapshot from(Throwable failure) {
|
||||
Set<Throwable> visited = Collections.newSetFromMap(new IdentityHashMap<>());
|
||||
StringBuilder builder = new StringBuilder();
|
||||
collect(failure, builder, visited);
|
||||
return new FailureSnapshot(builder.toString().toLowerCase(Locale.ROOT));
|
||||
}
|
||||
|
||||
private static void collect(Throwable throwable, StringBuilder builder, Set<Throwable> visited) {
|
||||
if (throwable == null || visited.contains(throwable)) {
|
||||
return;
|
||||
}
|
||||
visited.add(throwable);
|
||||
builder.append(throwable.getClass().getName()).append('\n');
|
||||
if (throwable.getMessage() != null) {
|
||||
builder.append(throwable.getMessage()).append('\n');
|
||||
}
|
||||
collect(throwable.getCause(), builder, visited);
|
||||
for (Throwable suppressed : throwable.getSuppressed()) {
|
||||
collect(suppressed, builder, visited);
|
||||
}
|
||||
}
|
||||
|
||||
boolean containsAny(List<String> patterns) {
|
||||
return patterns.stream().anyMatch(text::contains);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user