feat(chat): support streaming reply in agent turn

This commit is contained in:
2026-03-31 14:44:58 +08:00
parent b4c44c7d98
commit 81aa4b7933
5 changed files with 139 additions and 53 deletions

View File

@@ -11,6 +11,7 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentMod
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
import work.slhaf.partner.api.agent.factory.component.annotation.Init; import work.slhaf.partner.api.agent.factory.component.annotation.Init;
import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.api.chat.pojo.Message;
import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer;
import work.slhaf.partner.core.cognition.*; import work.slhaf.partner.core.cognition.*;
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext; import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
@@ -53,11 +54,6 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
return "communication_producer"; return "communication_producer";
} }
@Override
public boolean useStreaming() {
return true;
}
@Override @Override
public @NotNull List<Message> modulePrompt() { public @NotNull List<Message> modulePrompt() {
return List.of(new Message(Message.Character.SYSTEM, MODULE_PROMPT)); return List.of(new Message(Message.Character.SYSTEM, MODULE_PROMPT));
@@ -70,29 +66,10 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
} }
private void executeChat(PartnerRunningFlowContext runningFlowContext) { private void executeChat(PartnerRunningFlowContext runningFlowContext) {
String responseText = null; StreamChatMessageConsumer consumer = ReplyDispatcher.INSTANCE.createConsumer(runningFlowContext.getTarget());
this.streamChat(buildChatMessages(runningFlowContext), consumer);
// TODO considering removing retries in module updateChatMessages(runningFlowContext, consumer.collectResponse());
int count = 0; updateContext();
while (true) {
try {
// TODO 为各模块提供 emit msg 能力后, 在这里统一接收并分发结构化输出.
responseText = this.chat(buildChatMessages(runningFlowContext));
log.debug("CommunicationProducer responses: {}", responseText);
updateChatMessages(runningFlowContext, responseText);
updateContext();
break;
} catch (Exception e) {
count++;
log.error("Communicating exception occurred: {}", e.getLocalizedMessage());
if (count > 3) {
responseText = "CommunicationProducer Failed: " + e.getLocalizedMessage();
break;
}
} finally {
updateCoreResponse(runningFlowContext, responseText);
}
}
} }
private void updateContext() { private void updateContext() {

View File

@@ -0,0 +1,96 @@
package work.slhaf.partner.module.communication
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import work.slhaf.partner.api.agent.runtime.interaction.AgentRuntime
import work.slhaf.partner.api.agent.runtime.interaction.data.InteractionEvent.EventStatus
import work.slhaf.partner.api.agent.runtime.interaction.data.Reply
import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer
import kotlin.time.Duration.Companion.milliseconds
object ReplyDispatcher {
private const val AGGREGATE_WINDOW_MILLIS = 100L
// TODO 通过配置中心动态指定响应通道
private var channelName: String? = null
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
private val collectorChannel = Channel<ReplyChunk>(Channel.UNLIMITED)
init {
scope.launch {
var pendingChunk: ReplyChunk? = null
while (true) {
val firstChunk = pendingChunk ?: collectorChannel.receiveCatching().getOrNull() ?: break
pendingChunk = null
val builder = StringBuilder(firstChunk.delta)
while (true) {
val nextChunk = withTimeoutOrNull(AGGREGATE_WINDOW_MILLIS.milliseconds) {
collectorChannel.receiveCatching()
} ?: break
if (nextChunk.isClosed) {
flush(builder.toString(), firstChunk.target, firstChunk.channelName)
return@launch
}
val chunk = nextChunk.getOrNull() ?: break
if (chunk.target == firstChunk.target && chunk.channelName == firstChunk.channelName) {
builder.append(chunk.delta)
} else {
pendingChunk = chunk
break
}
}
flush(builder.toString(), firstChunk.target, firstChunk.channelName)
}
}
}
private fun flush(content: String, target: String, channelName: String?) {
if (content.isEmpty()) {
return
}
val event = Reply(
status = EventStatus.RUNNING,
target = target,
content = content,
mode = Reply.ContentMode.APPEND,
done = false
)
if (channelName.isNullOrBlank()) {
AgentRuntime.response(event)
} else {
AgentRuntime.response(event, channelName)
}
}
fun createConsumer(target: String): StreamChatMessageConsumer = ReplyConsumer(
collectorChannel = collectorChannel,
target = target,
channelName = channelName
)
private data class ReplyChunk(
val delta: String,
val target: String,
val channelName: String?
)
private class ReplyConsumer(
private val collectorChannel: Channel<ReplyChunk>,
private val target: String,
private val channelName: String?
) : StreamChatMessageConsumer() {
override fun consumeDelta(delta: String?) {
if (delta != null) {
collectorChannel.trySend(ReplyChunk(delta, target, channelName)).isSuccess
}
}
}
}

View File

@@ -7,6 +7,7 @@ import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
import work.slhaf.partner.api.agent.runtime.interaction.flow.RunningFlowContext import work.slhaf.partner.api.agent.runtime.interaction.flow.RunningFlowContext
import work.slhaf.partner.api.chat.pojo.Message import work.slhaf.partner.api.chat.pojo.Message
import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime
import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer
/** /**
* 模块基类 * 模块基类
@@ -52,11 +53,15 @@ interface ActivateModel {
} }
fun chat(messages: List<Message>): String { fun chat(messages: List<Message>): String {
return runtime.chat(mergeMessages(messages), useStreaming()) return runtime.chat(mergeMessages(messages))
}
fun streamChat(messages: List<Message>, handler: StreamChatMessageConsumer) {
return runtime.streamChat(mergeMessages(messages), handler)
} }
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T { fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T {
return runtime.formattedChat(mergeMessages(messages), useStreaming(), responseType) return runtime.formattedChat(mergeMessages(messages), responseType)
} }
fun mergeMessages(messages: List<Message>): List<Message> { fun mergeMessages(messages: List<Message>): List<Message> {
@@ -81,6 +86,4 @@ interface ActivateModel {
} }
fun modulePrompt(): List<Message> = emptyList() fun modulePrompt(): List<Message> = emptyList()
fun useStreaming(): Boolean = false
} }

View File

@@ -3,7 +3,6 @@ package work.slhaf.partner.api.chat.runtime;
import com.openai.client.OpenAIClient; import com.openai.client.OpenAIClient;
import com.openai.client.okhttp.OpenAIOkHttpClient; import com.openai.client.okhttp.OpenAIOkHttpClient;
import com.openai.core.http.StreamResponse; import com.openai.core.http.StreamResponse;
import com.openai.helpers.ChatCompletionAccumulator;
import com.openai.models.chat.completions.*; import com.openai.models.chat.completions.*;
import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.api.chat.pojo.Message;
@@ -24,32 +23,27 @@ public class OpenAiChatRuntime {
this.model = model; this.model = model;
} }
public String chat(List<Message> messages, boolean streaming) { public String chat(List<Message> messages) {
ChatCompletionCreateParams params = buildParams(messages); ChatCompletionCreateParams params = buildParams(messages);
if (!streaming) { return extractText(client.chat().completions().create(params));
return extractText(client.chat().completions().create(params));
}
ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
try (StreamResponse<ChatCompletionChunk> response = client.chat().completions().createStreaming(params)) {
response.stream().forEach(accumulator::accumulate);
}
return extractText(accumulator.chatCompletion());
} }
public <T> T formattedChat(List<Message> messages, boolean streaming, Class<T> responseType) { public void streamChat(List<Message> messages, StreamChatMessageConsumer handler) {
ChatCompletionCreateParams params = buildParams(messages);
try (StreamResponse<ChatCompletionChunk> streamResponse = client.chat().completions().createStreaming(params)) {
streamResponse.stream()
.flatMap(completion -> completion.choices().stream())
.flatMap(choice -> choice.delta().content().stream())
.filter(delta -> !delta.isEmpty())
.forEach(handler::onDelta);
}
}
public <T> T formattedChat(List<Message> messages, Class<T> responseType) {
StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder() StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder()
.responseFormat(responseType) .responseFormat(responseType)
.build(); .build();
if (!streaming) { return extractStructured(client.chat().completions().create(params));
return extractStructured(client.chat().completions().create(params));
}
ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
try (StreamResponse<ChatCompletionChunk> response = client.chat().completions().createStreaming(params.rawParams())) {
response.stream().forEach(accumulator::accumulate);
}
return extractStructured(accumulator.chatCompletion(responseType));
} }
private ChatCompletionCreateParams buildParams(List<Message> messages) { private ChatCompletionCreateParams buildParams(List<Message> messages) {

View File

@@ -0,0 +1,16 @@
package work.slhaf.partner.api.chat.runtime;
public abstract class StreamChatMessageConsumer {
private final StringBuilder responseText = new StringBuilder();
public void onDelta(String delta) {
consumeDelta(delta);
responseText.append(delta);
}
public String collectResponse() {
return responseText.toString();
}
protected abstract void consumeDelta(String delta);
}