mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
fix(model): correct deserialization behavior of model provider config
This commit is contained in:
@@ -13,7 +13,6 @@ import work.slhaf.partner.framework.agent.model.provider.ModelProvider
|
||||
import work.slhaf.partner.framework.agent.model.provider.ProviderOverride
|
||||
import work.slhaf.partner.framework.agent.model.provider.openai.OpenAiCompatibleProvider
|
||||
import java.nio.file.Path
|
||||
import java.util.Locale.getDefault
|
||||
import java.util.concurrent.locks.ReentrantLock
|
||||
import kotlin.concurrent.withLock
|
||||
|
||||
@@ -86,8 +85,15 @@ object ModelRuntimeRegistry : Configurable, ConfigRegistration<ModelRuntimeRegis
|
||||
override fun type(): Class<ModelRuntimeRegistryConfig> = ModelRuntimeRegistryConfig::class.java
|
||||
|
||||
override fun init(config: ModelRuntimeRegistryConfig, json: JSONObject?) = providerLock.withLock {
|
||||
val acceptableConfig = try {
|
||||
parseJsonConfig(json)
|
||||
} catch (e: Exception) {
|
||||
log.warn("Unable to load model config", e)
|
||||
config
|
||||
}
|
||||
|
||||
try {
|
||||
applyConfig(config)
|
||||
applyConfig(acceptableConfig)
|
||||
} catch (e: ModelRegistryException) {
|
||||
throw ModelRegistryStartupException(
|
||||
e.message ?: "Failed to apply model runtime config",
|
||||
@@ -104,36 +110,8 @@ object ModelRuntimeRegistry : Configurable, ConfigRegistration<ModelRuntimeRegis
|
||||
val baseProviderSnapshot = baseProvider.toMap()
|
||||
val runtimeProviderSnapshot = runtimeProvider.toMap()
|
||||
try {
|
||||
val providerSetJson = root.getJSONArray("providerConfigSet")
|
||||
?: throw runtimeModelException("providerConfigSet is missing or not an array")
|
||||
baseProvider.clear()
|
||||
for (i in providerSetJson.indices) {
|
||||
val providerJson = providerSetJson.getJSONObject(i)
|
||||
?: throw runtimeModelException("providerConfigSet[$i] is not an object")
|
||||
val typeText = providerJson.getString("type")
|
||||
?: throw runtimeModelException(
|
||||
"providerConfigSet[$i].type is missing",
|
||||
providerJson.getString("name") ?: COMPONENT_NAME
|
||||
)
|
||||
val providerType = try {
|
||||
ProviderConfig.ProviderType.valueOf(typeText.uppercase(getDefault()))
|
||||
} catch (e: IllegalArgumentException) {
|
||||
throw runtimeModelException(
|
||||
"Unsupported provider type: $typeText",
|
||||
providerJson.getString("name") ?: COMPONENT_NAME,
|
||||
cause = e
|
||||
)
|
||||
}
|
||||
val concreteProviderConfig = when (providerType) {
|
||||
OPENAI_COMPATIBLE -> providerJson.toJavaObject(OpenAiCompatibleProviderConfig::class.java)
|
||||
}
|
||||
registerProvider(concreteProviderConfig)
|
||||
}
|
||||
if (!baseProvider.containsKey(DEFAULT_PROVIDER)) {
|
||||
throw runtimeModelException("Provider default not found", DEFAULT_PROVIDER)
|
||||
}
|
||||
runtimeProvider.clear()
|
||||
config.runtimeConfigSet.forEach { forkProvider(it) }
|
||||
val parsedConfig = parseJsonConfig(root)
|
||||
applyConfig(parsedConfig)
|
||||
} catch (e: Exception) {
|
||||
log.error("Error while loading runtime provider config", e)
|
||||
baseProvider.clear()
|
||||
@@ -143,6 +121,32 @@ object ModelRuntimeRegistry : Configurable, ConfigRegistration<ModelRuntimeRegis
|
||||
}
|
||||
}
|
||||
|
||||
private fun parseJsonConfig(json: JSONObject?): ModelRuntimeRegistryConfig {
|
||||
if (json == null) {
|
||||
throw runtimeModelException("Unable to find model config")
|
||||
}
|
||||
val providerConfigSet = json.getJSONArray("providerConfigSet").filterIsInstance<JSONObject>()
|
||||
.map { config ->
|
||||
val type = config.getString("type")
|
||||
val config = if (type.equals(OPENAI_COMPATIBLE.name.uppercase())) {
|
||||
config.toJavaObject(OpenAiCompatibleProviderConfig::class.java)
|
||||
} else {
|
||||
null
|
||||
}
|
||||
if (config == null) {
|
||||
throw runtimeModelException("Unknown config type: $type")
|
||||
}
|
||||
config
|
||||
}.toSet()
|
||||
|
||||
val runtimeConfigSet = json.getJSONArray("runtimeConfigSet").filterIsInstance<JSONObject>()
|
||||
.map { config ->
|
||||
config.toJavaObject(RuntimeProviderConfig::class.java)
|
||||
}.toSet()
|
||||
|
||||
return ModelRuntimeRegistryConfig(providerConfigSet, runtimeConfigSet)
|
||||
}
|
||||
|
||||
override fun defaultConfig(): ModelRuntimeRegistryConfig? {
|
||||
val defaultBaseUrl = System.getenv("PARTNER_DEFAULT_BASE_URL") ?: return null
|
||||
val defaultApiKey = System.getenv("PARTNER_DEFAULT_API_KEY") ?: return null
|
||||
|
||||
Reference in New Issue
Block a user