mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
feat(partnerctl-init): add interactive model provider configuration
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
package work.slhaf.partner.ctl.commands
|
||||
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.*
|
||||
import picocli.CommandLine
|
||||
import work.slhaf.partner.ctl.commands.data.GatewayConfig
|
||||
import work.slhaf.partner.ctl.commands.data.OpenAiCompatible
|
||||
import work.slhaf.partner.ctl.commands.data.ProviderConfig
|
||||
import work.slhaf.partner.ctl.commands.init.buildFromSource
|
||||
import work.slhaf.partner.ctl.commands.init.configureExternalGateway
|
||||
import work.slhaf.partner.ctl.commands.init.configureOpenAiCompatible
|
||||
import work.slhaf.partner.ctl.commands.init.configureWebSocketGateway
|
||||
import work.slhaf.partner.ctl.support.loadAvailableGateway
|
||||
import work.slhaf.partner.ctl.ui.Choice
|
||||
@@ -152,16 +155,94 @@ class InitCommand : Runnable {
|
||||
}
|
||||
|
||||
private fun configureModel(prompt: Prompt) {
|
||||
TODO("Not yet implemented")
|
||||
prompt.section("Configure Model")
|
||||
|
||||
val modelChoices = ModelProviderChoice.entries.map { Choice(it.display, it) }
|
||||
|
||||
val chosenModelProviders = mutableListOf<ProviderConfig>()
|
||||
|
||||
var defaultAlreadySet = false
|
||||
while (true) {
|
||||
val choice = prompt.select(
|
||||
label = if (!defaultAlreadySet) {
|
||||
"Choose default model provider type"
|
||||
} else {
|
||||
"Choose model provider type"
|
||||
},
|
||||
choices = modelChoices
|
||||
)
|
||||
|
||||
val providerConfig = when (choice) {
|
||||
ModelProviderChoice.OPENAI_COMPATIBLE -> configureOpenAiCompatible(prompt, defaultAlreadySet)
|
||||
ModelProviderChoice.SKIP -> {
|
||||
if (defaultAlreadySet) {
|
||||
break
|
||||
} else {
|
||||
prompt.warn(
|
||||
"No default model provider configured. Partner may not start normally unless model.json exists " +
|
||||
"or PARTNER_DEFAULT_BASE_URL, PARTNER_DEFAULT_API_KEY, and PARTNER_DEFAULT_MODEL are provided at runtime."
|
||||
)
|
||||
if (prompt.confirm("Skip model configuration?", false)) {
|
||||
break
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
providerConfig?.let {
|
||||
chosenModelProviders.add(it)
|
||||
if (!defaultAlreadySet) {
|
||||
defaultAlreadySet = true
|
||||
if (!prompt.confirm("Add additional model provider?", false)) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if (chosenModelProviders.isNotEmpty()) {
|
||||
val json = Json {
|
||||
prettyPrint = true
|
||||
encodeDefaults = true
|
||||
}
|
||||
|
||||
val jsonObject = buildJsonObject {
|
||||
putJsonArray("providerConfigSet") {
|
||||
chosenModelProviders.forEach {
|
||||
add(json.encodeProviderConfig(it))
|
||||
}
|
||||
}
|
||||
putJsonArray("runtimeConfigSet") {}
|
||||
}
|
||||
|
||||
val modelPath = home.resolve("config").resolve("model.json").toAbsolutePath().normalize()
|
||||
Files.writeString(modelPath, json.encodeToString(JsonObject.serializer(), jsonObject))
|
||||
|
||||
prompt.success("Model config written to $modelPath")
|
||||
}
|
||||
}
|
||||
|
||||
private fun finalize(prompt: Prompt) {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
private fun Json.encodeProviderConfig(providerConfig: ProviderConfig): JsonElement {
|
||||
return when (providerConfig) {
|
||||
is OpenAiCompatible -> encodeToJsonElement(providerConfig)
|
||||
else -> error("Unsupported provider config type: ${providerConfig::class.simpleName}")
|
||||
}
|
||||
}
|
||||
|
||||
private enum class InstallChoice {
|
||||
BUILD_FROM_SOURCE
|
||||
}
|
||||
|
||||
private enum class ModelProviderChoice(val display: String) {
|
||||
OPENAI_COMPATIBLE("OpenAI Compatible"),
|
||||
SKIP("Skip")
|
||||
}
|
||||
|
||||
}
|
||||
@@ -13,4 +13,21 @@ data class GatewayConfig(
|
||||
val channelName: String,
|
||||
val params: JsonObject
|
||||
)
|
||||
}
|
||||
|
||||
interface ProviderConfig {
|
||||
val name: String
|
||||
val type: String
|
||||
val defaultModel: String
|
||||
}
|
||||
|
||||
@Serializable
|
||||
data class OpenAiCompatible(
|
||||
override val name: String,
|
||||
override val defaultModel: String,
|
||||
|
||||
val baseUrl: String,
|
||||
val apiKey: String,
|
||||
) : ProviderConfig {
|
||||
override val type: String = "OPENAI_COMPATIBLE"
|
||||
}
|
||||
@@ -2,8 +2,11 @@ package work.slhaf.partner.ctl.commands.init
|
||||
|
||||
import kotlinx.serialization.json.*
|
||||
import work.slhaf.partner.ctl.commands.data.GatewayConfig
|
||||
import work.slhaf.partner.ctl.commands.data.OpenAiCompatible
|
||||
import work.slhaf.partner.ctl.commands.data.ProviderConfig
|
||||
import work.slhaf.partner.ctl.support.*
|
||||
import work.slhaf.partner.ctl.ui.Prompt
|
||||
import java.net.URI
|
||||
import java.nio.file.Files
|
||||
import java.nio.file.Path
|
||||
import java.nio.file.Paths
|
||||
@@ -148,3 +151,49 @@ private fun validateFieldValue(field: Field, value: String): String? {
|
||||
?.let { "${field.label} only accepts valid JSON" }
|
||||
}
|
||||
}
|
||||
|
||||
fun configureOpenAiCompatible(prompt: Prompt, defaultAlreadySet: Boolean): ProviderConfig {
|
||||
val name = if (defaultAlreadySet) {
|
||||
prompt.ask("Provider name") {
|
||||
if (it == "default") {
|
||||
"Default provider cannot be duplicate"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
} else {
|
||||
"default"
|
||||
}
|
||||
|
||||
val baseUrl = prompt.ask("Base url") { value ->
|
||||
validateNetworkUrl(value)
|
||||
}
|
||||
|
||||
val apikey = prompt.ask("Apikey")
|
||||
val defaultModel = prompt.ask("Default model")
|
||||
return OpenAiCompatible(
|
||||
name = name,
|
||||
baseUrl = baseUrl,
|
||||
apiKey = apikey,
|
||||
defaultModel = defaultModel
|
||||
)
|
||||
}
|
||||
|
||||
private fun validateNetworkUrl(value: String): String? {
|
||||
val trimmed = value.trim()
|
||||
if (trimmed.isEmpty()) {
|
||||
return "Base url is required"
|
||||
}
|
||||
|
||||
val uri = runCatching { URI(trimmed) }.getOrElse {
|
||||
return "Base url must be a valid URL"
|
||||
}
|
||||
|
||||
return when {
|
||||
uri.scheme !in setOf("http", "https") -> "Base url must start with http:// or https://"
|
||||
uri.host.isNullOrBlank() -> "Base url must include a valid host"
|
||||
uri.rawUserInfo != null -> "Base url must not include user info"
|
||||
uri.rawFragment != null -> "Base url must not include fragment"
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user