fix(openai-provider): generate and inject response-type JSON shape instruction for structured chat

This commit is contained in:
2026-04-25 20:23:28 +08:00
parent 075a8ece3e
commit dd5ab3aaf3
3 changed files with 199 additions and 11 deletions

View File

@@ -0,0 +1,114 @@
package work.slhaf.partner.framework.agent.model.provider.openai;
import java.lang.reflect.*;
import java.util.*;
final class JsonShapeInstructionBuilder {
private static final int MAX_DEPTH = 4;
private JsonShapeInstructionBuilder() {
}
static String build(Class<?> responseType) {
return "Return only a valid JSON object.\n"
+ "The JSON object must directly match this exact output shape for " + responseType.getSimpleName() + ":\n"
+ buildJsonShape(responseType, 0, new HashSet<>()) + "\n\n"
+ "Rules:\n"
+ "- The top-level object must directly match the shape above.\n"
+ "- Do not wrap it in \"" + responseType.getSimpleName() + "\" or any other class name.\n"
+ "- Do not rename fields or invent alternative field names.\n"
+ "- Do not output markdown, comments, explanations, or code fences.";
}
private static String buildJsonShape(Type type, int depth, Set<Type> visiting) {
if (depth > MAX_DEPTH) {
return "{}";
}
if (type instanceof ParameterizedType parameterizedType) {
Type rawType = parameterizedType.getRawType();
if (rawType instanceof Class<?> rawClass && Collection.class.isAssignableFrom(rawClass)) {
Type[] arguments = parameterizedType.getActualTypeArguments();
if (arguments.length == 0) {
return "[]";
}
return arrayShape(buildJsonShape(arguments[0], depth + 1, visiting), depth);
}
if (rawType instanceof Class<?> rawClass && Map.class.isAssignableFrom(rawClass)) {
return "{}";
}
return buildJsonShape(rawType, depth, visiting);
}
if (type instanceof GenericArrayType genericArrayType) {
return arrayShape(buildJsonShape(genericArrayType.getGenericComponentType(), depth + 1, visiting), depth);
}
if (!(type instanceof Class<?> clazz)) {
return "null";
}
if (clazz.isArray()) {
return arrayShape(buildJsonShape(clazz.getComponentType(), depth + 1, visiting), depth);
}
if (clazz == String.class || clazz == Character.class || clazz == char.class) {
return "\"\"";
}
if (clazz == boolean.class || clazz == Boolean.class) {
return "false";
}
if (Number.class.isAssignableFrom(clazz) || clazz.isPrimitive()) {
return "0";
}
if (clazz.isEnum()) {
Object[] constants = clazz.getEnumConstants();
return constants == null || constants.length == 0 ? "\"\"" : "\"" + constants[0] + "\"";
}
if (Collection.class.isAssignableFrom(clazz)) {
return "[]";
}
if (Map.class.isAssignableFrom(clazz)) {
return "{}";
}
if (clazz.getName().startsWith("java.")) {
return "\"\"";
}
if (visiting.contains(clazz)) {
return "{}";
}
visiting.add(clazz);
List<Field> fields = Arrays.stream(clazz.getDeclaredFields())
.filter(field -> !field.isSynthetic())
.filter(field -> !Modifier.isStatic(field.getModifiers()))
.filter(field -> !Modifier.isTransient(field.getModifiers()))
.toList();
if (fields.isEmpty()) {
visiting.remove(clazz);
return "{}";
}
StringBuilder builder = new StringBuilder();
builder.append("{\n");
for (int i = 0; i < fields.size(); i++) {
Field field = fields.get(i);
builder.append(indent(depth + 1))
.append("\"")
.append(field.getName())
.append("\": ")
.append(buildJsonShape(field.getGenericType(), depth + 1, visiting));
if (i < fields.size() - 1) {
builder.append(",");
}
builder.append("\n");
}
builder.append(indent(depth)).append("}");
visiting.remove(clazz);
return builder.toString();
}
private static String arrayShape(String itemShape, int depth) {
return "[\n" + indent(depth + 1) + itemShape.replace("\n", "\n" + indent(depth + 1)) + "\n" + indent(depth) + "]";
}
private static String indent(int depth) {
return " ".repeat(Math.max(0, depth));
}
}

View File

@@ -100,7 +100,7 @@ public class OpenAiCompatibleProvider extends ModelProvider {
return executeWithRetry( return executeWithRetry(
"OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.", "OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.",
() -> { () -> {
StructuredChatCompletionCreateParams<T> params = buildParams(ensureJsonInstruction(messages)).toBuilder() StructuredChatCompletionCreateParams<T> params = buildParams(ensureJsonInstruction(messages, responseType)).toBuilder()
.responseFormat(responseType) .responseFormat(responseType)
.build(); .build();
return extractStructured(client.chat().completions().create(params)); return extractStructured(client.chat().completions().create(params));
@@ -108,23 +108,17 @@ public class OpenAiCompatibleProvider extends ModelProvider {
); );
} }
private List<Message> ensureJsonInstruction(List<Message> messages) { private List<Message> ensureJsonInstruction(List<Message> messages, Class<?> responseType) {
boolean containsJsonInstruction = messages.stream() String jsonInstruction = JsonShapeInstructionBuilder.build(responseType);
.map(Message::getContent)
.filter(content -> !content.isBlank())
.anyMatch(content -> content.toLowerCase(Locale.ROOT).contains("json"));
if (containsJsonInstruction) {
return messages;
}
String jsonInstruction = "Return only a valid JSON object.";
List<Message> patched = new ArrayList<>(messages.size() + 1); List<Message> patched = new ArrayList<>(messages.size() + 1);
boolean merged = false; boolean merged = false;
for (Message message : messages) { for (Message message : messages) {
if (!merged && message.getRole() == Message.Character.SYSTEM) { if (!merged && message.getRole() == Message.Character.SYSTEM) {
String separator = message.getContent().isBlank() ? "" : "\n\n";
patched.add(new Message( patched.add(new Message(
Message.Character.SYSTEM, Message.Character.SYSTEM,
message.getContent() + "\n\n" + jsonInstruction message.getContent() + separator + jsonInstruction
)); ));
merged = true; merged = true;
continue; continue;
@@ -137,6 +131,7 @@ public class OpenAiCompatibleProvider extends ModelProvider {
return patched; return patched;
} }
private ChatCompletionCreateParams buildParams(List<Message> messages) { private ChatCompletionCreateParams buildParams(List<Message> messages) {
ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder() ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder()
.model(model) .model(model)

View File

@@ -0,0 +1,79 @@
package work.slhaf.partner.framework.agent.model.provider.openai;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
class JsonShapeInstructionBuilderTest {
@Test
void shouldBuildShapeInstructionForComplexPojo() {
String instruction = JsonShapeInstructionBuilder.build(ComplexResponse.class);
assertTrue(instruction.startsWith("Return only a valid JSON object."));
assertTrue(instruction.contains("The JSON object must directly match this exact output shape for ComplexResponse:"));
assertTrue(instruction.contains("Do not wrap it in \"ComplexResponse\" or any other class name."));
assertTrue(instruction.contains("Do not rename fields or invent alternative field names."));
assertTrue(instruction.contains("\"id\": \"\""));
assertTrue(instruction.contains("\"success\": false"));
assertTrue(instruction.contains("\"retryCount\": 0"));
assertTrue(instruction.contains("\"score\": 0"));
assertTrue(instruction.contains("\"status\": \"READY\""));
assertTrue(instruction.contains("\"tags\": ["));
assertTrue(instruction.contains("\"items\": ["));
assertTrue(instruction.contains("\"metadata\": {}"));
assertTrue(instruction.contains("\"matrix\": ["));
assertTrue(instruction.contains("\"nested\": {"));
assertTrue(instruction.contains("\"name\": \"\""));
assertTrue(instruction.contains("\"enabled\": false"));
assertTrue(instruction.contains("\"notes\": ["));
assertFalse(instruction.contains("staticValue"));
assertFalse(instruction.contains("transientValue"));
}
@Test
void shouldPreventRecursiveExpansion() {
String instruction = JsonShapeInstructionBuilder.build(RecursiveResponse.class);
assertTrue(instruction.contains("\"name\": \"\""));
assertTrue(instruction.contains("\"next\": {}"));
}
private enum Status {
READY,
DONE
}
private static class ComplexResponse {
private static String staticValue;
private transient String transientValue;
private String id;
private boolean success;
private int retryCount;
private Double score;
private Status status;
private List<String> tags;
private List<NestedItem> items;
private Map<String, String> metadata;
private int[] matrix;
private NestedItem nested;
}
private static class NestedItem {
private String name;
private Boolean enabled;
private List<String> notes;
}
private static class RecursiveResponse {
private String name;
private RecursiveResponse next;
}
}