mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
feat(chat): support streaming reply in agent turn
This commit is contained in:
@@ -11,6 +11,7 @@ import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentMod
|
|||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
|
import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer;
|
||||||
import work.slhaf.partner.core.cognition.*;
|
import work.slhaf.partner.core.cognition.*;
|
||||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||||
|
|
||||||
@@ -53,11 +54,6 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
return "communication_producer";
|
return "communication_producer";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean useStreaming() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public @NotNull List<Message> modulePrompt() {
|
public @NotNull List<Message> modulePrompt() {
|
||||||
return List.of(new Message(Message.Character.SYSTEM, MODULE_PROMPT));
|
return List.of(new Message(Message.Character.SYSTEM, MODULE_PROMPT));
|
||||||
@@ -70,29 +66,10 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void executeChat(PartnerRunningFlowContext runningFlowContext) {
|
private void executeChat(PartnerRunningFlowContext runningFlowContext) {
|
||||||
String responseText = null;
|
StreamChatMessageConsumer consumer = ReplyDispatcher.INSTANCE.createConsumer(runningFlowContext.getTarget());
|
||||||
|
this.streamChat(buildChatMessages(runningFlowContext), consumer);
|
||||||
// TODO considering removing retries in module
|
updateChatMessages(runningFlowContext, consumer.collectResponse());
|
||||||
int count = 0;
|
|
||||||
while (true) {
|
|
||||||
try {
|
|
||||||
// TODO 为各模块提供 emit msg 能力后, 在这里统一接收并分发结构化输出.
|
|
||||||
responseText = this.chat(buildChatMessages(runningFlowContext));
|
|
||||||
log.debug("CommunicationProducer responses: {}", responseText);
|
|
||||||
updateChatMessages(runningFlowContext, responseText);
|
|
||||||
updateContext();
|
updateContext();
|
||||||
break;
|
|
||||||
} catch (Exception e) {
|
|
||||||
count++;
|
|
||||||
log.error("Communicating exception occurred: {}", e.getLocalizedMessage());
|
|
||||||
if (count > 3) {
|
|
||||||
responseText = "CommunicationProducer Failed: " + e.getLocalizedMessage();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
updateCoreResponse(runningFlowContext, responseText);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateContext() {
|
private void updateContext() {
|
||||||
|
|||||||
@@ -0,0 +1,96 @@
|
|||||||
|
package work.slhaf.partner.module.communication
|
||||||
|
|
||||||
|
import kotlinx.coroutines.*
|
||||||
|
import kotlinx.coroutines.channels.Channel
|
||||||
|
import work.slhaf.partner.api.agent.runtime.interaction.AgentRuntime
|
||||||
|
import work.slhaf.partner.api.agent.runtime.interaction.data.InteractionEvent.EventStatus
|
||||||
|
import work.slhaf.partner.api.agent.runtime.interaction.data.Reply
|
||||||
|
import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer
|
||||||
|
import kotlin.time.Duration.Companion.milliseconds
|
||||||
|
|
||||||
|
object ReplyDispatcher {
|
||||||
|
|
||||||
|
private const val AGGREGATE_WINDOW_MILLIS = 100L
|
||||||
|
|
||||||
|
// TODO 通过配置中心动态指定响应通道
|
||||||
|
private var channelName: String? = null
|
||||||
|
|
||||||
|
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
|
||||||
|
private val collectorChannel = Channel<ReplyChunk>(Channel.UNLIMITED)
|
||||||
|
|
||||||
|
init {
|
||||||
|
scope.launch {
|
||||||
|
var pendingChunk: ReplyChunk? = null
|
||||||
|
while (true) {
|
||||||
|
val firstChunk = pendingChunk ?: collectorChannel.receiveCatching().getOrNull() ?: break
|
||||||
|
pendingChunk = null
|
||||||
|
val builder = StringBuilder(firstChunk.delta)
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
val nextChunk = withTimeoutOrNull(AGGREGATE_WINDOW_MILLIS.milliseconds) {
|
||||||
|
collectorChannel.receiveCatching()
|
||||||
|
} ?: break
|
||||||
|
|
||||||
|
if (nextChunk.isClosed) {
|
||||||
|
flush(builder.toString(), firstChunk.target, firstChunk.channelName)
|
||||||
|
return@launch
|
||||||
|
}
|
||||||
|
|
||||||
|
val chunk = nextChunk.getOrNull() ?: break
|
||||||
|
if (chunk.target == firstChunk.target && chunk.channelName == firstChunk.channelName) {
|
||||||
|
builder.append(chunk.delta)
|
||||||
|
} else {
|
||||||
|
pendingChunk = chunk
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
flush(builder.toString(), firstChunk.target, firstChunk.channelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun flush(content: String, target: String, channelName: String?) {
|
||||||
|
if (content.isEmpty()) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
val event = Reply(
|
||||||
|
status = EventStatus.RUNNING,
|
||||||
|
target = target,
|
||||||
|
content = content,
|
||||||
|
mode = Reply.ContentMode.APPEND,
|
||||||
|
done = false
|
||||||
|
)
|
||||||
|
if (channelName.isNullOrBlank()) {
|
||||||
|
AgentRuntime.response(event)
|
||||||
|
} else {
|
||||||
|
AgentRuntime.response(event, channelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun createConsumer(target: String): StreamChatMessageConsumer = ReplyConsumer(
|
||||||
|
collectorChannel = collectorChannel,
|
||||||
|
target = target,
|
||||||
|
channelName = channelName
|
||||||
|
)
|
||||||
|
|
||||||
|
private data class ReplyChunk(
|
||||||
|
val delta: String,
|
||||||
|
val target: String,
|
||||||
|
val channelName: String?
|
||||||
|
)
|
||||||
|
|
||||||
|
private class ReplyConsumer(
|
||||||
|
private val collectorChannel: Channel<ReplyChunk>,
|
||||||
|
private val target: String,
|
||||||
|
private val channelName: String?
|
||||||
|
) : StreamChatMessageConsumer() {
|
||||||
|
|
||||||
|
override fun consumeDelta(delta: String?) {
|
||||||
|
if (delta != null) {
|
||||||
|
collectorChannel.trySend(ReplyChunk(delta, target, channelName)).isSuccess
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import work.slhaf.partner.api.agent.runtime.config.AgentConfigLoader
|
|||||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.RunningFlowContext
|
import work.slhaf.partner.api.agent.runtime.interaction.flow.RunningFlowContext
|
||||||
import work.slhaf.partner.api.chat.pojo.Message
|
import work.slhaf.partner.api.chat.pojo.Message
|
||||||
import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime
|
import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime
|
||||||
|
import work.slhaf.partner.api.chat.runtime.StreamChatMessageConsumer
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 模块基类
|
* 模块基类
|
||||||
@@ -52,11 +53,15 @@ interface ActivateModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun chat(messages: List<Message>): String {
|
fun chat(messages: List<Message>): String {
|
||||||
return runtime.chat(mergeMessages(messages), useStreaming())
|
return runtime.chat(mergeMessages(messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
fun streamChat(messages: List<Message>, handler: StreamChatMessageConsumer) {
|
||||||
|
return runtime.streamChat(mergeMessages(messages), handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T {
|
fun <T : Any> formattedChat(messages: List<Message>, responseType: Class<T>): T {
|
||||||
return runtime.formattedChat(mergeMessages(messages), useStreaming(), responseType)
|
return runtime.formattedChat(mergeMessages(messages), responseType)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun mergeMessages(messages: List<Message>): List<Message> {
|
fun mergeMessages(messages: List<Message>): List<Message> {
|
||||||
@@ -81,6 +86,4 @@ interface ActivateModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun modulePrompt(): List<Message> = emptyList()
|
fun modulePrompt(): List<Message> = emptyList()
|
||||||
|
|
||||||
fun useStreaming(): Boolean = false
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package work.slhaf.partner.api.chat.runtime;
|
|||||||
import com.openai.client.OpenAIClient;
|
import com.openai.client.OpenAIClient;
|
||||||
import com.openai.client.okhttp.OpenAIOkHttpClient;
|
import com.openai.client.okhttp.OpenAIOkHttpClient;
|
||||||
import com.openai.core.http.StreamResponse;
|
import com.openai.core.http.StreamResponse;
|
||||||
import com.openai.helpers.ChatCompletionAccumulator;
|
|
||||||
import com.openai.models.chat.completions.*;
|
import com.openai.models.chat.completions.*;
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
|
|
||||||
@@ -24,34 +23,29 @@ public class OpenAiChatRuntime {
|
|||||||
this.model = model;
|
this.model = model;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String chat(List<Message> messages, boolean streaming) {
|
public String chat(List<Message> messages) {
|
||||||
ChatCompletionCreateParams params = buildParams(messages);
|
ChatCompletionCreateParams params = buildParams(messages);
|
||||||
if (!streaming) {
|
|
||||||
return extractText(client.chat().completions().create(params));
|
return extractText(client.chat().completions().create(params));
|
||||||
}
|
}
|
||||||
|
|
||||||
ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
|
public void streamChat(List<Message> messages, StreamChatMessageConsumer handler) {
|
||||||
try (StreamResponse<ChatCompletionChunk> response = client.chat().completions().createStreaming(params)) {
|
ChatCompletionCreateParams params = buildParams(messages);
|
||||||
response.stream().forEach(accumulator::accumulate);
|
try (StreamResponse<ChatCompletionChunk> streamResponse = client.chat().completions().createStreaming(params)) {
|
||||||
|
streamResponse.stream()
|
||||||
|
.flatMap(completion -> completion.choices().stream())
|
||||||
|
.flatMap(choice -> choice.delta().content().stream())
|
||||||
|
.filter(delta -> !delta.isEmpty())
|
||||||
|
.forEach(handler::onDelta);
|
||||||
}
|
}
|
||||||
return extractText(accumulator.chatCompletion());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public <T> T formattedChat(List<Message> messages, boolean streaming, Class<T> responseType) {
|
public <T> T formattedChat(List<Message> messages, Class<T> responseType) {
|
||||||
StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder()
|
StructuredChatCompletionCreateParams<T> params = buildParams(messages).toBuilder()
|
||||||
.responseFormat(responseType)
|
.responseFormat(responseType)
|
||||||
.build();
|
.build();
|
||||||
if (!streaming) {
|
|
||||||
return extractStructured(client.chat().completions().create(params));
|
return extractStructured(client.chat().completions().create(params));
|
||||||
}
|
}
|
||||||
|
|
||||||
ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
|
|
||||||
try (StreamResponse<ChatCompletionChunk> response = client.chat().completions().createStreaming(params.rawParams())) {
|
|
||||||
response.stream().forEach(accumulator::accumulate);
|
|
||||||
}
|
|
||||||
return extractStructured(accumulator.chatCompletion(responseType));
|
|
||||||
}
|
|
||||||
|
|
||||||
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
private ChatCompletionCreateParams buildParams(List<Message> messages) {
|
||||||
return ChatCompletionCreateParams.builder()
|
return ChatCompletionCreateParams.builder()
|
||||||
.model(model)
|
.model(model)
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package work.slhaf.partner.api.chat.runtime;
|
||||||
|
|
||||||
|
public abstract class StreamChatMessageConsumer {
|
||||||
|
private final StringBuilder responseText = new StringBuilder();
|
||||||
|
|
||||||
|
public void onDelta(String delta) {
|
||||||
|
consumeDelta(delta);
|
||||||
|
responseText.append(delta);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String collectResponse() {
|
||||||
|
return responseText.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract void consumeDelta(String delta);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user