mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
refactor(chat): extract model activation into provider registry
- move ActivateModel and StreamChatMessageConsumer into api.chat - replace direct OpenAI runtime construction with ModelRuntimeRegistry - add provider config, runtime override and OpenAI-compatible provider forking - rename OpenAiChatRuntime to OpenAiCompatibleProvider and update imports
This commit is contained in:
@@ -5,13 +5,13 @@ import com.alibaba.fastjson2.JSONObject
|
||||
import org.slf4j.LoggerFactory
|
||||
import work.slhaf.partner.api.agent.factory.AgentBaseFactory
|
||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule
|
||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
|
||||
import work.slhaf.partner.api.agent.factory.component.exception.ModuleFactoryInitFailedException
|
||||
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig
|
||||
import work.slhaf.partner.api.agent.factory.context.AgentContext
|
||||
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext
|
||||
import work.slhaf.partner.api.agent.factory.context.ModuleContextData
|
||||
import work.slhaf.partner.api.chat.ActivateModel
|
||||
import java.lang.reflect.Modifier
|
||||
import java.time.ZonedDateTime
|
||||
|
||||
|
||||
@@ -3,11 +3,7 @@ package work.slhaf.partner.api.agent.factory.component.abstracts
|
||||
import org.slf4j.Logger
|
||||
import org.slf4j.LoggerFactory
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.RunningFlowContext
|
||||
import work.slhaf.partner.api.chat.pojo.Message
|
||||
import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime
|
||||
import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer
|
||||
|
||||
/**
|
||||
* 模块基类
|
||||
@@ -35,55 +31,3 @@ sealed class AbstractAgentModule {
|
||||
|
||||
}
|
||||
|
||||
interface ActivateModel {
|
||||
|
||||
val runtime: OpenAiChatRuntime
|
||||
get() = runtimeMap.computeIfAbsent(modelKey()) {
|
||||
buildRuntime()
|
||||
}
|
||||
|
||||
companion object {
|
||||
val runtimeMap: MutableMap<String, OpenAiChatRuntime> = mutableMapOf()
|
||||
private val configManager: AgentConfigLoader = AgentConfigLoader.INSTANCE
|
||||
}
|
||||
|
||||
fun buildRuntime(): OpenAiChatRuntime {
|
||||
val modelConfig = configManager.loadModelConfig(modelKey())
|
||||
return OpenAiChatRuntime(modelConfig.baseUrl, modelConfig.apikey, modelConfig.model)
|
||||
}
|
||||
|
||||
fun chat(messages: List<Message>): String {
|
||||
return runtime.chat(mergeMessages(messages))
|
||||
}
|
||||
|
||||
fun streamChat(messages: List<Message>, handler: StreamChatMessageConsumer) {
|
||||
return runtime.streamChat(mergeMessages(messages), handler)
|
||||
}
|
||||
|
||||
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T {
|
||||
return runtime.formattedChat(mergeMessages(messages), responseType)
|
||||
}
|
||||
|
||||
fun mergeMessages(messages: List<Message>): List<Message> {
|
||||
if (modulePrompt().isEmpty()) {
|
||||
return messages
|
||||
}
|
||||
return buildList {
|
||||
addAll(modulePrompt())
|
||||
addAll(messages)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 对应调用的模型配置名称
|
||||
*/
|
||||
fun modelKey(): String {
|
||||
return if (this is AbstractAgentModule) {
|
||||
this.moduleName
|
||||
} else {
|
||||
javaClass.simpleName
|
||||
}
|
||||
}
|
||||
|
||||
fun modulePrompt(): List<Message> = emptyList()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package work.slhaf.partner.api.chat
|
||||
|
||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule
|
||||
import work.slhaf.partner.api.chat.pojo.Message
|
||||
|
||||
interface ActivateModel {
|
||||
|
||||
fun chat(messages: List<Message>): String {
|
||||
return ModelRuntimeRegistry.resolveProvider(modelKey()).chat(mergeMessages(messages))
|
||||
}
|
||||
|
||||
fun streamChat(messages: List<Message>, handler: StreamChatMessageConsumer) {
|
||||
ModelRuntimeRegistry.resolveProvider(modelKey()).streamChat(mergeMessages(messages), handler)
|
||||
}
|
||||
|
||||
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T {
|
||||
return ModelRuntimeRegistry.resolveProvider(modelKey()).formattedChat(mergeMessages(messages), responseType)
|
||||
}
|
||||
|
||||
fun mergeMessages(messages: List<Message>): List<Message> {
|
||||
if (modulePrompt().isEmpty()) {
|
||||
return messages
|
||||
}
|
||||
return buildList {
|
||||
addAll(modulePrompt())
|
||||
addAll(messages)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 对应调用的模型配置名称
|
||||
*/
|
||||
fun modelKey(): String {
|
||||
return if (this is AbstractAgentModule) {
|
||||
this.moduleName
|
||||
} else {
|
||||
javaClass.simpleName
|
||||
}
|
||||
}
|
||||
|
||||
fun modulePrompt(): List<Message> = emptyList()
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package work.slhaf.partner.api.chat
|
||||
|
||||
import work.slhaf.partner.api.chat.provider.ModelProvider
|
||||
import work.slhaf.partner.api.chat.provider.ProviderOverride
|
||||
import work.slhaf.partner.api.chat.provider.openai.OpenAiCompatibleProvider
|
||||
|
||||
object ModelRuntimeRegistry {
|
||||
|
||||
private const val DEFAULT_PROVIDER = "default"
|
||||
|
||||
/**
|
||||
* 基础的 provider 提供商,可 fork 出新的 runtime provider,必须提供一个 default provider
|
||||
*/
|
||||
private val baseProvider = mutableMapOf<String, ModelProvider>()
|
||||
|
||||
/**
|
||||
* 根据模块进行对应的 provider
|
||||
*/
|
||||
private val runtimeProvider = mutableMapOf<String, ModelProvider>()
|
||||
|
||||
fun resolveProvider(modelKey: String): ModelProvider {
|
||||
val provider = runtimeProvider[modelKey]
|
||||
if (provider != null) {
|
||||
return provider
|
||||
}
|
||||
return baseProvider[DEFAULT_PROVIDER]!!
|
||||
}
|
||||
|
||||
private fun registerProvider(config: ProviderConfig) {
|
||||
when (config) {
|
||||
is OpenAiCompatibleProviderConfig -> baseProvider[config.name] =
|
||||
OpenAiCompatibleProvider(config.baseUrl, config.apiKey, config.defaultModel)
|
||||
}
|
||||
}
|
||||
|
||||
private fun forkProvider(config: RuntimeProviderConfig) {
|
||||
val provider = baseProvider[config.providerName]
|
||||
?: throw IllegalArgumentException("Provider ${config.providerName} not found")
|
||||
val override = config.override
|
||||
|
||||
runtimeProvider[config.modelKey] = if (override != null) {
|
||||
provider.fork(override)
|
||||
} else {
|
||||
provider
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
data class RuntimeProviderConfig(
|
||||
val modelKey: String,
|
||||
val providerName: String,
|
||||
|
||||
val override: ProviderOverride?
|
||||
)
|
||||
|
||||
|
||||
sealed class ProviderConfig {
|
||||
abstract val name: String
|
||||
abstract val type: ProviderType
|
||||
abstract val defaultModel: String
|
||||
|
||||
enum class ProviderType {
|
||||
OPENAI_COMPATIBLE
|
||||
}
|
||||
}
|
||||
|
||||
data class OpenAiCompatibleProviderConfig(
|
||||
override val name: String,
|
||||
override val type: ProviderType = ProviderType.OPENAI_COMPATIBLE,
|
||||
override val defaultModel: String,
|
||||
|
||||
val baseUrl: String,
|
||||
val apiKey: String
|
||||
) : ProviderConfig()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package work.slhaf.partner.api.chat.runtime;
|
||||
package work.slhaf.partner.api.chat;
|
||||
|
||||
public abstract class StreamChatMessageConsumer {
|
||||
private final StringBuilder responseText = new StringBuilder();
|
||||
@@ -0,0 +1,27 @@
|
||||
package work.slhaf.partner.api.chat.provider
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject
|
||||
import work.slhaf.partner.api.chat.StreamChatMessageConsumer
|
||||
import work.slhaf.partner.api.chat.pojo.Message
|
||||
|
||||
abstract class ModelProvider @JvmOverloads constructor(
|
||||
val override: ProviderOverride? = null
|
||||
) {
|
||||
|
||||
abstract fun fork(override: ProviderOverride): ModelProvider
|
||||
|
||||
abstract fun streamChat(messages: List<Message>, consumer: StreamChatMessageConsumer)
|
||||
|
||||
abstract fun chat(messages: List<Message>): String
|
||||
|
||||
abstract fun <T> formattedChat(messages: List<Message>, type: Class<T>): T
|
||||
}
|
||||
|
||||
data class ProviderOverride(
|
||||
val model: String,
|
||||
val temperature: Double?,
|
||||
val topP: Double?,
|
||||
val maxTokens: Int?,
|
||||
|
||||
val extras: JSONObject?
|
||||
)
|
||||
@@ -1,34 +1,57 @@
|
||||
package work.slhaf.partner.api.chat.runtime;
|
||||
package work.slhaf.partner.api.chat.provider.openai;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import com.openai.client.OpenAIClient;
|
||||
import com.openai.client.okhttp.OpenAIOkHttpClient;
|
||||
import com.openai.core.JsonValue;
|
||||
import com.openai.core.http.StreamResponse;
|
||||
import com.openai.models.chat.completions.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import work.slhaf.partner.api.chat.StreamChatMessageConsumer;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.chat.provider.ModelProvider;
|
||||
import work.slhaf.partner.api.chat.provider.ProviderOverride;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
|
||||
public class OpenAiChatRuntime {
|
||||
public class OpenAiCompatibleProvider extends ModelProvider {
|
||||
|
||||
private final OpenAIClient client;
|
||||
private final String baseUrl;
|
||||
private final String apiKey;
|
||||
private final String model;
|
||||
|
||||
public OpenAiChatRuntime(String baseUrl, String apikey, String model) {
|
||||
private final OpenAIClient client;
|
||||
|
||||
public OpenAiCompatibleProvider(String baseUrl, String apikey, String model) {
|
||||
this.client = OpenAIOkHttpClient.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.apiKey(apikey)
|
||||
.timeout(Duration.ofSeconds(30))
|
||||
.build();
|
||||
this.baseUrl = baseUrl;
|
||||
this.apiKey = apikey;
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
public String chat(List<Message> messages) {
|
||||
public OpenAiCompatibleProvider(String baseUrl, String apikey, String model, ProviderOverride override) {
|
||||
super(override);
|
||||
this.client = OpenAIOkHttpClient.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.apiKey(apikey)
|
||||
.timeout(Duration.ofSeconds(30))
|
||||
.build();
|
||||
this.baseUrl = baseUrl;
|
||||
this.apiKey = apikey;
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
public @NotNull String chat(@NotNull List<Message> messages) {
|
||||
ChatCompletionCreateParams params = buildParams(messages);
|
||||
return extractText(client.chat().completions().create(params));
|
||||
}
|
||||
|
||||
public void streamChat(List<Message> messages, StreamChatMessageConsumer handler) {
|
||||
public void streamChat(@NotNull List<Message> messages, StreamChatMessageConsumer handler) {
|
||||
ChatCompletionCreateParams params = buildParams(messages);
|
||||
try (StreamResponse<ChatCompletionChunk> streamResponse = client.chat().completions().createStreaming(params)) {
|
||||
streamResponse.stream()
|
||||
@@ -39,7 +62,7 @@ public class OpenAiChatRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
public <T> T formattedChat(List<Message> messages, Class<T> responseType) {
|
||||
public <T> T formattedChat(@NotNull List<Message> messages, @NotNull Class<T> responseType) {
|
||||
StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder()
|
||||
.responseFormat(responseType)
|
||||
.build();
|
||||
@@ -47,10 +70,30 @@ public class OpenAiChatRuntime {
|
||||
}
|
||||
|
||||
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
||||
return ChatCompletionCreateParams.builder()
|
||||
ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder()
|
||||
.model(model)
|
||||
.messages(OpenAiMessageAdapter.toParams(messages))
|
||||
.build();
|
||||
.messages(OpenAiMessageAdapter.toParams(messages));
|
||||
|
||||
ProviderOverride override = getOverride();
|
||||
if (override != null) {
|
||||
if (override.getTemperature() != null) {
|
||||
paramsBuilder.temperature(override.getTemperature());
|
||||
}
|
||||
if (override.getTopP() != null) {
|
||||
paramsBuilder.topP(override.getTopP());
|
||||
}
|
||||
if (override.getMaxTokens() != null) {
|
||||
paramsBuilder.maxCompletionTokens(override.getMaxTokens());
|
||||
}
|
||||
JSONObject extras = override.getExtras();
|
||||
if (extras != null) {
|
||||
extras.forEach((key, value) -> {
|
||||
paramsBuilder.putAdditionalBodyProperty(key, JsonValue.from(value));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return paramsBuilder.build();
|
||||
}
|
||||
|
||||
private String extractText(ChatCompletion completion) {
|
||||
@@ -68,4 +111,9 @@ public class OpenAiChatRuntime {
|
||||
return completion.choices().getFirst().message().content()
|
||||
.orElseThrow(() -> new IllegalStateException("OpenAI structured chat completion returned empty content."));
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull ModelProvider fork(@NotNull ProviderOverride override) {
|
||||
return new OpenAiCompatibleProvider(baseUrl, apiKey, override.getModel(), override);
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package work.slhaf.partner.api.chat.runtime;
|
||||
package work.slhaf.partner.api.chat.provider.openai;
|
||||
|
||||
import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
|
||||
import com.openai.models.chat.completions.ChatCompletionMessageParam;
|
||||
@@ -2,7 +2,7 @@ package work.slhaf.partner.api.chat.pojo;
|
||||
|
||||
import com.alibaba.fastjson2.JSON;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import work.slhaf.partner.api.chat.runtime.OpenAiMessageAdapter;
|
||||
import work.slhaf.partner.api.chat.provider.openai.OpenAiMessageAdapter;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
Reference in New Issue
Block a user