diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/action/executor/ActionCorrectionRecognizer.java b/Partner-Core/src/main/java/work/slhaf/partner/module/action/executor/ActionCorrectionRecognizer.java index 7509bda8..47cb3e9c 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/action/executor/ActionCorrectionRecognizer.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/action/executor/ActionCorrectionRecognizer.java @@ -6,7 +6,7 @@ import org.w3c.dom.Document; import org.w3c.dom.Element; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; -import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; +import work.slhaf.partner.api.chat.ActivateModel; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.cognition.ContextBlock; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/action/executor/ActionCorrector.java b/Partner-Core/src/main/java/work/slhaf/partner/module/action/executor/ActionCorrector.java index bfdb74bb..c8d31a5e 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/action/executor/ActionCorrector.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/action/executor/ActionCorrector.java @@ -6,7 +6,7 @@ import org.w3c.dom.Document; import org.w3c.dom.Element; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; -import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; +import work.slhaf.partner.api.chat.ActivateModel; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.cognition.ContextBlock; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/action/executor/ParamsExtractor.java b/Partner-Core/src/main/java/work/slhaf/partner/module/action/executor/ParamsExtractor.java index ec7d98b8..84dff8c1 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/action/executor/ParamsExtractor.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/action/executor/ParamsExtractor.java @@ -6,7 +6,7 @@ import org.w3c.dom.Document; import org.w3c.dom.Element; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; -import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; +import work.slhaf.partner.api.chat.ActivateModel; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.action.entity.MetaActionInfo; import work.slhaf.partner.core.cognition.CognitionCapability; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/action/planner/evaluator/ActionEvaluator.java b/Partner-Core/src/main/java/work/slhaf/partner/module/action/planner/evaluator/ActionEvaluator.java index 9583ce43..e68324a2 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/action/planner/evaluator/ActionEvaluator.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/action/planner/evaluator/ActionEvaluator.java @@ -6,8 +6,8 @@ import org.w3c.dom.Document; import org.w3c.dom.Element; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; -import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.component.annotation.Init; +import work.slhaf.partner.api.chat.ActivateModel; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.core.action.ActionCore; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/action/planner/extractor/ActionExtractor.java b/Partner-Core/src/main/java/work/slhaf/partner/module/action/planner/extractor/ActionExtractor.java index 125b4045..3610c12f 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/action/planner/extractor/ActionExtractor.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/action/planner/extractor/ActionExtractor.java @@ -3,7 +3,7 @@ package work.slhaf.partner.module.action.planner.extractor; import org.jetbrains.annotations.NotNull; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; -import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; +import work.slhaf.partner.api.chat.ActivateModel; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.core.cognition.CognitionCapability; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/CommunicationProducer.java b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/CommunicationProducer.java index f01e1435..c0902ae4 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/CommunicationProducer.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/CommunicationProducer.java @@ -8,10 +8,10 @@ import org.w3c.dom.Document; import org.w3c.dom.Element; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; -import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.component.annotation.Init; +import work.slhaf.partner.api.chat.ActivateModel; +import work.slhaf.partner.api.chat.StreamChatMessageConsumer; import work.slhaf.partner.api.chat.pojo.Message; -import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer; import work.slhaf.partner.core.cognition.*; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRollingService.java b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRollingService.java index bdef636d..9de70ed4 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRollingService.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/DialogRollingService.java @@ -7,7 +7,7 @@ import org.w3c.dom.Document; import org.w3c.dom.Element; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; -import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; +import work.slhaf.partner.api.chat.ActivateModel; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.cognition.BlockContent; import work.slhaf.partner.core.cognition.CognitionCapability; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/ReplyDispatcher.kt b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/ReplyDispatcher.kt index c722ab9f..0dd4abc9 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/ReplyDispatcher.kt +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/ReplyDispatcher.kt @@ -5,7 +5,7 @@ import kotlinx.coroutines.channels.Channel import work.slhaf.partner.api.agent.runtime.interaction.AgentRuntime import work.slhaf.partner.api.agent.runtime.interaction.data.InteractionEvent.EventStatus import work.slhaf.partner.api.agent.runtime.interaction.data.Reply -import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer +import work.slhaf.partner.api.chat.StreamChatMessageConsumer import kotlin.time.Duration.Companion.milliseconds object ReplyDispatcher { diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/selector/evaluator/SliceSelectEvaluator.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/selector/evaluator/SliceSelectEvaluator.java index b2ef836b..f8386ef5 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/selector/evaluator/SliceSelectEvaluator.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/selector/evaluator/SliceSelectEvaluator.java @@ -8,8 +8,8 @@ import org.w3c.dom.Document; import org.w3c.dom.Element; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; -import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.component.annotation.Init; +import work.slhaf.partner.api.chat.ActivateModel; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.core.action.ActionCore; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/selector/extractor/MemorySelectExtractor.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/selector/extractor/MemorySelectExtractor.java index a9f6b148..a7bfec32 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/selector/extractor/MemorySelectExtractor.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/selector/extractor/MemorySelectExtractor.java @@ -8,8 +8,8 @@ import org.w3c.dom.Document; import org.w3c.dom.Element; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; -import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule; +import work.slhaf.partner.api.chat.ActivateModel; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.cognition.CognitionCapability; import work.slhaf.partner.core.cognition.ContextBlock; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/MultiSummarizer.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/MultiSummarizer.java index ec1fe56e..67c76a74 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/MultiSummarizer.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/MultiSummarizer.java @@ -5,7 +5,7 @@ import com.alibaba.fastjson2.JSONObject; import lombok.Data; import lombok.EqualsAndHashCode; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; -import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; +import work.slhaf.partner.api.chat.ActivateModel; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeInput; import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/SingleSummarizer.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/SingleSummarizer.java index b91fb586..70f3eeef 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/SingleSummarizer.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/SingleSummarizer.java @@ -5,8 +5,8 @@ import lombok.Data; import lombok.EqualsAndHashCode; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; -import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.component.annotation.Init; +import work.slhaf.partner.api.chat.ActivateModel; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.core.action.ActionCore; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/TotalSummarizer.java b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/TotalSummarizer.java index e77dc51b..1930cc2b 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/TotalSummarizer.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/memory/updater/summarizer/TotalSummarizer.java @@ -4,7 +4,7 @@ import cn.hutool.json.JSONUtil; import lombok.Data; import lombok.EqualsAndHashCode; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; -import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; +import work.slhaf.partner.api.chat.ActivateModel; import work.slhaf.partner.api.chat.pojo.Message; import java.util.HashMap; diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/factory/component/ComponentRegisterFactory.kt b/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/factory/component/ComponentRegisterFactory.kt index bf32a42e..e5ffba40 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/factory/component/ComponentRegisterFactory.kt +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/factory/component/ComponentRegisterFactory.kt @@ -5,13 +5,13 @@ import com.alibaba.fastjson2.JSONObject import org.slf4j.LoggerFactory import work.slhaf.partner.api.agent.factory.AgentBaseFactory import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule -import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent import work.slhaf.partner.api.agent.factory.component.exception.ModuleFactoryInitFailedException import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig import work.slhaf.partner.api.agent.factory.context.AgentContext import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext import work.slhaf.partner.api.agent.factory.context.ModuleContextData +import work.slhaf.partner.api.chat.ActivateModel import java.lang.reflect.Modifier import java.time.ZonedDateTime diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/factory/component/abstracts/AgentModule.kt b/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/factory/component/abstracts/AgentModule.kt index 00a5eecb..ac943f3c 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/factory/component/abstracts/AgentModule.kt +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/factory/component/abstracts/AgentModule.kt @@ -3,11 +3,7 @@ package work.slhaf.partner.api.agent.factory.component.abstracts import org.slf4j.Logger import org.slf4j.LoggerFactory import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent -import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader import work.slhaf.partner.api.agent.runtime.interaction.flow.RunningFlowContext -import work.slhaf.partner.api.chat.pojo.Message -import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime -import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer /** * 模块基类 @@ -35,55 +31,3 @@ sealed class AbstractAgentModule { } -interface ActivateModel { - - val runtime: OpenAiChatRuntime - get() = runtimeMap.computeIfAbsent(modelKey()) { - buildRuntime() - } - - companion object { - val runtimeMap: MutableMap = mutableMapOf() - private val configManager: AgentConfigLoader = AgentConfigLoader.INSTANCE - } - - fun buildRuntime(): OpenAiChatRuntime { - val modelConfig = configManager.loadModelConfig(modelKey()) - return OpenAiChatRuntime(modelConfig.baseUrl, modelConfig.apikey, modelConfig.model) - } - - fun chat(messages: List): String { - return runtime.chat(mergeMessages(messages)) - } - - fun streamChat(messages: List, handler: StreamChatMessageConsumer) { - return runtime.streamChat(mergeMessages(messages), handler) - } - - fun formattedChat(messages: List, responseType: Class): T { - return runtime.formattedChat(mergeMessages(messages), responseType) - } - - fun mergeMessages(messages: List): List { - if (modulePrompt().isEmpty()) { - return messages - } - return buildList { - addAll(modulePrompt()) - addAll(messages) - } - } - - /** - * 对应调用的模型配置名称 - */ - fun modelKey(): String { - return if (this is AbstractAgentModule) { - this.moduleName - } else { - javaClass.simpleName - } - } - - fun modulePrompt(): List = emptyList() -} diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/ActivateModel.kt b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/ActivateModel.kt new file mode 100644 index 00000000..d626492e --- /dev/null +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/ActivateModel.kt @@ -0,0 +1,42 @@ +package work.slhaf.partner.api.chat + +import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule +import work.slhaf.partner.api.chat.pojo.Message + +interface ActivateModel { + + fun chat(messages: List): String { + return ModelRuntimeRegistry.resolveProvider(modelKey()).chat(mergeMessages(messages)) + } + + fun streamChat(messages: List, handler: StreamChatMessageConsumer) { + ModelRuntimeRegistry.resolveProvider(modelKey()).streamChat(mergeMessages(messages), handler) + } + + fun formattedChat(messages: List, responseType: Class): T { + return ModelRuntimeRegistry.resolveProvider(modelKey()).formattedChat(mergeMessages(messages), responseType) + } + + fun mergeMessages(messages: List): List { + if (modulePrompt().isEmpty()) { + return messages + } + return buildList { + addAll(modulePrompt()) + addAll(messages) + } + } + + /** + * 对应调用的模型配置名称 + */ + fun modelKey(): String { + return if (this is AbstractAgentModule) { + this.moduleName + } else { + javaClass.simpleName + } + } + + fun modulePrompt(): List = emptyList() +} \ No newline at end of file diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/ModelRuntimeRegistry.kt b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/ModelRuntimeRegistry.kt new file mode 100644 index 00000000..baf1e04c --- /dev/null +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/ModelRuntimeRegistry.kt @@ -0,0 +1,76 @@ +package work.slhaf.partner.api.chat + +import work.slhaf.partner.api.chat.provider.ModelProvider +import work.slhaf.partner.api.chat.provider.ProviderOverride +import work.slhaf.partner.api.chat.provider.openai.OpenAiCompatibleProvider + +object ModelRuntimeRegistry { + + private const val DEFAULT_PROVIDER = "default" + + /** + * 基础的 provider 提供商,可 fork 出新的 runtime provider,必须提供一个 default provider + */ + private val baseProvider = mutableMapOf() + + /** + * 根据模块进行对应的 provider + */ + private val runtimeProvider = mutableMapOf() + + fun resolveProvider(modelKey: String): ModelProvider { + val provider = runtimeProvider[modelKey] + if (provider != null) { + return provider + } + return baseProvider[DEFAULT_PROVIDER]!! + } + + private fun registerProvider(config: ProviderConfig) { + when (config) { + is OpenAiCompatibleProviderConfig -> baseProvider[config.name] = + OpenAiCompatibleProvider(config.baseUrl, config.apiKey, config.defaultModel) + } + } + + private fun forkProvider(config: RuntimeProviderConfig) { + val provider = baseProvider[config.providerName] + ?: throw IllegalArgumentException("Provider ${config.providerName} not found") + val override = config.override + + runtimeProvider[config.modelKey] = if (override != null) { + provider.fork(override) + } else { + provider + } + } + +} + +data class RuntimeProviderConfig( + val modelKey: String, + val providerName: String, + + val override: ProviderOverride? +) + + +sealed class ProviderConfig { + abstract val name: String + abstract val type: ProviderType + abstract val defaultModel: String + + enum class ProviderType { + OPENAI_COMPATIBLE + } +} + +data class OpenAiCompatibleProviderConfig( + override val name: String, + override val type: ProviderType = ProviderType.OPENAI_COMPATIBLE, + override val defaultModel: String, + + val baseUrl: String, + val apiKey: String +) : ProviderConfig() + diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/StreamChatMessageConsumer.java b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/StreamChatMessageConsumer.java similarity index 89% rename from Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/StreamChatMessageConsumer.java rename to Partner-Framework/src/main/java/work/slhaf/partner/api/chat/StreamChatMessageConsumer.java index 12feabe1..603ae279 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/StreamChatMessageConsumer.java +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/StreamChatMessageConsumer.java @@ -1,4 +1,4 @@ -package work.slhaf.partner.api.chat.runtime; +package work.slhaf.partner.api.chat; public abstract class StreamChatMessageConsumer { private final StringBuilder responseText = new StringBuilder(); diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/provider/ModelProvider.kt b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/provider/ModelProvider.kt new file mode 100644 index 00000000..408052a0 --- /dev/null +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/provider/ModelProvider.kt @@ -0,0 +1,27 @@ +package work.slhaf.partner.api.chat.provider + +import com.alibaba.fastjson2.JSONObject +import work.slhaf.partner.api.chat.StreamChatMessageConsumer +import work.slhaf.partner.api.chat.pojo.Message + +abstract class ModelProvider @JvmOverloads constructor( + val override: ProviderOverride? = null +) { + + abstract fun fork(override: ProviderOverride): ModelProvider + + abstract fun streamChat(messages: List, consumer: StreamChatMessageConsumer) + + abstract fun chat(messages: List): String + + abstract fun formattedChat(messages: List, type: Class): T +} + +data class ProviderOverride( + val model: String, + val temperature: Double?, + val topP: Double?, + val maxTokens: Int?, + + val extras: JSONObject? +) \ No newline at end of file diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiChatRuntime.java b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/provider/openai/OpenAiCompatibleProvider.java similarity index 51% rename from Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiChatRuntime.java rename to Partner-Framework/src/main/java/work/slhaf/partner/api/chat/provider/openai/OpenAiCompatibleProvider.java index 325239d6..c6afb844 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiChatRuntime.java +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/provider/openai/OpenAiCompatibleProvider.java @@ -1,34 +1,57 @@ -package work.slhaf.partner.api.chat.runtime; +package work.slhaf.partner.api.chat.provider.openai; +import com.alibaba.fastjson2.JSONObject; import com.openai.client.OpenAIClient; import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonValue; import com.openai.core.http.StreamResponse; import com.openai.models.chat.completions.*; +import org.jetbrains.annotations.NotNull; +import work.slhaf.partner.api.chat.StreamChatMessageConsumer; import work.slhaf.partner.api.chat.pojo.Message; +import work.slhaf.partner.api.chat.provider.ModelProvider; +import work.slhaf.partner.api.chat.provider.ProviderOverride; import java.time.Duration; import java.util.List; -public class OpenAiChatRuntime { +public class OpenAiCompatibleProvider extends ModelProvider { - private final OpenAIClient client; + private final String baseUrl; + private final String apiKey; private final String model; - public OpenAiChatRuntime(String baseUrl, String apikey, String model) { + private final OpenAIClient client; + + public OpenAiCompatibleProvider(String baseUrl, String apikey, String model) { this.client = OpenAIOkHttpClient.builder() .baseUrl(baseUrl) .apiKey(apikey) .timeout(Duration.ofSeconds(30)) .build(); + this.baseUrl = baseUrl; + this.apiKey = apikey; this.model = model; } - public String chat(List messages) { + public OpenAiCompatibleProvider(String baseUrl, String apikey, String model, ProviderOverride override) { + super(override); + this.client = OpenAIOkHttpClient.builder() + .baseUrl(baseUrl) + .apiKey(apikey) + .timeout(Duration.ofSeconds(30)) + .build(); + this.baseUrl = baseUrl; + this.apiKey = apikey; + this.model = model; + } + + public @NotNull String chat(@NotNull List messages) { ChatCompletionCreateParams params = buildParams(messages); return extractText(client.chat().completions().create(params)); } - public void streamChat(List messages, StreamChatMessageConsumer handler) { + public void streamChat(@NotNull List messages, StreamChatMessageConsumer handler) { ChatCompletionCreateParams params = buildParams(messages); try (StreamResponse streamResponse = client.chat().completions().createStreaming(params)) { streamResponse.stream() @@ -39,7 +62,7 @@ public class OpenAiChatRuntime { } } - public T formattedChat(List messages, Class responseType) { + public T formattedChat(@NotNull List messages, @NotNull Class responseType) { StructuredChatCompletionCreateParams params = buildParams(messages).toBuilder() .responseFormat(responseType) .build(); @@ -47,10 +70,30 @@ public class OpenAiChatRuntime { } private ChatCompletionCreateParams buildParams(List messages) { - return ChatCompletionCreateParams.builder() + ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder() .model(model) - .messages(OpenAiMessageAdapter.toParams(messages)) - .build(); + .messages(OpenAiMessageAdapter.toParams(messages)); + + ProviderOverride override = getOverride(); + if (override != null) { + if (override.getTemperature() != null) { + paramsBuilder.temperature(override.getTemperature()); + } + if (override.getTopP() != null) { + paramsBuilder.topP(override.getTopP()); + } + if (override.getMaxTokens() != null) { + paramsBuilder.maxCompletionTokens(override.getMaxTokens()); + } + JSONObject extras = override.getExtras(); + if (extras != null) { + extras.forEach((key, value) -> { + paramsBuilder.putAdditionalBodyProperty(key, JsonValue.from(value)); + }); + } + } + + return paramsBuilder.build(); } private String extractText(ChatCompletion completion) { @@ -68,4 +111,9 @@ public class OpenAiChatRuntime { return completion.choices().getFirst().message().content() .orElseThrow(() -> new IllegalStateException("OpenAI structured chat completion returned empty content.")); } + + @Override + public @NotNull ModelProvider fork(@NotNull ProviderOverride override) { + return new OpenAiCompatibleProvider(baseUrl, apiKey, override.getModel(), override); + } } diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiMessageAdapter.java b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/provider/openai/OpenAiMessageAdapter.java similarity index 96% rename from Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiMessageAdapter.java rename to Partner-Framework/src/main/java/work/slhaf/partner/api/chat/provider/openai/OpenAiMessageAdapter.java index 61ee167a..827b7d8b 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiMessageAdapter.java +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/provider/openai/OpenAiMessageAdapter.java @@ -1,4 +1,4 @@ -package work.slhaf.partner.api.chat.runtime; +package work.slhaf.partner.api.chat.provider.openai; import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; import com.openai.models.chat.completions.ChatCompletionMessageParam; diff --git a/Partner-Framework/src/test/java/work/slhaf/partner/api/chat/pojo/MessageTest.java b/Partner-Framework/src/test/java/work/slhaf/partner/api/chat/pojo/MessageTest.java index 750d79a1..17c3d421 100644 --- a/Partner-Framework/src/test/java/work/slhaf/partner/api/chat/pojo/MessageTest.java +++ b/Partner-Framework/src/test/java/work/slhaf/partner/api/chat/pojo/MessageTest.java @@ -2,7 +2,7 @@ package work.slhaf.partner.api.chat.pojo; import com.alibaba.fastjson2.JSON; import org.junit.jupiter.api.Test; -import work.slhaf.partner.api.chat.runtime.OpenAiMessageAdapter; +import work.slhaf.partner.api.chat.provider.openai.OpenAiMessageAdapter; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows;