refactor(chat): extract model activation into provider registry

- move ActivateModel and StreamChatMessageConsumer into api.chat
- replace direct OpenAI runtime construction with ModelRuntimeRegistry
- add provider config, runtime override and OpenAI-compatible provider forking
- rename OpenAiChatRuntime to OpenAiCompatibleProvider and update imports
This commit is contained in:
2026-03-31 18:37:41 +08:00
parent 81aa4b7933
commit e4df68ea5d
22 changed files with 221 additions and 84 deletions

View File

@@ -6,7 +6,7 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.ActivateModel;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.cognition.ContextBlock;

View File

@@ -6,7 +6,7 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.ActivateModel;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.cognition.ContextBlock;

View File

@@ -6,7 +6,7 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.ActivateModel;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.entity.MetaActionInfo;
import work.slhaf.partner.core.cognition.CognitionCapability;

View File

@@ -6,8 +6,8 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.ActivateModel;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;

View File

@@ -3,7 +3,7 @@ package work.slhaf.partner.module.action.planner.extractor;
import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.ActivateModel;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.cognition.CognitionCapability;

View File

@@ -8,10 +8,10 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.ActivateModel;
import work.slhaf.partner.api.chat.StreamChatMessageConsumer;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer;
import work.slhaf.partner.core.cognition.*;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;

View File

@@ -7,7 +7,7 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.ActivateModel;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.cognition.BlockContent;
import work.slhaf.partner.core.cognition.CognitionCapability;

View File

@@ -5,7 +5,7 @@ import kotlinx.coroutines.channels.Channel
import work.slhaf.partner.api.agent.runtime.interaction.AgentRuntime
import work.slhaf.partner.api.agent.runtime.interaction.data.InteractionEvent.EventStatus
import work.slhaf.partner.api.agent.runtime.interaction.data.Reply
import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer
import work.slhaf.partner.api.chat.StreamChatMessageConsumer
import kotlin.time.Duration.Companion.milliseconds
object ReplyDispatcher {

View File

@@ -8,8 +8,8 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.ActivateModel;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;

View File

@@ -8,8 +8,8 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule;
import work.slhaf.partner.api.chat.ActivateModel;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.cognition.CognitionCapability;
import work.slhaf.partner.core.cognition.ContextBlock;

View File

@@ -5,7 +5,7 @@ import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.ActivateModel;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeInput;
import work.slhaf.partner.module.memory.updater.summarizer.entity.SummarizeResult;

View File

@@ -5,8 +5,8 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.ActivateModel;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.core.action.ActionCapability;
import work.slhaf.partner.core.action.ActionCore;

View File

@@ -4,7 +4,7 @@ import cn.hutool.json.JSONUtil;
import lombok.Data;
import lombok.EqualsAndHashCode;
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.chat.ActivateModel;
import work.slhaf.partner.api.chat.pojo.Message;
import java.util.HashMap;

View File

@@ -5,13 +5,13 @@ import com.alibaba.fastjson2.JSONObject
import org.slf4j.LoggerFactory
import work.slhaf.partner.api.agent.factory.AgentBaseFactory
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
import work.slhaf.partner.api.agent.factory.component.exception.ModuleFactoryInitFailedException
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig
import work.slhaf.partner.api.agent.factory.context.AgentContext
import work.slhaf.partner.api.agent.factory.context.AgentRegisterContext
import work.slhaf.partner.api.agent.factory.context.ModuleContextData
import work.slhaf.partner.api.chat.ActivateModel
import java.lang.reflect.Modifier
import java.time.ZonedDateTime

View File

@@ -3,11 +3,7 @@ package work.slhaf.partner.api.agent.factory.component.abstracts
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import work.slhaf.partner.api.agent.factory.component.annotation.AgentComponent
import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
import work.slhaf.partner.api.agent.runtime.interaction.flow.RunningFlowContext
import work.slhaf.partner.api.chat.pojo.Message
import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime
import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer
/**
* 模块基类
@@ -35,55 +31,3 @@ sealed class AbstractAgentModule {
}
interface ActivateModel {
val runtime: OpenAiChatRuntime
get() = runtimeMap.computeIfAbsent(modelKey()) {
buildRuntime()
}
companion object {
val runtimeMap: MutableMap<String, OpenAiChatRuntime> = mutableMapOf()
private val configManager: AgentConfigLoader = AgentConfigLoader.INSTANCE
}
fun buildRuntime(): OpenAiChatRuntime {
val modelConfig = configManager.loadModelConfig(modelKey())
return OpenAiChatRuntime(modelConfig.baseUrl, modelConfig.apikey, modelConfig.model)
}
fun chat(messages: List<Message>): String {
return runtime.chat(mergeMessages(messages))
}
fun streamChat(messages: List<Message>, handler: StreamChatMessageConsumer) {
return runtime.streamChat(mergeMessages(messages), handler)
}
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T {
return runtime.formattedChat(mergeMessages(messages), responseType)
}
fun mergeMessages(messages: List<Message>): List<Message> {
if (modulePrompt().isEmpty()) {
return messages
}
return buildList {
addAll(modulePrompt())
addAll(messages)
}
}
/**
* 对应调用的模型配置名称
*/
fun modelKey(): String {
return if (this is AbstractAgentModule) {
this.moduleName
} else {
javaClass.simpleName
}
}
fun modulePrompt(): List<Message> = emptyList()
}

View File

@@ -0,0 +1,42 @@
package work.slhaf.partner.api.chat
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule
import work.slhaf.partner.api.chat.pojo.Message
interface ActivateModel {
fun chat(messages: List<Message>): String {
return ModelRuntimeRegistry.resolveProvider(modelKey()).chat(mergeMessages(messages))
}
fun streamChat(messages: List<Message>, handler: StreamChatMessageConsumer) {
ModelRuntimeRegistry.resolveProvider(modelKey()).streamChat(mergeMessages(messages), handler)
}
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T {
return ModelRuntimeRegistry.resolveProvider(modelKey()).formattedChat(mergeMessages(messages), responseType)
}
fun mergeMessages(messages: List<Message>): List<Message> {
if (modulePrompt().isEmpty()) {
return messages
}
return buildList {
addAll(modulePrompt())
addAll(messages)
}
}
/**
* 对应调用的模型配置名称
*/
fun modelKey(): String {
return if (this is AbstractAgentModule) {
this.moduleName
} else {
javaClass.simpleName
}
}
fun modulePrompt(): List<Message> = emptyList()
}

View File

@@ -0,0 +1,76 @@
package work.slhaf.partner.api.chat
import work.slhaf.partner.api.chat.provider.ModelProvider
import work.slhaf.partner.api.chat.provider.ProviderOverride
import work.slhaf.partner.api.chat.provider.openai.OpenAiCompatibleProvider
object ModelRuntimeRegistry {
private const val DEFAULT_PROVIDER = "default"
/**
* 基础的 provider 提供商,可 fork 出新的 runtime provider必须提供一个 default provider
*/
private val baseProvider = mutableMapOf<String, ModelProvider>()
/**
* 根据模块进行对应的 provider
*/
private val runtimeProvider = mutableMapOf<String, ModelProvider>()
fun resolveProvider(modelKey: String): ModelProvider {
val provider = runtimeProvider[modelKey]
if (provider != null) {
return provider
}
return baseProvider[DEFAULT_PROVIDER]!!
}
private fun registerProvider(config: ProviderConfig) {
when (config) {
is OpenAiCompatibleProviderConfig -> baseProvider[config.name] =
OpenAiCompatibleProvider(config.baseUrl, config.apiKey, config.defaultModel)
}
}
private fun forkProvider(config: RuntimeProviderConfig) {
val provider = baseProvider[config.providerName]
?: throw IllegalArgumentException("Provider ${config.providerName} not found")
val override = config.override
runtimeProvider[config.modelKey] = if (override != null) {
provider.fork(override)
} else {
provider
}
}
}
data class RuntimeProviderConfig(
val modelKey: String,
val providerName: String,
val override: ProviderOverride?
)
sealed class ProviderConfig {
abstract val name: String
abstract val type: ProviderType
abstract val defaultModel: String
enum class ProviderType {
OPENAI_COMPATIBLE
}
}
data class OpenAiCompatibleProviderConfig(
override val name: String,
override val type: ProviderType = ProviderType.OPENAI_COMPATIBLE,
override val defaultModel: String,
val baseUrl: String,
val apiKey: String
) : ProviderConfig()

View File

@@ -1,4 +1,4 @@
package work.slhaf.partner.api.chat.runtime;
package work.slhaf.partner.api.chat;
public abstract class StreamChatMessageConsumer {
private final StringBuilder responseText = new StringBuilder();

View File

@@ -0,0 +1,27 @@
package work.slhaf.partner.api.chat.provider
import com.alibaba.fastjson2.JSONObject
import work.slhaf.partner.api.chat.StreamChatMessageConsumer
import work.slhaf.partner.api.chat.pojo.Message
abstract class ModelProvider @JvmOverloads constructor(
val override: ProviderOverride? = null
) {
abstract fun fork(override: ProviderOverride): ModelProvider
abstract fun streamChat(messages: List<Message>, consumer: StreamChatMessageConsumer)
abstract fun chat(messages: List<Message>): String
abstract fun <T> formattedChat(messages: List<Message>, type: Class<T>): T
}
data class ProviderOverride(
val model: String,
val temperature: Double?,
val topP: Double?,
val maxTokens: Int?,
val extras: JSONObject?
)

View File

@@ -1,34 +1,57 @@
package work.slhaf.partner.api.chat.runtime;
package work.slhaf.partner.api.chat.provider.openai;
import com.alibaba.fastjson2.JSONObject;
import com.openai.client.OpenAIClient;
import com.openai.client.okhttp.OpenAIOkHttpClient;
import com.openai.core.JsonValue;
import com.openai.core.http.StreamResponse;
import com.openai.models.chat.completions.*;
import org.jetbrains.annotations.NotNull;
import work.slhaf.partner.api.chat.StreamChatMessageConsumer;
import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.provider.ModelProvider;
import work.slhaf.partner.api.chat.provider.ProviderOverride;
import java.time.Duration;
import java.util.List;
public class OpenAiChatRuntime {
public class OpenAiCompatibleProvider extends ModelProvider {
private final OpenAIClient client;
private final String baseUrl;
private final String apiKey;
private final String model;
public OpenAiChatRuntime(String baseUrl, String apikey, String model) {
private final OpenAIClient client;
public OpenAiCompatibleProvider(String baseUrl, String apikey, String model) {
this.client = OpenAIOkHttpClient.builder()
.baseUrl(baseUrl)
.apiKey(apikey)
.timeout(Duration.ofSeconds(30))
.build();
this.baseUrl = baseUrl;
this.apiKey = apikey;
this.model = model;
}
public String chat(List<Message> messages) {
public OpenAiCompatibleProvider(String baseUrl, String apikey, String model, ProviderOverride override) {
super(override);
this.client = OpenAIOkHttpClient.builder()
.baseUrl(baseUrl)
.apiKey(apikey)
.timeout(Duration.ofSeconds(30))
.build();
this.baseUrl = baseUrl;
this.apiKey = apikey;
this.model = model;
}
public @NotNull String chat(@NotNull List<Message> messages) {
ChatCompletionCreateParams params = buildParams(messages);
return extractText(client.chat().completions().create(params));
}
public void streamChat(List<Message> messages, StreamChatMessageConsumer handler) {
public void streamChat(@NotNull List<Message> messages, StreamChatMessageConsumer handler) {
ChatCompletionCreateParams params = buildParams(messages);
try (StreamResponse<ChatCompletionChunk> streamResponse = client.chat().completions().createStreaming(params)) {
streamResponse.stream()
@@ -39,7 +62,7 @@ public class OpenAiChatRuntime {
}
}
public <T> T formattedChat(List<Message> messages, Class<T> responseType) {
public <T> T formattedChat(@NotNull List<Message> messages, @NotNull Class<T> responseType) {
StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder()
.responseFormat(responseType)
.build();
@@ -47,10 +70,30 @@ public class OpenAiChatRuntime {
}
private ChatCompletionCreateParams buildParams(List<Message> messages) {
return ChatCompletionCreateParams.builder()
ChatCompletionCreateParams.Builder paramsBuilder = ChatCompletionCreateParams.builder()
.model(model)
.messages(OpenAiMessageAdapter.toParams(messages))
.build();
.messages(OpenAiMessageAdapter.toParams(messages));
ProviderOverride override = getOverride();
if (override != null) {
if (override.getTemperature() != null) {
paramsBuilder.temperature(override.getTemperature());
}
if (override.getTopP() != null) {
paramsBuilder.topP(override.getTopP());
}
if (override.getMaxTokens() != null) {
paramsBuilder.maxCompletionTokens(override.getMaxTokens());
}
JSONObject extras = override.getExtras();
if (extras != null) {
extras.forEach((key, value) -> {
paramsBuilder.putAdditionalBodyProperty(key, JsonValue.from(value));
});
}
}
return paramsBuilder.build();
}
private String extractText(ChatCompletion completion) {
@@ -68,4 +111,9 @@ public class OpenAiChatRuntime {
return completion.choices().getFirst().message().content()
.orElseThrow(() -> new IllegalStateException("OpenAI structured chat completion returned empty content."));
}
@Override
public @NotNull ModelProvider fork(@NotNull ProviderOverride override) {
return new OpenAiCompatibleProvider(baseUrl, apiKey, override.getModel(), override);
}
}

View File

@@ -1,4 +1,4 @@
package work.slhaf.partner.api.chat.runtime;
package work.slhaf.partner.api.chat.provider.openai;
import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
import com.openai.models.chat.completions.ChatCompletionMessageParam;

View File

@@ -2,7 +2,7 @@ package work.slhaf.partner.api.chat.pojo;
import com.alibaba.fastjson2.JSON;
import org.junit.jupiter.api.Test;
import work.slhaf.partner.api.chat.runtime.OpenAiMessageAdapter;
import work.slhaf.partner.api.chat.provider.openai.OpenAiMessageAdapter;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;