refactor(model): manage model registry by ConfigCenter

This commit is contained in:
2026-04-04 17:53:54 +08:00
parent 660bb01440
commit 9771aa1de5

View File

@@ -1,10 +1,18 @@
package work.slhaf.partner.api.agent.model package work.slhaf.partner.api.agent.model
import org.slf4j.LoggerFactory
import work.slhaf.partner.api.agent.model.provider.ModelProvider import work.slhaf.partner.api.agent.model.provider.ModelProvider
import work.slhaf.partner.api.agent.model.provider.ProviderOverride import work.slhaf.partner.api.agent.model.provider.ProviderOverride
import work.slhaf.partner.api.agent.model.provider.openai.OpenAiCompatibleProvider import work.slhaf.partner.api.agent.model.provider.openai.OpenAiCompatibleProvider
import work.slhaf.partner.api.agent.runtime.config.Config
import work.slhaf.partner.api.agent.runtime.config.ConfigDoc
import work.slhaf.partner.api.agent.runtime.config.ConfigRegistration
import work.slhaf.partner.api.agent.runtime.config.Configurable
import java.nio.file.Path
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
object ModelRuntimeRegistry { object ModelRuntimeRegistry : Configurable, ConfigRegistration<ModelRuntimeRegistryConfig> {
private const val DEFAULT_PROVIDER = "default" private const val DEFAULT_PROVIDER = "default"
@@ -18,7 +26,10 @@ object ModelRuntimeRegistry {
*/ */
private val runtimeProvider = mutableMapOf<String, ModelProvider>() private val runtimeProvider = mutableMapOf<String, ModelProvider>()
fun resolveProvider(modelKey: String): ModelProvider { private val providerLock = ReentrantLock()
private val log = LoggerFactory.getLogger(ModelRuntimeRegistry::class.java)
fun resolveProvider(modelKey: String): ModelProvider = providerLock.withLock {
val provider = runtimeProvider[modelKey] val provider = runtimeProvider[modelKey]
if (provider != null) { if (provider != null) {
return provider return provider
@@ -45,8 +56,86 @@ object ModelRuntimeRegistry {
} }
} }
override fun declare(): Map<Path, ConfigRegistration<out Config>> {
return mapOf(Path.of("model", "model.json") to this)
} }
override fun type(): Class<ModelRuntimeRegistryConfig> = ModelRuntimeRegistryConfig::class.java
override fun init(config: ModelRuntimeRegistryConfig) = providerLock.withLock {
config.providerConfigSet.forEach { registerProvider(it) }
config.runtimeConfigSet.forEach { forkProvider(it) }
}
override fun onReload(config: ModelRuntimeRegistryConfig) = providerLock.withLock {
val baseProviderSnapshot = baseProvider.toMap()
val runtimeProviderSnapshot = runtimeProvider.toMap()
try {
baseProvider.clear()
config.providerConfigSet.forEach { registerProvider(it) }
runtimeProvider.clear()
config.runtimeConfigSet.forEach { forkProvider(it) }
} catch (e: Exception) {
log.error("Error while loading runtime provider config", e)
baseProvider.clear()
baseProvider.putAll(baseProviderSnapshot)
runtimeProvider.clear()
runtimeProvider.putAll(runtimeProviderSnapshot)
}
}
override fun defaultConfig(): ModelRuntimeRegistryConfig? {
val defaultBaseUrl = System.getenv("PARTNER_DEFAULT_BASE_URL") ?: return null
val defaultApiKey = System.getenv("PARTNER_DEFAULT_API_KEY") ?: return null
val defaultModel = System.getenv("PARTNER_DEFAULT_MODEL") ?: return null
return ModelRuntimeRegistryConfig(
setOf(
OpenAiCompatibleProviderConfig(
"default",
ProviderConfig.ProviderType.OPENAI_COMPATIBLE,
defaultModel, defaultBaseUrl, defaultApiKey
)
), setOf()
)
}
}
data class ModelRuntimeRegistryConfig(
@field:ConfigDoc(
description = "提供商配置集合", example = """ [ {
"name": "example_provider_name",
"type": "OPENAI_COMPATIBLE",
"defaultModel": "example_default_model",
"baseUrl": "example_base_url",
"apiKey": "example_apikey"
}
]
"""
)
val providerConfigSet: Set<ProviderConfig>,
@field:ConfigDoc(
description = "模块所用提供商配置,可覆写 model、temperature、top_p、max_tokens 等配置",
example = """
[
{
"modelKey": "example_model_key", // 模块通过该 key 定位到对应的提供商配置
"providerName: "example_provider_name", // 该配置对应的提供商名称
"override": { // 该配置需要重写的内容, 如果无需重写,可忽略该字段,该字段的各个子字段均为可选覆写
"model": "example_override_model",
"temperature": "example_override_temperature",
"topP": "example_override_top_p",
"maxTokens": "example_override_max_tokens",
"extra": { // 要覆写的额外内容
"example1": "value1"
}
}
}
]
"""
)
val runtimeConfigSet: Set<RuntimeProviderConfig>
) : Config()
data class RuntimeProviderConfig( data class RuntimeProviderConfig(
val modelKey: String, val modelKey: String,
val providerName: String, val providerName: String,