mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
fix(openai-provider): generate and inject response-type JSON shape instruction for structured chat
This commit is contained in:
@@ -0,0 +1,114 @@
|
|||||||
|
package work.slhaf.partner.framework.agent.model.provider.openai;
|
||||||
|
|
||||||
|
import java.lang.reflect.*;
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
final class JsonShapeInstructionBuilder {
|
||||||
|
|
||||||
|
private static final int MAX_DEPTH = 4;
|
||||||
|
|
||||||
|
private JsonShapeInstructionBuilder() {
|
||||||
|
}
|
||||||
|
|
||||||
|
static String build(Class<?> responseType) {
|
||||||
|
return "Return only a valid JSON object.\n"
|
||||||
|
+ "The JSON object must directly match this exact output shape for " + responseType.getSimpleName() + ":\n"
|
||||||
|
+ buildJsonShape(responseType, 0, new HashSet<>()) + "\n\n"
|
||||||
|
+ "Rules:\n"
|
||||||
|
+ "- The top-level object must directly match the shape above.\n"
|
||||||
|
+ "- Do not wrap it in \"" + responseType.getSimpleName() + "\" or any other class name.\n"
|
||||||
|
+ "- Do not rename fields or invent alternative field names.\n"
|
||||||
|
+ "- Do not output markdown, comments, explanations, or code fences.";
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String buildJsonShape(Type type, int depth, Set<Type> visiting) {
|
||||||
|
if (depth > MAX_DEPTH) {
|
||||||
|
return "{}";
|
||||||
|
}
|
||||||
|
if (type instanceof ParameterizedType parameterizedType) {
|
||||||
|
Type rawType = parameterizedType.getRawType();
|
||||||
|
if (rawType instanceof Class<?> rawClass && Collection.class.isAssignableFrom(rawClass)) {
|
||||||
|
Type[] arguments = parameterizedType.getActualTypeArguments();
|
||||||
|
if (arguments.length == 0) {
|
||||||
|
return "[]";
|
||||||
|
}
|
||||||
|
return arrayShape(buildJsonShape(arguments[0], depth + 1, visiting), depth);
|
||||||
|
}
|
||||||
|
if (rawType instanceof Class<?> rawClass && Map.class.isAssignableFrom(rawClass)) {
|
||||||
|
return "{}";
|
||||||
|
}
|
||||||
|
return buildJsonShape(rawType, depth, visiting);
|
||||||
|
}
|
||||||
|
if (type instanceof GenericArrayType genericArrayType) {
|
||||||
|
return arrayShape(buildJsonShape(genericArrayType.getGenericComponentType(), depth + 1, visiting), depth);
|
||||||
|
}
|
||||||
|
if (!(type instanceof Class<?> clazz)) {
|
||||||
|
return "null";
|
||||||
|
}
|
||||||
|
if (clazz.isArray()) {
|
||||||
|
return arrayShape(buildJsonShape(clazz.getComponentType(), depth + 1, visiting), depth);
|
||||||
|
}
|
||||||
|
if (clazz == String.class || clazz == Character.class || clazz == char.class) {
|
||||||
|
return "\"\"";
|
||||||
|
}
|
||||||
|
if (clazz == boolean.class || clazz == Boolean.class) {
|
||||||
|
return "false";
|
||||||
|
}
|
||||||
|
if (Number.class.isAssignableFrom(clazz) || clazz.isPrimitive()) {
|
||||||
|
return "0";
|
||||||
|
}
|
||||||
|
if (clazz.isEnum()) {
|
||||||
|
Object[] constants = clazz.getEnumConstants();
|
||||||
|
return constants == null || constants.length == 0 ? "\"\"" : "\"" + constants[0] + "\"";
|
||||||
|
}
|
||||||
|
if (Collection.class.isAssignableFrom(clazz)) {
|
||||||
|
return "[]";
|
||||||
|
}
|
||||||
|
if (Map.class.isAssignableFrom(clazz)) {
|
||||||
|
return "{}";
|
||||||
|
}
|
||||||
|
if (clazz.getName().startsWith("java.")) {
|
||||||
|
return "\"\"";
|
||||||
|
}
|
||||||
|
if (visiting.contains(clazz)) {
|
||||||
|
return "{}";
|
||||||
|
}
|
||||||
|
|
||||||
|
visiting.add(clazz);
|
||||||
|
List<Field> fields = Arrays.stream(clazz.getDeclaredFields())
|
||||||
|
.filter(field -> !field.isSynthetic())
|
||||||
|
.filter(field -> !Modifier.isStatic(field.getModifiers()))
|
||||||
|
.filter(field -> !Modifier.isTransient(field.getModifiers()))
|
||||||
|
.toList();
|
||||||
|
if (fields.isEmpty()) {
|
||||||
|
visiting.remove(clazz);
|
||||||
|
return "{}";
|
||||||
|
}
|
||||||
|
|
||||||
|
StringBuilder builder = new StringBuilder();
|
||||||
|
builder.append("{\n");
|
||||||
|
for (int i = 0; i < fields.size(); i++) {
|
||||||
|
Field field = fields.get(i);
|
||||||
|
builder.append(indent(depth + 1))
|
||||||
|
.append("\"")
|
||||||
|
.append(field.getName())
|
||||||
|
.append("\": ")
|
||||||
|
.append(buildJsonShape(field.getGenericType(), depth + 1, visiting));
|
||||||
|
if (i < fields.size() - 1) {
|
||||||
|
builder.append(",");
|
||||||
|
}
|
||||||
|
builder.append("\n");
|
||||||
|
}
|
||||||
|
builder.append(indent(depth)).append("}");
|
||||||
|
visiting.remove(clazz);
|
||||||
|
return builder.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String arrayShape(String itemShape, int depth) {
|
||||||
|
return "[\n" + indent(depth + 1) + itemShape.replace("\n", "\n" + indent(depth + 1)) + "\n" + indent(depth) + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String indent(int depth) {
|
||||||
|
return " ".repeat(Math.max(0, depth));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -100,7 +100,7 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
|||||||
return executeWithRetry(
|
return executeWithRetry(
|
||||||
"OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.",
|
"OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.",
|
||||||
() -> {
|
() -> {
|
||||||
StructuredChatCompletionCreateParams<T> params = buildParams(ensureJsonInstruction(messages)).toBuilder()
|
StructuredChatCompletionCreateParams<T> params = buildParams(ensureJsonInstruction(messages, responseType)).toBuilder()
|
||||||
.responseFormat(responseType)
|
.responseFormat(responseType)
|
||||||
.build();
|
.build();
|
||||||
return extractStructured(client.chat().completions().create(params));
|
return extractStructured(client.chat().completions().create(params));
|
||||||
@@ -108,23 +108,17 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<Message> ensureJsonInstruction(List<Message> messages) {
|
private List<Message> ensureJsonInstruction(List<Message> messages, Class<?> responseType) {
|
||||||
boolean containsJsonInstruction = messages.stream()
|
String jsonInstruction = JsonShapeInstructionBuilder.build(responseType);
|
||||||
.map(Message::getContent)
|
|
||||||
.filter(content -> !content.isBlank())
|
|
||||||
.anyMatch(content -> content.toLowerCase(Locale.ROOT).contains("json"));
|
|
||||||
if (containsJsonInstruction) {
|
|
||||||
return messages;
|
|
||||||
}
|
|
||||||
|
|
||||||
String jsonInstruction = "Return only a valid JSON object.";
|
|
||||||
List<Message> patched = new ArrayList<>(messages.size() + 1);
|
List<Message> patched = new ArrayList<>(messages.size() + 1);
|
||||||
boolean merged = false;
|
boolean merged = false;
|
||||||
for (Message message : messages) {
|
for (Message message : messages) {
|
||||||
if (!merged && message.getRole() == Message.Character.SYSTEM) {
|
if (!merged && message.getRole() == Message.Character.SYSTEM) {
|
||||||
|
String separator = message.getContent().isBlank() ? "" : "\n\n";
|
||||||
patched.add(new Message(
|
patched.add(new Message(
|
||||||
Message.Character.SYSTEM,
|
Message.Character.SYSTEM,
|
||||||
message.getContent() + "\n\n" + jsonInstruction
|
message.getContent() + separator + jsonInstruction
|
||||||
));
|
));
|
||||||
merged = true;
|
merged = true;
|
||||||
continue;
|
continue;
|
||||||
@@ -137,6 +131,7 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
|||||||
return patched;
|
return patched;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
||||||
ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder()
|
ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder()
|
||||||
.model(model)
|
.model(model)
|
||||||
|
|||||||
@@ -0,0 +1,79 @@
|
|||||||
|
package work.slhaf.partner.framework.agent.model.provider.openai;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
class JsonShapeInstructionBuilderTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldBuildShapeInstructionForComplexPojo() {
|
||||||
|
String instruction = JsonShapeInstructionBuilder.build(ComplexResponse.class);
|
||||||
|
|
||||||
|
assertTrue(instruction.startsWith("Return only a valid JSON object."));
|
||||||
|
assertTrue(instruction.contains("The JSON object must directly match this exact output shape for ComplexResponse:"));
|
||||||
|
assertTrue(instruction.contains("Do not wrap it in \"ComplexResponse\" or any other class name."));
|
||||||
|
assertTrue(instruction.contains("Do not rename fields or invent alternative field names."));
|
||||||
|
|
||||||
|
assertTrue(instruction.contains("\"id\": \"\""));
|
||||||
|
assertTrue(instruction.contains("\"success\": false"));
|
||||||
|
assertTrue(instruction.contains("\"retryCount\": 0"));
|
||||||
|
assertTrue(instruction.contains("\"score\": 0"));
|
||||||
|
assertTrue(instruction.contains("\"status\": \"READY\""));
|
||||||
|
assertTrue(instruction.contains("\"tags\": ["));
|
||||||
|
assertTrue(instruction.contains("\"items\": ["));
|
||||||
|
assertTrue(instruction.contains("\"metadata\": {}"));
|
||||||
|
assertTrue(instruction.contains("\"matrix\": ["));
|
||||||
|
assertTrue(instruction.contains("\"nested\": {"));
|
||||||
|
assertTrue(instruction.contains("\"name\": \"\""));
|
||||||
|
assertTrue(instruction.contains("\"enabled\": false"));
|
||||||
|
assertTrue(instruction.contains("\"notes\": ["));
|
||||||
|
|
||||||
|
assertFalse(instruction.contains("staticValue"));
|
||||||
|
assertFalse(instruction.contains("transientValue"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldPreventRecursiveExpansion() {
|
||||||
|
String instruction = JsonShapeInstructionBuilder.build(RecursiveResponse.class);
|
||||||
|
|
||||||
|
assertTrue(instruction.contains("\"name\": \"\""));
|
||||||
|
assertTrue(instruction.contains("\"next\": {}"));
|
||||||
|
}
|
||||||
|
|
||||||
|
private enum Status {
|
||||||
|
READY,
|
||||||
|
DONE
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class ComplexResponse {
|
||||||
|
private static String staticValue;
|
||||||
|
private transient String transientValue;
|
||||||
|
|
||||||
|
private String id;
|
||||||
|
private boolean success;
|
||||||
|
private int retryCount;
|
||||||
|
private Double score;
|
||||||
|
private Status status;
|
||||||
|
private List<String> tags;
|
||||||
|
private List<NestedItem> items;
|
||||||
|
private Map<String, String> metadata;
|
||||||
|
private int[] matrix;
|
||||||
|
private NestedItem nested;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class NestedItem {
|
||||||
|
private String name;
|
||||||
|
private Boolean enabled;
|
||||||
|
private List<String> notes;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class RecursiveResponse {
|
||||||
|
private String name;
|
||||||
|
private RecursiveResponse next;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user