mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
fix(openai-provider): ensure structured chat adds JSON instruction when missing
This commit is contained in:
@@ -17,10 +17,7 @@ import work.slhaf.partner.framework.agent.model.provider.ProviderOverride;
|
|||||||
import work.slhaf.partner.framework.agent.support.Result;
|
import work.slhaf.partner.framework.agent.support.Result;
|
||||||
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.Iterator;
|
import java.util.*;
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
public class OpenAiCompatibleProvider extends ModelProvider {
|
public class OpenAiCompatibleProvider extends ModelProvider {
|
||||||
|
|
||||||
@@ -103,7 +100,7 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
|||||||
return executeWithRetry(
|
return executeWithRetry(
|
||||||
"OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.",
|
"OpenAI-compatible provider failed to complete the structured chat request after 3 attempts.",
|
||||||
() -> {
|
() -> {
|
||||||
StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder()
|
StructuredChatCompletionCreateParams<T> params = buildParams(ensureJsonInstruction(messages)).toBuilder()
|
||||||
.responseFormat(responseType)
|
.responseFormat(responseType)
|
||||||
.build();
|
.build();
|
||||||
return extractStructured(client.chat().completions().create(params));
|
return extractStructured(client.chat().completions().create(params));
|
||||||
@@ -111,6 +108,35 @@ public class OpenAiCompatibleProvider extends ModelProvider {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private List<Message> ensureJsonInstruction(List<Message> messages) {
|
||||||
|
boolean containsJsonInstruction = messages.stream()
|
||||||
|
.map(Message::getContent)
|
||||||
|
.filter(content -> !content.isBlank())
|
||||||
|
.anyMatch(content -> content.toLowerCase(Locale.ROOT).contains("json"));
|
||||||
|
if (containsJsonInstruction) {
|
||||||
|
return messages;
|
||||||
|
}
|
||||||
|
|
||||||
|
String jsonInstruction = "Return only a valid JSON object that matches the requested response schema.";
|
||||||
|
List<Message> patched = new ArrayList<>(messages.size() + 1);
|
||||||
|
boolean merged = false;
|
||||||
|
for (Message message : messages) {
|
||||||
|
if (!merged && message.getRole() == Message.Character.SYSTEM) {
|
||||||
|
patched.add(new Message(
|
||||||
|
Message.Character.SYSTEM,
|
||||||
|
message.getContent() + "\n\n" + jsonInstruction
|
||||||
|
));
|
||||||
|
merged = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
patched.add(message);
|
||||||
|
}
|
||||||
|
if (!merged) {
|
||||||
|
patched.addFirst(new Message(Message.Character.SYSTEM, jsonInstruction));
|
||||||
|
}
|
||||||
|
return patched;
|
||||||
|
}
|
||||||
|
|
||||||
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
||||||
ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder()
|
ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder()
|
||||||
.model(model)
|
.model(model)
|
||||||
|
|||||||
Reference in New Issue
Block a user