mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
refactor(chat): replace custom client with OpenAI runtime and remove file-based module prompt loading logic, prompt will be provided by each module
This commit is contained in:
@@ -8,11 +8,9 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentMod
|
||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
|
||||
import work.slhaf.partner.api.agent.factory.component.exception.ModuleFactoryInitFailedException
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException
|
||||
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig
|
||||
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext
|
||||
import work.slhaf.partner.api.agent.factory.context.ModuleContextData
|
||||
import work.slhaf.partner.api.chat.pojo.Message
|
||||
import java.lang.reflect.Modifier
|
||||
import java.time.ZonedDateTime
|
||||
|
||||
@@ -21,7 +19,7 @@ import java.time.ZonedDateTime
|
||||
*
|
||||
* 行为:
|
||||
* - 若实例是 [AbstractAgentModule],按 Running/Sub/Standalone 构造 `ModuleContextData` 并注册到 modules。
|
||||
* - 若实现了 [ActivateModel],必须存在对应 `modelPromptMap` 条目,随后构建 `modelInfo`。
|
||||
* - 若实现了 [ActivateModel],使用模块提供的 prompt 元数据构建 `modelInfo`。
|
||||
* - 若不是模块类型,尝试注册为 additional component(失败仅记录错误日志)。
|
||||
*/
|
||||
class ComponentRegisterFactory : AgentBaseFactory() {
|
||||
@@ -35,7 +33,6 @@ class ComponentRegisterFactory : AgentBaseFactory() {
|
||||
val agentContext = context.agentContext
|
||||
|
||||
val modelConfigMap = configFactoryContext.modelConfigMap
|
||||
val modelPromptMap = configFactoryContext.modelPromptMap
|
||||
val defaultConfig = modelConfigMap["default"]!!
|
||||
|
||||
reflections.getTypesAnnotatedWith(AgentComponent::class.java)
|
||||
@@ -56,7 +53,6 @@ class ComponentRegisterFactory : AgentBaseFactory() {
|
||||
componentClass,
|
||||
componentInstance,
|
||||
modelConfigMap,
|
||||
modelPromptMap,
|
||||
defaultConfig
|
||||
)
|
||||
} else {
|
||||
@@ -71,7 +67,6 @@ class ComponentRegisterFactory : AgentBaseFactory() {
|
||||
componentClass: Class<*>,
|
||||
module: AbstractAgentModule,
|
||||
modelConfigMap: Map<String, ModelConfig>,
|
||||
modelPromptMap: Map<String, List<Message>>,
|
||||
defaultConfig: ModelConfig
|
||||
) {
|
||||
if (agentContext.modules.containsKey(module.moduleName)) {
|
||||
@@ -84,12 +79,10 @@ class ComponentRegisterFactory : AgentBaseFactory() {
|
||||
val modelInfo = if (module is ActivateModel) {
|
||||
val modelKey = module.modelKey()
|
||||
val modelConfig = modelConfigMap[modelKey] ?: defaultConfig
|
||||
val modelPrompt = modelPromptMap[modelKey]
|
||||
?: throw PromptNotExistException("不存在的modelPrompt: $modelKey")
|
||||
ModuleContextData.ModelInfo(
|
||||
modelConfig.baseUrl,
|
||||
modelConfig.model,
|
||||
JSONArray.parseArray(JSONObject.toJSONString(modelPrompt))
|
||||
JSONArray.parseArray(JSONObject.toJSONString(module.modulePrompt()))
|
||||
)
|
||||
} else {
|
||||
null
|
||||
|
||||
@@ -3,13 +3,10 @@ package work.slhaf.partner.api.agent.factory.component.abstracts
|
||||
import org.slf4j.Logger
|
||||
import org.slf4j.LoggerFactory
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.RunningFlowContext
|
||||
import work.slhaf.partner.api.chat.ChatClient
|
||||
import work.slhaf.partner.api.chat.constant.ChatConstant
|
||||
import work.slhaf.partner.api.chat.pojo.ChatResponse
|
||||
import work.slhaf.partner.api.chat.pojo.Message
|
||||
import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime
|
||||
|
||||
/**
|
||||
* 模块基类
|
||||
@@ -39,58 +36,37 @@ sealed class AbstractAgentModule {
|
||||
|
||||
interface ActivateModel {
|
||||
|
||||
val model: Model
|
||||
get() = modelMap.computeIfAbsent(modelKey()) {
|
||||
buildModel()
|
||||
val runtime: OpenAiChatRuntime
|
||||
get() = runtimeMap.computeIfAbsent(modelKey()) {
|
||||
buildRuntime()
|
||||
}
|
||||
|
||||
companion object {
|
||||
val modelMap: MutableMap<String, Model> = mutableMapOf()
|
||||
val runtimeMap: MutableMap<String, OpenAiChatRuntime> = mutableMapOf()
|
||||
private val configManager: AgentConfigLoader = AgentConfigLoader.INSTANCE
|
||||
}
|
||||
|
||||
@Init(order = -1)
|
||||
fun modelSettings() {
|
||||
modelMap[modelKey()] = buildModel()
|
||||
}
|
||||
|
||||
fun buildModel(): Model {
|
||||
fun buildRuntime(): OpenAiChatRuntime {
|
||||
val modelConfig = configManager.loadModelConfig(modelKey())
|
||||
val chatClient = ChatClient(modelConfig.baseUrl, modelConfig.apikey, modelConfig.model)
|
||||
val model = Model(chatClient)
|
||||
return OpenAiChatRuntime(modelConfig.baseUrl, modelConfig.apikey, modelConfig.model)
|
||||
}
|
||||
|
||||
val baseMessages = if (withBasicPrompt()) {
|
||||
loadSpecificPromptAndBasicPrompt(modelKey())
|
||||
} else {
|
||||
configManager.loadModelPrompt(modelKey())
|
||||
fun chat(messages: List<Message>): String {
|
||||
return runtime.chat(mergeMessages(messages), useStreaming())
|
||||
}
|
||||
|
||||
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T {
|
||||
return runtime.formattedChat(mergeMessages(messages), useStreaming(), responseType)
|
||||
}
|
||||
|
||||
fun mergeMessages(messages: List<Message>): List<Message> {
|
||||
if (modulePrompt().isEmpty()) {
|
||||
return messages
|
||||
}
|
||||
return buildList {
|
||||
addAll(modulePrompt())
|
||||
addAll(messages)
|
||||
}
|
||||
model.baseMessages.addAll(baseMessages)
|
||||
return model
|
||||
}
|
||||
|
||||
private fun loadSpecificPromptAndBasicPrompt(modelKey: String): MutableList<Message> {
|
||||
val messages: MutableList<Message> = ArrayList()
|
||||
messages.addAll(configManager.loadModelPrompt("basic"))
|
||||
messages.addAll(configManager.loadModelPrompt(modelKey))
|
||||
return messages
|
||||
}
|
||||
|
||||
fun chat(): ChatResponse {
|
||||
val temp = ArrayList<Message?>()
|
||||
temp.addAll(model.baseMessages)
|
||||
temp.addAll(model.chatMessages)
|
||||
return model.chatClient.runChat(temp)
|
||||
}
|
||||
|
||||
fun singleChat(input: String): ChatResponse {
|
||||
val temp = ArrayList<Message>(model.baseMessages)
|
||||
temp.add(Message(ChatConstant.Character.USER, input))
|
||||
return model.chatClient.runChat(temp)
|
||||
}
|
||||
|
||||
fun updateChatClientSettings() {
|
||||
model.chatClient.temperature = 0.4
|
||||
model.chatClient.top_p = 0.8
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -104,11 +80,7 @@ interface ActivateModel {
|
||||
}
|
||||
}
|
||||
|
||||
fun withBasicPrompt(): Boolean
|
||||
fun modulePrompt(): List<Message> = emptyList()
|
||||
|
||||
data class Model(
|
||||
val chatClient: ChatClient,
|
||||
val chatMessages: MutableList<Message> = mutableListOf(),
|
||||
val baseMessages: MutableList<Message> = mutableListOf()
|
||||
)
|
||||
fun useStreaming(): Boolean = false
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package work.slhaf.partner.api.agent.factory.config
|
||||
import org.slf4j.LoggerFactory
|
||||
import work.slhaf.partner.api.agent.factory.AgentBaseFactory
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException
|
||||
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
|
||||
import work.slhaf.partner.api.agent.runtime.config.FileAgentConfigLoader
|
||||
@@ -14,8 +13,8 @@ import java.lang.reflect.Modifier
|
||||
*
|
||||
* 行为:
|
||||
* - 使用全局 `AgentConfigLoader.INSTANCE`,为空时退回 [FileAgentConfigLoader]。
|
||||
* - 加载并写入 `modelConfigMap`、`modelPromptMap` 到 `ConfigFactoryContext`。
|
||||
* - 校验 `default` 配置与 `basic` 提示词是否存在。
|
||||
* - 加载并写入 `modelConfigMap` 到 `ConfigFactoryContext`。
|
||||
* - 校验 `default` 配置是否存在。
|
||||
* - 反射读取配置加载器实现类(相对基类新增)的静态字段,并写入 `AgentContext.metadata`。
|
||||
*/
|
||||
class ConfigLoaderFactory : AgentBaseFactory() {
|
||||
@@ -33,26 +32,16 @@ class ConfigLoaderFactory : AgentBaseFactory() {
|
||||
|
||||
val configFactoryContext = context.configFactoryContext
|
||||
configFactoryContext.modelConfigMap.putAll(agentConfigLoader.modelConfigMap)
|
||||
configFactoryContext.modelPromptMap.putAll(agentConfigLoader.modelPromptMap)
|
||||
|
||||
check(configFactoryContext.modelConfigMap.keys, configFactoryContext.modelPromptMap.keys)
|
||||
check(configFactoryContext.modelConfigMap.keys)
|
||||
collectLoaderMetadata(context, agentConfigLoader)
|
||||
}
|
||||
|
||||
private fun check(configKeys: Set<String>, promptKeys: Set<String>) {
|
||||
log.info("执行config与prompt检测...")
|
||||
private fun check(configKeys: Set<String>) {
|
||||
log.info("执行config检测...")
|
||||
if (!configKeys.contains("default")) {
|
||||
throw ConfigNotExistException("缺少默认配置! 需确保存在一个模型配置的key为`default`")
|
||||
}
|
||||
if (!promptKeys.contains("basic")) {
|
||||
throw PromptNotExistException("缺少基础Prompt! 需要确保存在key为basic的Prompt文件,它将与其他Prompt共同作用于模块节点。")
|
||||
}
|
||||
|
||||
val configKeySet = configKeys.toMutableSet().apply { remove("default") }
|
||||
val promptKeySet = promptKeys.toMutableSet().apply { remove("basic") }
|
||||
if (!promptKeySet.containsAll(configKeySet)) {
|
||||
log.warn("存在未被提示词包含的模型配置,该配置将无法生效!")
|
||||
}
|
||||
log.info("检测完毕.")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package work.slhaf.partner.api.agent.factory.config.pojo;
|
||||
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class PrimaryModelPrompt {
|
||||
private String key;
|
||||
private List<Message> messages;
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import org.reflections.Reflections
|
||||
import org.reflections.scanners.Scanners
|
||||
import org.reflections.util.ConfigurationBuilder
|
||||
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig
|
||||
import work.slhaf.partner.api.chat.pojo.Message
|
||||
import java.lang.reflect.Method
|
||||
import java.net.URL
|
||||
|
||||
@@ -25,7 +24,6 @@ class AgentRegisterContext(urls: List<URL>) {
|
||||
}
|
||||
|
||||
class ConfigFactoryContext {
|
||||
val modelPromptMap: HashMap<String, List<Message>> = HashMap()
|
||||
val modelConfigMap: HashMap<String, ModelConfig> = HashMap()
|
||||
}
|
||||
|
||||
|
||||
@@ -3,12 +3,9 @@ package work.slhaf.partner.api.agent.runtime.config;
|
||||
import lombok.Data;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.PromptNotExistException;
|
||||
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Data
|
||||
@@ -18,45 +15,22 @@ public abstract class AgentConfigLoader {
|
||||
@Setter
|
||||
public static AgentConfigLoader INSTANCE;
|
||||
protected HashMap<String, ModelConfig> modelConfigMap;
|
||||
protected HashMap<String, List<Message>> modelPromptMap;
|
||||
|
||||
public void load() {
|
||||
modelConfigMap = loadModelConfig();
|
||||
modelPromptMap = loadModelPrompt();
|
||||
}
|
||||
|
||||
protected abstract HashMap<String, List<Message>> loadModelPrompt();
|
||||
|
||||
protected abstract HashMap<String, ModelConfig> loadModelConfig();
|
||||
|
||||
public abstract void dumpModelConfig(String key);
|
||||
|
||||
// Keep explicit getters for Kotlin compilation phase (without Lombok-generated methods).
|
||||
public HashMap<String, ModelConfig> getModelConfigMap() {
|
||||
return modelConfigMap;
|
||||
}
|
||||
|
||||
public HashMap<String, List<Message>> getModelPromptMap() {
|
||||
return modelPromptMap;
|
||||
}
|
||||
|
||||
public List<Message> loadModelPrompt(String modelKey) {
|
||||
if (!modelPromptMap.containsKey(modelKey)) {
|
||||
throw new PromptNotExistException("不存在的modelPrompt: " + modelKey);
|
||||
}
|
||||
return modelPromptMap.get(modelKey);
|
||||
}
|
||||
|
||||
public ModelConfig loadModelConfig(String modelKey) {
|
||||
if (!modelConfigMap.containsKey(modelKey)) {
|
||||
return modelConfigMap.get(DEFAULT_KEY);
|
||||
}
|
||||
return modelConfigMap.get(modelKey);
|
||||
}
|
||||
|
||||
public void updateModelConfig(String modelKey, ModelConfig config) {
|
||||
modelConfigMap.put(modelKey, config);
|
||||
dumpModelConfig(modelKey);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -2,17 +2,14 @@ package work.slhaf.partner.api.agent.runtime.config;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.*;
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.ConfigDirNotExistException;
|
||||
import work.slhaf.partner.api.agent.factory.config.exception.ConfigNotExistException;
|
||||
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
|
||||
import work.slhaf.partner.api.agent.factory.config.pojo.PrimaryModelConfig;
|
||||
import work.slhaf.partner.api.agent.factory.config.pojo.PrimaryModelPrompt;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 默认配置工厂
|
||||
@@ -23,28 +20,6 @@ public class FileAgentConfigLoader extends AgentConfigLoader {
|
||||
|
||||
protected static final String CONFIG_DIR = "./config/";
|
||||
protected static final String MODEL_CONFIG_DIR = "./config/model/";
|
||||
protected static final String PROMPT_CONFIG_DIR = "./config/prompt/";
|
||||
|
||||
@Override
|
||||
protected HashMap<String, List<Message>> loadModelPrompt() {
|
||||
File file = new File(PROMPT_CONFIG_DIR);
|
||||
if (!file.exists() && !file.isDirectory()) {
|
||||
throw new PromptDirNotExistException("未找到提示词目录: " + PROMPT_CONFIG_DIR + " 请手动创建!");
|
||||
}
|
||||
File[] files = file.listFiles();
|
||||
if (files == null || files.length == 0) {
|
||||
throw new PromptNotExistException("在目录 " + PROMPT_CONFIG_DIR + " 中未找到提示词配置!");
|
||||
}
|
||||
HashMap<String, List<Message>> promptMap = new HashMap<>();
|
||||
for (File f : files) {
|
||||
if (f.isDirectory()) {
|
||||
continue;
|
||||
}
|
||||
PrimaryModelPrompt primaryModelPrompt = JSONUtil.readJSONObject(f, StandardCharsets.UTF_8).toBean(PrimaryModelPrompt.class);
|
||||
promptMap.put(primaryModelPrompt.getKey(), primaryModelPrompt.getMessages());
|
||||
}
|
||||
return promptMap;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected HashMap<String, ModelConfig> loadModelConfig() {
|
||||
@@ -67,17 +42,4 @@ public class FileAgentConfigLoader extends AgentConfigLoader {
|
||||
}
|
||||
return configMap;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void dumpModelConfig(String key) {
|
||||
try {
|
||||
File file = new File(MODEL_CONFIG_DIR + key + ".json");
|
||||
if (!file.exists()) {
|
||||
file.createNewFile();
|
||||
}
|
||||
FileUtils.writeStringToFile(file, JSONUtil.toJsonPrettyStr(modelConfigMap.get(key)), StandardCharsets.UTF_8, false);
|
||||
} catch (Exception e) {
|
||||
throw new ConfigUpdateFailedException("ModelConfig 配置文件更新失败!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
package work.slhaf.partner.api.chat;
|
||||
|
||||
import cn.hutool.core.io.IORuntimeException;
|
||||
import cn.hutool.http.HttpRequest;
|
||||
import cn.hutool.http.HttpResponse;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||
import work.slhaf.partner.api.chat.pojo.ChatBody;
|
||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.chat.pojo.PrimaryChatResponse;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class ChatClient {
|
||||
private String clientId;
|
||||
|
||||
private String url;
|
||||
private String apikey;
|
||||
private String model;
|
||||
|
||||
private double top_p;
|
||||
private double temperature;
|
||||
private int max_tokens;
|
||||
|
||||
public ChatClient(String url, String apikey, String model) {
|
||||
this.url = url;
|
||||
this.apikey = apikey;
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
public ChatResponse runChat(List<Message> messages) {
|
||||
HttpRequest request = HttpRequest.post(url);
|
||||
request.setConnectionTimeout(2000);
|
||||
request.setReadTimeout(15000);
|
||||
request.header("Content-Type", "application/json");
|
||||
request.header("Authorization", "Bearer " + apikey);
|
||||
|
||||
ChatBody body;
|
||||
if (top_p > 0) {
|
||||
body = ChatBody.builder()
|
||||
.model(model)
|
||||
.messages(messages)
|
||||
.top_p(top_p)
|
||||
.temperature(temperature)
|
||||
.max_tokens(max_tokens)
|
||||
.build();
|
||||
} else {
|
||||
body = ChatBody.builder()
|
||||
.model(model)
|
||||
.messages(messages)
|
||||
.build();
|
||||
}
|
||||
|
||||
ChatResponse finalResponse;
|
||||
|
||||
try {
|
||||
HttpResponse response = request.body(JSONUtil.toJsonStr(body)).execute();
|
||||
PrimaryChatResponse primaryChatResponse = JSONUtil.toBean(response.body(), PrimaryChatResponse.class);
|
||||
finalResponse = ChatResponse.builder()
|
||||
.status(ChatConstant.ResponseStatus.SUCCESS)
|
||||
.message(primaryChatResponse.getChoices().get(0).getMessage().getContent())
|
||||
.usageBean(primaryChatResponse.getUsage())
|
||||
.build();
|
||||
|
||||
response.close();
|
||||
} catch (IORuntimeException e) {
|
||||
log.error("请求超时", e);
|
||||
finalResponse = ChatResponse.builder()
|
||||
.message("连接超时")
|
||||
.status(ChatConstant.ResponseStatus.FAILED)
|
||||
.usageBean(null)
|
||||
.build();
|
||||
}
|
||||
return finalResponse;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package work.slhaf.partner.api.chat.pojo;
|
||||
|
||||
import lombok.*;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Builder
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class ChatBody {
|
||||
@NonNull
|
||||
private String model;
|
||||
@NonNull
|
||||
private List<Message> messages;
|
||||
@Builder.Default
|
||||
private double temperature = 1;
|
||||
@Builder.Default
|
||||
private double top_p = 1;
|
||||
private boolean stream;
|
||||
@Builder.Default
|
||||
private int max_tokens = 1024;
|
||||
private int presence_penalty;
|
||||
private int frequency_penalty;
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package work.slhaf.partner.api.chat.pojo;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class ChatResponse {
|
||||
private ChatConstant.ResponseStatus status;
|
||||
private String message;
|
||||
private PrimaryChatResponse.UsageBean usageBean;
|
||||
}
|
||||
@@ -1,111 +0,0 @@
|
||||
package work.slhaf.partner.api.chat.pojo;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public class PrimaryChatResponse {
|
||||
|
||||
/**
|
||||
* id
|
||||
*/
|
||||
private String id;
|
||||
/**
|
||||
* object
|
||||
*/
|
||||
private String object;
|
||||
/**
|
||||
* created
|
||||
*/
|
||||
private int created;
|
||||
/**
|
||||
* model
|
||||
*/
|
||||
private String model;
|
||||
/**
|
||||
* choices
|
||||
*/
|
||||
private List<ChoicesBean> choices;
|
||||
/**
|
||||
* usage
|
||||
*/
|
||||
private UsageBean usage;
|
||||
/**
|
||||
* system_fingerprint
|
||||
*/
|
||||
private String system_fingerprint;
|
||||
|
||||
@Setter
|
||||
@Getter
|
||||
public static class UsageBean {
|
||||
/**
|
||||
* prompt_tokens
|
||||
*/
|
||||
private int prompt_tokens;
|
||||
/**
|
||||
* completion_tokens
|
||||
*/
|
||||
private int completion_tokens;
|
||||
/**
|
||||
* total_tokens
|
||||
*/
|
||||
private int total_tokens;
|
||||
/**
|
||||
* prompt_cache_hit_tokens
|
||||
*/
|
||||
private int prompt_cache_hit_tokens;
|
||||
/**
|
||||
* prompt_cache_miss_tokens
|
||||
*/
|
||||
private int prompt_cache_miss_tokens;
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "UsageBean{" +
|
||||
"prompt_tokens=" + prompt_tokens +
|
||||
", completion_tokens=" + completion_tokens +
|
||||
", total_tokens=" + total_tokens +
|
||||
", prompt_cache_hit_tokens=" + prompt_cache_hit_tokens +
|
||||
", prompt_cache_miss_tokens=" + prompt_cache_miss_tokens +
|
||||
'}';
|
||||
}
|
||||
}
|
||||
|
||||
@Setter
|
||||
@Getter
|
||||
public static class ChoicesBean {
|
||||
/**
|
||||
* index
|
||||
*/
|
||||
private int index;
|
||||
/**
|
||||
* message
|
||||
*/
|
||||
private MessageBean message;
|
||||
/**
|
||||
* logprobs
|
||||
*/
|
||||
private Object logprobs;
|
||||
/**
|
||||
* finish_reason
|
||||
*/
|
||||
private String finish_reason;
|
||||
|
||||
@Setter
|
||||
@Getter
|
||||
public static class MessageBean {
|
||||
/**
|
||||
* role
|
||||
*/
|
||||
private String role;
|
||||
/**
|
||||
* content
|
||||
*/
|
||||
private String content;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package work.slhaf.partner.api.chat.runtime;
|
||||
|
||||
import com.openai.client.OpenAIClient;
|
||||
import com.openai.client.okhttp.OpenAIOkHttpClient;
|
||||
import com.openai.core.http.StreamResponse;
|
||||
import com.openai.helpers.ChatCompletionAccumulator;
|
||||
import com.openai.models.chat.completions.*;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
|
||||
public class OpenAiChatRuntime {
|
||||
|
||||
private final OpenAIClient client;
|
||||
private final String model;
|
||||
|
||||
public OpenAiChatRuntime(String baseUrl, String apikey, String model) {
|
||||
this.client = OpenAIOkHttpClient.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.apiKey(apikey)
|
||||
.timeout(Duration.ofSeconds(30))
|
||||
.build();
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
public String chat(List<Message> messages, boolean streaming) {
|
||||
ChatCompletionCreateParams params = buildParams(messages);
|
||||
if (!streaming) {
|
||||
return extractText(client.chat().completions().create(params));
|
||||
}
|
||||
|
||||
ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
|
||||
try (StreamResponse<ChatCompletionChunk> response = client.chat().completions().createStreaming(params)) {
|
||||
response.stream().forEach(accumulator::accumulate);
|
||||
}
|
||||
return extractText(accumulator.chatCompletion());
|
||||
}
|
||||
|
||||
public <T> T formattedChat(List<Message> messages, boolean streaming, Class<T> responseType) {
|
||||
StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder()
|
||||
.responseFormat(responseType)
|
||||
.build();
|
||||
if (!streaming) {
|
||||
return extractStructured(client.chat().completions().create(params));
|
||||
}
|
||||
|
||||
ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
|
||||
try (StreamResponse<ChatCompletionChunk> response = client.chat().completions().createStreaming(params.rawParams())) {
|
||||
response.stream().forEach(accumulator::accumulate);
|
||||
}
|
||||
return extractStructured(accumulator.chatCompletion(responseType));
|
||||
}
|
||||
|
||||
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
||||
return ChatCompletionCreateParams.builder()
|
||||
.model(model)
|
||||
.messages(OpenAiMessageAdapter.toParams(messages))
|
||||
.build();
|
||||
}
|
||||
|
||||
private String extractText(ChatCompletion completion) {
|
||||
if (completion.choices().isEmpty()) {
|
||||
throw new IllegalStateException("OpenAI chat completion returned no choices.");
|
||||
}
|
||||
return completion.choices().getFirst().message().content()
|
||||
.orElseThrow(() -> new IllegalStateException("OpenAI chat completion returned empty content."));
|
||||
}
|
||||
|
||||
private <T> T extractStructured(StructuredChatCompletion<T> completion) {
|
||||
if (completion.choices().isEmpty()) {
|
||||
throw new IllegalStateException("OpenAI structured chat completion returned no choices.");
|
||||
}
|
||||
return completion.choices().getFirst().message().content()
|
||||
.orElseThrow(() -> new IllegalStateException("OpenAI structured chat completion returned empty content."));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package work.slhaf.partner.api.chat.runtime;
|
||||
|
||||
import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
|
||||
import com.openai.models.chat.completions.ChatCompletionMessageParam;
|
||||
import com.openai.models.chat.completions.ChatCompletionSystemMessageParam;
|
||||
import com.openai.models.chat.completions.ChatCompletionUserMessageParam;
|
||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public final class OpenAiMessageAdapter {
|
||||
|
||||
private OpenAiMessageAdapter() {
|
||||
}
|
||||
|
||||
public static List<ChatCompletionMessageParam> toParams(List<Message> messages) {
|
||||
List<ChatCompletionMessageParam> params = new ArrayList<>(messages.size());
|
||||
for (Message message : messages) {
|
||||
params.add(toParam(message));
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
public static ChatCompletionMessageParam toParam(Message message) {
|
||||
return switch (message.getRole()) {
|
||||
case ChatConstant.Character.SYSTEM -> ChatCompletionMessageParam.ofSystem(
|
||||
ChatCompletionSystemMessageParam.builder().content(message.getContent()).build()
|
||||
);
|
||||
case ChatConstant.Character.ASSISTANT -> ChatCompletionMessageParam.ofAssistant(
|
||||
ChatCompletionAssistantMessageParam.builder().content(message.getContent()).build()
|
||||
);
|
||||
case ChatConstant.Character.USER -> ChatCompletionMessageParam.ofUser(
|
||||
ChatCompletionUserMessageParam.builder().content(message.getContent()).build()
|
||||
);
|
||||
default -> throw new IllegalArgumentException("Unsupported message role: " + message.getRole());
|
||||
};
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user