mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
refactor(CommunicationProducer): split context and supply message assembly
This commit is contained in:
@@ -4,21 +4,35 @@ import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.w3c.dom.Document;
|
||||
import org.w3c.dom.Element;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||
import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule;
|
||||
import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel;
|
||||
import work.slhaf.partner.api.agent.factory.component.annotation.Init;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.ContextBlock;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
import work.slhaf.partner.api.chat.pojo.MetaMessage;
|
||||
import work.slhaf.partner.core.cognation.CognationCapability;
|
||||
import work.slhaf.partner.module.common.entity.AppendPromptData;
|
||||
import work.slhaf.partner.module.common.model.ModelConstant;
|
||||
import work.slhaf.partner.runtime.interaction.data.context.PartnerRunningFlowContext;
|
||||
|
||||
import javax.xml.parsers.DocumentBuilderFactory;
|
||||
import javax.xml.transform.OutputKeys;
|
||||
import javax.xml.transform.Transformer;
|
||||
import javax.xml.transform.TransformerFactory;
|
||||
import javax.xml.transform.dom.DOMSource;
|
||||
import javax.xml.transform.stream.StreamResult;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.StringWriter;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
|
||||
|
||||
@@ -26,9 +40,18 @@ import static work.slhaf.partner.common.util.ExtractUtil.extractJson;
|
||||
@Data
|
||||
public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRunningFlowContext> implements ActivateModel {
|
||||
|
||||
private static final String MODULE_PROMPT = """
|
||||
你是 Partner 的表达模块。
|
||||
你接下来收到的消息固定分为三个区段:
|
||||
1. system message 是 Head, 用于说明整个输入结构与输出要求。
|
||||
2. <context> 区段只承载 type=CONTEXT 的上下文块, 其中每个子块都带有独立来源, 仅作为理解当前状态与辅助决策的依据。
|
||||
3. Conversation 区段是对话轨迹; 最新的一条 user message 会使用 <input> 结构, 其中 <content> 是本轮用户原始输入, 其他子标签是输入元信息与 type=SUPPLY 的补充块, 补充块会按 blockName 分区。
|
||||
你必须综合 Context 与 Conversation 回答最新输入, 不要把 XML 标签当作需要原样复述给用户的内容。
|
||||
直接输出最终回应内容即可, 不需要额外包装为 JSON。
|
||||
""";
|
||||
|
||||
@InjectCapability
|
||||
private CognationCapability cognationCapability;
|
||||
private final List<Message> appendedMessages = new ArrayList<>();
|
||||
private final List<Message> chatMessages = new ArrayList<>();
|
||||
|
||||
@Init
|
||||
@@ -48,172 +71,220 @@ public class CommunicationProducer extends AbstractAgentModule.Running<PartnerRu
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull List<Message> modulePrompt() {
|
||||
return List.of(new Message(Message.Character.SYSTEM, MODULE_PROMPT));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(PartnerRunningFlowContext runningFlowContext) {
|
||||
String userId = runningFlowContext.getSource();
|
||||
log.debug("[CommunicationProducer] 主对话流程开始: {}", userId);
|
||||
beforeChat(runningFlowContext);
|
||||
log.debug("Communicating with: {}",runningFlowContext.getSource());
|
||||
executeChat(runningFlowContext);
|
||||
log.debug("[CommunicationProducer] 主对话流程({})结束...", userId);
|
||||
}
|
||||
|
||||
private void beforeChat(PartnerRunningFlowContext runningFlowContext) {
|
||||
setAppendedPromptMessage(runningFlowContext);
|
||||
activateModule(runningFlowContext);
|
||||
setMessageCount(runningFlowContext);
|
||||
|
||||
log.debug("[CommunicationProducer] 当前消息列表大小: {}", chatMessages.size());
|
||||
log.debug("[CommunicationProducer] 当前核心prompt内容: {}", runningFlowContext.getCoreContext().toString());
|
||||
|
||||
setMessage(runningFlowContext.getCoreContext().toString());
|
||||
}
|
||||
|
||||
// TODO need to update message appending logic
|
||||
private void setAppendedPromptMessage(PartnerRunningFlowContext runningFlowContext) {
|
||||
List<AppendPromptData> appendedPrompt = runningFlowContext.getModuleContext().getAppendedPrompt();
|
||||
int appendedPromptSize = getAppendedPromptSize(appendedPrompt);
|
||||
if (appendedPromptSize > 0) {
|
||||
setAppendedPromptMessage(appendedPrompt);
|
||||
}
|
||||
}
|
||||
|
||||
private void executeChat(PartnerRunningFlowContext runningFlowContext) {
|
||||
JSONObject response = new JSONObject();
|
||||
String responseText = null;
|
||||
|
||||
// TODO considering removing retries in module
|
||||
int count = 0;
|
||||
while (true) {
|
||||
try {
|
||||
String chatResponse = this.chat(buildChatMessages());
|
||||
try {
|
||||
response.putAll(JSONObject.parse(extractJson(chatResponse)));
|
||||
} catch (Exception e) {
|
||||
log.warn("主模型回复格式出错, 将直接作为消息返回, 建议尝试更换主模型...");
|
||||
handleExceptionResponse(response, chatResponse);
|
||||
}
|
||||
log.debug("[CommunicationProducer] CommunicationProducer 响应内容: {}", response);
|
||||
updateModuleContextAndChatMessages(runningFlowContext, response.getString("text"));
|
||||
// TODO 为各模块提供 emit msg 能力后, 在这里统一接收并分发结构化输出.
|
||||
responseText = this.chat(buildChatMessages(runningFlowContext));
|
||||
log.debug("CommunicationProducer responses: {}", responseText);
|
||||
updateModuleContextAndChatMessages(runningFlowContext, responseText);
|
||||
break;
|
||||
} catch (Exception e) {
|
||||
count++;
|
||||
log.error("[CommunicationProducer] CoreModel执行异常: {}", e.getLocalizedMessage());
|
||||
log.error("Communicating exception occurred: {}", e.getLocalizedMessage());
|
||||
if (count > 3) {
|
||||
handleExceptionResponse(response, "主模型交互出错: " + e.getLocalizedMessage());
|
||||
chatMessages.removeLast();
|
||||
responseText = "CommunicationProducer Failed: " + e.getLocalizedMessage();
|
||||
break;
|
||||
}
|
||||
} finally {
|
||||
updateCoreResponse(runningFlowContext, response);
|
||||
resetAppendedMessages();
|
||||
log.debug("[CommunicationProducer] 消息列表更新大小: {}", chatMessages.size());
|
||||
updateCoreResponse(runningFlowContext, responseText);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private int getAppendedPromptSize(List<AppendPromptData> appendedPrompt) {
|
||||
int size = 0;
|
||||
for (AppendPromptData data : appendedPrompt) {
|
||||
size += data.getAppendedPrompt().size();
|
||||
private void updateCoreResponse(PartnerRunningFlowContext runningFlowContext, String responseText) {
|
||||
runningFlowContext.getCoreResponse().put("text", responseText);
|
||||
}
|
||||
|
||||
private List<Message> buildChatMessages(PartnerRunningFlowContext runningFlowContext) {
|
||||
List<Message> temp = new ArrayList<>(chatMessages.size() + 2);
|
||||
Message contextMessage = buildContextMessage(runningFlowContext);
|
||||
if (contextMessage != null) {
|
||||
temp.add(contextMessage);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
private void activateModule(PartnerRunningFlowContext context) {
|
||||
for (AppendPromptData data : context.getModuleContext().getAppendedPrompt()) {
|
||||
if (data.getAppendedPrompt().isEmpty()) continue;
|
||||
context.getCoreContext().activateModule(data.getModuleName());
|
||||
}
|
||||
}
|
||||
|
||||
private void updateCoreResponse(PartnerRunningFlowContext runningFlowContext, JSONObject response) {
|
||||
runningFlowContext.getCoreResponse().put("text", response.getString("text"));
|
||||
}
|
||||
|
||||
private void resetAppendedMessages() {
|
||||
this.appendedMessages.clear();
|
||||
}
|
||||
|
||||
private List<Message> buildChatMessages() {
|
||||
List<Message> temp = new ArrayList<>(appendedMessages.size() + chatMessages.size());
|
||||
temp.addAll(appendedMessages);
|
||||
temp.addAll(chatMessages);
|
||||
temp.add(buildInputMessage(runningFlowContext));
|
||||
return temp;
|
||||
}
|
||||
|
||||
private void updateModuleContextAndChatMessages(PartnerRunningFlowContext runningFlowContext, String response) {
|
||||
cognationCapability.getMessageLock().lock();
|
||||
chatMessages.removeIf(m -> {
|
||||
if (m.getRole() == Message.Character.ASSISTANT) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
JSONObject.parseObject(extractJson(m.getContent()));
|
||||
return true;
|
||||
} catch (Exception e) {
|
||||
return false;
|
||||
}
|
||||
});
|
||||
//添加时间标志
|
||||
// TODO 此处的时间标识应当采用 RunningFlowContext 携带时间
|
||||
String dateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("\r\n**[yyyy-MM-dd HH:mm:ss]"));
|
||||
Message primaryUserMessage = new Message(Message.Character.USER, runningFlowContext.getCoreContext().getText() + dateTime);
|
||||
chatMessages.add(primaryUserMessage);
|
||||
Message assistantMessage = new Message(Message.Character.ASSISTANT, response);
|
||||
chatMessages.add(assistantMessage);
|
||||
cognationCapability.getMessageLock().unlock();
|
||||
//区分单人聊天场景
|
||||
// if (runningFlowContext.isSingle()) {
|
||||
MetaMessage metaMessage = new MetaMessage(primaryUserMessage, assistantMessage);
|
||||
cognationCapability.addMetaMessage(runningFlowContext.getSource(), metaMessage);
|
||||
// }
|
||||
}
|
||||
|
||||
private void setMessage(String coreContextStr) {
|
||||
Message userMessage = new Message(Message.Character.USER, coreContextStr);
|
||||
chatMessages.add(userMessage);
|
||||
}
|
||||
|
||||
private void handleExceptionResponse(JSONObject response, String chatResponse) {
|
||||
response.put("text", chatResponse);
|
||||
// interactionContext.setFinished(true);
|
||||
}
|
||||
|
||||
private void setMessageCount(PartnerRunningFlowContext runningFlowContext) {
|
||||
runningFlowContext.getModuleContext().getExtraContext().put("message_count", chatMessages.size());
|
||||
}
|
||||
|
||||
private void setAppendedPromptMessage(List<AppendPromptData> appendPrompt) {
|
||||
Message appendDeclareMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + "认知补充开始");
|
||||
this.appendedMessages.add(appendDeclareMessage);
|
||||
for (AppendPromptData data : appendPrompt) {
|
||||
setStartMessage(data);
|
||||
setContentMessage(data);
|
||||
setEndMessage(data);
|
||||
setAssistantMessage();
|
||||
try {
|
||||
chatMessages.removeIf(this::isStructuredUserMessage);
|
||||
// TODO 此处的时间标识应当采用 RunningFlowContext 携带时间
|
||||
String dateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("\r\n**[yyyy-MM-dd HH:mm:ss]"));
|
||||
Message primaryUserMessage = new Message(
|
||||
Message.Character.USER,
|
||||
formatConversationUserMessage(runningFlowContext) + dateTime
|
||||
);
|
||||
chatMessages.add(primaryUserMessage);
|
||||
Message assistantMessage = new Message(Message.Character.ASSISTANT, response);
|
||||
chatMessages.add(assistantMessage);
|
||||
MetaMessage metaMessage = new MetaMessage(primaryUserMessage, assistantMessage);
|
||||
cognationCapability.addMetaMessage(runningFlowContext.getSource(), metaMessage);
|
||||
} finally {
|
||||
cognationCapability.getMessageLock().unlock();
|
||||
}
|
||||
Message appendEndMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + "认知补充结束");
|
||||
this.appendedMessages.add(appendEndMessage);
|
||||
}
|
||||
|
||||
private void setAssistantMessage() {
|
||||
Message message = new Message(Message.Character.ASSISTANT, "嗯,明白了");
|
||||
appendedMessages.add(message);
|
||||
private Message buildContextMessage(PartnerRunningFlowContext runningFlowContext) {
|
||||
List<ContextBlock> contextBlocks = filterContextBlocks(
|
||||
runningFlowContext.getContextBlocks(),
|
||||
ContextBlock.Type.CONTEXT
|
||||
);
|
||||
if (contextBlocks.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
return new Message(Message.Character.USER, buildContextXml(contextBlocks));
|
||||
}
|
||||
|
||||
private void setEndMessage(AppendPromptData data) {
|
||||
Message endMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + data.getModuleName() + "认知补充结束.");
|
||||
appendedMessages.add(endMessage);
|
||||
private Message buildInputMessage(PartnerRunningFlowContext runningFlowContext) {
|
||||
return new Message(Message.Character.USER, buildInputXml(runningFlowContext));
|
||||
}
|
||||
|
||||
private void setContentMessage(AppendPromptData data) {
|
||||
data.getAppendedPrompt().forEach((k, v) -> {
|
||||
Message contentMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + k + v + "\r\n");
|
||||
appendedMessages.add(contentMessage);
|
||||
});
|
||||
private String buildContextXml(List<ContextBlock> contextBlocks) {
|
||||
try {
|
||||
Document document = newDocument();
|
||||
Element root = document.createElement("context");
|
||||
document.appendChild(root);
|
||||
|
||||
contextBlocks.stream()
|
||||
.sorted(Comparator.comparingInt(ContextBlock::getPriority))
|
||||
.map(ContextBlock::encodeToXml)
|
||||
.forEach(blockXml -> {
|
||||
Element blockElement = parseElement(blockXml);
|
||||
root.appendChild(document.importNode(blockElement, true));
|
||||
});
|
||||
|
||||
return toXmlString(document);
|
||||
} catch (Exception e) {
|
||||
throw new IllegalStateException("构建 context 区段失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void setStartMessage(AppendPromptData data) {
|
||||
Message startMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + data.getModuleName() + "以下为" + data.getModuleName() + "相关认知.");
|
||||
appendedMessages.add(startMessage);
|
||||
private String buildInputXml(PartnerRunningFlowContext runningFlowContext) {
|
||||
try {
|
||||
Document document = newDocument();
|
||||
Element root = document.createElement("input");
|
||||
document.appendChild(root);
|
||||
|
||||
appendTextElement(document, root, "content", runningFlowContext.getInput());
|
||||
appendTextElement(document, root, "source", runningFlowContext.getSource());
|
||||
for (Map.Entry<String, String> entry : runningFlowContext.getAdditionalUserInfo().entrySet()) {
|
||||
appendTextElement(document, root, sanitizeTagName(entry.getKey()), entry.getValue());
|
||||
}
|
||||
appendSupplyBlocks(document, root, runningFlowContext.getContextBlocks());
|
||||
|
||||
return toXmlString(document);
|
||||
} catch (Exception e) {
|
||||
throw new IllegalStateException("构建 input 区段失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isStructuredUserMessage(Message message) {
|
||||
if (message.getRole() == Message.Character.ASSISTANT) {
|
||||
return false;
|
||||
}
|
||||
String content = message.getContent();
|
||||
String trimmed = content.trim();
|
||||
if (trimmed.startsWith("<input>") || trimmed.startsWith("<context>") || trimmed.startsWith("<?xml")) {
|
||||
return true;
|
||||
}
|
||||
try {
|
||||
JSONObject.parseObject(extractJson(content));
|
||||
return true;
|
||||
} catch (Exception e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private List<ContextBlock> filterContextBlocks(List<ContextBlock> contextBlocks, ContextBlock.Type type) {
|
||||
return contextBlocks.stream()
|
||||
.filter(block -> block.getType() == type)
|
||||
.sorted(Comparator.comparingInt(ContextBlock::getPriority))
|
||||
.toList();
|
||||
}
|
||||
|
||||
private String formatConversationUserMessage(PartnerRunningFlowContext runningFlowContext) {
|
||||
return runningFlowContext.getSource() + ": " + runningFlowContext.getInput();
|
||||
}
|
||||
|
||||
private Document newDocument() throws Exception {
|
||||
return DocumentBuilderFactory.newInstance()
|
||||
.newDocumentBuilder()
|
||||
.newDocument();
|
||||
}
|
||||
|
||||
private void appendTextElement(Document document, Element parent, String tagName, String value) {
|
||||
Element element = document.createElement(tagName);
|
||||
element.setTextContent(value == null ? "" : value);
|
||||
parent.appendChild(element);
|
||||
}
|
||||
|
||||
private Element parseElement(String xml) {
|
||||
try {
|
||||
Document parsedDocument = DocumentBuilderFactory.newInstance()
|
||||
.newDocumentBuilder()
|
||||
.parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)));
|
||||
return parsedDocument.getDocumentElement();
|
||||
} catch (Exception e) {
|
||||
throw new IllegalStateException("解析 ContextBlock XML 失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void appendSupplyBlocks(Document document, Element inputRoot, List<ContextBlock> contextBlocks) {
|
||||
Map<String, List<ContextBlock>> groupedBlocks = filterContextBlocks(contextBlocks, ContextBlock.Type.SUPPLY).stream()
|
||||
.collect(Collectors.groupingBy(
|
||||
block -> sanitizeTagName(block.getBlockName()),
|
||||
LinkedHashMap::new,
|
||||
Collectors.toList()
|
||||
));
|
||||
|
||||
for (Map.Entry<String, List<ContextBlock>> entry : groupedBlocks.entrySet()) {
|
||||
Element groupElement = document.createElement(entry.getKey());
|
||||
inputRoot.appendChild(groupElement);
|
||||
for (ContextBlock block : entry.getValue()) {
|
||||
Element blockElement = parseElement(block.encodeToXml());
|
||||
groupElement.appendChild(document.importNode(blockElement, true));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private String sanitizeTagName(String rawTagName) {
|
||||
if (rawTagName == null || rawTagName.isBlank()) {
|
||||
return "meta";
|
||||
}
|
||||
String sanitized = rawTagName.replaceAll("[^A-Za-z0-9_.-]", "_");
|
||||
if (!Character.isLetter(sanitized.charAt(0)) && sanitized.charAt(0) != '_') {
|
||||
sanitized = "_" + sanitized;
|
||||
}
|
||||
return sanitized;
|
||||
}
|
||||
|
||||
private String toXmlString(Document document) throws Exception {
|
||||
Transformer transformer = TransformerFactory.newInstance().newTransformer();
|
||||
transformer.setOutputProperty(OutputKeys.INDENT, "yes");
|
||||
transformer.setOutputProperty(OutputKeys.ENCODING, "UTF-8");
|
||||
transformer.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "yes");
|
||||
transformer.setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "2");
|
||||
StringWriter writer = new StringWriter();
|
||||
transformer.transform(new DOMSource(document), new StreamResult(writer));
|
||||
return writer.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
Reference in New Issue
Block a user