diff --git a/PartnerCtl/src/main/java/work/slhaf/partner/ctl/commands/InitCommand.kt b/PartnerCtl/src/main/java/work/slhaf/partner/ctl/commands/InitCommand.kt index e0e825a2..8963b6f4 100644 --- a/PartnerCtl/src/main/java/work/slhaf/partner/ctl/commands/InitCommand.kt +++ b/PartnerCtl/src/main/java/work/slhaf/partner/ctl/commands/InitCommand.kt @@ -1,10 +1,13 @@ package work.slhaf.partner.ctl.commands -import kotlinx.serialization.json.Json +import kotlinx.serialization.json.* import picocli.CommandLine import work.slhaf.partner.ctl.commands.data.GatewayConfig +import work.slhaf.partner.ctl.commands.data.OpenAiCompatible +import work.slhaf.partner.ctl.commands.data.ProviderConfig import work.slhaf.partner.ctl.commands.init.buildFromSource import work.slhaf.partner.ctl.commands.init.configureExternalGateway +import work.slhaf.partner.ctl.commands.init.configureOpenAiCompatible import work.slhaf.partner.ctl.commands.init.configureWebSocketGateway import work.slhaf.partner.ctl.support.loadAvailableGateway import work.slhaf.partner.ctl.ui.Choice @@ -152,16 +155,94 @@ class InitCommand : Runnable { } private fun configureModel(prompt: Prompt) { - TODO("Not yet implemented") + prompt.section("Configure Model") + + val modelChoices = ModelProviderChoice.entries.map { Choice(it.display, it) } + + val chosenModelProviders = mutableListOf() + + var defaultAlreadySet = false + while (true) { + val choice = prompt.select( + label = if (!defaultAlreadySet) { + "Choose default model provider type" + } else { + "Choose model provider type" + }, + choices = modelChoices + ) + + val providerConfig = when (choice) { + ModelProviderChoice.OPENAI_COMPATIBLE -> configureOpenAiCompatible(prompt, defaultAlreadySet) + ModelProviderChoice.SKIP -> { + if (defaultAlreadySet) { + break + } else { + prompt.warn( + "No default model provider configured. Partner may not start normally unless model.json exists " + + "or PARTNER_DEFAULT_BASE_URL, PARTNER_DEFAULT_API_KEY, and PARTNER_DEFAULT_MODEL are provided at runtime." + ) + if (prompt.confirm("Skip model configuration?", false)) { + break + } else { + null + } + } + } + } + + providerConfig?.let { + chosenModelProviders.add(it) + if (!defaultAlreadySet) { + defaultAlreadySet = true + if (!prompt.confirm("Add additional model provider?", false)) { + break + } + } + + } + } + + if (chosenModelProviders.isNotEmpty()) { + val json = Json { + prettyPrint = true + encodeDefaults = true + } + + val jsonObject = buildJsonObject { + putJsonArray("providerConfigSet") { + chosenModelProviders.forEach { + add(json.encodeProviderConfig(it)) + } + } + putJsonArray("runtimeConfigSet") {} + } + + val modelPath = home.resolve("config").resolve("model.json").toAbsolutePath().normalize() + Files.writeString(modelPath, json.encodeToString(JsonObject.serializer(), jsonObject)) + + prompt.success("Model config written to $modelPath") + } } private fun finalize(prompt: Prompt) { TODO("Not yet implemented") } + private fun Json.encodeProviderConfig(providerConfig: ProviderConfig): JsonElement { + return when (providerConfig) { + is OpenAiCompatible -> encodeToJsonElement(providerConfig) + else -> error("Unsupported provider config type: ${providerConfig::class.simpleName}") + } + } + private enum class InstallChoice { BUILD_FROM_SOURCE } + private enum class ModelProviderChoice(val display: String) { + OPENAI_COMPATIBLE("OpenAI Compatible"), + SKIP("Skip") + } } \ No newline at end of file diff --git a/PartnerCtl/src/main/java/work/slhaf/partner/ctl/commands/data/config.kt b/PartnerCtl/src/main/java/work/slhaf/partner/ctl/commands/data/config.kt index 9ff6d897..26f88c11 100644 --- a/PartnerCtl/src/main/java/work/slhaf/partner/ctl/commands/data/config.kt +++ b/PartnerCtl/src/main/java/work/slhaf/partner/ctl/commands/data/config.kt @@ -13,4 +13,21 @@ data class GatewayConfig( val channelName: String, val params: JsonObject ) +} + +interface ProviderConfig { + val name: String + val type: String + val defaultModel: String +} + +@Serializable +data class OpenAiCompatible( + override val name: String, + override val defaultModel: String, + + val baseUrl: String, + val apiKey: String, +) : ProviderConfig { + override val type: String = "OPENAI_COMPATIBLE" } \ No newline at end of file diff --git a/PartnerCtl/src/main/java/work/slhaf/partner/ctl/commands/init/configure.kt b/PartnerCtl/src/main/java/work/slhaf/partner/ctl/commands/init/configure.kt index c7282373..788b70d1 100644 --- a/PartnerCtl/src/main/java/work/slhaf/partner/ctl/commands/init/configure.kt +++ b/PartnerCtl/src/main/java/work/slhaf/partner/ctl/commands/init/configure.kt @@ -2,8 +2,11 @@ package work.slhaf.partner.ctl.commands.init import kotlinx.serialization.json.* import work.slhaf.partner.ctl.commands.data.GatewayConfig +import work.slhaf.partner.ctl.commands.data.OpenAiCompatible +import work.slhaf.partner.ctl.commands.data.ProviderConfig import work.slhaf.partner.ctl.support.* import work.slhaf.partner.ctl.ui.Prompt +import java.net.URI import java.nio.file.Files import java.nio.file.Path import java.nio.file.Paths @@ -148,3 +151,49 @@ private fun validateFieldValue(field: Field, value: String): String? { ?.let { "${field.label} only accepts valid JSON" } } } + +fun configureOpenAiCompatible(prompt: Prompt, defaultAlreadySet: Boolean): ProviderConfig { + val name = if (defaultAlreadySet) { + prompt.ask("Provider name") { + if (it == "default") { + "Default provider cannot be duplicate" + } else { + null + } + } + } else { + "default" + } + + val baseUrl = prompt.ask("Base url") { value -> + validateNetworkUrl(value) + } + + val apikey = prompt.ask("Apikey") + val defaultModel = prompt.ask("Default model") + return OpenAiCompatible( + name = name, + baseUrl = baseUrl, + apiKey = apikey, + defaultModel = defaultModel + ) +} + +private fun validateNetworkUrl(value: String): String? { + val trimmed = value.trim() + if (trimmed.isEmpty()) { + return "Base url is required" + } + + val uri = runCatching { URI(trimmed) }.getOrElse { + return "Base url must be a valid URL" + } + + return when { + uri.scheme !in setOf("http", "https") -> "Base url must start with http:// or https://" + uri.host.isNullOrBlank() -> "Base url must include a valid host" + uri.rawUserInfo != null -> "Base url must not include user info" + uri.rawFragment != null -> "Base url must not include fragment" + else -> null + } +} \ No newline at end of file