mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
refactor(CommunicationProducer): split context and supply message assembly
This commit is contained in:
@@ -4,21 +4,35 @@ import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.w3c.dom.Document;
|
||||
import org.w3c.dom.Element;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.ContextBlock;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||
import work.slhaf.partner.module.common.entity.AppendPromptData;
|
||||
import work.slhaf.partner.module.common.model.ModelConstant;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import javax.xml.parsers.DocumentBuilderFactory;
|
||||
import javax.xml.transform.OutputKeys;
|
||||
import javax.xml.transform.Transformer;
|
||||
import javax.xml.transform.TransformerFactory;
|
||||
import javax.xml.transform.dom.DOMSource;
|
||||
import javax.xml.transform.stream.StreamResult;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.StringWriter;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
|
||||
|
||||
@@ -26,9 +40,18 @@ import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
|
||||
@Data
|
||||
public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRunningFlowContext> implements ActivateModel {
|
||||
|
||||
private static final String MODULE_PROMPT = """
|
||||
你是 Partner 的表达模块。
|
||||
你接下来收到的消息固定分为三个区段:
|
||||
1. system message 是 Head, 用于说明整个输入结构与输出要求。
|
||||
2. <context> 区段只承载 type=CONTEXT 的上下文块, 其中每个子块都带有独立来源, 仅作为理解当前状态与辅助决策的依据。
|
||||
3. Conversation 区段是对话轨迹; 最新的一条 user message 会使用 <input> 结构, 其中 <content> 是本轮用户原始输入, 其他子标签是输入元信息与 type=SUPPLY 的补充块, 补充块会按 blockName 分区。
|
||||
你必须综合 Context 与 Conversation 回答最新输入, 不要把 XML 标签当作需要原样复述给用户的内容。
|
||||
直接输出最终回应内容即可, 不需要额外包装为 JSON。
|
||||
""";
|
||||
|
||||
@InjectCapability
|
||||
private CognationCapability cognationCapability;
|
||||
private final List<Message> appendedMessages = new ArrayList<>();
|
||||
private final List<Message> chatMessages = new ArrayList<>();
|
||||
|
||||
@Init
|
||||
@@ -48,172 +71,220 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull List<Message> modulePrompt() {
|
||||
return List.of(new Message(Message.Character.SYSTEM, MODULE_PROMPT));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(PartnerRunningFlowContext runningFlowContext) {
|
||||
String userId = runningFlowContext.getSource();
|
||||
log.debug("[CommunicationProducer] 主对话流程开始: {}", userId);
|
||||
beforeChat(runningFlowContext);
|
||||
log.debug("Communicating with: {}",runningFlowContext.getSource());
|
||||
executeChat(runningFlowContext);
|
||||
log.debug("[CommunicationProducer] 主对话流程({})结束...", userId);
|
||||
}
|
||||
|
||||
private void beforeChat(PartnerRunningFlowContext runningFlowContext) {
|
||||
setAppendedPromptMessage(runningFlowContext);
|
||||
activateModule(runningFlowContext);
|
||||
setMessageCount(runningFlowContext);
|
||||
|
||||
log.debug("[CommunicationProducer] 当前消息列表大小: {}", chatMessages.size());
|
||||
log.debug("[CommunicationProducer] 当前核心prompt内容: {}", runningFlowContext.getCoreContext().toString());
|
||||
|
||||
setMessage(runningFlowContext.getCoreContext().toString());
|
||||
}
|
||||
|
||||
// TODO need to update message appending logic
|
||||
private void setAppendedPromptMessage(PartnerRunningFlowContext runningFlowContext) {
|
||||
List<AppendPromptData> appendedPrompt = runningFlowContext.getModuleContext().getAppendedPrompt();
|
||||
int appendedPromptSize = getAppendedPromptSize(appendedPrompt);
|
||||
if (appendedPromptSize > 0) {
|
||||
setAppendedPromptMessage(appendedPrompt);
|
||||
}
|
||||
}
|
||||
|
||||
private void executeChat(PartnerRunningFlowContext runningFlowContext) {
|
||||
JSONObject response = new JSONObject();
|
||||
String responseText = null;
|
||||
|
||||
// TODO considering removing retries in module
|
||||
int count = 0;
|
||||
while (true) {
|
||||
try {
|
||||
String chatResponse = this.chat(buildChatMessages());
|
||||
try {
|
||||
response.putAll(JSONObject.parse(extractJson(chatResponse)));
|
||||
} catch (Exception e) {
|
||||
log.warn("主模型回复格式出错, 将直接作为消息返回, 建议尝试更换主模型...");
|
||||
handleExceptionResponse(response, chatResponse);
|
||||
}
|
||||
log.debug("[CommunicationProducer] CommunicationProducer 响应内容: {}", response);
|
||||
updateModuleContextAndChatMessages(runningFlowContext, response.getString("text"));
|
||||
// TODO 为各模块提供 emit msg 能力后, 在这里统一接收并分发结构化输出.
|
||||
responseText = this.chat(buildChatMessages(runningFlowContext));
|
||||
log.debug("CommunicationProducer responses: {}", responseText);
|
||||
updateModuleContextAndChatMessages(runningFlowContext, responseText);
|
||||
break;
|
||||
} catch (Exception e) {
|
||||
count++;
|
||||
log.error("[CommunicationProducer] CoreModel执行异常: {}", e.getLocalizedMessage());
|
||||
log.error("Communicating exception occurred: {}", e.getLocalizedMessage());
|
||||
if (count > 3) {
|
||||
handleExceptionResponse(response, "主模型交互出错: " + e.getLocalizedMessage());
|
||||
chatMessages.removeLast();
|
||||
responseText = "CommunicationProducer Failed: " + e.getLocalizedMessage();
|
||||
break;
|
||||
}
|
||||
} finally {
|
||||
updateCoreResponse(runningFlowContext, response);
|
||||
resetAppendedMessages();
|
||||
log.debug("[CommunicationProducer] 消息列表更新大小: {}", chatMessages.size());
|
||||
updateCoreResponse(runningFlowContext, responseText);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private int getAppendedPromptSize(List<AppendPromptData> appendedPrompt) {
|
||||
int size = 0;
|
||||
for (AppendPromptData data : appendedPrompt) {
|
||||
size += data.getAppendedPrompt().size();
|
||||
private void updateCoreResponse(PartnerRunningFlowContext runningFlowContext, String responseText) {
|
||||
runningFlowContext.getCoreResponse().put("text", responseText);
|
||||
}
|
||||
|
||||
private List<Message> buildChatMessages(PartnerRunningFlowContext runningFlowContext) {
|
||||
List<Message> temp = new ArrayList<>(chatMessages.size() + 2);
|
||||
Message contextMessage = buildContextMessage(runningFlowContext);
|
||||
if (contextMessage != null) {
|
||||
temp.add(contextMessage);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
private void activateModule(PartnerRunningFlowContext context) {
|
||||
for (AppendPromptData data : context.getModuleContext().getAppendedPrompt()) {
|
||||
if (data.getAppendedPrompt().isEmpty()) continue;
|
||||
context.getCoreContext().activateModule(data.getModuleName());
|
||||
}
|
||||
}
|
||||
|
||||
private void updateCoreResponse(PartnerRunningFlowContext runningFlowContext, JSONObject response) {
|
||||
runningFlowContext.getCoreResponse().put("text", response.getString("text"));
|
||||
}
|
||||
|
||||
private void resetAppendedMessages() {
|
||||
this.appendedMessages.clear();
|
||||
}
|
||||
|
||||
private List<Message> buildChatMessages() {
|
||||
List<Message> temp = new ArrayList<>(appendedMessages.size() + chatMessages.size());
|
||||
temp.addAll(appendedMessages);
|
||||
temp.addAll(chatMessages);
|
||||
temp.add(buildInputMessage(runningFlowContext));
|
||||
return temp;
|
||||
}
|
||||
|
||||
private void updateModuleContextAndChatMessages(PartnerRunningFlowContext runningFlowContext, String response) {
|
||||
cognationCapability.getMessageLock().lock();
|
||||
chatMessages.removeIf(m -> {
|
||||
if (m.getRole() == Message.Character.ASSISTANT) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
JSONObject.parseObject(extractJson(m.getContent()));
|
||||
return true;
|
||||
} catch (Exception e) {
|
||||
return false;
|
||||
}
|
||||
});
|
||||
//添加时间标志
|
||||
// TODO 此处的时间标识应当采用 RunningFlowContext 携带时间
|
||||
String dateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("\r\n**[yyyy-MM-dd HH:mm:ss]"));
|
||||
Message primaryUserMessage = new Message(Message.Character.USER, runningFlowContext.getCoreContext().getText() + dateTime);
|
||||
chatMessages.add(primaryUserMessage);
|
||||
Message assistantMessage = new Message(Message.Character.ASSISTANT, response);
|
||||
chatMessages.add(assistantMessage);
|
||||
cognationCapability.getMessageLock().unlock();
|
||||
//区分单人聊天场景
|
||||
// if (runningFlowContext.isSingle()) {
|
||||
MetaMessage metaMessage = new MetaMessage(primaryUserMessage, assistantMessage);
|
||||
cognationCapability.addMetaMessage(runningFlowContext.getSource(), metaMessage);
|
||||
// }
|
||||
}
|
||||
|
||||
private void setMessage(String coreContextStr) {
|
||||
Message userMessage = new Message(Message.Character.USER, coreContextStr);
|
||||
chatMessages.add(userMessage);
|
||||
}
|
||||
|
||||
private void handleExceptionResponse(JSONObject response, String chatResponse) {
|
||||
response.put("text", chatResponse);
|
||||
// interactionContext.setFinished(true);
|
||||
}
|
||||
|
||||
private void setMessageCount(PartnerRunningFlowContext runningFlowContext) {
|
||||
runningFlowContext.getModuleContext().getExtraContext().put("message_count", chatMessages.size());
|
||||
}
|
||||
|
||||
private void setAppendedPromptMessage(List<AppendPromptData> appendPrompt) {
|
||||
Message appendDeclareMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + "认知补充开始");
|
||||
this.appendedMessages.add(appendDeclareMessage);
|
||||
for (AppendPromptData data : appendPrompt) {
|
||||
setStartMessage(data);
|
||||
setContentMessage(data);
|
||||
setEndMessage(data);
|
||||
setAssistantMessage();
|
||||
try {
|
||||
chatMessages.removeIf(this::isStructuredUserMessage);
|
||||
// TODO 此处的时间标识应当采用 RunningFlowContext 携带时间
|
||||
String dateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("\r\n**[yyyy-MM-dd HH:mm:ss]"));
|
||||
Message primaryUserMessage = new Message(
|
||||
Message.Character.USER,
|
||||
formatConversationUserMessage(runningFlowContext) + dateTime
|
||||
);
|
||||
chatMessages.add(primaryUserMessage);
|
||||
Message assistantMessage = new Message(Message.Character.ASSISTANT, response);
|
||||
chatMessages.add(assistantMessage);
|
||||
MetaMessage metaMessage = new MetaMessage(primaryUserMessage, assistantMessage);
|
||||
cognationCapability.addMetaMessage(runningFlowContext.getSource(), metaMessage);
|
||||
} finally {
|
||||
cognationCapability.getMessageLock().unlock();
|
||||
}
|
||||
Message appendEndMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + "认知补充结束");
|
||||
this.appendedMessages.add(appendEndMessage);
|
||||
}
|
||||
|
||||
private void setAssistantMessage() {
|
||||
Message message = new Message(Message.Character.ASSISTANT, "嗯,明白了");
|
||||
appendedMessages.add(message);
|
||||
private Message buildContextMessage(PartnerRunningFlowContext runningFlowContext) {
|
||||
List<ContextBlock> contextBlocks = filterContextBlocks(
|
||||
runningFlowContext.getContextBlocks(),
|
||||
ContextBlock.Type.CONTEXT
|
||||
);
|
||||
if (contextBlocks.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
return new Message(Message.Character.USER, buildContextXml(contextBlocks));
|
||||
}
|
||||
|
||||
private void setEndMessage(AppendPromptData data) {
|
||||
Message endMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + data.getModuleName() + "认知补充结束.");
|
||||
appendedMessages.add(endMessage);
|
||||
private Message buildInputMessage(PartnerRunningFlowContext runningFlowContext) {
|
||||
return new Message(Message.Character.USER, buildInputXml(runningFlowContext));
|
||||
}
|
||||
|
||||
private void setContentMessage(AppendPromptData data) {
|
||||
data.getAppendedPrompt().forEach((k, v) -> {
|
||||
Message contentMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + k + v + "\r\n");
|
||||
appendedMessages.add(contentMessage);
|
||||
});
|
||||
private String buildContextXml(List<ContextBlock> contextBlocks) {
|
||||
try {
|
||||
Document document = newDocument();
|
||||
Element root = document.createElement("context");
|
||||
document.appendChild(root);
|
||||
|
||||
contextBlocks.stream()
|
||||
.sorted(Comparator.comparingInt(ContextBlock::getPriority))
|
||||
.map(ContextBlock::encodeToXml)
|
||||
.forEach(blockXml -> {
|
||||
Element blockElement = parseElement(blockXml);
|
||||
root.appendChild(document.importNode(blockElement, true));
|
||||
});
|
||||
|
||||
return toXmlString(document);
|
||||
} catch (Exception e) {
|
||||
throw new IllegalStateException("构建 context 区段失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void setStartMessage(AppendPromptData data) {
|
||||
Message startMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + data.getModuleName() + "以下为" + data.getModuleName() + "相关认知.");
|
||||
appendedMessages.add(startMessage);
|
||||
private String buildInputXml(PartnerRunningFlowContext runningFlowContext) {
|
||||
try {
|
||||
Document document = newDocument();
|
||||
Element root = document.createElement("input");
|
||||
document.appendChild(root);
|
||||
|
||||
appendTextElement(document, root, "content", runningFlowContext.getInput());
|
||||
appendTextElement(document, root, "source", runningFlowContext.getSource());
|
||||
for (Map.Entry<String, String> entry : runningFlowContext.getAdditionalUserInfo().entrySet()) {
|
||||
appendTextElement(document, root, sanitizeTagName(entry.getKey()), entry.getValue());
|
||||
}
|
||||
appendSupplyBlocks(document, root, runningFlowContext.getContextBlocks());
|
||||
|
||||
return toXmlString(document);
|
||||
} catch (Exception e) {
|
||||
throw new IllegalStateException("构建 input 区段失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isStructuredUserMessage(Message message) {
|
||||
if (message.getRole() == Message.Character.ASSISTANT) {
|
||||
return false;
|
||||
}
|
||||
String content = message.getContent();
|
||||
String trimmed = content.trim();
|
||||
if (trimmed.startsWith("<input>") || trimmed.startsWith("<context>") || trimmed.startsWith("<?xml")) {
|
||||
return true;
|
||||
}
|
||||
try {
|
||||
JSONObject.parseObject(extractJson(content));
|
||||
return true;
|
||||
} catch (Exception e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private List<ContextBlock> filterContextBlocks(List<ContextBlock> contextBlocks, ContextBlock.Type type) {
|
||||
return contextBlocks.stream()
|
||||
.filter(block -> block.getType() == type)
|
||||
.sorted(Comparator.comparingInt(ContextBlock::getPriority))
|
||||
.toList();
|
||||
}
|
||||
|
||||
private String formatConversationUserMessage(PartnerRunningFlowContext runningFlowContext) {
|
||||
return runningFlowContext.getSource() + ": " + runningFlowContext.getInput();
|
||||
}
|
||||
|
||||
private Document newDocument() throws Exception {
|
||||
return DocumentBuilderFactory.newInstance()
|
||||
.newDocumentBuilder()
|
||||
.newDocument();
|
||||
}
|
||||
|
||||
private void appendTextElement(Document document, Element parent, String tagName, String value) {
|
||||
Element element = document.createElement(tagName);
|
||||
element.setTextContent(value == null ? "" : value);
|
||||
parent.appendChild(element);
|
||||
}
|
||||
|
||||
private Element parseElement(String xml) {
|
||||
try {
|
||||
Document parsedDocument = DocumentBuilderFactory.newInstance()
|
||||
.newDocumentBuilder()
|
||||
.parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)));
|
||||
return parsedDocument.getDocumentElement();
|
||||
} catch (Exception e) {
|
||||
throw new IllegalStateException("解析 ContextBlock XML 失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void appendSupplyBlocks(Document document, Element inputRoot, List<ContextBlock> contextBlocks) {
|
||||
Map<String, List<ContextBlock>> groupedBlocks = filterContextBlocks(contextBlocks, ContextBlock.Type.SUPPLY).stream()
|
||||
.collect(Collectors.groupingBy(
|
||||
block -> sanitizeTagName(block.getBlockName()),
|
||||
LinkedHashMap::new,
|
||||
Collectors.toList()
|
||||
));
|
||||
|
||||
for (Map.Entry<String, List<ContextBlock>> entry : groupedBlocks.entrySet()) {
|
||||
Element groupElement = document.createElement(entry.getKey());
|
||||
inputRoot.appendChild(groupElement);
|
||||
for (ContextBlock block : entry.getValue()) {
|
||||
Element blockElement = parseElement(block.encodeToXml());
|
||||
groupElement.appendChild(document.importNode(blockElement, true));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private String sanitizeTagName(String rawTagName) {
|
||||
if (rawTagName == null || rawTagName.isBlank()) {
|
||||
return "meta";
|
||||
}
|
||||
String sanitized = rawTagName.replaceAll("[^A-Za-z0-9_.-]", "_");
|
||||
if (!Character.isLetter(sanitized.charAt(0)) && sanitized.charAt(0) != '_') {
|
||||
sanitized = "_" + sanitized;
|
||||
}
|
||||
return sanitized;
|
||||
}
|
||||
|
||||
private String toXmlString(Document document) throws Exception {
|
||||
Transformer transformer = TransformerFactory.newInstance().newTransformer();
|
||||
transformer.setOutputProperty(OutputKeys.INDENT, "yes");
|
||||
transformer.setOutputProperty(OutputKeys.ENCODING, "UTF-8");
|
||||
transformer.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "yes");
|
||||
transformer.setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "2");
|
||||
StringWriter writer = new StringWriter();
|
||||
transformer.transform(new DOMSource(document), new StreamResult(writer));
|
||||
return writer.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
package work.slhaf.partner.module.modules.core;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import org.w3c.dom.Document;
|
||||
import org.w3c.dom.Element;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.ContextBlock;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.lenient;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class CommunicationProducerTest {
|
||||
|
||||
@Mock
|
||||
private CognationCapability cognationCapability;
|
||||
|
||||
private TestCommunicationProducer producer;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
producer = new TestCommunicationProducer();
|
||||
producer.setCognationCapability(cognationCapability);
|
||||
lenient().when(cognationCapability.getMessageLock()).thenReturn(new ReentrantLock());
|
||||
}
|
||||
|
||||
@Test
|
||||
void execute_shouldAssembleHeadContextConversationAndPersistCompactHistory() {
|
||||
List<Message> history = List.of(
|
||||
new Message(Message.Character.USER, "[USER]: old-user: 旧消息"),
|
||||
new Message(Message.Character.ASSISTANT, "旧回复")
|
||||
);
|
||||
lenient().when(cognationCapability.getChatMessages()).thenReturn(history);
|
||||
|
||||
producer.init();
|
||||
|
||||
PartnerRunningFlowContext context = PartnerRunningFlowContext.Companion.fromUser(
|
||||
"u-1",
|
||||
"你好,介绍一下你现在看到的上下文",
|
||||
"wechat",
|
||||
"tester"
|
||||
);
|
||||
context.getContextBlocks().add(new TestContextBlock(20, "memory", "slice-A"));
|
||||
context.getContextBlocks().add(new TestContextBlock(10, "perceive", "slice-B"));
|
||||
context.getContextBlocks().add(new TestContextBlock(30, ContextBlock.Type.SUPPLY, "tool", "source-a", "supply-A"));
|
||||
context.getContextBlocks().add(new TestContextBlock(40, ContextBlock.Type.SUPPLY, "tool", "source-b", "supply-B"));
|
||||
|
||||
producer.execute(context);
|
||||
|
||||
List<Message> sentMessages = producer.getCapturedMessages();
|
||||
assertEquals(5, sentMessages.size());
|
||||
assertEquals(Message.Character.SYSTEM, sentMessages.get(0).getRole());
|
||||
assertTrue(sentMessages.get(0).getContent().contains("Head"));
|
||||
|
||||
String contextXml = sentMessages.get(1).getContent();
|
||||
assertTrue(contextXml.contains("<context>"));
|
||||
assertTrue(contextXml.indexOf("slice-B") < contextXml.indexOf("slice-A"));
|
||||
assertTrue(contextXml.contains("source=\"perceive\""));
|
||||
assertTrue(contextXml.contains("source=\"memory\""));
|
||||
assertFalse(contextXml.contains("supply-A"));
|
||||
assertFalse(contextXml.contains("supply-B"));
|
||||
|
||||
String inputXml = sentMessages.get(4).getContent();
|
||||
assertTrue(inputXml.contains("<input>"));
|
||||
assertTrue(inputXml.contains("<content>你好,介绍一下你现在看到的上下文</content>"));
|
||||
assertTrue(inputXml.contains("<source>[USER]: u-1</source>"));
|
||||
assertTrue(inputXml.contains("<platform>wechat</platform>"));
|
||||
assertTrue(inputXml.contains("<nickname>tester</nickname>"));
|
||||
assertTrue(inputXml.contains("<tool>"));
|
||||
assertEquals(inputXml.indexOf("<tool>"), inputXml.lastIndexOf("<tool>"));
|
||||
assertTrue(inputXml.contains("source=\"source-a\""));
|
||||
assertTrue(inputXml.contains("source=\"source-b\""));
|
||||
assertTrue(inputXml.contains("supply-A"));
|
||||
assertTrue(inputXml.contains("supply-B"));
|
||||
|
||||
assertEquals(4, producer.getChatMessages().size());
|
||||
Message lastUserMessage = producer.getChatMessages().get(2);
|
||||
assertEquals(Message.Character.USER, lastUserMessage.getRole());
|
||||
assertTrue(lastUserMessage.getContent().startsWith("[USER]: u-1: 你好,介绍一下你现在看到的上下文"));
|
||||
assertFalse(lastUserMessage.getContent().contains("<input>"));
|
||||
|
||||
Message lastAssistantMessage = producer.getChatMessages().get(3);
|
||||
assertEquals("收到", lastAssistantMessage.getContent());
|
||||
|
||||
ArgumentCaptor<MetaMessage> metaMessageCaptor = ArgumentCaptor.forClass(MetaMessage.class);
|
||||
verify(cognationCapability).addMetaMessage(anyString(), metaMessageCaptor.capture());
|
||||
MetaMessage metaMessage = metaMessageCaptor.getValue();
|
||||
assertNotNull(metaMessage);
|
||||
assertTrue(metaMessage.getUserMessage().getContent().startsWith("[USER]: u-1: 你好,介绍一下你现在看到的上下文"));
|
||||
assertEquals("收到", metaMessage.getAssistantMessage().getContent());
|
||||
assertEquals("收到", context.getCoreResponse().getString("text"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void execute_shouldDropLegacyStructuredUserMessagesFromHistory() {
|
||||
List<Message> history = new ArrayList<>();
|
||||
history.add(new Message(Message.Character.USER, "<input><content>legacy</content></input>"));
|
||||
history.add(new Message(Message.Character.ASSISTANT, "legacy-assistant"));
|
||||
history.add(new Message(Message.Character.USER, "{\"text\":\"legacy-json\"}"));
|
||||
lenient().when(cognationCapability.getChatMessages()).thenReturn(history);
|
||||
|
||||
producer.init();
|
||||
|
||||
PartnerRunningFlowContext context = PartnerRunningFlowContext.Companion.fromSelf("新输入");
|
||||
producer.execute(context);
|
||||
|
||||
List<Message> updatedHistory = producer.getChatMessages();
|
||||
assertEquals(3, updatedHistory.size());
|
||||
assertEquals("legacy-assistant", updatedHistory.get(0).getContent());
|
||||
assertTrue(updatedHistory.get(1).getContent().startsWith("[AGENT]: self: 新输入"));
|
||||
assertEquals("收到", updatedHistory.get(2).getContent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void execute_shouldSkipContextMessageWhenOnlySupplyBlocksExist() {
|
||||
lenient().when(cognationCapability.getChatMessages()).thenReturn(List.of());
|
||||
|
||||
producer.init();
|
||||
|
||||
PartnerRunningFlowContext context = PartnerRunningFlowContext.Companion.fromUser(
|
||||
"u-2",
|
||||
"只有补充块",
|
||||
"qq",
|
||||
"tester2"
|
||||
);
|
||||
context.getContextBlocks().add(new TestContextBlock(5, ContextBlock.Type.SUPPLY, "tool", "source-x", "supply-X"));
|
||||
|
||||
producer.execute(context);
|
||||
|
||||
List<Message> sentMessages = producer.getCapturedMessages();
|
||||
assertEquals(2, sentMessages.size());
|
||||
assertEquals(Message.Character.SYSTEM, sentMessages.get(0).getRole());
|
||||
assertTrue(sentMessages.get(1).getContent().contains("<input>"));
|
||||
assertFalse(sentMessages.get(1).getContent().contains("<context>"));
|
||||
assertTrue(sentMessages.get(1).getContent().contains("<tool>"));
|
||||
assertTrue(sentMessages.get(1).getContent().contains("supply-X"));
|
||||
}
|
||||
|
||||
private static final class TestCommunicationProducer extends CommunicationProducer {
|
||||
private List<Message> capturedMessages = List.of();
|
||||
|
||||
@Override
|
||||
public @NotNull String chat(@NotNull List<Message> messages) {
|
||||
capturedMessages = new ArrayList<>(mergeMessages(messages));
|
||||
return "收到";
|
||||
}
|
||||
|
||||
List<Message> getCapturedMessages() {
|
||||
return capturedMessages;
|
||||
}
|
||||
}
|
||||
|
||||
private static final class TestContextBlock extends ContextBlock {
|
||||
private final int priority;
|
||||
private final Type type;
|
||||
private final String blockName;
|
||||
private final String source;
|
||||
private final String payload;
|
||||
|
||||
private TestContextBlock(int priority, String source, String payload) {
|
||||
this(priority, Type.CONTEXT, "test-block", source, payload);
|
||||
}
|
||||
|
||||
private TestContextBlock(int priority, Type type, String blockName, String source, String payload) {
|
||||
this.priority = priority;
|
||||
this.type = type;
|
||||
this.blockName = blockName;
|
||||
this.source = source;
|
||||
this.payload = payload;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getPriority() {
|
||||
return priority;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull Type getType() {
|
||||
return type;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull String getBlockName() {
|
||||
return blockName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull String getSource() {
|
||||
return source;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void fillXml(@NotNull Document document, @NotNull Element root) {
|
||||
appendTextElement(document, root, "payload", payload);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package work.slhaf.partner.api.agent.runtime.interaction.flow
|
||||
|
||||
import org.w3c.dom.Document
|
||||
import org.w3c.dom.Element
|
||||
import java.io.StringWriter
|
||||
import javax.xml.parsers.DocumentBuilderFactory
|
||||
import javax.xml.transform.OutputKeys
|
||||
import javax.xml.transform.TransformerFactory
|
||||
import javax.xml.transform.dom.DOMSource
|
||||
import javax.xml.transform.stream.StreamResult
|
||||
|
||||
abstract class ContextBlock {
|
||||
|
||||
abstract val priority: Int
|
||||
abstract val type: Type
|
||||
|
||||
abstract val blockName: String
|
||||
abstract val source: String
|
||||
|
||||
enum class Type {
|
||||
CONTEXT,
|
||||
SUPPLY
|
||||
}
|
||||
|
||||
fun encodeToXml(): String {
|
||||
val document = DocumentBuilderFactory.newInstance()
|
||||
.newDocumentBuilder()
|
||||
.newDocument()
|
||||
|
||||
val root = document.createElement(blockName)
|
||||
root.setAttribute("source",source)
|
||||
document.appendChild(root)
|
||||
|
||||
fillXml(document, root)
|
||||
|
||||
return document.toXmlString()
|
||||
}
|
||||
|
||||
protected abstract fun fillXml(document: Document, root: Element)
|
||||
|
||||
protected fun appendTextElement(
|
||||
document: Document,
|
||||
parent: Element,
|
||||
tagName: String,
|
||||
value: Any?
|
||||
): Element {
|
||||
val element = document.createElement(tagName)
|
||||
element.textContent = value?.toString() ?: ""
|
||||
parent.appendChild(element)
|
||||
return element
|
||||
}
|
||||
|
||||
protected fun appendChildElement(
|
||||
document: Document,
|
||||
parent: Element,
|
||||
tagName: String,
|
||||
block: Element.() -> Unit = {}
|
||||
): Element {
|
||||
val element = document.createElement(tagName)
|
||||
parent.appendChild(element)
|
||||
element.block()
|
||||
return element
|
||||
}
|
||||
|
||||
protected fun appendCDataElement(
|
||||
document: Document,
|
||||
parent: Element,
|
||||
tagName: String,
|
||||
value: String?
|
||||
): Element {
|
||||
val element = document.createElement(tagName)
|
||||
element.appendChild(document.createCDATASection(value ?: ""))
|
||||
parent.appendChild(element)
|
||||
return element
|
||||
}
|
||||
|
||||
protected fun <T> appendListElement(
|
||||
document: Document,
|
||||
parent: Element,
|
||||
wrapperTagName: String,
|
||||
itemTagName: String,
|
||||
values: Iterable<T>,
|
||||
block: Element.(T) -> Unit = { value ->
|
||||
textContent = value?.toString() ?: ""
|
||||
}
|
||||
): Element {
|
||||
val wrapper = document.createElement(wrapperTagName)
|
||||
parent.appendChild(wrapper)
|
||||
|
||||
for (value in values) {
|
||||
val item = document.createElement(itemTagName)
|
||||
wrapper.appendChild(item)
|
||||
item.block(value)
|
||||
}
|
||||
|
||||
return wrapper
|
||||
}
|
||||
|
||||
protected fun <T> appendRepeatedElements(
|
||||
document: Document,
|
||||
parent: Element,
|
||||
itemTagName: String,
|
||||
values: Iterable<T>,
|
||||
block: Element.(T) -> Unit = { value ->
|
||||
textContent = value?.toString() ?: ""
|
||||
}
|
||||
) {
|
||||
for (value in values) {
|
||||
val item = document.createElement(itemTagName)
|
||||
parent.appendChild(item)
|
||||
item.block(value)
|
||||
}
|
||||
}
|
||||
|
||||
private fun Document.toXmlString(): String {
|
||||
val transformer = TransformerFactory.newInstance().newTransformer().apply {
|
||||
setOutputProperty(OutputKeys.INDENT, "yes")
|
||||
setOutputProperty(OutputKeys.ENCODING, "UTF-8")
|
||||
setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "2")
|
||||
}
|
||||
|
||||
return StringWriter().use { writer ->
|
||||
transformer.transform(DOMSource(this), StreamResult(writer))
|
||||
writer.toString()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,8 @@ abstract class RunningFlowContext {
|
||||
*/
|
||||
var target = source
|
||||
|
||||
val contextBlocks = mutableListOf<ContextBlock>()
|
||||
|
||||
private val _additionalUserInfo = mutableMapOf<String, String>()
|
||||
val additionalUserInfo: Map<String, String>
|
||||
get() = _additionalUserInfo
|
||||
|
||||
Reference in New Issue
Block a user