From 81aa4b79336358fda2d0931ac15e525dbb5beddd Mon Sep 17 00:00:00 2001 From: slhafzjw Date: Tue, 31 Mar 2026 14:44:58 +0800 Subject: [PATCH] feat(chat): support streaming reply in agent turn --- .../communication/CommunicationProducer.java | 33 +------ .../module/communication/ReplyDispatcher.kt | 96 +++++++++++++++++++ .../component/abstracts/AgentModule.kt | 11 ++- .../api/chat/runtime/OpenAiChatRuntime.java | 36 +++---- .../runtime/StreamChatMessageConsumer.java | 16 ++++ 5 files changed, 139 insertions(+), 53 deletions(-) create mode 100644 Partner-Core/src/main/java/work/slhaf/partner/module/communication/ReplyDispatcher.kt create mode 100644 Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/StreamChatMessageConsumer.java diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/CommunicationProducer.java b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/CommunicationProducer.java index 259de998..f01e1435 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/CommunicationProducer.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/CommunicationProducer.java @@ -11,6 +11,7 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentMod import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.component.annotation.Init; import work.slhaf.partner.api.chat.pojo.Message; +import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer; import work.slhaf.partner.core.cognition.*; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; @@ -53,11 +54,6 @@ public class CommunicationProducer extends AbstractAgentModule.Running modulePrompt() { return List.of(new Message(Message.Character.SYSTEM, MODULE_PROMPT)); @@ -70,29 +66,10 @@ public class CommunicationProducer extends AbstractAgentModule.Running 3) { - responseText = "CommunicationProducer Failed: " + e.getLocalizedMessage(); - break; - } - } finally { - updateCoreResponse(runningFlowContext, responseText); - } - } + StreamChatMessageConsumer consumer = ReplyDispatcher.INSTANCE.createConsumer(runningFlowContext.getTarget()); + this.streamChat(buildChatMessages(runningFlowContext), consumer); + updateChatMessages(runningFlowContext, consumer.collectResponse()); + updateContext(); } private void updateContext() { diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/communication/ReplyDispatcher.kt b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/ReplyDispatcher.kt new file mode 100644 index 00000000..c722ab9f --- /dev/null +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/communication/ReplyDispatcher.kt @@ -0,0 +1,96 @@ +package work.slhaf.partner.module.communication + +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import work.slhaf.partner.api.agent.runtime.interaction.AgentRuntime +import work.slhaf.partner.api.agent.runtime.interaction.data.InteractionEvent.EventStatus +import work.slhaf.partner.api.agent.runtime.interaction.data.Reply +import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer +import kotlin.time.Duration.Companion.milliseconds + +object ReplyDispatcher { + + private const val AGGREGATE_WINDOW_MILLIS = 100L + + // TODO 通过配置中心动态指定响应通道 + private var channelName: String? = null + + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + private val collectorChannel = Channel(Channel.UNLIMITED) + + init { + scope.launch { + var pendingChunk: ReplyChunk? = null + while (true) { + val firstChunk = pendingChunk ?: collectorChannel.receiveCatching().getOrNull() ?: break + pendingChunk = null + val builder = StringBuilder(firstChunk.delta) + + while (true) { + val nextChunk = withTimeoutOrNull(AGGREGATE_WINDOW_MILLIS.milliseconds) { + collectorChannel.receiveCatching() + } ?: break + + if (nextChunk.isClosed) { + flush(builder.toString(), firstChunk.target, firstChunk.channelName) + return@launch + } + + val chunk = nextChunk.getOrNull() ?: break + if (chunk.target == firstChunk.target && chunk.channelName == firstChunk.channelName) { + builder.append(chunk.delta) + } else { + pendingChunk = chunk + break + } + } + + flush(builder.toString(), firstChunk.target, firstChunk.channelName) + } + } + } + + private fun flush(content: String, target: String, channelName: String?) { + if (content.isEmpty()) { + return + } + val event = Reply( + status = EventStatus.RUNNING, + target = target, + content = content, + mode = Reply.ContentMode.APPEND, + done = false + ) + if (channelName.isNullOrBlank()) { + AgentRuntime.response(event) + } else { + AgentRuntime.response(event, channelName) + } + } + + fun createConsumer(target: String): StreamChatMessageConsumer = ReplyConsumer( + collectorChannel = collectorChannel, + target = target, + channelName = channelName + ) + + private data class ReplyChunk( + val delta: String, + val target: String, + val channelName: String? + ) + + private class ReplyConsumer( + private val collectorChannel: Channel, + private val target: String, + private val channelName: String? + ) : StreamChatMessageConsumer() { + + override fun consumeDelta(delta: String?) { + if (delta != null) { + collectorChannel.trySend(ReplyChunk(delta, target, channelName)).isSuccess + } + } + + } +} diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/factory/component/abstracts/AgentModule.kt b/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/factory/component/abstracts/AgentModule.kt index c358c78b..00a5eecb 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/factory/component/abstracts/AgentModule.kt +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/agent/factory/component/abstracts/AgentModule.kt @@ -7,6 +7,7 @@ import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader import work.slhaf.partner.api.agent.runtime.interaction.flow.RunningFlowContext import work.slhaf.partner.api.chat.pojo.Message import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime +import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer /** * 模块基类 @@ -52,11 +53,15 @@ interface ActivateModel { } fun chat(messages: List): String { - return runtime.chat(mergeMessages(messages), useStreaming()) + return runtime.chat(mergeMessages(messages)) + } + + fun streamChat(messages: List, handler: StreamChatMessageConsumer) { + return runtime.streamChat(mergeMessages(messages), handler) } fun formattedChat(messages: List, responseType: Class): T { - return runtime.formattedChat(mergeMessages(messages), useStreaming(), responseType) + return runtime.formattedChat(mergeMessages(messages), responseType) } fun mergeMessages(messages: List): List { @@ -81,6 +86,4 @@ interface ActivateModel { } fun modulePrompt(): List = emptyList() - - fun useStreaming(): Boolean = false } diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiChatRuntime.java b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiChatRuntime.java index b06a4277..325239d6 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiChatRuntime.java +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiChatRuntime.java @@ -3,7 +3,6 @@ package work.slhaf.partner.api.chat.runtime; import com.openai.client.OpenAIClient; import com.openai.client.okhttp.OpenAIOkHttpClient; import com.openai.core.http.StreamResponse; -import com.openai.helpers.ChatCompletionAccumulator; import com.openai.models.chat.completions.*; import work.slhaf.partner.api.chat.pojo.Message; @@ -24,32 +23,27 @@ public class OpenAiChatRuntime { this.model = model; } - public String chat(List messages, boolean streaming) { + public String chat(List messages) { ChatCompletionCreateParams params = buildParams(messages); - if (!streaming) { - return extractText(client.chat().completions().create(params)); - } - - ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create(); - try (StreamResponse response = client.chat().completions().createStreaming(params)) { - response.stream().forEach(accumulator::accumulate); - } - return extractText(accumulator.chatCompletion()); + return extractText(client.chat().completions().create(params)); } - public T formattedChat(List messages, boolean streaming, Class responseType) { + public void streamChat(List messages, StreamChatMessageConsumer handler) { + ChatCompletionCreateParams params = buildParams(messages); + try (StreamResponse streamResponse = client.chat().completions().createStreaming(params)) { + streamResponse.stream() + .flatMap(completion -> completion.choices().stream()) + .flatMap(choice -> choice.delta().content().stream()) + .filter(delta -> !delta.isEmpty()) + .forEach(handler::onDelta); + } + } + + public T formattedChat(List messages, Class responseType) { StructuredChatCompletionCreateParams params = buildParams(messages).toBuilder() .responseFormat(responseType) .build(); - if (!streaming) { - return extractStructured(client.chat().completions().create(params)); - } - - ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create(); - try (StreamResponse response = client.chat().completions().createStreaming(params.rawParams())) { - response.stream().forEach(accumulator::accumulate); - } - return extractStructured(accumulator.chatCompletion(responseType)); + return extractStructured(client.chat().completions().create(params)); } private ChatCompletionCreateParams buildParams(List messages) { diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/StreamChatMessageConsumer.java b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/StreamChatMessageConsumer.java new file mode 100644 index 00000000..12feabe1 --- /dev/null +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/StreamChatMessageConsumer.java @@ -0,0 +1,16 @@ +package work.slhaf.partner.api.chat.runtime; + +public abstract class StreamChatMessageConsumer { + private final StringBuilder responseText = new StringBuilder(); + + public void onDelta(String delta) { + consumeDelta(delta); + responseText.append(delta); + } + + public String collectResponse() { + return responseText.toString(); + } + + protected abstract void consumeDelta(String delta); +}