diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/JsonShapeInstructionBuilder.java b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/JsonShapeInstructionBuilder.java new file mode 100644 index 00000000..7e10b30e --- /dev/null +++ b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/JsonShapeInstructionBuilder.java @@ -0,0 +1,114 @@ +package work.slhaf.partner.framework.agent.model.provider.openai; + +import java.lang.reflect.*; +import java.util.*; + +final class JsonShapeInstructionBuilder { + + private static final int MAX_DEPTH = 4; + + private JsonShapeInstructionBuilder() { + } + + static String build(Class responseType) { + return "Return only a valid JSON object.\n" + + "The JSON object must directly match this exact output shape for " + responseType.getSimpleName() + ":\n" + + buildJsonShape(responseType, 0, new HashSet<>()) + "\n\n" + + "Rules:\n" + + "- The top-level object must directly match the shape above.\n" + + "- Do not wrap it in \"" + responseType.getSimpleName() + "\" or any other class name.\n" + + "- Do not rename fields or invent alternative field names.\n" + + "- Do not output markdown, comments, explanations, or code fences."; + } + + private static String buildJsonShape(Type type, int depth, Set visiting) { + if (depth > MAX_DEPTH) { + return "{}"; + } + if (type instanceof ParameterizedType parameterizedType) { + Type rawType = parameterizedType.getRawType(); + if (rawType instanceof Class rawClass && Collection.class.isAssignableFrom(rawClass)) { + Type[] arguments = parameterizedType.getActualTypeArguments(); + if (arguments.length == 0) { + return "[]"; + } + return arrayShape(buildJsonShape(arguments[0], depth + 1, visiting), depth); + } + if (rawType instanceof Class rawClass && Map.class.isAssignableFrom(rawClass)) { + return "{}"; + } + return buildJsonShape(rawType, depth, visiting); + } + if (type instanceof GenericArrayType genericArrayType) { + return arrayShape(buildJsonShape(genericArrayType.getGenericComponentType(), depth + 1, visiting), depth); + } + if (!(type instanceof Class clazz)) { + return "null"; + } + if (clazz.isArray()) { + return arrayShape(buildJsonShape(clazz.getComponentType(), depth + 1, visiting), depth); + } + if (clazz == String.class || clazz == Character.class || clazz == char.class) { + return "\"\""; + } + if (clazz == boolean.class || clazz == Boolean.class) { + return "false"; + } + if (Number.class.isAssignableFrom(clazz) || clazz.isPrimitive()) { + return "0"; + } + if (clazz.isEnum()) { + Object[] constants = clazz.getEnumConstants(); + return constants == null || constants.length == 0 ? "\"\"" : "\"" + constants[0] + "\""; + } + if (Collection.class.isAssignableFrom(clazz)) { + return "[]"; + } + if (Map.class.isAssignableFrom(clazz)) { + return "{}"; + } + if (clazz.getName().startsWith("java.")) { + return "\"\""; + } + if (visiting.contains(clazz)) { + return "{}"; + } + + visiting.add(clazz); + List fields = Arrays.stream(clazz.getDeclaredFields()) + .filter(field -> !field.isSynthetic()) + .filter(field -> !Modifier.isStatic(field.getModifiers())) + .filter(field -> !Modifier.isTransient(field.getModifiers())) + .toList(); + if (fields.isEmpty()) { + visiting.remove(clazz); + return "{}"; + } + + StringBuilder builder = new StringBuilder(); + builder.append("{\n"); + for (int i = 0; i < fields.size(); i++) { + Field field = fields.get(i); + builder.append(indent(depth + 1)) + .append("\"") + .append(field.getName()) + .append("\": ") + .append(buildJsonShape(field.getGenericType(), depth + 1, visiting)); + if (i < fields.size() - 1) { + builder.append(","); + } + builder.append("\n"); + } + builder.append(indent(depth)).append("}"); + visiting.remove(clazz); + return builder.toString(); + } + + private static String arrayShape(String itemShape, int depth) { + return "[\n" + indent(depth + 1) + itemShape.replace("\n", "\n" + indent(depth + 1)) + "\n" + indent(depth) + "]"; + } + + private static String indent(int depth) { + return " ".repeat(Math.max(0, depth)); + } +} diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/OpenAiCompatibleProvider.java b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/OpenAiCompatibleProvider.java index c3cd2c5e..02593dcd 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/OpenAiCompatibleProvider.java +++ b/Partner-Framework/src/main/java/work/slhaf/partner/framework/agent/model/provider/openai/OpenAiCompatibleProvider.java @@ -100,7 +100,7 @@ public class OpenAiCompatibleProvider extends ModelProvider { return executeWithRetry( "OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.", () -> { - StructuredChatCompletionCreateParams params = buildParams(ensureJsonInstruction(messages)).toBuilder() + StructuredChatCompletionCreateParams params = buildParams(ensureJsonInstruction(messages, responseType)).toBuilder() .responseFormat(responseType) .build(); return extractStructured(client.chat().completions().create(params)); @@ -108,23 +108,17 @@ public class OpenAiCompatibleProvider extends ModelProvider { ); } - private List ensureJsonInstruction(List messages) { - boolean containsJsonInstruction = messages.stream() - .map(Message::getContent) - .filter(content -> !content.isBlank()) - .anyMatch(content -> content.toLowerCase(Locale.ROOT).contains("json")); - if (containsJsonInstruction) { - return messages; - } + private List ensureJsonInstruction(List messages, Class responseType) { + String jsonInstruction = JsonShapeInstructionBuilder.build(responseType); - String jsonInstruction = "Return only a valid JSON object."; List patched = new ArrayList<>(messages.size() + 1); boolean merged = false; for (Message message : messages) { if (!merged && message.getRole() == Message.Character.SYSTEM) { + String separator = message.getContent().isBlank() ? "" : "\n\n"; patched.add(new Message( Message.Character.SYSTEM, - message.getContent() + "\n\n" + jsonInstruction + message.getContent() + separator + jsonInstruction )); merged = true; continue; @@ -137,6 +131,7 @@ public class OpenAiCompatibleProvider extends ModelProvider { return patched; } + private ChatCompletionCreateParams buildParams(List messages) { ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder() .model(model) diff --git a/Partner-Framework/src/test/java/work/slhaf/partner/framework/agent/model/provider/openai/JsonShapeInstructionBuilderTest.java b/Partner-Framework/src/test/java/work/slhaf/partner/framework/agent/model/provider/openai/JsonShapeInstructionBuilderTest.java new file mode 100644 index 00000000..ee391357 --- /dev/null +++ b/Partner-Framework/src/test/java/work/slhaf/partner/framework/agent/model/provider/openai/JsonShapeInstructionBuilderTest.java @@ -0,0 +1,79 @@ +package work.slhaf.partner.framework.agent.model.provider.openai; + +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class JsonShapeInstructionBuilderTest { + + @Test + void shouldBuildShapeInstructionForComplexPojo() { + String instruction = JsonShapeInstructionBuilder.build(ComplexResponse.class); + + assertTrue(instruction.startsWith("Return only a valid JSON object.")); + assertTrue(instruction.contains("The JSON object must directly match this exact output shape for ComplexResponse:")); + assertTrue(instruction.contains("Do not wrap it in \"ComplexResponse\" or any other class name.")); + assertTrue(instruction.contains("Do not rename fields or invent alternative field names.")); + + assertTrue(instruction.contains("\"id\": \"\"")); + assertTrue(instruction.contains("\"success\": false")); + assertTrue(instruction.contains("\"retryCount\": 0")); + assertTrue(instruction.contains("\"score\": 0")); + assertTrue(instruction.contains("\"status\": \"READY\"")); + assertTrue(instruction.contains("\"tags\": [")); + assertTrue(instruction.contains("\"items\": [")); + assertTrue(instruction.contains("\"metadata\": {}")); + assertTrue(instruction.contains("\"matrix\": [")); + assertTrue(instruction.contains("\"nested\": {")); + assertTrue(instruction.contains("\"name\": \"\"")); + assertTrue(instruction.contains("\"enabled\": false")); + assertTrue(instruction.contains("\"notes\": [")); + + assertFalse(instruction.contains("staticValue")); + assertFalse(instruction.contains("transientValue")); + } + + @Test + void shouldPreventRecursiveExpansion() { + String instruction = JsonShapeInstructionBuilder.build(RecursiveResponse.class); + + assertTrue(instruction.contains("\"name\": \"\"")); + assertTrue(instruction.contains("\"next\": {}")); + } + + private enum Status { + READY, + DONE + } + + private static class ComplexResponse { + private static String staticValue; + private transient String transientValue; + + private String id; + private boolean success; + private int retryCount; + private Double score; + private Status status; + private List tags; + private List items; + private Map metadata; + private int[] matrix; + private NestedItem nested; + } + + private static class NestedItem { + private String name; + private Boolean enabled; + private List notes; + } + + private static class RecursiveResponse { + private String name; + private RecursiveResponse next; + } +}