mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
refactor(CommunicationProducer): split context and supply message assembly
This commit is contained in:
@@ -4,21 +4,35 @@ import com.alibaba.fastjson2.JSONObject;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
|
import org.w3c.dom.Document;
|
||||||
|
import org.w3c.dom.Element;
|
||||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||||
|
import work.slhaf.partner.api.agent.runtime.interaction.flow.ContextBlock;
|
||||||
import work.slhaf.partner.api.chat.pojo.Message;
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||||
import work.slhaf.partner.module.common.entity.AppendPromptData;
|
|
||||||
import work.slhaf.partner.module.common.model.ModelConstant;
|
|
||||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||||
|
|
||||||
|
import javax.xml.parsers.DocumentBuilderFactory;
|
||||||
|
import javax.xml.transform.OutputKeys;
|
||||||
|
import javax.xml.transform.Transformer;
|
||||||
|
import javax.xml.transform.TransformerFactory;
|
||||||
|
import javax.xml.transform.dom.DOMSource;
|
||||||
|
import javax.xml.transform.stream.StreamResult;
|
||||||
|
import java.io.ByteArrayInputStream;
|
||||||
|
import java.io.StringWriter;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDateTime;
|
||||||
import java.time.format.DateTimeFormatter;
|
import java.time.format.DateTimeFormatter;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
|
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
|
||||||
|
|
||||||
@@ -26,9 +40,18 @@ import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
|
|||||||
@Data
|
@Data
|
||||||
public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRunningFlowContext> implements ActivateModel {
|
public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRunningFlowContext> implements ActivateModel {
|
||||||
|
|
||||||
|
private static final String MODULE_PROMPT = """
|
||||||
|
你是 Partner 的表达模块。
|
||||||
|
你接下来收到的消息固定分为三个区段:
|
||||||
|
1. system message 是 Head, 用于说明整个输入结构与输出要求。
|
||||||
|
2. <context> 区段只承载 type=CONTEXT 的上下文块, 其中每个子块都带有独立来源, 仅作为理解当前状态与辅助决策的依据。
|
||||||
|
3. Conversation 区段是对话轨迹; 最新的一条 user message 会使用 <input> 结构, 其中 <content> 是本轮用户原始输入, 其他子标签是输入元信息与 type=SUPPLY 的补充块, 补充块会按 blockName 分区。
|
||||||
|
你必须综合 Context 与 Conversation 回答最新输入, 不要把 XML 标签当作需要原样复述给用户的内容。
|
||||||
|
直接输出最终回应内容即可, 不需要额外包装为 JSON。
|
||||||
|
""";
|
||||||
|
|
||||||
@InjectCapability
|
@InjectCapability
|
||||||
private CognationCapability cognationCapability;
|
private CognationCapability cognationCapability;
|
||||||
private final List<Message> appendedMessages = new ArrayList<>();
|
|
||||||
private final List<Message> chatMessages = new ArrayList<>();
|
private final List<Message> chatMessages = new ArrayList<>();
|
||||||
|
|
||||||
@Init
|
@Init
|
||||||
@@ -48,172 +71,220 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public @NotNull List<Message> modulePrompt() {
|
||||||
|
return List.of(new Message(Message.Character.SYSTEM, MODULE_PROMPT));
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void execute(PartnerRunningFlowContext runningFlowContext) {
|
public void execute(PartnerRunningFlowContext runningFlowContext) {
|
||||||
String userId = runningFlowContext.getSource();
|
log.debug("Communicating with: {}",runningFlowContext.getSource());
|
||||||
log.debug("[CommunicationProducer] 主对话流程开始: {}", userId);
|
|
||||||
beforeChat(runningFlowContext);
|
|
||||||
executeChat(runningFlowContext);
|
executeChat(runningFlowContext);
|
||||||
log.debug("[CommunicationProducer] 主对话流程({})结束...", userId);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void beforeChat(PartnerRunningFlowContext runningFlowContext) {
|
|
||||||
setAppendedPromptMessage(runningFlowContext);
|
|
||||||
activateModule(runningFlowContext);
|
|
||||||
setMessageCount(runningFlowContext);
|
|
||||||
|
|
||||||
log.debug("[CommunicationProducer] 当前消息列表大小: {}", chatMessages.size());
|
|
||||||
log.debug("[CommunicationProducer] 当前核心prompt内容: {}", runningFlowContext.getCoreContext().toString());
|
|
||||||
|
|
||||||
setMessage(runningFlowContext.getCoreContext().toString());
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO need to update message appending logic
|
|
||||||
private void setAppendedPromptMessage(PartnerRunningFlowContext runningFlowContext) {
|
|
||||||
List<AppendPromptData> appendedPrompt = runningFlowContext.getModuleContext().getAppendedPrompt();
|
|
||||||
int appendedPromptSize = getAppendedPromptSize(appendedPrompt);
|
|
||||||
if (appendedPromptSize > 0) {
|
|
||||||
setAppendedPromptMessage(appendedPrompt);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void executeChat(PartnerRunningFlowContext runningFlowContext) {
|
private void executeChat(PartnerRunningFlowContext runningFlowContext) {
|
||||||
JSONObject response = new JSONObject();
|
String responseText = null;
|
||||||
|
|
||||||
|
// TODO considering removing retries in module
|
||||||
int count = 0;
|
int count = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
try {
|
try {
|
||||||
String chatResponse = this.chat(buildChatMessages());
|
// TODO 为各模块提供 emit msg 能力后, 在这里统一接收并分发结构化输出.
|
||||||
try {
|
responseText = this.chat(buildChatMessages(runningFlowContext));
|
||||||
response.putAll(JSONObject.parse(extractJson(chatResponse)));
|
log.debug("CommunicationProducer responses: {}", responseText);
|
||||||
} catch (Exception e) {
|
updateModuleContextAndChatMessages(runningFlowContext, responseText);
|
||||||
log.warn("主模型回复格式出错, 将直接作为消息返回, 建议尝试更换主模型...");
|
|
||||||
handleExceptionResponse(response, chatResponse);
|
|
||||||
}
|
|
||||||
log.debug("[CommunicationProducer] CommunicationProducer 响应内容: {}", response);
|
|
||||||
updateModuleContextAndChatMessages(runningFlowContext, response.getString("text"));
|
|
||||||
break;
|
break;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
count++;
|
count++;
|
||||||
log.error("[CommunicationProducer] CoreModel执行异常: {}", e.getLocalizedMessage());
|
log.error("Communicating exception occurred: {}", e.getLocalizedMessage());
|
||||||
if (count > 3) {
|
if (count > 3) {
|
||||||
handleExceptionResponse(response, "主模型交互出错: " + e.getLocalizedMessage());
|
responseText = "CommunicationProducer Failed: " + e.getLocalizedMessage();
|
||||||
chatMessages.removeLast();
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
updateCoreResponse(runningFlowContext, response);
|
updateCoreResponse(runningFlowContext, responseText);
|
||||||
resetAppendedMessages();
|
|
||||||
log.debug("[CommunicationProducer] 消息列表更新大小: {}", chatMessages.size());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private int getAppendedPromptSize(List<AppendPromptData> appendedPrompt) {
|
private void updateCoreResponse(PartnerRunningFlowContext runningFlowContext, String responseText) {
|
||||||
int size = 0;
|
runningFlowContext.getCoreResponse().put("text", responseText);
|
||||||
for (AppendPromptData data : appendedPrompt) {
|
}
|
||||||
size += data.getAppendedPrompt().size();
|
|
||||||
|
private List<Message> buildChatMessages(PartnerRunningFlowContext runningFlowContext) {
|
||||||
|
List<Message> temp = new ArrayList<>(chatMessages.size() + 2);
|
||||||
|
Message contextMessage = buildContextMessage(runningFlowContext);
|
||||||
|
if (contextMessage != null) {
|
||||||
|
temp.add(contextMessage);
|
||||||
}
|
}
|
||||||
return size;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void activateModule(PartnerRunningFlowContext context) {
|
|
||||||
for (AppendPromptData data : context.getModuleContext().getAppendedPrompt()) {
|
|
||||||
if (data.getAppendedPrompt().isEmpty()) continue;
|
|
||||||
context.getCoreContext().activateModule(data.getModuleName());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void updateCoreResponse(PartnerRunningFlowContext runningFlowContext, JSONObject response) {
|
|
||||||
runningFlowContext.getCoreResponse().put("text", response.getString("text"));
|
|
||||||
}
|
|
||||||
|
|
||||||
private void resetAppendedMessages() {
|
|
||||||
this.appendedMessages.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<Message> buildChatMessages() {
|
|
||||||
List<Message> temp = new ArrayList<>(appendedMessages.size() + chatMessages.size());
|
|
||||||
temp.addAll(appendedMessages);
|
|
||||||
temp.addAll(chatMessages);
|
temp.addAll(chatMessages);
|
||||||
|
temp.add(buildInputMessage(runningFlowContext));
|
||||||
return temp;
|
return temp;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateModuleContextAndChatMessages(PartnerRunningFlowContext runningFlowContext, String response) {
|
private void updateModuleContextAndChatMessages(PartnerRunningFlowContext runningFlowContext, String response) {
|
||||||
cognationCapability.getMessageLock().lock();
|
cognationCapability.getMessageLock().lock();
|
||||||
chatMessages.removeIf(m -> {
|
try {
|
||||||
if (m.getRole() == Message.Character.ASSISTANT) {
|
chatMessages.removeIf(this::isStructuredUserMessage);
|
||||||
return false;
|
// TODO 此处的时间标识应当采用 RunningFlowContext 携带时间
|
||||||
}
|
String dateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("\r\n**[yyyy-MM-dd HH:mm:ss]"));
|
||||||
try {
|
Message primaryUserMessage = new Message(
|
||||||
JSONObject.parseObject(extractJson(m.getContent()));
|
Message.Character.USER,
|
||||||
return true;
|
formatConversationUserMessage(runningFlowContext) + dateTime
|
||||||
} catch (Exception e) {
|
);
|
||||||
return false;
|
chatMessages.add(primaryUserMessage);
|
||||||
}
|
Message assistantMessage = new Message(Message.Character.ASSISTANT, response);
|
||||||
});
|
chatMessages.add(assistantMessage);
|
||||||
//添加时间标志
|
MetaMessage metaMessage = new MetaMessage(primaryUserMessage, assistantMessage);
|
||||||
// TODO 此处的时间标识应当采用 RunningFlowContext 携带时间
|
cognationCapability.addMetaMessage(runningFlowContext.getSource(), metaMessage);
|
||||||
String dateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("\r\n**[yyyy-MM-dd HH:mm:ss]"));
|
} finally {
|
||||||
Message primaryUserMessage = new Message(Message.Character.USER, runningFlowContext.getCoreContext().getText() + dateTime);
|
cognationCapability.getMessageLock().unlock();
|
||||||
chatMessages.add(primaryUserMessage);
|
|
||||||
Message assistantMessage = new Message(Message.Character.ASSISTANT, response);
|
|
||||||
chatMessages.add(assistantMessage);
|
|
||||||
cognationCapability.getMessageLock().unlock();
|
|
||||||
//区分单人聊天场景
|
|
||||||
// if (runningFlowContext.isSingle()) {
|
|
||||||
MetaMessage metaMessage = new MetaMessage(primaryUserMessage, assistantMessage);
|
|
||||||
cognationCapability.addMetaMessage(runningFlowContext.getSource(), metaMessage);
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
|
|
||||||
private void setMessage(String coreContextStr) {
|
|
||||||
Message userMessage = new Message(Message.Character.USER, coreContextStr);
|
|
||||||
chatMessages.add(userMessage);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void handleExceptionResponse(JSONObject response, String chatResponse) {
|
|
||||||
response.put("text", chatResponse);
|
|
||||||
// interactionContext.setFinished(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void setMessageCount(PartnerRunningFlowContext runningFlowContext) {
|
|
||||||
runningFlowContext.getModuleContext().getExtraContext().put("message_count", chatMessages.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
private void setAppendedPromptMessage(List<AppendPromptData> appendPrompt) {
|
|
||||||
Message appendDeclareMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + "认知补充开始");
|
|
||||||
this.appendedMessages.add(appendDeclareMessage);
|
|
||||||
for (AppendPromptData data : appendPrompt) {
|
|
||||||
setStartMessage(data);
|
|
||||||
setContentMessage(data);
|
|
||||||
setEndMessage(data);
|
|
||||||
setAssistantMessage();
|
|
||||||
}
|
}
|
||||||
Message appendEndMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + "认知补充结束");
|
|
||||||
this.appendedMessages.add(appendEndMessage);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void setAssistantMessage() {
|
private Message buildContextMessage(PartnerRunningFlowContext runningFlowContext) {
|
||||||
Message message = new Message(Message.Character.ASSISTANT, "嗯,明白了");
|
List<ContextBlock> contextBlocks = filterContextBlocks(
|
||||||
appendedMessages.add(message);
|
runningFlowContext.getContextBlocks(),
|
||||||
|
ContextBlock.Type.CONTEXT
|
||||||
|
);
|
||||||
|
if (contextBlocks.isEmpty()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return new Message(Message.Character.USER, buildContextXml(contextBlocks));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void setEndMessage(AppendPromptData data) {
|
private Message buildInputMessage(PartnerRunningFlowContext runningFlowContext) {
|
||||||
Message endMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + data.getModuleName() + "认知补充结束.");
|
return new Message(Message.Character.USER, buildInputXml(runningFlowContext));
|
||||||
appendedMessages.add(endMessage);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void setContentMessage(AppendPromptData data) {
|
private String buildContextXml(List<ContextBlock> contextBlocks) {
|
||||||
data.getAppendedPrompt().forEach((k, v) -> {
|
try {
|
||||||
Message contentMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + k + v + "\r\n");
|
Document document = newDocument();
|
||||||
appendedMessages.add(contentMessage);
|
Element root = document.createElement("context");
|
||||||
});
|
document.appendChild(root);
|
||||||
|
|
||||||
|
contextBlocks.stream()
|
||||||
|
.sorted(Comparator.comparingInt(ContextBlock::getPriority))
|
||||||
|
.map(ContextBlock::encodeToXml)
|
||||||
|
.forEach(blockXml -> {
|
||||||
|
Element blockElement = parseElement(blockXml);
|
||||||
|
root.appendChild(document.importNode(blockElement, true));
|
||||||
|
});
|
||||||
|
|
||||||
|
return toXmlString(document);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new IllegalStateException("构建 context 区段失败", e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void setStartMessage(AppendPromptData data) {
|
private String buildInputXml(PartnerRunningFlowContext runningFlowContext) {
|
||||||
Message startMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + data.getModuleName() + "以下为" + data.getModuleName() + "相关认知.");
|
try {
|
||||||
appendedMessages.add(startMessage);
|
Document document = newDocument();
|
||||||
|
Element root = document.createElement("input");
|
||||||
|
document.appendChild(root);
|
||||||
|
|
||||||
|
appendTextElement(document, root, "content", runningFlowContext.getInput());
|
||||||
|
appendTextElement(document, root, "source", runningFlowContext.getSource());
|
||||||
|
for (Map.Entry<String, String> entry : runningFlowContext.getAdditionalUserInfo().entrySet()) {
|
||||||
|
appendTextElement(document, root, sanitizeTagName(entry.getKey()), entry.getValue());
|
||||||
|
}
|
||||||
|
appendSupplyBlocks(document, root, runningFlowContext.getContextBlocks());
|
||||||
|
|
||||||
|
return toXmlString(document);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new IllegalStateException("构建 input 区段失败", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean isStructuredUserMessage(Message message) {
|
||||||
|
if (message.getRole() == Message.Character.ASSISTANT) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
String content = message.getContent();
|
||||||
|
String trimmed = content.trim();
|
||||||
|
if (trimmed.startsWith("<input>") || trimmed.startsWith("<context>") || trimmed.startsWith("<?xml")) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
JSONObject.parseObject(extractJson(content));
|
||||||
|
return true;
|
||||||
|
} catch (Exception e) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<ContextBlock> filterContextBlocks(List<ContextBlock> contextBlocks, ContextBlock.Type type) {
|
||||||
|
return contextBlocks.stream()
|
||||||
|
.filter(block -> block.getType() == type)
|
||||||
|
.sorted(Comparator.comparingInt(ContextBlock::getPriority))
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
|
|
||||||
|
private String formatConversationUserMessage(PartnerRunningFlowContext runningFlowContext) {
|
||||||
|
return runningFlowContext.getSource() + ": " + runningFlowContext.getInput();
|
||||||
|
}
|
||||||
|
|
||||||
|
private Document newDocument() throws Exception {
|
||||||
|
return DocumentBuilderFactory.newInstance()
|
||||||
|
.newDocumentBuilder()
|
||||||
|
.newDocument();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void appendTextElement(Document document, Element parent, String tagName, String value) {
|
||||||
|
Element element = document.createElement(tagName);
|
||||||
|
element.setTextContent(value == null ? "" : value);
|
||||||
|
parent.appendChild(element);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Element parseElement(String xml) {
|
||||||
|
try {
|
||||||
|
Document parsedDocument = DocumentBuilderFactory.newInstance()
|
||||||
|
.newDocumentBuilder()
|
||||||
|
.parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)));
|
||||||
|
return parsedDocument.getDocumentElement();
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new IllegalStateException("解析 ContextBlock XML 失败", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void appendSupplyBlocks(Document document, Element inputRoot, List<ContextBlock> contextBlocks) {
|
||||||
|
Map<String, List<ContextBlock>> groupedBlocks = filterContextBlocks(contextBlocks, ContextBlock.Type.SUPPLY).stream()
|
||||||
|
.collect(Collectors.groupingBy(
|
||||||
|
block -> sanitizeTagName(block.getBlockName()),
|
||||||
|
LinkedHashMap::new,
|
||||||
|
Collectors.toList()
|
||||||
|
));
|
||||||
|
|
||||||
|
for (Map.Entry<String, List<ContextBlock>> entry : groupedBlocks.entrySet()) {
|
||||||
|
Element groupElement = document.createElement(entry.getKey());
|
||||||
|
inputRoot.appendChild(groupElement);
|
||||||
|
for (ContextBlock block : entry.getValue()) {
|
||||||
|
Element blockElement = parseElement(block.encodeToXml());
|
||||||
|
groupElement.appendChild(document.importNode(blockElement, true));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String sanitizeTagName(String rawTagName) {
|
||||||
|
if (rawTagName == null || rawTagName.isBlank()) {
|
||||||
|
return "meta";
|
||||||
|
}
|
||||||
|
String sanitized = rawTagName.replaceAll("[^A-Za-z0-9_.-]", "_");
|
||||||
|
if (!Character.isLetter(sanitized.charAt(0)) && sanitized.charAt(0) != '_') {
|
||||||
|
sanitized = "_" + sanitized;
|
||||||
|
}
|
||||||
|
return sanitized;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String toXmlString(Document document) throws Exception {
|
||||||
|
Transformer transformer = TransformerFactory.newInstance().newTransformer();
|
||||||
|
transformer.setOutputProperty(OutputKeys.INDENT, "yes");
|
||||||
|
transformer.setOutputProperty(OutputKeys.ENCODING, "UTF-8");
|
||||||
|
transformer.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "yes");
|
||||||
|
transformer.setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "2");
|
||||||
|
StringWriter writer = new StringWriter();
|
||||||
|
transformer.transform(new DOMSource(document), new StreamResult(writer));
|
||||||
|
return writer.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -0,0 +1,215 @@
|
|||||||
|
package work.slhaf.partner.module.modules.core;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.junit.jupiter.MockitoExtension;
|
||||||
|
import org.w3c.dom.Document;
|
||||||
|
import org.w3c.dom.Element;
|
||||||
|
import work.slhaf.partner.api.agent.runtime.interaction.flow.ContextBlock;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.Message;
|
||||||
|
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||||
|
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||||
|
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.locks.ReentrantLock;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyString;
|
||||||
|
import static org.mockito.Mockito.lenient;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
|
||||||
|
@ExtendWith(MockitoExtension.class)
|
||||||
|
class CommunicationProducerTest {
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private CognationCapability cognationCapability;
|
||||||
|
|
||||||
|
private TestCommunicationProducer producer;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
producer = new TestCommunicationProducer();
|
||||||
|
producer.setCognationCapability(cognationCapability);
|
||||||
|
lenient().when(cognationCapability.getMessageLock()).thenReturn(new ReentrantLock());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void execute_shouldAssembleHeadContextConversationAndPersistCompactHistory() {
|
||||||
|
List<Message> history = List.of(
|
||||||
|
new Message(Message.Character.USER, "[USER]: old-user: 旧消息"),
|
||||||
|
new Message(Message.Character.ASSISTANT, "旧回复")
|
||||||
|
);
|
||||||
|
lenient().when(cognationCapability.getChatMessages()).thenReturn(history);
|
||||||
|
|
||||||
|
producer.init();
|
||||||
|
|
||||||
|
PartnerRunningFlowContext context = PartnerRunningFlowContext.Companion.fromUser(
|
||||||
|
"u-1",
|
||||||
|
"你好,介绍一下你现在看到的上下文",
|
||||||
|
"wechat",
|
||||||
|
"tester"
|
||||||
|
);
|
||||||
|
context.getContextBlocks().add(new TestContextBlock(20, "memory", "slice-A"));
|
||||||
|
context.getContextBlocks().add(new TestContextBlock(10, "perceive", "slice-B"));
|
||||||
|
context.getContextBlocks().add(new TestContextBlock(30, ContextBlock.Type.SUPPLY, "tool", "source-a", "supply-A"));
|
||||||
|
context.getContextBlocks().add(new TestContextBlock(40, ContextBlock.Type.SUPPLY, "tool", "source-b", "supply-B"));
|
||||||
|
|
||||||
|
producer.execute(context);
|
||||||
|
|
||||||
|
List<Message> sentMessages = producer.getCapturedMessages();
|
||||||
|
assertEquals(5, sentMessages.size());
|
||||||
|
assertEquals(Message.Character.SYSTEM, sentMessages.get(0).getRole());
|
||||||
|
assertTrue(sentMessages.get(0).getContent().contains("Head"));
|
||||||
|
|
||||||
|
String contextXml = sentMessages.get(1).getContent();
|
||||||
|
assertTrue(contextXml.contains("<context>"));
|
||||||
|
assertTrue(contextXml.indexOf("slice-B") < contextXml.indexOf("slice-A"));
|
||||||
|
assertTrue(contextXml.contains("source=\"perceive\""));
|
||||||
|
assertTrue(contextXml.contains("source=\"memory\""));
|
||||||
|
assertFalse(contextXml.contains("supply-A"));
|
||||||
|
assertFalse(contextXml.contains("supply-B"));
|
||||||
|
|
||||||
|
String inputXml = sentMessages.get(4).getContent();
|
||||||
|
assertTrue(inputXml.contains("<input>"));
|
||||||
|
assertTrue(inputXml.contains("<content>你好,介绍一下你现在看到的上下文</content>"));
|
||||||
|
assertTrue(inputXml.contains("<source>[USER]: u-1</source>"));
|
||||||
|
assertTrue(inputXml.contains("<platform>wechat</platform>"));
|
||||||
|
assertTrue(inputXml.contains("<nickname>tester</nickname>"));
|
||||||
|
assertTrue(inputXml.contains("<tool>"));
|
||||||
|
assertEquals(inputXml.indexOf("<tool>"), inputXml.lastIndexOf("<tool>"));
|
||||||
|
assertTrue(inputXml.contains("source=\"source-a\""));
|
||||||
|
assertTrue(inputXml.contains("source=\"source-b\""));
|
||||||
|
assertTrue(inputXml.contains("supply-A"));
|
||||||
|
assertTrue(inputXml.contains("supply-B"));
|
||||||
|
|
||||||
|
assertEquals(4, producer.getChatMessages().size());
|
||||||
|
Message lastUserMessage = producer.getChatMessages().get(2);
|
||||||
|
assertEquals(Message.Character.USER, lastUserMessage.getRole());
|
||||||
|
assertTrue(lastUserMessage.getContent().startsWith("[USER]: u-1: 你好,介绍一下你现在看到的上下文"));
|
||||||
|
assertFalse(lastUserMessage.getContent().contains("<input>"));
|
||||||
|
|
||||||
|
Message lastAssistantMessage = producer.getChatMessages().get(3);
|
||||||
|
assertEquals("收到", lastAssistantMessage.getContent());
|
||||||
|
|
||||||
|
ArgumentCaptor<MetaMessage> metaMessageCaptor = ArgumentCaptor.forClass(MetaMessage.class);
|
||||||
|
verify(cognationCapability).addMetaMessage(anyString(), metaMessageCaptor.capture());
|
||||||
|
MetaMessage metaMessage = metaMessageCaptor.getValue();
|
||||||
|
assertNotNull(metaMessage);
|
||||||
|
assertTrue(metaMessage.getUserMessage().getContent().startsWith("[USER]: u-1: 你好,介绍一下你现在看到的上下文"));
|
||||||
|
assertEquals("收到", metaMessage.getAssistantMessage().getContent());
|
||||||
|
assertEquals("收到", context.getCoreResponse().getString("text"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void execute_shouldDropLegacyStructuredUserMessagesFromHistory() {
|
||||||
|
List<Message> history = new ArrayList<>();
|
||||||
|
history.add(new Message(Message.Character.USER, "<input><content>legacy</content></input>"));
|
||||||
|
history.add(new Message(Message.Character.ASSISTANT, "legacy-assistant"));
|
||||||
|
history.add(new Message(Message.Character.USER, "{\"text\":\"legacy-json\"}"));
|
||||||
|
lenient().when(cognationCapability.getChatMessages()).thenReturn(history);
|
||||||
|
|
||||||
|
producer.init();
|
||||||
|
|
||||||
|
PartnerRunningFlowContext context = PartnerRunningFlowContext.Companion.fromSelf("新输入");
|
||||||
|
producer.execute(context);
|
||||||
|
|
||||||
|
List<Message> updatedHistory = producer.getChatMessages();
|
||||||
|
assertEquals(3, updatedHistory.size());
|
||||||
|
assertEquals("legacy-assistant", updatedHistory.get(0).getContent());
|
||||||
|
assertTrue(updatedHistory.get(1).getContent().startsWith("[AGENT]: self: 新输入"));
|
||||||
|
assertEquals("收到", updatedHistory.get(2).getContent());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void execute_shouldSkipContextMessageWhenOnlySupplyBlocksExist() {
|
||||||
|
lenient().when(cognationCapability.getChatMessages()).thenReturn(List.of());
|
||||||
|
|
||||||
|
producer.init();
|
||||||
|
|
||||||
|
PartnerRunningFlowContext context = PartnerRunningFlowContext.Companion.fromUser(
|
||||||
|
"u-2",
|
||||||
|
"只有补充块",
|
||||||
|
"qq",
|
||||||
|
"tester2"
|
||||||
|
);
|
||||||
|
context.getContextBlocks().add(new TestContextBlock(5, ContextBlock.Type.SUPPLY, "tool", "source-x", "supply-X"));
|
||||||
|
|
||||||
|
producer.execute(context);
|
||||||
|
|
||||||
|
List<Message> sentMessages = producer.getCapturedMessages();
|
||||||
|
assertEquals(2, sentMessages.size());
|
||||||
|
assertEquals(Message.Character.SYSTEM, sentMessages.get(0).getRole());
|
||||||
|
assertTrue(sentMessages.get(1).getContent().contains("<input>"));
|
||||||
|
assertFalse(sentMessages.get(1).getContent().contains("<context>"));
|
||||||
|
assertTrue(sentMessages.get(1).getContent().contains("<tool>"));
|
||||||
|
assertTrue(sentMessages.get(1).getContent().contains("supply-X"));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class TestCommunicationProducer extends CommunicationProducer {
|
||||||
|
private List<Message> capturedMessages = List.of();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public @NotNull String chat(@NotNull List<Message> messages) {
|
||||||
|
capturedMessages = new ArrayList<>(mergeMessages(messages));
|
||||||
|
return "收到";
|
||||||
|
}
|
||||||
|
|
||||||
|
List<Message> getCapturedMessages() {
|
||||||
|
return capturedMessages;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class TestContextBlock extends ContextBlock {
|
||||||
|
private final int priority;
|
||||||
|
private final Type type;
|
||||||
|
private final String blockName;
|
||||||
|
private final String source;
|
||||||
|
private final String payload;
|
||||||
|
|
||||||
|
private TestContextBlock(int priority, String source, String payload) {
|
||||||
|
this(priority, Type.CONTEXT, "test-block", source, payload);
|
||||||
|
}
|
||||||
|
|
||||||
|
private TestContextBlock(int priority, Type type, String blockName, String source, String payload) {
|
||||||
|
this.priority = priority;
|
||||||
|
this.type = type;
|
||||||
|
this.blockName = blockName;
|
||||||
|
this.source = source;
|
||||||
|
this.payload = payload;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getPriority() {
|
||||||
|
return priority;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public @NotNull Type getType() {
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public @NotNull String getBlockName() {
|
||||||
|
return blockName;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public @NotNull String getSource() {
|
||||||
|
return source;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void fillXml(@NotNull Document document, @NotNull Element root) {
|
||||||
|
appendTextElement(document, root, "payload", payload);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,127 @@
|
|||||||
|
package work.slhaf.partner.api.agent.runtime.interaction.flow
|
||||||
|
|
||||||
|
import org.w3c.dom.Document
|
||||||
|
import org.w3c.dom.Element
|
||||||
|
import java.io.StringWriter
|
||||||
|
import javax.xml.parsers.DocumentBuilderFactory
|
||||||
|
import javax.xml.transform.OutputKeys
|
||||||
|
import javax.xml.transform.TransformerFactory
|
||||||
|
import javax.xml.transform.dom.DOMSource
|
||||||
|
import javax.xml.transform.stream.StreamResult
|
||||||
|
|
||||||
|
abstract class ContextBlock {
|
||||||
|
|
||||||
|
abstract val priority: Int
|
||||||
|
abstract val type: Type
|
||||||
|
|
||||||
|
abstract val blockName: String
|
||||||
|
abstract val source: String
|
||||||
|
|
||||||
|
enum class Type {
|
||||||
|
CONTEXT,
|
||||||
|
SUPPLY
|
||||||
|
}
|
||||||
|
|
||||||
|
fun encodeToXml(): String {
|
||||||
|
val document = DocumentBuilderFactory.newInstance()
|
||||||
|
.newDocumentBuilder()
|
||||||
|
.newDocument()
|
||||||
|
|
||||||
|
val root = document.createElement(blockName)
|
||||||
|
root.setAttribute("source",source)
|
||||||
|
document.appendChild(root)
|
||||||
|
|
||||||
|
fillXml(document, root)
|
||||||
|
|
||||||
|
return document.toXmlString()
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract fun fillXml(document: Document, root: Element)
|
||||||
|
|
||||||
|
protected fun appendTextElement(
|
||||||
|
document: Document,
|
||||||
|
parent: Element,
|
||||||
|
tagName: String,
|
||||||
|
value: Any?
|
||||||
|
): Element {
|
||||||
|
val element = document.createElement(tagName)
|
||||||
|
element.textContent = value?.toString() ?: ""
|
||||||
|
parent.appendChild(element)
|
||||||
|
return element
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun appendChildElement(
|
||||||
|
document: Document,
|
||||||
|
parent: Element,
|
||||||
|
tagName: String,
|
||||||
|
block: Element.() -> Unit = {}
|
||||||
|
): Element {
|
||||||
|
val element = document.createElement(tagName)
|
||||||
|
parent.appendChild(element)
|
||||||
|
element.block()
|
||||||
|
return element
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun appendCDataElement(
|
||||||
|
document: Document,
|
||||||
|
parent: Element,
|
||||||
|
tagName: String,
|
||||||
|
value: String?
|
||||||
|
): Element {
|
||||||
|
val element = document.createElement(tagName)
|
||||||
|
element.appendChild(document.createCDATASection(value ?: ""))
|
||||||
|
parent.appendChild(element)
|
||||||
|
return element
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun <T> appendListElement(
|
||||||
|
document: Document,
|
||||||
|
parent: Element,
|
||||||
|
wrapperTagName: String,
|
||||||
|
itemTagName: String,
|
||||||
|
values: Iterable<T>,
|
||||||
|
block: Element.(T) -> Unit = { value ->
|
||||||
|
textContent = value?.toString() ?: ""
|
||||||
|
}
|
||||||
|
): Element {
|
||||||
|
val wrapper = document.createElement(wrapperTagName)
|
||||||
|
parent.appendChild(wrapper)
|
||||||
|
|
||||||
|
for (value in values) {
|
||||||
|
val item = document.createElement(itemTagName)
|
||||||
|
wrapper.appendChild(item)
|
||||||
|
item.block(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun <T> appendRepeatedElements(
|
||||||
|
document: Document,
|
||||||
|
parent: Element,
|
||||||
|
itemTagName: String,
|
||||||
|
values: Iterable<T>,
|
||||||
|
block: Element.(T) -> Unit = { value ->
|
||||||
|
textContent = value?.toString() ?: ""
|
||||||
|
}
|
||||||
|
) {
|
||||||
|
for (value in values) {
|
||||||
|
val item = document.createElement(itemTagName)
|
||||||
|
parent.appendChild(item)
|
||||||
|
item.block(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun Document.toXmlString(): String {
|
||||||
|
val transformer = TransformerFactory.newInstance().newTransformer().apply {
|
||||||
|
setOutputProperty(OutputKeys.INDENT, "yes")
|
||||||
|
setOutputProperty(OutputKeys.ENCODING, "UTF-8")
|
||||||
|
setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "2")
|
||||||
|
}
|
||||||
|
|
||||||
|
return StringWriter().use { writer ->
|
||||||
|
transformer.transform(DOMSource(this), StreamResult(writer))
|
||||||
|
writer.toString()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -24,6 +24,8 @@ abstract class RunningFlowContext {
|
|||||||
*/
|
*/
|
||||||
var target = source
|
var target = source
|
||||||
|
|
||||||
|
val contextBlocks = mutableListOf<ContextBlock>()
|
||||||
|
|
||||||
private val _additionalUserInfo = mutableMapOf<String, String>()
|
private val _additionalUserInfo = mutableMapOf<String, String>()
|
||||||
val additionalUserInfo: Map<String, String>
|
val additionalUserInfo: Map<String, String>
|
||||||
get() = _additionalUserInfo
|
get() = _additionalUserInfo
|
||||||
|
|||||||
Reference in New Issue
Block a user