mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
refactor(model): manage model registry by ConfigCenter
This commit is contained in:
@@ -1,10 +1,18 @@
|
||||
package work.slhaf.partner.api.agent.model
|
||||
|
||||
import org.slf4j.LoggerFactory
|
||||
import work.slhaf.partner.api.agent.model.provider.ModelProvider
|
||||
import work.slhaf.partner.api.agent.model.provider.ProviderOverride
|
||||
import work.slhaf.partner.api.agent.model.provider.openai.OpenAiCompatibleProvider
|
||||
import work.slhaf.partner.api.agent.runtime.config.Config
|
||||
import work.slhaf.partner.api.agent.runtime.config.ConfigDoc
|
||||
import work.slhaf.partner.api.agent.runtime.config.ConfigRegistration
|
||||
import work.slhaf.partner.api.agent.runtime.config.Configurable
|
||||
import java.nio.file.Path
|
||||
import java.util.concurrent.locks.ReentrantLock
|
||||
import kotlin.concurrent.withLock
|
||||
|
||||
object ModelRuntimeRegistry {
|
||||
object ModelRuntimeRegistry : Configurable, ConfigRegistration<ModelRuntimeRegistryConfig> {
|
||||
|
||||
private const val DEFAULT_PROVIDER = "default"
|
||||
|
||||
@@ -18,7 +26,10 @@ object ModelRuntimeRegistry {
|
||||
*/
|
||||
private val runtimeProvider = mutableMapOf<String, ModelProvider>()
|
||||
|
||||
fun resolveProvider(modelKey: String): ModelProvider {
|
||||
private val providerLock = ReentrantLock()
|
||||
private val log = LoggerFactory.getLogger(ModelRuntimeRegistry::class.java)
|
||||
|
||||
fun resolveProvider(modelKey: String): ModelProvider = providerLock.withLock {
|
||||
val provider = runtimeProvider[modelKey]
|
||||
if (provider != null) {
|
||||
return provider
|
||||
@@ -45,8 +56,86 @@ object ModelRuntimeRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
override fun declare(): Map<Path, ConfigRegistration<out Config>> {
|
||||
return mapOf(Path.of("model", "model.json") to this)
|
||||
}
|
||||
|
||||
override fun type(): Class<ModelRuntimeRegistryConfig> = ModelRuntimeRegistryConfig::class.java
|
||||
|
||||
override fun init(config: ModelRuntimeRegistryConfig) = providerLock.withLock {
|
||||
config.providerConfigSet.forEach { registerProvider(it) }
|
||||
config.runtimeConfigSet.forEach { forkProvider(it) }
|
||||
}
|
||||
|
||||
override fun onReload(config: ModelRuntimeRegistryConfig) = providerLock.withLock {
|
||||
val baseProviderSnapshot = baseProvider.toMap()
|
||||
val runtimeProviderSnapshot = runtimeProvider.toMap()
|
||||
try {
|
||||
baseProvider.clear()
|
||||
config.providerConfigSet.forEach { registerProvider(it) }
|
||||
runtimeProvider.clear()
|
||||
config.runtimeConfigSet.forEach { forkProvider(it) }
|
||||
} catch (e: Exception) {
|
||||
log.error("Error while loading runtime provider config", e)
|
||||
baseProvider.clear()
|
||||
baseProvider.putAll(baseProviderSnapshot)
|
||||
runtimeProvider.clear()
|
||||
runtimeProvider.putAll(runtimeProviderSnapshot)
|
||||
}
|
||||
}
|
||||
|
||||
override fun defaultConfig(): ModelRuntimeRegistryConfig? {
|
||||
val defaultBaseUrl = System.getenv("PARTNER_DEFAULT_BASE_URL") ?: return null
|
||||
val defaultApiKey = System.getenv("PARTNER_DEFAULT_API_KEY") ?: return null
|
||||
val defaultModel = System.getenv("PARTNER_DEFAULT_MODEL") ?: return null
|
||||
return ModelRuntimeRegistryConfig(
|
||||
setOf(
|
||||
OpenAiCompatibleProviderConfig(
|
||||
"default",
|
||||
ProviderConfig.ProviderType.OPENAI_COMPATIBLE,
|
||||
defaultModel, defaultBaseUrl, defaultApiKey
|
||||
)
|
||||
), setOf()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
data class ModelRuntimeRegistryConfig(
|
||||
@field:ConfigDoc(
|
||||
description = "提供商配置集合", example = """ [ {
|
||||
"name": "example_provider_name",
|
||||
"type": "OPENAI_COMPATIBLE",
|
||||
"defaultModel": "example_default_model",
|
||||
"baseUrl": "example_base_url",
|
||||
"apiKey": "example_apikey"
|
||||
}
|
||||
]
|
||||
"""
|
||||
)
|
||||
val providerConfigSet: Set<ProviderConfig>,
|
||||
@field:ConfigDoc(
|
||||
description = "模块所用提供商配置,可覆写 model、temperature、top_p、max_tokens 等配置",
|
||||
example = """
|
||||
[
|
||||
{
|
||||
"modelKey": "example_model_key", // 模块通过该 key 定位到对应的提供商配置
|
||||
"providerName: "example_provider_name", // 该配置对应的提供商名称
|
||||
"override": { // 该配置需要重写的内容, 如果无需重写,可忽略该字段,该字段的各个子字段均为可选覆写
|
||||
"model": "example_override_model",
|
||||
"temperature": "example_override_temperature",
|
||||
"topP": "example_override_top_p",
|
||||
"maxTokens": "example_override_max_tokens",
|
||||
"extra": { // 要覆写的额外内容
|
||||
"example1": "value1"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
)
|
||||
val runtimeConfigSet: Set<RuntimeProviderConfig>
|
||||
) : Config()
|
||||
|
||||
data class RuntimeProviderConfig(
|
||||
val modelKey: String,
|
||||
val providerName: String,
|
||||
|
||||
Reference in New Issue
Block a user