refactor(chat): replace custom client with OpenAI runtime and remove file-based module prompt loading logic; prompts are now provided by each module

This commit is contained in:
2026-03-09 21:51:07 +08:00
parent 8dc7ed080b
commit 1b2ccaee9c
32 changed files with 288 additions and 615 deletions

View File

@@ -8,11 +8,9 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentMod
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
import work.slhaf.partner.api.agent.factory.component.exception.ModuleFactoryInitFailedException
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext
import work.slhaf.partner.api.agent.factory.context.ModuleContextData
import work.slhaf.partner.api.chat.pojo.Message
import java.lang.reflect.Modifier
import java.time.ZonedDateTime
@@ -21,7 +19,7 @@ import java.time.ZonedDateTime
*
* 行为:
* - 若实例是 [AbstractAgentModule],按 Running/Sub/Standalone 构造 `ModuleContextData` 并注册到 modules。
* - 若实现了 [ActivateModel]必须存在对应 `modelPromptMap` 条目,随后构建 `modelInfo`。
* - 若实现了 [ActivateModel]使用模块提供的 prompt 元数据构建 `modelInfo`。
* - 若不是模块类型,尝试注册为 additional component失败仅记录错误日志
*/
class ComponentRegisterFactory : AgentBaseFactory() {
@@ -35,7 +33,6 @@ class ComponentRegisterFactory : AgentBaseFactory() {
val agentContext = context.agentContext
val modelConfigMap = configFactoryContext.modelConfigMap
val modelPromptMap = configFactoryContext.modelPromptMap
val defaultConfig = modelConfigMap["default"]!!
reflections.getTypesAnnotatedWith(AgentComponent::class.java)
@@ -56,7 +53,6 @@ class ComponentRegisterFactory : AgentBaseFactory() {
componentClass,
componentInstance,
modelConfigMap,
modelPromptMap,
defaultConfig
)
} else {
@@ -71,7 +67,6 @@ class ComponentRegisterFactory : AgentBaseFactory() {
componentClass: Class<*>,
module: AbstractAgentModule,
modelConfigMap: Map<String, ModelConfig>,
modelPromptMap: Map<String, List<Message>>,
defaultConfig: ModelConfig
) {
if (agentContext.modules.containsKey(module.moduleName)) {
@@ -84,12 +79,10 @@ class ComponentRegisterFactory : AgentBaseFactory() {
val modelInfo = if (module is ActivateModel) {
val modelKey = module.modelKey()
val modelConfig = modelConfigMap[modelKey] ?: defaultConfig
val modelPrompt = modelPromptMap[modelKey]
?: throw PromptNotExistException("不存在的modelPrompt: $modelKey")
ModuleContextData.ModelInfo(
modelConfig.baseUrl,
modelConfig.model,
JSONArray.parseArray(JSONObject.toJSONString(modelPrompt))
JSONArray.parseArray(JSONObject.toJSONString(module.modulePrompt()))
)
} else {
null

View File

@@ -3,13 +3,10 @@ package work.slhaf.partner.api.agent.factory.component.abstracts
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
import work.slhaf.partner.api.agent.factory.component.annotation.Init
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
import work.slhaf.partner.api.agent.runtime.interaction.flow.RunningFlowContext
import work.slhaf.partner.api.chat.ChatClient
import work.slhaf.partner.api.chat.constant.ChatConstant
import work.slhaf.partner.api.chat.pojo.ChatResponse
import work.slhaf.partner.api.chat.pojo.Message
import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime
/**
* 模块基类
@@ -39,58 +36,37 @@ sealed class AbstractAgentModule {
interface ActivateModel {
val model: Model
get() = modelMap.computeIfAbsent(modelKey()) {
buildModel()
val runtime: OpenAiChatRuntime
get() = runtimeMap.computeIfAbsent(modelKey()) {
buildRuntime()
}
companion object {
val modelMap: MutableMap<String, Model> = mutableMapOf()
val runtimeMap: MutableMap<String, OpenAiChatRuntime> = mutableMapOf()
private val configManager: AgentConfigLoader = AgentConfigLoader.INSTANCE
}
@Init(order = -1)
fun modelSettings() {
modelMap[modelKey()] = buildModel()
}
fun buildModel(): Model {
fun buildRuntime(): OpenAiChatRuntime {
val modelConfig = configManager.loadModelConfig(modelKey())
val chatClient = ChatClient(modelConfig.baseUrl, modelConfig.apikey, modelConfig.model)
val model = Model(chatClient)
return OpenAiChatRuntime(modelConfig.baseUrl, modelConfig.apikey, modelConfig.model)
}
val baseMessages = if (withBasicPrompt()) {
loadSpecificPromptAndBasicPrompt(modelKey())
} else {
configManager.loadModelPrompt(modelKey())
fun chat(messages: List<Message>): String {
return runtime.chat(mergeMessages(messages), useStreaming())
}
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T {
return runtime.formattedChat(mergeMessages(messages), useStreaming(), responseType)
}
fun mergeMessages(messages: List<Message>): List<Message> {
if (modulePrompt().isEmpty()) {
return messages
}
return buildList {
addAll(modulePrompt())
addAll(messages)
}
model.baseMessages.addAll(baseMessages)
return model
}
private fun loadSpecificPromptAndBasicPrompt(modelKey: String): MutableList<Message> {
val messages: MutableList<Message> = ArrayList()
messages.addAll(configManager.loadModelPrompt("basic"))
messages.addAll(configManager.loadModelPrompt(modelKey))
return messages
}
fun chat(): ChatResponse {
val temp = ArrayList<Message?>()
temp.addAll(model.baseMessages)
temp.addAll(model.chatMessages)
return model.chatClient.runChat(temp)
}
fun singleChat(input: String): ChatResponse {
val temp = ArrayList<Message>(model.baseMessages)
temp.add(Message(ChatConstant.Character.USER, input))
return model.chatClient.runChat(temp)
}
fun updateChatClientSettings() {
model.chatClient.temperature = 0.4
model.chatClient.top_p = 0.8
}
/**
@@ -104,11 +80,7 @@ interface ActivateModel {
}
}
fun withBasicPrompt(): Boolean
fun modulePrompt(): List<Message> = emptyList()
data class Model(
val chatClient: ChatClient,
val chatMessages: MutableList<Message> = mutableListOf(),
val baseMessages: MutableList<Message> = mutableListOf()
)
fun useStreaming(): Boolean = false
}

View File

@@ -3,7 +3,6 @@ package work.slhaf.partner.api.agent.factory.config
import org.slf4j.LoggerFactory
import work.slhaf.partner.api.agent.factory.AgentBaseFactory
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
import work.slhaf.partner.api.agent.runtime.config.FileAgentConfigLoader
@@ -14,8 +13,8 @@ import java.lang.reflect.Modifier
*
* 行为:
* - 使用全局 `AgentConfigLoader.INSTANCE`,为空时退回 [FileAgentConfigLoader]。
* - 加载并写入 `modelConfigMap`、`modelPromptMap` 到 `ConfigFactoryContext`。
* - 校验 `default` 配置与 `basic` 提示词是否存在。
* - 加载并写入 `modelConfigMap` 到 `ConfigFactoryContext`。
* - 校验 `default` 配置是否存在。
* - 反射读取配置加载器实现类(相对基类新增)的静态字段,并写入 `AgentContext.metadata`。
*/
class ConfigLoaderFactory : AgentBaseFactory() {
@@ -33,26 +32,16 @@ class ConfigLoaderFactory : AgentBaseFactory() {
val configFactoryContext = context.configFactoryContext
configFactoryContext.modelConfigMap.putAll(agentConfigLoader.modelConfigMap)
configFactoryContext.modelPromptMap.putAll(agentConfigLoader.modelPromptMap)
check(configFactoryContext.modelConfigMap.keys, configFactoryContext.modelPromptMap.keys)
check(configFactoryContext.modelConfigMap.keys)
collectLoaderMetadata(context, agentConfigLoader)
}
private fun check(configKeys: Set<String>, promptKeys: Set<String>) {
log.info("执行config与prompt检测...")
private fun check(configKeys: Set<String>) {
log.info("执行config检测...")
if (!configKeys.contains("default")) {
throw ConfigNotExistException("缺少默认配置! 需确保存在一个模型配置的key为`default`")
}
if (!promptKeys.contains("basic")) {
throw PromptNotExistException("缺少基础Prompt! 需要确保存在key为basic的Prompt文件它将与其他Prompt共同作用于模块节点。")
}
val configKeySet = configKeys.toMutableSet().apply { remove("default") }
val promptKeySet = promptKeys.toMutableSet().apply { remove("basic") }
if (!promptKeySet.containsAll(configKeySet)) {
log.warn("存在未被提示词包含的模型配置,该配置将无法生效!")
}
log.info("检测完毕.")
}

View File

@@ -1,12 +0,0 @@
package work.slhaf.partner.api.agent.factory.config.pojo;
import lombok.Data;
import work.slhaf.partner.api.chat.pojo.Message;
import java.util.List;
/**
 * File-based prompt bundle for a single model: a lookup key plus the ordered
 * prompt messages loaded from a prompt config file.
 * Removed by this commit together with file-based prompt loading.
 */
@Data
public class PrimaryModelPrompt {
    // Key identifying which model this prompt applies to.
    private String key;
    // Ordered prompt messages associated with that key.
    private List<Message> messages;
}

View File

@@ -4,7 +4,6 @@ import org.reflections.Reflections
import org.reflections.scanners.Scanners
import org.reflections.util.ConfigurationBuilder
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig
import work.slhaf.partner.api.chat.pojo.Message
import java.lang.reflect.Method
import java.net.URL
@@ -25,7 +24,6 @@ class AgentRegisterContext(urls: List<URL>) {
}
class ConfigFactoryContext {
val modelPromptMap: HashMap<String, List<Message>> = HashMap()
val modelConfigMap: HashMap<String, ModelConfig> = HashMap()
}

View File

@@ -3,12 +3,9 @@ package work.slhaf.partner.api.agent.runtime.config;
import lombok.Data;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException;
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
import work.slhaf.partner.api.chat.pojo.Message;
import java.util.HashMap;
import java.util.List;
@Slf4j
@Data
@@ -18,45 +15,22 @@ public abstract class AgentConfigLoader {
@Setter
public static AgentConfigLoader INSTANCE;
protected HashMap<String, ModelConfig> modelConfigMap;
protected HashMap<String, List<Message>> modelPromptMap;
public void load() {
modelConfigMap = loadModelConfig();
modelPromptMap = loadModelPrompt();
}
protected abstract HashMap<String, List<Message>> loadModelPrompt();
protected abstract HashMap<String, ModelConfig> loadModelConfig();
public abstract void dumpModelConfig(String key);
// Keep explicit getters for Kotlin compilation phase (without Lombok-generated methods).
public HashMap<String, ModelConfig> getModelConfigMap() {
return modelConfigMap;
}
public HashMap<String, List<Message>> getModelPromptMap() {
return modelPromptMap;
}
public List<Message> loadModelPrompt(String modelKey) {
if (!modelPromptMap.containsKey(modelKey)) {
throw new PromptNotExistException("不存在的modelPrompt: " + modelKey);
}
return modelPromptMap.get(modelKey);
}
public ModelConfig loadModelConfig(String modelKey) {
if (!modelConfigMap.containsKey(modelKey)) {
return modelConfigMap.get(DEFAULT_KEY);
}
return modelConfigMap.get(modelKey);
}
public void updateModelConfig(String modelKey, ModelConfig config) {
modelConfigMap.put(modelKey, config);
dumpModelConfig(modelKey);
}
}

View File

@@ -2,17 +2,14 @@ package work.slhaf.partner.api.agent.runtime.config;
import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import work.slhaf.partner.api.agent.factory.config.exception.*;
import work.slhaf.partner.api.agent.factory.config.exception.ConfigDirNotExistException;
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException;
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
import work.slhaf.partner.api.agent.factory.config.pojo.PrimaryModelConfig;
import work.slhaf.partner.api.agent.factory.config.pojo.PrimaryModelPrompt;
import work.slhaf.partner.api.chat.pojo.Message;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
/**
* 默认配置工厂
@@ -23,28 +20,6 @@ public class FileAgentConfigLoader extends AgentConfigLoader {
protected static final String CONFIG_DIR = "./config/";
protected static final String MODEL_CONFIG_DIR = "./config/model/";
protected static final String PROMPT_CONFIG_DIR = "./config/prompt/";
@Override
protected HashMap<String, List<Message>> loadModelPrompt() {
File file = new File(PROMPT_CONFIG_DIR);
if (!file.exists() && !file.isDirectory()) {
throw new PromptDirNotExistException("未找到提示词目录: " + PROMPT_CONFIG_DIR + " 请手动创建!");
}
File[] files = file.listFiles();
if (files == null || files.length == 0) {
throw new PromptNotExistException("在目录 " + PROMPT_CONFIG_DIR + " 中未找到提示词配置!");
}
HashMap<String, List<Message>> promptMap = new HashMap<>();
for (File f : files) {
if (f.isDirectory()) {
continue;
}
PrimaryModelPrompt primaryModelPrompt = JSONUtil.readJSONObject(f, StandardCharsets.UTF_8).toBean(PrimaryModelPrompt.class);
promptMap.put(primaryModelPrompt.getKey(), primaryModelPrompt.getMessages());
}
return promptMap;
}
@Override
protected HashMap<String, ModelConfig> loadModelConfig() {
@@ -67,17 +42,4 @@ public class FileAgentConfigLoader extends AgentConfigLoader {
}
return configMap;
}
@Override
public void dumpModelConfig(String key) {
try {
File file = new File(MODEL_CONFIG_DIR + key + ".json");
if (!file.exists()) {
file.createNewFile();
}
FileUtils.writeStringToFile(file, JSONUtil.toJsonPrettyStr(modelConfigMap.get(key)), StandardCharsets.UTF_8, false);
} catch (Exception e) {
throw new ConfigUpdateFailedException("ModelConfig 配置文件更新失败!");
}
}
}

View File

@@ -1,84 +0,0 @@
package work.slhaf.partner.api.chat;
import cn.hutool.core.io.IORuntimeException;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.ChatBody;
import work.slhaf.partner.api.chat.pojo.ChatResponse;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.pojo.PrimaryChatResponse;
import java.util.List;
/**
 * Hand-rolled HTTP client for an OpenAI-compatible chat completion endpoint.
 * Removed by this commit in favor of OpenAiChatRuntime.
 *
 * NOTE(review): sampling fields (top_p, temperature, max_tokens) are plain
 * mutable instance state — presumably not intended for concurrent mutation;
 * confirm usage pattern before reuse.
 */
@Slf4j
@Data
@NoArgsConstructor
public class ChatClient {

    private String clientId;
    // Chat completion endpoint URL.
    private String url;
    // Bearer token sent in the Authorization header.
    private String apikey;
    private String model;
    // 0 (unset) means sampling parameters are omitted from the request body.
    private double top_p;
    private double temperature;
    private int max_tokens;

    public ChatClient(String url, String apikey, String model) {
        this.url = url;
        this.apikey = apikey;
        this.model = model;
    }

    /**
     * Sends the given messages to the endpoint and maps the first choice into
     * a ChatResponse. On IORuntimeException (e.g. timeout) it returns a FAILED
     * response instead of propagating the exception.
     */
    public ChatResponse runChat(List<Message> messages) {
        HttpRequest request = HttpRequest.post(url);
        // 2s connect / 15s read timeouts; breaches surface as IORuntimeException below.
        request.setConnectionTimeout(2000);
        request.setReadTimeout(15000);
        request.header("Content-Type", "application/json");
        request.header("Authorization", "Bearer " + apikey);
        ChatBody body;
        if (top_p > 0) {
            // Sampling explicitly configured: include top_p/temperature/max_tokens.
            body = ChatBody.builder()
                    .model(model)
                    .messages(messages)
                    .top_p(top_p)
                    .temperature(temperature)
                    .max_tokens(max_tokens)
                    .build();
        } else {
            // Sampling left unset: rely on the builder defaults declared in ChatBody.
            body = ChatBody.builder()
                    .model(model)
                    .messages(messages)
                    .build();
        }
        ChatResponse finalResponse;
        try {
            HttpResponse response = request.body(JSONUtil.toJsonStr(body)).execute();
            PrimaryChatResponse primaryChatResponse = JSONUtil.toBean(response.body(), PrimaryChatResponse.class);
            // Only the first choice is consumed; assumes the API returned at
            // least one choice — a malformed payload would NPE here.
            finalResponse = ChatResponse.builder()
                    .status(ChatConstant.ResponseStatus.SUCCESS)
                    .message(primaryChatResponse.getChoices().get(0).getMessage().getContent())
                    .usageBean(primaryChatResponse.getUsage())
                    .build();
            response.close();
        } catch (IORuntimeException e) {
            // Connection/read failure: degrade to a FAILED response with no usage data.
            log.error("请求超时", e);
            finalResponse = ChatResponse.builder()
                    .message("连接超时")
                    .status(ChatConstant.ResponseStatus.FAILED)
                    .usageBean(null)
                    .build();
        }
        return finalResponse;
    }
}

View File

@@ -1,25 +0,0 @@
package work.slhaf.partner.api.chat.pojo;
import lombok.*;
import java.util.List;
/**
 * OpenAI-compatible chat completion request body. Field names intentionally
 * use snake_case so JSON serialization matches the wire format directly.
 */
@Builder
@Data
@AllArgsConstructor
@NoArgsConstructor
public class ChatBody {

    @NonNull
    private String model;
    @NonNull
    private List<Message> messages;
    // Builder defaults below apply when the caller leaves sampling unset.
    @Builder.Default
    private double temperature = 1;
    @Builder.Default
    private double top_p = 1;
    private boolean stream;
    @Builder.Default
    private int max_tokens = 1024;
    private int presence_penalty;
    private int frequency_penalty;
}

View File

@@ -1,17 +0,0 @@
package work.slhaf.partner.api.chat.pojo;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import work.slhaf.partner.api.chat.constant.ChatConstant;
/**
 * Normalized result of a chat call: terminal status, the assistant's message
 * text (or an error note on failure), and token usage (null on failure).
 */
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class ChatResponse {
    // SUCCESS or FAILED outcome of the request.
    private ChatConstant.ResponseStatus status;
    // Assistant content on success; error text on failure.
    private String message;
    // Token accounting from the provider; null when the call failed.
    private PrimaryChatResponse.UsageBean usageBean;
}

View File

@@ -1,111 +0,0 @@
package work.slhaf.partner.api.chat.pojo;
import lombok.Getter;
import lombok.Setter;
import java.util.List;
/**
 * Raw deserialized shape of an OpenAI-compatible chat completion response.
 * Field names mirror the JSON wire format (snake_case) for direct binding.
 */
@Getter
@Setter
public class PrimaryChatResponse {

    /** Response identifier assigned by the provider. */
    private String id;

    /** Payload object type tag. */
    private String object;

    /** Creation timestamp as sent by the provider. */
    private int created;

    /** Model that produced the completion. */
    private String model;

    /** Candidate completions returned for the request. */
    private List<ChoicesBean> choices;

    /** Token accounting for the request/response pair. */
    private UsageBean usage;

    /** Backend fingerprint string, if the provider sends one. */
    private String system_fingerprint;

    /** Token usage counters reported by the provider. */
    @Setter
    @Getter
    public static class UsageBean {

        /** Tokens consumed by the prompt. */
        private int prompt_tokens;

        /** Tokens generated in the completion. */
        private int completion_tokens;

        /** Total tokens as reported by the provider. */
        private int total_tokens;

        /** Prompt tokens served from the provider cache, if reported. */
        private int prompt_cache_hit_tokens;

        /** Prompt tokens not served from cache, if reported. */
        private int prompt_cache_miss_tokens;

        @Override
        public String toString() {
            return "UsageBean{" +
                    "prompt_tokens=" + prompt_tokens +
                    ", completion_tokens=" + completion_tokens +
                    ", total_tokens=" + total_tokens +
                    ", prompt_cache_hit_tokens=" + prompt_cache_hit_tokens +
                    ", prompt_cache_miss_tokens=" + prompt_cache_miss_tokens +
                    '}';
        }
    }

    /** One completion candidate. */
    @Setter
    @Getter
    public static class ChoicesBean {

        /** Position of this choice in the returned list. */
        private int index;

        /** The generated assistant message. */
        private MessageBean message;

        /** Log-probability payload; left untyped as it is not consumed here. */
        private Object logprobs;

        /** Provider-defined reason string for why generation stopped. */
        private String finish_reason;

        /** Role/content pair of the generated message. */
        @Setter
        @Getter
        public static class MessageBean {

            /** Speaker role of the message. */
            private String role;

            /** Message text content. */
            private String content;
        }
    }
}

View File

@@ -0,0 +1,77 @@
package work.slhaf.partner.api.chat.runtime;
import com.openai.client.OpenAIClient;
import com.openai.client.okhttp.OpenAIOkHttpClient;
import com.openai.core.http.StreamResponse;
import com.openai.helpers.ChatCompletionAccumulator;
import com.openai.models.chat.completions.*;
import work.slhaf.partner.api.chat.pojo.Message;
import java.time.Duration;
import java.util.List;
/**
 * Thin wrapper around the official OpenAI Java SDK for chat completions.
 * Supports plain-text and structured (typed) responses, each in blocking or
 * streaming mode. Introduced by this commit to replace ChatClient.
 */
public class OpenAiChatRuntime {

    // SDK client bound to one endpoint/credential pair.
    private final OpenAIClient client;
    // Model name attached to every request built by this runtime.
    private final String model;

    /**
     * Builds a client against the given OpenAI-compatible endpoint.
     * A fixed 30-second timeout is applied to all calls.
     */
    public OpenAiChatRuntime(String baseUrl, String apikey, String model) {
        this.client = OpenAIOkHttpClient.builder()
                .baseUrl(baseUrl)
                .apiKey(apikey)
                .timeout(Duration.ofSeconds(30))
                .build();
        this.model = model;
    }

    /**
     * Runs a chat completion and returns the first choice's text content.
     *
     * @param messages  conversation to send (adapted via OpenAiMessageAdapter)
     * @param streaming when true, consumes the streaming API and accumulates
     *                  chunks into a full completion before extracting text
     * @throws IllegalStateException if the response has no choices or no content
     */
    public String chat(List<Message> messages, boolean streaming) {
        ChatCompletionCreateParams params = buildParams(messages);
        if (!streaming) {
            return extractText(client.chat().completions().create(params));
        }
        // Streaming path: fold every chunk into the accumulator, then read the
        // reassembled completion as if it were a blocking response.
        ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
        try (StreamResponse<ChatCompletionChunk> response = client.chat().completions().createStreaming(params)) {
            response.stream().forEach(accumulator::accumulate);
        }
        return extractText(accumulator.chatCompletion());
    }

    /**
     * Runs a chat completion whose output is parsed into {@code responseType}
     * via the SDK's structured-output support.
     *
     * @param streaming when true, streams using the raw (untyped) params and
     *                  re-types the accumulated completion on extraction
     * @throws IllegalStateException if the response has no choices or no content
     */
    public <T> T formattedChat(List<Message> messages, boolean streaming, Class<T> responseType) {
        StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder()
                .responseFormat(responseType)
                .build();
        if (!streaming) {
            return extractStructured(client.chat().completions().create(params));
        }
        ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
        try (StreamResponse<ChatCompletionChunk> response = client.chat().completions().createStreaming(params.rawParams())) {
            response.stream().forEach(accumulator::accumulate);
        }
        return extractStructured(accumulator.chatCompletion(responseType));
    }

    // Shared request construction: fixed model plus adapted messages; no
    // sampling overrides are set here.
    private ChatCompletionCreateParams buildParams(List<Message> messages) {
        return ChatCompletionCreateParams.builder()
                .model(model)
                .messages(OpenAiMessageAdapter.toParams(messages))
                .build();
    }

    // Fail fast on empty choices/content rather than returning null.
    private String extractText(ChatCompletion completion) {
        if (completion.choices().isEmpty()) {
            throw new IllegalStateException("OpenAI chat completion returned no choices.");
        }
        return completion.choices().getFirst().message().content()
                .orElseThrow(() -> new IllegalStateException("OpenAI chat completion returned empty content."));
    }

    // Structured counterpart of extractText; same fail-fast contract.
    private <T> T extractStructured(StructuredChatCompletion<T> completion) {
        if (completion.choices().isEmpty()) {
            throw new IllegalStateException("OpenAI structured chat completion returned no choices.");
        }
        return completion.choices().getFirst().message().content()
                .orElseThrow(() -> new IllegalStateException("OpenAI structured chat completion returned empty content."));
    }
}

View File

@@ -0,0 +1,40 @@
package work.slhaf.partner.api.chat.runtime;
import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
import com.openai.models.chat.completions.ChatCompletionMessageParam;
import com.openai.models.chat.completions.ChatCompletionSystemMessageParam;
import com.openai.models.chat.completions.ChatCompletionUserMessageParam;
import work.slhaf.partner.api.chat.constant.ChatConstant;
import work.slhaf.partner.api.chat.pojo.Message;
import java.util.ArrayList;
import java.util.List;
/**
 * Converts the project's Message POJOs into the OpenAI SDK's per-role
 * ChatCompletionMessageParam variants. Stateless utility; not instantiable.
 */
public final class OpenAiMessageAdapter {

    private OpenAiMessageAdapter() {
        // Static utility — no instances.
    }

    /**
     * Maps each message in order; see {@link #toParam(Message)}.
     */
    public static List<ChatCompletionMessageParam> toParams(List<Message> messages) {
        List<ChatCompletionMessageParam> params = new ArrayList<>(messages.size());
        for (Message message : messages) {
            params.add(toParam(message));
        }
        return params;
    }

    /**
     * Maps a single message to the SDK param matching its role
     * (SYSTEM / ASSISTANT / USER).
     *
     * @throws IllegalArgumentException for any other role value
     */
    public static ChatCompletionMessageParam toParam(Message message) {
        return switch (message.getRole()) {
            case ChatConstant.Character.SYSTEM -> ChatCompletionMessageParam.ofSystem(
                    ChatCompletionSystemMessageParam.builder().content(message.getContent()).build()
            );
            case ChatConstant.Character.ASSISTANT -> ChatCompletionMessageParam.ofAssistant(
                    ChatCompletionAssistantMessageParam.builder().content(message.getContent()).build()
            );
            case ChatConstant.Character.USER -> ChatCompletionMessageParam.ofUser(
                    ChatCompletionUserMessageParam.builder().content(message.getContent()).build()
            );
            default -> throw new IllegalArgumentException("Unsupported message role: " + message.getRole());
        };
    }
}