mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
fix(model): correct deserialization behavior of model provider config
This commit is contained in:
@@ -13,7 +13,6 @@ import work.slhaf.partner.framework.agent.model.provider.ModelProvider
|
|||||||
import work.slhaf.partner.framework.agent.model.provider.ProviderOverride
|
import work.slhaf.partner.framework.agent.model.provider.ProviderOverride
|
||||||
import work.slhaf.partner.framework.agent.model.provider.openai.OpenAiCompatibleProvider
|
import work.slhaf.partner.framework.agent.model.provider.openai.OpenAiCompatibleProvider
|
||||||
import java.nio.file.Path
|
import java.nio.file.Path
|
||||||
import java.util.Locale.getDefault
|
|
||||||
import java.util.concurrent.locks.ReentrantLock
|
import java.util.concurrent.locks.ReentrantLock
|
||||||
import kotlin.concurrent.withLock
|
import kotlin.concurrent.withLock
|
||||||
|
|
||||||
@@ -86,8 +85,15 @@ object ModelRuntimeRegistry : Configurable, ConfigRegistration<ModelRuntimeRegis
|
|||||||
override fun type(): Class<ModelRuntimeRegistryConfig> = ModelRuntimeRegistryConfig::class.java
|
override fun type(): Class<ModelRuntimeRegistryConfig> = ModelRuntimeRegistryConfig::class.java
|
||||||
|
|
||||||
override fun init(config: ModelRuntimeRegistryConfig, json: JSONObject?) = providerLock.withLock {
|
override fun init(config: ModelRuntimeRegistryConfig, json: JSONObject?) = providerLock.withLock {
|
||||||
|
val acceptableConfig = try {
|
||||||
|
parseJsonConfig(json)
|
||||||
|
} catch (e: Exception) {
|
||||||
|
log.warn("Unable to load model config", e)
|
||||||
|
config
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
applyConfig(config)
|
applyConfig(acceptableConfig)
|
||||||
} catch (e: ModelRegistryException) {
|
} catch (e: ModelRegistryException) {
|
||||||
throw ModelRegistryStartupException(
|
throw ModelRegistryStartupException(
|
||||||
e.message ?: "Failed to apply model runtime config",
|
e.message ?: "Failed to apply model runtime config",
|
||||||
@@ -104,36 +110,8 @@ object ModelRuntimeRegistry : Configurable, ConfigRegistration<ModelRuntimeRegis
|
|||||||
val baseProviderSnapshot = baseProvider.toMap()
|
val baseProviderSnapshot = baseProvider.toMap()
|
||||||
val runtimeProviderSnapshot = runtimeProvider.toMap()
|
val runtimeProviderSnapshot = runtimeProvider.toMap()
|
||||||
try {
|
try {
|
||||||
val providerSetJson = root.getJSONArray("providerConfigSet")
|
val parsedConfig = parseJsonConfig(root)
|
||||||
?: throw runtimeModelException("providerConfigSet is missing or not an array")
|
applyConfig(parsedConfig)
|
||||||
baseProvider.clear()
|
|
||||||
for (i in providerSetJson.indices) {
|
|
||||||
val providerJson = providerSetJson.getJSONObject(i)
|
|
||||||
?: throw runtimeModelException("providerConfigSet[$i] is not an object")
|
|
||||||
val typeText = providerJson.getString("type")
|
|
||||||
?: throw runtimeModelException(
|
|
||||||
"providerConfigSet[$i].type is missing",
|
|
||||||
providerJson.getString("name") ?: COMPONENT_NAME
|
|
||||||
)
|
|
||||||
val providerType = try {
|
|
||||||
ProviderConfig.ProviderType.valueOf(typeText.uppercase(getDefault()))
|
|
||||||
} catch (e: IllegalArgumentException) {
|
|
||||||
throw runtimeModelException(
|
|
||||||
"Unsupported provider type: $typeText",
|
|
||||||
providerJson.getString("name") ?: COMPONENT_NAME,
|
|
||||||
cause = e
|
|
||||||
)
|
|
||||||
}
|
|
||||||
val concreteProviderConfig = when (providerType) {
|
|
||||||
OPENAI_COMPATIBLE -> providerJson.toJavaObject(OpenAiCompatibleProviderConfig::class.java)
|
|
||||||
}
|
|
||||||
registerProvider(concreteProviderConfig)
|
|
||||||
}
|
|
||||||
if (!baseProvider.containsKey(DEFAULT_PROVIDER)) {
|
|
||||||
throw runtimeModelException("Provider default not found", DEFAULT_PROVIDER)
|
|
||||||
}
|
|
||||||
runtimeProvider.clear()
|
|
||||||
config.runtimeConfigSet.forEach { forkProvider(it) }
|
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
log.error("Error while loading runtime provider config", e)
|
log.error("Error while loading runtime provider config", e)
|
||||||
baseProvider.clear()
|
baseProvider.clear()
|
||||||
@@ -143,6 +121,32 @@ object ModelRuntimeRegistry : Configurable, ConfigRegistration<ModelRuntimeRegis
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun parseJsonConfig(json: JSONObject?): ModelRuntimeRegistryConfig {
|
||||||
|
if (json == null) {
|
||||||
|
throw runtimeModelException("Unable to find model config")
|
||||||
|
}
|
||||||
|
val providerConfigSet = json.getJSONArray("providerConfigSet").filterIsInstance<JSONObject>()
|
||||||
|
.map { config ->
|
||||||
|
val type = config.getString("type")
|
||||||
|
val config = if (type.equals(OPENAI_COMPATIBLE.name.uppercase())) {
|
||||||
|
config.toJavaObject(OpenAiCompatibleProviderConfig::class.java)
|
||||||
|
} else {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
if (config == null) {
|
||||||
|
throw runtimeModelException("Unknown config type: $type")
|
||||||
|
}
|
||||||
|
config
|
||||||
|
}.toSet()
|
||||||
|
|
||||||
|
val runtimeConfigSet = json.getJSONArray("runtimeConfigSet").filterIsInstance<JSONObject>()
|
||||||
|
.map { config ->
|
||||||
|
config.toJavaObject(RuntimeProviderConfig::class.java)
|
||||||
|
}.toSet()
|
||||||
|
|
||||||
|
return ModelRuntimeRegistryConfig(providerConfigSet, runtimeConfigSet)
|
||||||
|
}
|
||||||
|
|
||||||
override fun defaultConfig(): ModelRuntimeRegistryConfig? {
|
override fun defaultConfig(): ModelRuntimeRegistryConfig? {
|
||||||
val defaultBaseUrl = System.getenv("PARTNER_DEFAULT_BASE_URL") ?: return null
|
val defaultBaseUrl = System.getenv("PARTNER_DEFAULT_BASE_URL") ?: return null
|
||||||
val defaultApiKey = System.getenv("PARTNER_DEFAULT_API_KEY") ?: return null
|
val defaultApiKey = System.getenv("PARTNER_DEFAULT_API_KEY") ?: return null
|
||||||
|
|||||||
Reference in New Issue
Block a user