mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
feat(partnerctl-init): add interactive model provider configuration
This commit is contained in:
@@ -1,10 +1,13 @@
|
|||||||
package work.slhaf.partner.ctl.commands
|
package work.slhaf.partner.ctl.commands
|
||||||
|
|
||||||
import kotlinx.serialization.json.Json
|
import kotlinx.serialization.json.*
|
||||||
import picocli.CommandLine
|
import picocli.CommandLine
|
||||||
import work.slhaf.partner.ctl.commands.data.GatewayConfig
|
import work.slhaf.partner.ctl.commands.data.GatewayConfig
|
||||||
|
import work.slhaf.partner.ctl.commands.data.OpenAiCompatible
|
||||||
|
import work.slhaf.partner.ctl.commands.data.ProviderConfig
|
||||||
import work.slhaf.partner.ctl.commands.init.buildFromSource
|
import work.slhaf.partner.ctl.commands.init.buildFromSource
|
||||||
import work.slhaf.partner.ctl.commands.init.configureExternalGateway
|
import work.slhaf.partner.ctl.commands.init.configureExternalGateway
|
||||||
|
import work.slhaf.partner.ctl.commands.init.configureOpenAiCompatible
|
||||||
import work.slhaf.partner.ctl.commands.init.configureWebSocketGateway
|
import work.slhaf.partner.ctl.commands.init.configureWebSocketGateway
|
||||||
import work.slhaf.partner.ctl.support.loadAvailableGateway
|
import work.slhaf.partner.ctl.support.loadAvailableGateway
|
||||||
import work.slhaf.partner.ctl.ui.Choice
|
import work.slhaf.partner.ctl.ui.Choice
|
||||||
@@ -152,16 +155,94 @@ class InitCommand : Runnable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private fun configureModel(prompt: Prompt) {
|
private fun configureModel(prompt: Prompt) {
|
||||||
TODO("Not yet implemented")
|
prompt.section("Configure Model")
|
||||||
|
|
||||||
|
val modelChoices = ModelProviderChoice.entries.map { Choice(it.display, it) }
|
||||||
|
|
||||||
|
val chosenModelProviders = mutableListOf<ProviderConfig>()
|
||||||
|
|
||||||
|
var defaultAlreadySet = false
|
||||||
|
while (true) {
|
||||||
|
val choice = prompt.select(
|
||||||
|
label = if (!defaultAlreadySet) {
|
||||||
|
"Choose default model provider type"
|
||||||
|
} else {
|
||||||
|
"Choose model provider type"
|
||||||
|
},
|
||||||
|
choices = modelChoices
|
||||||
|
)
|
||||||
|
|
||||||
|
val providerConfig = when (choice) {
|
||||||
|
ModelProviderChoice.OPENAI_COMPATIBLE -> configureOpenAiCompatible(prompt, defaultAlreadySet)
|
||||||
|
ModelProviderChoice.SKIP -> {
|
||||||
|
if (defaultAlreadySet) {
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
prompt.warn(
|
||||||
|
"No default model provider configured. Partner may not start normally unless model.json exists " +
|
||||||
|
"or PARTNER_DEFAULT_BASE_URL, PARTNER_DEFAULT_API_KEY, and PARTNER_DEFAULT_MODEL are provided at runtime."
|
||||||
|
)
|
||||||
|
if (prompt.confirm("Skip model configuration?", false)) {
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
providerConfig?.let {
|
||||||
|
chosenModelProviders.add(it)
|
||||||
|
if (!defaultAlreadySet) {
|
||||||
|
defaultAlreadySet = true
|
||||||
|
if (!prompt.confirm("Add additional model provider?", false)) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chosenModelProviders.isNotEmpty()) {
|
||||||
|
val json = Json {
|
||||||
|
prettyPrint = true
|
||||||
|
encodeDefaults = true
|
||||||
|
}
|
||||||
|
|
||||||
|
val jsonObject = buildJsonObject {
|
||||||
|
putJsonArray("providerConfigSet") {
|
||||||
|
chosenModelProviders.forEach {
|
||||||
|
add(json.encodeProviderConfig(it))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
putJsonArray("runtimeConfigSet") {}
|
||||||
|
}
|
||||||
|
|
||||||
|
val modelPath = home.resolve("config").resolve("model.json").toAbsolutePath().normalize()
|
||||||
|
Files.writeString(modelPath, json.encodeToString(JsonObject.serializer(), jsonObject))
|
||||||
|
|
||||||
|
prompt.success("Model config written to $modelPath")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun finalize(prompt: Prompt) {
|
private fun finalize(prompt: Prompt) {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun Json.encodeProviderConfig(providerConfig: ProviderConfig): JsonElement {
|
||||||
|
return when (providerConfig) {
|
||||||
|
is OpenAiCompatible -> encodeToJsonElement(providerConfig)
|
||||||
|
else -> error("Unsupported provider config type: ${providerConfig::class.simpleName}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private enum class InstallChoice {
|
private enum class InstallChoice {
|
||||||
BUILD_FROM_SOURCE
|
BUILD_FROM_SOURCE
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private enum class ModelProviderChoice(val display: String) {
|
||||||
|
OPENAI_COMPATIBLE("OpenAI Compatible"),
|
||||||
|
SKIP("Skip")
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -14,3 +14,20 @@ data class GatewayConfig(
|
|||||||
val params: JsonObject
|
val params: JsonObject
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface ProviderConfig {
|
||||||
|
val name: String
|
||||||
|
val type: String
|
||||||
|
val defaultModel: String
|
||||||
|
}
|
||||||
|
|
||||||
|
@Serializable
|
||||||
|
data class OpenAiCompatible(
|
||||||
|
override val name: String,
|
||||||
|
override val defaultModel: String,
|
||||||
|
|
||||||
|
val baseUrl: String,
|
||||||
|
val apiKey: String,
|
||||||
|
) : ProviderConfig {
|
||||||
|
override val type: String = "OPENAI_COMPATIBLE"
|
||||||
|
}
|
||||||
@@ -2,8 +2,11 @@ package work.slhaf.partner.ctl.commands.init
|
|||||||
|
|
||||||
import kotlinx.serialization.json.*
|
import kotlinx.serialization.json.*
|
||||||
import work.slhaf.partner.ctl.commands.data.GatewayConfig
|
import work.slhaf.partner.ctl.commands.data.GatewayConfig
|
||||||
|
import work.slhaf.partner.ctl.commands.data.OpenAiCompatible
|
||||||
|
import work.slhaf.partner.ctl.commands.data.ProviderConfig
|
||||||
import work.slhaf.partner.ctl.support.*
|
import work.slhaf.partner.ctl.support.*
|
||||||
import work.slhaf.partner.ctl.ui.Prompt
|
import work.slhaf.partner.ctl.ui.Prompt
|
||||||
|
import java.net.URI
|
||||||
import java.nio.file.Files
|
import java.nio.file.Files
|
||||||
import java.nio.file.Path
|
import java.nio.file.Path
|
||||||
import java.nio.file.Paths
|
import java.nio.file.Paths
|
||||||
@@ -148,3 +151,49 @@ private fun validateFieldValue(field: Field, value: String): String? {
|
|||||||
?.let { "${field.label} only accepts valid JSON" }
|
?.let { "${field.label} only accepts valid JSON" }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun configureOpenAiCompatible(prompt: Prompt, defaultAlreadySet: Boolean): ProviderConfig {
|
||||||
|
val name = if (defaultAlreadySet) {
|
||||||
|
prompt.ask("Provider name") {
|
||||||
|
if (it == "default") {
|
||||||
|
"Default provider cannot be duplicate"
|
||||||
|
} else {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
"default"
|
||||||
|
}
|
||||||
|
|
||||||
|
val baseUrl = prompt.ask("Base url") { value ->
|
||||||
|
validateNetworkUrl(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
val apikey = prompt.ask("Apikey")
|
||||||
|
val defaultModel = prompt.ask("Default model")
|
||||||
|
return OpenAiCompatible(
|
||||||
|
name = name,
|
||||||
|
baseUrl = baseUrl,
|
||||||
|
apiKey = apikey,
|
||||||
|
defaultModel = defaultModel
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun validateNetworkUrl(value: String): String? {
|
||||||
|
val trimmed = value.trim()
|
||||||
|
if (trimmed.isEmpty()) {
|
||||||
|
return "Base url is required"
|
||||||
|
}
|
||||||
|
|
||||||
|
val uri = runCatching { URI(trimmed) }.getOrElse {
|
||||||
|
return "Base url must be a valid URL"
|
||||||
|
}
|
||||||
|
|
||||||
|
return when {
|
||||||
|
uri.scheme !in setOf("http", "https") -> "Base url must start with http:// or https://"
|
||||||
|
uri.host.isNullOrBlank() -> "Base url must include a valid host"
|
||||||
|
uri.rawUserInfo != null -> "Base url must not include user info"
|
||||||
|
uri.rawFragment != null -> "Base url must not include fragment"
|
||||||
|
else -> null
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user