diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/model/ModelRuntimeRegistry.kt b/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/model/ModelRuntimeRegistry.kt index cc4f4168..eee8e081 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/model/ModelRuntimeRegistry.kt +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/model/ModelRuntimeRegistry.kt @@ -1,10 +1,18 @@ package work.slhaf.partner.api.agent.model +import org.slf4j.LoggerFactory import work.slhaf.partner.api.agent.model.provider.ModelProvider import work.slhaf.partner.api.agent.model.provider.ProviderOverride import work.slhaf.partner.api.agent.model.provider.openai.OpenAiCompatibleProvider +import work.slhaf.partner.api.agent.runtime.config.Config +import work.slhaf.partner.api.agent.runtime.config.ConfigDoc +import work.slhaf.partner.api.agent.runtime.config.ConfigRegistration +import work.slhaf.partner.api.agent.runtime.config.Configurable +import java.nio.file.Path +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock -object ModelRuntimeRegistry { +object ModelRuntimeRegistry : Configurable, ConfigRegistration { private const val DEFAULT_PROVIDER = "default" @@ -18,7 +26,10 @@ object ModelRuntimeRegistry { */ private val runtimeProvider = mutableMapOf() - fun resolveProvider(modelKey: String): ModelProvider { + private val providerLock = ReentrantLock() + private val log = LoggerFactory.getLogger(ModelRuntimeRegistry::class.java) + + fun resolveProvider(modelKey: String): ModelProvider = providerLock.withLock { val provider = runtimeProvider[modelKey] if (provider != null) { return provider @@ -45,8 +56,86 @@ object ModelRuntimeRegistry { } } + override fun declare(): Map> { + return mapOf(Path.of("model", "model.json") to this) + } + + override fun type(): Class = ModelRuntimeRegistryConfig::class.java + + override fun init(config: ModelRuntimeRegistryConfig) = providerLock.withLock { + config.providerConfigSet.forEach { registerProvider(it) } + config.runtimeConfigSet.forEach { forkProvider(it) } + } + + override fun onReload(config: ModelRuntimeRegistryConfig) = providerLock.withLock { + val baseProviderSnapshot = baseProvider.toMap() + val runtimeProviderSnapshot = runtimeProvider.toMap() + try { + baseProvider.clear() + config.providerConfigSet.forEach { registerProvider(it) } + runtimeProvider.clear() + config.runtimeConfigSet.forEach { forkProvider(it) } + } catch (e: Exception) { + log.error("Error while loading runtime provider config", e) + baseProvider.clear() + baseProvider.putAll(baseProviderSnapshot) + runtimeProvider.clear() + runtimeProvider.putAll(runtimeProviderSnapshot) + } + } + + override fun defaultConfig(): ModelRuntimeRegistryConfig? { + val defaultBaseUrl = System.getenv("PARTNER_DEFAULT_BASE_URL") ?: return null + val defaultApiKey = System.getenv("PARTNER_DEFAULT_API_KEY") ?: return null + val defaultModel = System.getenv("PARTNER_DEFAULT_MODEL") ?: return null + return ModelRuntimeRegistryConfig( + setOf( + OpenAiCompatibleProviderConfig( + "default", + ProviderConfig.ProviderType.OPENAI_COMPATIBLE, + defaultModel, defaultBaseUrl, defaultApiKey + ) + ), setOf() + ) + } } +data class ModelRuntimeRegistryConfig( + @field:ConfigDoc( + description = "提供商配置集合", example = """ [ { + "name": "example_provider_name", + "type": "OPENAI_COMPATIBLE", + "defaultModel": "example_default_model", + "baseUrl": "example_base_url", + "apiKey": "example_apikey" + } + ] + """ + ) + val providerConfigSet: Set, + @field:ConfigDoc( + description = "模块所用提供商配置,可覆写 model、temperature、top_p、max_tokens 等配置", + example = """ + [ + { + "modelKey": "example_model_key", // 模块通过该 key 定位到对应的提供商配置 + "providerName: "example_provider_name", // 该配置对应的提供商名称 + "override": { // 该配置需要重写的内容, 如果无需重写,可忽略该字段,该字段的各个子字段均为可选覆写 + "model": "example_override_model", + "temperature": "example_override_temperature", + "topP": "example_override_top_p", + "maxTokens": "example_override_max_tokens", + "extra": { // 要覆写的额外内容 + "example1": "value1" + } + } + } + ] + """ + ) + val runtimeConfigSet: Set +) : Config() + data class RuntimeProviderConfig( val modelKey: String, val providerName: String,