refactor(runtime): support collect context by source and interrupt same-source running flow by module order

This commit is contained in:
2026-04-15 14:32:52 +08:00
parent 247057e100
commit dc147000ba
12 changed files with 537 additions and 122 deletions

View File

@@ -15,18 +15,36 @@ import work.slhaf.partner.framework.agent.interaction.data.InteractionEvent
import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext
import work.slhaf.partner.framework.agent.support.Result
import java.nio.file.Path
import java.util.*
object AgentRuntime : Configurable, ConfigRegistration<ModuleMaskConfig> {
private const val DEFAULT_LOG_CHANNEL = "log_channel"
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
private val channel = Channel<RunningFlowContext>(Channel.UNLIMITED)
private val responseChannels = mutableMapOf<String, ResponseChannel>(
LogChannel.channelName to LogChannel
)
private val wakeSignal = Channel<Unit>(Channel.UNLIMITED)
private val stateLock = Any()
/**
* 按照 source 分开存储的最新的 contextinput 聚合、其余信息按照最新输入
*/
private val latestContextsBySource = LinkedHashMap<String, RunningFlowContext>()
/**
* source 队列,其中元素不会重复,触发唤醒信号时,从该队列取出 source 并处理对应的 context
*/
private val sourceQueue = ArrayDeque<String>()
/**
* 与对应 source 的最新 context 对应,用于记录 context 版本状态
*/
private val sourceVersions = mutableMapOf<String, Long>()
private val responseChannels = mutableMapOf<String, ResponseChannel>()
@Volatile
private var defaultChannel: String = LogChannel.channelName
private var defaultChannel: String = DEFAULT_LOG_CHANNEL
@Volatile
private var runningModules: Map<Int, List<AbstractAgentModule.Running<RunningFlowContext>>> = emptyMap()
@@ -34,13 +52,20 @@ object AgentRuntime : Configurable, ConfigRegistration<ModuleMaskConfig> {
@Volatile
private var maskedModules: Set<String> = emptySet()
@Volatile
private var currentExecutingSource: String? = null
@Volatile
private var currentExecutingContext: RunningFlowContext? = null
init {
register()
scope.launch {
for (ctx in channel) {
executeTurn(ctx)
for (@Suppress("UNUSED_VARIABLE") ignored in wakeSignal) {
drainQueue()
}
}
responseChannels.putIfAbsent(DEFAULT_LOG_CHANNEL, LogChannel)
}
fun registerResponseChannel(channelName: String, responseChannel: ResponseChannel) {
@@ -48,7 +73,7 @@ object AgentRuntime : Configurable, ConfigRegistration<ModuleMaskConfig> {
}
fun unregisterResponseChannel(channelName: String) {
if (channelName == LogChannel.channelName) {
if (channelName == DEFAULT_LOG_CHANNEL) {
return
}
responseChannels.remove(channelName)
@@ -64,26 +89,95 @@ object AgentRuntime : Configurable, ConfigRegistration<ModuleMaskConfig> {
fun response(event: InteractionEvent, channelName: String = defaultChannel) {
val channel = responseChannels[channelName]
if (channel == null) {
responseChannels[defaultChannel]?.response(event) ?: LogChannel.response(event)
responseChannels[defaultChannel]?.response(event)
?: responseChannels[DEFAULT_LOG_CHANNEL]?.response(event)
?: LogChannel.response(event)
} else {
channel.response(event)
}
}
fun <C : RunningFlowContext> submit(context: C) = runBlocking {
channel.send(context)
synchronized(stateLock) {
val source = context.source
latestContextsBySource[source] = latestContextsBySource[source]?.mergedWith(context) ?: context
sourceVersions[source] = (sourceVersions[source] ?: 0L) + 1L
if (!sourceQueue.contains(source)) {
sourceQueue.addLast(source)
}
if (currentExecutingSource == source) {
currentExecutingContext?.status?.interrupted = true
}
}
wakeSignal.send(Unit)
}
private suspend fun executeTurn(runningFlowContext: RunningFlowContext) {
private suspend fun drainQueue() {
while (true) {
val source = synchronized(stateLock) {
sourceQueue.firstOrNull()
} ?: return
executeSource(source)
}
}
private suspend fun executeSource(source: String) {
while (true) {
val execution = synchronized(stateLock) {
val context = latestContextsBySource[source] ?: run {
sourceQueue.remove(source)
sourceVersions.remove(source)
return
}
currentExecutingSource = source
currentExecutingContext = context
context.status.interrupted = false
SourceExecution(context, sourceVersions[source] ?: 0L)
}
val interrupted = executeTurn(execution.context)
val shouldRetry = synchronized(stateLock) {
currentExecutingSource = null
currentExecutingContext = null
val latestContext = latestContextsBySource[source]
val latestVersion = sourceVersions[source] ?: execution.version
when {
latestContext == null -> {
sourceQueue.remove(source)
sourceVersions.remove(source)
false
}
interrupted || latestVersion != execution.version -> true
else -> {
latestContextsBySource.remove(source)
sourceQueue.remove(source)
sourceVersions.remove(source)
false
}
}
}
if (!shouldRetry) {
return
}
}
}
private suspend fun executeTurn(runningFlowContext: RunningFlowContext): Boolean {
if (runningModules.isEmpty()) {
refreshRunningModules()
}
for (modules in runningModules.values) {
if (runningFlowContext.status.interrupted) {
return true
}
executeOrder(modules, runningFlowContext)
}
return runningFlowContext.status.interrupted
}
private fun refreshRunningModules() {
@@ -102,6 +196,9 @@ object AgentRuntime : Configurable, ConfigRegistration<ModuleMaskConfig> {
coroutineScope {
val jobs = modules.map { module ->
async {
if (runningFlowContext.status.interrupted) {
return@async
}
if (runningFlowContext.skippedModules.contains(module.moduleName)) {
return@async
}
@@ -144,6 +241,10 @@ object AgentRuntime : Configurable, ConfigRegistration<ModuleMaskConfig> {
refreshRunningModules()
}
private data class SourceExecution(
val context: RunningFlowContext,
val version: Long
)
}
data class ModuleMaskConfig(

View File

@@ -1,39 +1,60 @@
package work.slhaf.partner.framework.agent.interaction.flow
import com.alibaba.fastjson2.JSONObject
import org.w3c.dom.Document
import org.w3c.dom.Element
import java.time.Instant
import java.time.LocalDateTime
import java.time.ZonedDateTime
import java.time.ZoneId
import java.util.*
import kotlin.math.min
/**
* 流程上下文
*/
abstract class RunningFlowContext {
abstract class RunningFlowContext protected constructor(
inputs: List<InputEntry>,
val firstInputEpochMillis: Long,
additionalUserInfo: Map<String, String> = emptyMap(),
skippedModules: Set<String> = emptySet(),
target: String = ""
) {
/**
* 消息来源: 由谁发出
*/
abstract val source: String
/**
* 消息内容
* 输入序列
*/
abstract val input: String
val inputs: List<InputEntry> = inputs.sortedBy { it.offsetMillis }
/**
* 兼容旧路径的纯文本输入表示,按时间顺序换行拼接
*/
val input: String
get() = formatInputsForHistory()
/**
* 消息回应对象,默认与 source 一致
*/
var target = source
var target: String = target
private val _additionalUserInfo = mutableMapOf<String, String>()
private val _additionalUserInfo = additionalUserInfo.toMutableMap()
val additionalUserInfo: Map<String, String>
get() = _additionalUserInfo
private val _skippedModules = mutableSetOf<String>()
private val _skippedModules = skippedModules.toMutableSet()
val skippedModules: Set<String>
get() = _skippedModules
val status = Status()
val firstInputDateTime: LocalDateTime
get() = Instant.ofEpochMilli(firstInputEpochMillis)
.atZone(ZoneId.systemDefault())
.toLocalDateTime()
fun addSkippedModule(moduleName: String) {
_skippedModules.add(moduleName)
}
@@ -45,14 +66,104 @@ abstract class RunningFlowContext {
fun putUserInfo(key: String, value: Any) {
_additionalUserInfo[key] = try {
JSONObject.toJSONString(value)
} catch (e: Exception) {
} catch (_: Exception) {
value.toString()
}
}
fun formatInputsForHistory(): String = inputs.joinToString("\n") { it.content }
@JvmOverloads
fun appendInputsXml(
document: Document,
parent: Element,
containerTagName: String = "inputs",
inputTagName: String = "input",
intervalAttributeName: String = "interval-to-first"
) {
val inputsElement = document.createElement(containerTagName)
parent.appendChild(inputsElement)
inputs.forEach { entry ->
val inputElement = document.createElement(inputTagName)
inputElement.setAttribute(intervalAttributeName, entry.offsetMillis.toString())
inputElement.textContent = entry.content
inputsElement.appendChild(inputElement)
}
}
fun encodeInputsXml(): String {
val builder = StringBuilder()
builder.append("<inputs>")
inputs.forEach { entry ->
builder.append("<input interval-to-first=\"")
.append(escapeXml(entry.offsetMillis.toString()))
.append("\">")
.append(escapeXml(entry.content))
.append("</input>")
}
builder.append("</inputs>")
return builder.toString()
}
fun mergedWith(other: RunningFlowContext): RunningFlowContext {
require(source == other.source) {
"Unable to merge RunningFlowContext from different source: $source != ${other.source}"
}
val mergedFirstEpochMillis = min(firstInputEpochMillis, other.firstInputEpochMillis)
val mergedInputs = buildList(inputs.size + other.inputs.size) {
addAll(normalizeInputs(this@RunningFlowContext, mergedFirstEpochMillis))
addAll(normalizeInputs(other, mergedFirstEpochMillis))
}.sortedBy { it.offsetMillis }
val mergedAdditionalUserInfo = LinkedHashMap<String, String>(_additionalUserInfo)
mergedAdditionalUserInfo.putAll(other.additionalUserInfo)
val mergedSkippedModules = LinkedHashSet<String>(_skippedModules)
mergedSkippedModules.addAll(other.skippedModules)
return copyWith(
inputs = mergedInputs,
firstInputEpochMillis = mergedFirstEpochMillis,
additionalUserInfo = mergedAdditionalUserInfo,
skippedModules = mergedSkippedModules,
target = other.target.ifBlank { target }
)
}
protected abstract fun copyWith(
inputs: List<InputEntry>,
firstInputEpochMillis: Long,
additionalUserInfo: Map<String, String>,
skippedModules: Set<String>,
target: String
): RunningFlowContext
private fun normalizeInputs(context: RunningFlowContext, firstEpochMillis: Long): List<InputEntry> {
return context.inputs.map { entry ->
InputEntry(
offsetMillis = context.firstInputEpochMillis + entry.offsetMillis - firstEpochMillis,
content = entry.content
)
}
}
private fun escapeXml(value: String): String {
return value
.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace("\"", "&quot;")
.replace("'", "&apos;")
}
data class InputEntry(
val offsetMillis: Long,
val content: String
)
class Info {
val uuid = UUID.randomUUID().toString()
val dateTime: LocalDateTime = ZonedDateTime.now().toLocalDateTime()
val dateTime: LocalDateTime = LocalDateTime.now()
}
class Status {
@@ -62,6 +173,12 @@ abstract class RunningFlowContext {
val ok: Boolean
get() = errors.isEmpty()
/**
* 模块边界上的协作式打断标记
*/
@Volatile
var interrupted: Boolean = false
/**
* 本次执行时收集到的异常信息
*/

View File

@@ -0,0 +1,224 @@
package work.slhaf.partner.framework.agent.interaction
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import work.slhaf.partner.framework.agent.factory.component.abstracts.AbstractAgentModule
import work.slhaf.partner.framework.agent.factory.context.AgentContext
import work.slhaf.partner.framework.agent.factory.context.ModuleContextData
import work.slhaf.partner.framework.agent.interaction.flow.RunningFlowContext
import java.time.ZonedDateTime
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
class AgentRuntimeTest {
@BeforeEach
fun setUp() {
resetAgentRuntime()
clearModules()
}
@AfterEach
fun tearDown() {
resetAgentRuntime()
clearModules()
}
@Test
fun `running flow context preserves offsets and xml encoding`() {
val first = TestRunningFlowContext.of("source-a", "first", 1_000L)
val second = TestRunningFlowContext.of("source-a", "second", 1_250L)
val merged = first.mergedWith(second)
assertEquals(listOf(0L, 250L), merged.inputs.map { it.offsetMillis })
assertEquals("first\nsecond", merged.input)
assertEquals(
"<inputs><input interval-to-first=\"0\">first</input><input interval-to-first=\"250\">second</input></inputs>",
merged.encodeInputsXml()
)
}
@Test
fun `agent runtime keeps source queue in first arrival order`() {
val recorder = RecordingModule(order = 1, expectedExecutions = 2)
registerModule("queue-recorder", recorder)
AgentRuntime.submit(TestRunningFlowContext.of("source-a", "alpha"))
AgentRuntime.submit(TestRunningFlowContext.of("source-b", "beta"))
assertTrue(recorder.latch.await(5, TimeUnit.SECONDS))
assertEquals(listOf("source-a", "source-b"), recorder.sources)
}
@Test
fun `agent runtime interrupts current source and reruns from chain head with merged context`() {
val blocking = BlockingModule()
val finalizer = RecordingModule(order = 2, expectedExecutions = 1)
registerModule("blocking-module", blocking)
registerModule("finalizer-module", finalizer)
AgentRuntime.submit(TestRunningFlowContext.of("source-a", "first", 1_000L))
assertTrue(blocking.firstExecutionStarted.await(5, TimeUnit.SECONDS))
AgentRuntime.submit(TestRunningFlowContext.of("source-a", "second", 1_300L))
blocking.releaseFirstExecution.countDown()
assertTrue(finalizer.latch.await(5, TimeUnit.SECONDS))
waitUntil { blocking.seenInputSizes.size >= 2 }
assertEquals(listOf(1, 2), blocking.seenInputSizes)
assertEquals(listOf(2), finalizer.inputSizes)
assertEquals(listOf("first\nsecond"), finalizer.historyInputs)
}
private fun registerModule(name: String, module: AbstractAgentModule.Running<*>) {
@Suppress("UNCHECKED_CAST")
AgentContext.addModule(
name,
ModuleContextData.Running(
module.javaClass,
module,
ZonedDateTime.now(),
null,
module.order()
) as ModuleContextData<AbstractAgentModule>
)
}
private fun clearModules() {
@Suppress("UNCHECKED_CAST")
val modules = AgentContext.modules as MutableMap<String, ModuleContextData<AbstractAgentModule>>
modules.clear()
}
private fun resetAgentRuntime() {
setPrivateField("runningModules", emptyMap<Int, List<AbstractAgentModule.Running<RunningFlowContext>>>())
setPrivateField("maskedModules", emptySet<String>())
setPrivateField("currentExecutingSource", null)
setPrivateField("currentExecutingContext", null)
getPrivateMutableMap<String, RunningFlowContext>("latestContextsBySource").clear()
getPrivateMutableMap<String, Long>("sourceVersions").clear()
getPrivateDeque<String>("sourceQueue").clear()
}
private fun waitUntil(timeoutMillis: Long = 5_000L, condition: () -> Boolean) {
val deadline = System.currentTimeMillis() + timeoutMillis
while (System.currentTimeMillis() < deadline) {
if (condition()) {
return
}
Thread.sleep(20L)
}
error("Condition was not satisfied within $timeoutMillis ms")
}
private fun setPrivateField(fieldName: String, value: Any?) {
val field = AgentRuntime::class.java.getDeclaredField(fieldName)
field.isAccessible = true
field.set(AgentRuntime, value)
}
@Suppress("UNCHECKED_CAST")
private fun <K, V> getPrivateMutableMap(fieldName: String): MutableMap<K, V> {
val field = AgentRuntime::class.java.getDeclaredField(fieldName)
field.isAccessible = true
return field.get(AgentRuntime) as MutableMap<K, V>
}
@Suppress("UNCHECKED_CAST")
private fun <T> getPrivateDeque(fieldName: String): java.util.ArrayDeque<T> {
val field = AgentRuntime::class.java.getDeclaredField(fieldName)
field.isAccessible = true
return field.get(AgentRuntime) as java.util.ArrayDeque<T>
}
private class RecordingModule(
private val order: Int,
expectedExecutions: Int
) : AbstractAgentModule.Running<TestRunningFlowContext>() {
val sources = CopyOnWriteArrayList<String>()
val inputSizes = CopyOnWriteArrayList<Int>()
val historyInputs = CopyOnWriteArrayList<String>()
val latch = CountDownLatch(expectedExecutions)
init {
moduleName = "recording-$order"
}
override fun doExecute(context: TestRunningFlowContext) {
sources.add(context.source)
inputSizes.add(context.inputs.size)
historyInputs.add(context.input)
latch.countDown()
}
override fun order(): Int = order
}
private class BlockingModule : AbstractAgentModule.Running<TestRunningFlowContext>() {
val seenInputSizes = CopyOnWriteArrayList<Int>()
val firstExecutionStarted = CountDownLatch(1)
val releaseFirstExecution = CountDownLatch(1)
private val invocationCount = AtomicInteger(0)
init {
moduleName = "blocking"
}
override fun doExecute(context: TestRunningFlowContext) {
seenInputSizes.add(context.inputs.size)
if (invocationCount.getAndIncrement() == 0) {
firstExecutionStarted.countDown()
releaseFirstExecution.await(5, TimeUnit.SECONDS)
}
}
override fun order(): Int = 1
}
private class TestRunningFlowContext private constructor(
override val source: String,
inputs: List<InputEntry>,
firstInputEpochMillis: Long,
target: String = source
) : RunningFlowContext(inputs, firstInputEpochMillis, target = target) {
companion object {
fun of(
source: String,
input: String,
receivedAtMillis: Long = System.currentTimeMillis()
): TestRunningFlowContext {
return TestRunningFlowContext(
source = source,
inputs = listOf(InputEntry(0L, input)),
firstInputEpochMillis = receivedAtMillis
)
}
}
override fun copyWith(
inputs: List<InputEntry>,
firstInputEpochMillis: Long,
additionalUserInfo: Map<String, String>,
skippedModules: Set<String>,
target: String
): RunningFlowContext {
return TestRunningFlowContext(
source = source,
inputs = inputs,
firstInputEpochMillis = firstInputEpochMillis,
target = target
).apply {
additionalUserInfo.forEach(::putUserInfo)
skippedModules.forEach(::addSkippedModule)
}
}
}
}