mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
fix(openai-provider): add cached structured-output downgrade to prompt-only JSON fallback
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
package work.slhaf.partner.framework.agent.model.provider.openai;
|
package work.slhaf.partner.framework.agent.model.provider.openai;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson2.JSON;
|
||||||
import com.alibaba.fastjson2.JSONObject;
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
import com.openai.client.OpenAIClient;
|
import com.openai.client.OpenAIClient;
|
||||||
import com.openai.client.okhttp.OpenAIOkHttpClient;
|
import com.openai.client.okhttp.OpenAIOkHttpClient;
|
||||||
@@ -18,10 +19,13 @@ import work.slhaf.partner.framework.agent.support.Result;
|
|||||||
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.concurrent.ConcurrentMap;
|
||||||
|
|
||||||
public class OpenAiCompatibleProvider extends ModelProvider {
|
public class OpenAiCompatibleProvider extends ModelProvider {
|
||||||
|
|
||||||
private static final int MAX_ATTEMPTS = 3;
|
private static final int MAX_ATTEMPTS = 3;
|
||||||
|
private static final ConcurrentMap<StructuredOutputCacheKey, StructuredOutputMode> STRUCTURED_OUTPUT_MODE_CACHE = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
private final String baseUrl;
|
private final String baseUrl;
|
||||||
private final String apiKey;
|
private final String apiKey;
|
||||||
@@ -99,15 +103,52 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
|||||||
public <T> @NotNull Result<T> formattedChat(@NotNull List<Message> messages, @NotNull Class<T> responseType) {
|
public <T> @NotNull Result<T> formattedChat(@NotNull List<Message> messages, @NotNull Class<T> responseType) {
|
||||||
return executeWithRetry(
|
return executeWithRetry(
|
||||||
"OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.",
|
"OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.",
|
||||||
() -> {
|
() -> formattedChatByCachedMode(messages, responseType)
|
||||||
StructuredChatCompletionCreateParams<T> params = buildParams(ensureJsonInstruction(messages, responseType)).toBuilder()
|
|
||||||
.responseFormat(responseType)
|
|
||||||
.build();
|
|
||||||
return extractStructured(client.chat().completions().create(params));
|
|
||||||
}
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private <T> T formattedChatByCachedMode(List<Message> messages, Class<T> responseType) {
|
||||||
|
StructuredOutputCacheKey cacheKey = new StructuredOutputCacheKey(baseUrl, model);
|
||||||
|
StructuredOutputMode mode = STRUCTURED_OUTPUT_MODE_CACHE.getOrDefault(cacheKey, StructuredOutputMode.UNKNOWN);
|
||||||
|
if (mode == StructuredOutputMode.PROMPT_ONLY_JSON) {
|
||||||
|
return promptOnlyFormattedChat(messages, responseType);
|
||||||
|
}
|
||||||
|
return strictThenPromptFallback(messages, responseType, cacheKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
private <T> T strictThenPromptFallback(List<Message> messages, Class<T> responseType, StructuredOutputCacheKey cacheKey) {
|
||||||
|
try {
|
||||||
|
T result = strictFormattedChat(messages, responseType);
|
||||||
|
STRUCTURED_OUTPUT_MODE_CACHE.put(cacheKey, StructuredOutputMode.STRICT_RESPONSE_FORMAT);
|
||||||
|
return result;
|
||||||
|
} catch (Exception structuredFailure) {
|
||||||
|
try {
|
||||||
|
T result = promptOnlyFormattedChat(messages, responseType);
|
||||||
|
if (StructuredOutputFailureClassifier.shouldDowngradeToPromptOnlyJson(structuredFailure)) {
|
||||||
|
STRUCTURED_OUTPUT_MODE_CACHE.put(cacheKey, StructuredOutputMode.PROMPT_ONLY_JSON);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
} catch (Exception fallbackFailure) {
|
||||||
|
structuredFailure.addSuppressed(fallbackFailure);
|
||||||
|
throw structuredFailure;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private <T> T strictFormattedChat(List<Message> messages, Class<T> responseType) {
|
||||||
|
StructuredChatCompletionCreateParams<T> params = buildParams(ensureJsonInstruction(messages, responseType)).toBuilder()
|
||||||
|
.responseFormat(responseType)
|
||||||
|
.build();
|
||||||
|
return extractStructured(client.chat().completions().create(params));
|
||||||
|
}
|
||||||
|
|
||||||
|
private <T> T promptOnlyFormattedChat(List<Message> messages, Class<T> responseType) {
|
||||||
|
ChatCompletionCreateParams params = buildParams(ensureJsonInstruction(messages, responseType));
|
||||||
|
String rawText = extractText(client.chat().completions().create(params));
|
||||||
|
String jsonText = extractJsonObject(rawText);
|
||||||
|
return JSON.parseObject(jsonText, responseType);
|
||||||
|
}
|
||||||
|
|
||||||
private List<Message> ensureJsonInstruction(List<Message> messages, Class<?> responseType) {
|
private List<Message> ensureJsonInstruction(List<Message> messages, Class<?> responseType) {
|
||||||
String jsonInstruction = JsonShapeInstructionBuilder.build(responseType);
|
String jsonInstruction = JsonShapeInstructionBuilder.build(responseType);
|
||||||
|
|
||||||
@@ -132,6 +173,32 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private String extractJsonObject(String text) {
|
||||||
|
String trimmed = text == null ? "" : text.trim();
|
||||||
|
if (trimmed.isBlank()) {
|
||||||
|
throw invokeException("OpenAI-compatible provider returned empty content in prompt-only JSON fallback.", null);
|
||||||
|
}
|
||||||
|
trimmed = stripMarkdownFence(trimmed);
|
||||||
|
if (trimmed.startsWith("{") && trimmed.endsWith("}")) {
|
||||||
|
return trimmed;
|
||||||
|
}
|
||||||
|
int start = trimmed.indexOf('{');
|
||||||
|
int end = trimmed.lastIndexOf('}');
|
||||||
|
if (start >= 0 && end > start) {
|
||||||
|
return trimmed.substring(start, end + 1).trim();
|
||||||
|
}
|
||||||
|
throw invokeException("OpenAI-compatible provider prompt-only JSON fallback returned no JSON object.", null);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String stripMarkdownFence(String text) {
|
||||||
|
String trimmed = text.trim();
|
||||||
|
if (!trimmed.startsWith("```")) {
|
||||||
|
return trimmed;
|
||||||
|
}
|
||||||
|
String withoutOpeningFence = trimmed.replaceFirst("^```[a-zA-Z0-9_-]*\\s*", "");
|
||||||
|
return withoutOpeningFence.replaceFirst("\\s*```$", "").trim();
|
||||||
|
}
|
||||||
|
|
||||||
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
||||||
ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder()
|
ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder()
|
||||||
.model(model)
|
.model(model)
|
||||||
@@ -225,6 +292,15 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private enum StructuredOutputMode {
|
||||||
|
UNKNOWN,
|
||||||
|
STRICT_RESPONSE_FORMAT,
|
||||||
|
PROMPT_ONLY_JSON
|
||||||
|
}
|
||||||
|
|
||||||
|
private record StructuredOutputCacheKey(String baseUrl, String model) {
|
||||||
|
}
|
||||||
|
|
||||||
@FunctionalInterface
|
@FunctionalInterface
|
||||||
private interface ThrowingSupplier<T> {
|
private interface ThrowingSupplier<T> {
|
||||||
T get() throws Exception;
|
T get() throws Exception;
|
||||||
|
|||||||
@@ -0,0 +1,87 @@
|
|||||||
|
package work.slhaf.partner.framework.agent.model.provider.openai;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
 * Classifies a failed strict structured-output request to decide whether the caller should
 * remember to use prompt-only JSON for the same endpoint from now on.
 *
 * <p>The decision is a text heuristic: class names and messages of the failure, its causes and
 * its suppressed exceptions are concatenated, lower-cased and searched for known fragments.
 * Transient and auth/config failures never trigger a downgrade; only failures that match a
 * known structured-output compatibility fragment do.
 */
final class StructuredOutputFailureClassifier {

    /** Lower-case fragments that mark a failure as transient (network, overload, I/O). */
    private static final List<String> TRANSIENT_FAILURE_PATTERNS = List.of(
            "timeout",
            "error reading response",
            "streamresetexception",
            "interruptedioexception",
            "sockettimeoutexception",
            "connectexception",
            "connection reset",
            "internalserverexception",
            "provider returned error",
            "openaiioexception",
            "request failed"
    );

    /**
     * HTTP status codes that mark a failure as transient. They are matched only when not
     * adjacent to another digit, so numbers that merely contain the code (offsets, column
     * numbers, ids such as {@code 15034}) are not mistaken for a status code.
     */
    private static final java.util.regex.Pattern TRANSIENT_STATUS_CODE_PATTERN =
            java.util.regex.Pattern.compile("(?<!\\d)(?:503|502|504|429)(?!\\d)");

    /** Lower-case fragments that mark a failure as an authorization or configuration problem. */
    private static final List<String> AUTH_OR_CONFIG_FAILURE_PATTERNS = List.of(
            "没有权限",
            "not activated",
            "permission",
            "unauthorized",
            "forbidden",
            "invalid api key",
            "sslexception",
            "unsupported or unrecognized ssl message"
    );

    /** Lower-case fragments that indicate the strict structured-output request itself failed. */
    private static final List<String> STRUCTURED_COMPATIBILITY_FAILURE_PATTERNS = List.of(
            "response_format type is unavailable",
            "messages must contain the word 'json'",
            "messages must contain the word json",
            "structured chat completion returned empty content",
            "error parsing json:",
            "unrecognizedpropertyexception",
            "mismatchedinputexception",
            "invalidformatexception"
    );

    private StructuredOutputFailureClassifier() {
    }

    /**
     * Tells whether {@code failure} justifies switching to prompt-only JSON.
     *
     * @param failure the exception thrown by the strict request; {@code null} yields {@code false}
     * @return {@code true} only when the failure text matches a structured-compatibility
     *         fragment and matches neither a transient nor an auth/config indicator
     */
    static boolean shouldDowngradeToPromptOnlyJson(Throwable failure) {
        FailureSnapshot snapshot = FailureSnapshot.from(failure);
        if (snapshot.containsAny(TRANSIENT_FAILURE_PATTERNS)
                || snapshot.matches(TRANSIENT_STATUS_CODE_PATTERN)) {
            return false;
        }
        if (snapshot.containsAny(AUTH_OR_CONFIG_FAILURE_PATTERNS)) {
            return false;
        }
        return snapshot.containsAny(STRUCTURED_COMPATIBILITY_FAILURE_PATTERNS);
    }

    /**
     * Lower-cased concatenation of the class names and messages found in a throwable graph.
     *
     * @param text the searchable text, already lower-cased with {@link Locale#ROOT}
     */
    private record FailureSnapshot(String text) {

        /** Builds a snapshot from {@code failure}, following causes and suppressed exceptions. */
        static FailureSnapshot from(Throwable failure) {
            // Identity-based set: throwables may override equals, and graphs may contain cycles.
            Set<Throwable> visited = Collections.newSetFromMap(new IdentityHashMap<>());
            StringBuilder builder = new StringBuilder();
            collect(failure, builder, visited);
            return new FailureSnapshot(builder.toString().toLowerCase(Locale.ROOT));
        }

        /** Appends class name and message of {@code throwable} and of everything reachable from it. */
        private static void collect(Throwable throwable, StringBuilder builder, Set<Throwable> visited) {
            // add() returns false for an already-seen throwable, which stops cycles.
            if (throwable == null || !visited.add(throwable)) {
                return;
            }
            builder.append(throwable.getClass().getName()).append('\n');
            if (throwable.getMessage() != null) {
                builder.append(throwable.getMessage()).append('\n');
            }
            collect(throwable.getCause(), builder, visited);
            for (Throwable suppressed : throwable.getSuppressed()) {
                collect(suppressed, builder, visited);
            }
        }

        /** Returns whether the text contains at least one of the given lower-case fragments. */
        boolean containsAny(List<String> patterns) {
            return patterns.stream().anyMatch(text::contains);
        }

        /** Returns whether {@code pattern} occurs anywhere in the text. */
        boolean matches(java.util.regex.Pattern pattern) {
            return pattern.matcher(text).find();
        }
    }
}
|
||||||
Reference in New Issue
Block a user