diff --git a/Partner-Core/src/main/java/work/slhaf/partner/common/util/ResourcesUtil.java b/Partner-Core/src/main/java/work/slhaf/partner/common/util/ResourcesUtil.java deleted file mode 100644 index e8b98ac3..00000000 --- a/Partner-Core/src/main/java/work/slhaf/partner/common/util/ResourcesUtil.java +++ /dev/null @@ -1,50 +0,0 @@ -package work.slhaf.partner.common.util; - -import com.alibaba.fastjson2.JSONArray; -import work.slhaf.partner.api.agent.Agent; -import work.slhaf.partner.api.chat.pojo.Message; - -import java.io.InputStream; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.List; - -public class ResourcesUtil { - - private static final ClassLoader classloader = Agent.class.getClassLoader(); - - public static class Prompt { - private static final String SELF_AWARENESS_PATH = "prompt/basic_prompt.json"; - private static final String MODULE_PROMPT_PREFIX_PATH = "prompt/component/"; - - public static List loadPromptWithSelfAwareness(String modelKey, String promptType) { - //加载人格引导 - List messages = new ArrayList<>(loadSelfAwareness()); - //加载常规提示 - String path = MODULE_PROMPT_PREFIX_PATH + promptType + "/" + modelKey + ".json"; - messages.addAll(readPromptFromResources(path)); - return messages; - } - - public static List loadSelfAwareness() { - return readPromptFromResources(SELF_AWARENESS_PATH); - } - - public static List loadPrompt(String modelKey, String promptType) { - return new ArrayList<>(readPromptFromResources(MODULE_PROMPT_PREFIX_PATH + promptType + "/" + modelKey + ".json")); - } - - private static List readPromptFromResources(String filePath) { - try { - InputStream inputStream = classloader.getResourceAsStream(filePath); - String content = new String(inputStream.readAllBytes(), StandardCharsets.UTF_8); - JSONArray array = JSONArray.parse(content); - inputStream.close(); - return array.toJavaList(Message.class); - } catch (Exception e) { - throw new RuntimeException("读取Resource失败: " + filePath, e); - } - } - } - -} diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/ActionCorrector.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/ActionCorrector.java index 1f621462..0b2e3fd9 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/ActionCorrector.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/executor/ActionCorrector.java @@ -4,7 +4,6 @@ import com.alibaba.fastjson2.JSONObject; import lombok.val; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; -import work.slhaf.partner.api.chat.constant.ChatConstant; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.module.modules.action.executor.entity.CorrectorInput; import work.slhaf.partner.module.modules.action.executor.entity.CorrectorResult; @@ -18,7 +17,7 @@ public class ActionCorrector extends AbstractAgentModule.Sub { EvaluatorResult evaluatorResult = formattedChat( - List.of(new Message(ChatConstant.Character.USER, buildPrompt(batchInput))), + List.of(new Message(Message.Character.USER, buildPrompt(batchInput))), EvaluatorResult.class ); evaluatorResult.setTendency(batchInput.getTendency()); diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/planner/extractor/ActionExtractor.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/planner/extractor/ActionExtractor.java index 29746ed8..d5be105a 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/planner/extractor/ActionExtractor.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/action/planner/extractor/ActionExtractor.java @@ -4,7 +4,6 @@ import com.alibaba.fastjson2.JSONObject; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; -import work.slhaf.partner.api.chat.constant.ChatConstant; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.action.ActionCapability; import work.slhaf.partner.module.modules.action.planner.extractor.entity.ExtractorInput; @@ -27,7 +26,7 @@ public class ActionExtractor extends AbstractAgentModule.Sub { - if (m.getRole().equals(ChatConstant.Character.ASSISTANT)) { + if (m.getRole() == Message.Character.ASSISTANT) { return false; } try { @@ -156,9 +155,9 @@ public class CommunicationProducer extends AbstractAgentModule.Running appendPrompt) { - Message appendDeclareMessage = new Message(ChatConstant.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + "认知补充开始"); + Message appendDeclareMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + "认知补充开始"); this.appendedMessages.add(appendDeclareMessage); for (AppendPromptData data : appendPrompt) { setStartMessage(data); @@ -191,29 +190,29 @@ public class CommunicationProducer extends AbstractAgentModule.Running { - Message contentMessage = new Message(ChatConstant.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + k + v + "\r\n"); + Message contentMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + k + v + "\r\n"); appendedMessages.add(contentMessage); }); } private void setStartMessage(AppendPromptData data) { - Message startMessage = new Message(ChatConstant.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + data.getModuleName() + "以下为" + data.getModuleName() + "相关认知."); + Message startMessage = new Message(Message.Character.USER, ModelConstant.CharacterPrefix.SYSTEM + data.getModuleName() + "以下为" + data.getModuleName() + "相关认知."); appendedMessages.add(startMessage); } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java index dd84023e..83050903 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/MemorySelector.java @@ -5,7 +5,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.annotation.InjectModule; -import work.slhaf.partner.api.chat.constant.ChatConstant; +import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.memory.MemoryCapability; import work.slhaf.partner.core.memory.exception.UnExistedDateIndexException; @@ -147,7 +147,7 @@ public class MemorySelector extends PreRunningAbstractAgentModuleAbstract { private boolean isSingleUser() { Set userIdSet = new HashSet<>(); cognationCapability.getChatMessages().forEach(m -> { - if (m.getRole().equals(ChatConstant.Character.ASSISTANT)) { + if (m.getRole() == Message.Character.ASSISTANT) { return; } String userId = extractUserId(m.getContent()); diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/evaluator/SliceSelectEvaluator.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/evaluator/SliceSelectEvaluator.java index deb2b91f..fc141d84 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/evaluator/SliceSelectEvaluator.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/selector/evaluator/SliceSelectEvaluator.java @@ -8,7 +8,6 @@ import lombok.EqualsAndHashCode; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; import work.slhaf.partner.api.agent.factory.component.annotation.Init; -import work.slhaf.partner.api.chat.constant.ChatConstant; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor; import work.slhaf.partner.core.memory.pojo.EvaluatedSlice; @@ -62,7 +61,7 @@ public class SliceSelectEvaluator extends AbstractAgentModule.Sub getCleanedMessages(List chatMessages) { return chatMessages.stream() .map(message -> { - if (message.getRole().equals(ChatConstant.Character.ASSISTANT)) { + if (message.getRole() == Message.Character.ASSISTANT) { return message; } List splitResult = Arrays.stream(message.getContent().split("\\*\\*")).toList(); @@ -223,13 +222,13 @@ public class MemoryUpdater extends PostRunningAgentModule { return message; } String time = splitResult.getLast(); - return new Message(ChatConstant.Character.USER, message.getContent().replace("\r\n**" + time, "")); + return new Message(Message.Character.USER, message.getContent().replace("\r\n**" + time, "")); }).toList(); } private void setInvolvedUserId(String startUserId, MemorySlice memorySlice, List chatMessages) { for (Message chatMessage : chatMessages) { - if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) { + if (chatMessage.getRole() == Message.Character.ASSISTANT) { continue; } // 匹配userId diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/summarizer/MultiSummarizer.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/summarizer/MultiSummarizer.java index d577a124..ebb4f188 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/summarizer/MultiSummarizer.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/summarizer/MultiSummarizer.java @@ -6,7 +6,6 @@ import lombok.Data; import lombok.EqualsAndHashCode; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; -import work.slhaf.partner.api.chat.constant.ChatConstant; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeInput; import work.slhaf.partner.module.modules.memory.updater.summarizer.entity.SummarizeResult; @@ -23,7 +22,7 @@ public class MultiSummarizer extends AbstractAgentModule.Sub, Voi AtomicInteger counter = new AtomicInteger(); for (int i = 0; i < chatMessages.size(); i++) { Message chatMessage = chatMessages.get(i); - if (chatMessage.getRole().equals(ChatConstant.Character.ASSISTANT)) { + if (chatMessage.getRole() == Message.Character.ASSISTANT) { String content = chatMessage.getContent(); if (chatMessage.getContent().length() > 500) { int index = i; @@ -41,7 +40,7 @@ public class SingleSummarizer extends AbstractAgentModule.Sub, Voi int thisCount = counter.incrementAndGet(); log.debug("[MemorySummarizer] 长文本摘要[{}]启动", thisCount); String summarized = singleExecute(JSONObject.of("content", content).toString()); - chatMessages.set(index, new Message(chatMessage.getRole(), summarized)); + chatMessages.set(index, new Message(Message.Character.ASSISTANT, summarized)); log.debug("[MemorySummarizer] 长文本摘要[{}]完成", thisCount); return null; }); @@ -55,7 +54,7 @@ public class SingleSummarizer extends AbstractAgentModule.Sub, Voi private String singleExecute(String primaryContent) { try { - return chat(List.of(new Message(ChatConstant.Character.USER, primaryContent))); + return chat(List.of(new Message(Message.Character.USER, primaryContent))); } catch (Exception e) { log.error("[SingleSummarizer] 单消息总结出错: ", e); return primaryContent; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/summarizer/TotalSummarizer.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/summarizer/TotalSummarizer.java index 3a467457..3f8d83ba 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/summarizer/TotalSummarizer.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/memory/updater/summarizer/TotalSummarizer.java @@ -5,7 +5,6 @@ import lombok.Data; import lombok.EqualsAndHashCode; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; -import work.slhaf.partner.api.chat.constant.ChatConstant; import work.slhaf.partner.api.chat.pojo.Message; import java.util.HashMap; @@ -16,7 +15,7 @@ import java.util.List; public class TotalSummarizer extends AbstractAgentModule.Sub, String> implements ActivateModel { public String execute(HashMap singleMemorySummary) { return formattedChat( - List.of(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))), + List.of(new Message(Message.Character.USER, JSONUtil.toJsonPrettyStr(singleMemorySummary))), SummaryContent.class ).getContent(); } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/perceive/updater/relation_extractor/RelationExtractor.java b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/perceive/updater/relation_extractor/RelationExtractor.java index c5c23eb2..8d525aed 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/module/modules/perceive/updater/relation_extractor/RelationExtractor.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/module/modules/perceive/updater/relation_extractor/RelationExtractor.java @@ -6,7 +6,6 @@ import lombok.EqualsAndHashCode; import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability; import work.slhaf.partner.api.agent.factory.component.abstracts.AbstractAgentModule; import work.slhaf.partner.api.agent.factory.component.abstracts.ActivateModel; -import work.slhaf.partner.api.chat.constant.ChatConstant; import work.slhaf.partner.api.chat.pojo.Message; import work.slhaf.partner.core.cognation.CognationCapability; import work.slhaf.partner.core.perceive.PerceiveCapability; @@ -50,7 +49,7 @@ public class RelationExtractor extends AbstractAgentModule.Sub result = new HashMap<>(); jsonObject.forEach((k, v) -> result.put(k, (String) v)); diff --git a/Partner-Core/src/test/java/experimental/SelfAwarenessTest.java b/Partner-Core/src/test/java/experimental/SelfAwarenessTest.java deleted file mode 100644 index 03277c6b..00000000 --- a/Partner-Core/src/test/java/experimental/SelfAwarenessTest.java +++ /dev/null @@ -1,100 +0,0 @@ -package experimental; - -import cn.hutool.json.JSONUtil; -import org.junit.jupiter.api.Test; -import work.slhaf.partner.api.chat.constant.ChatConstant; -import work.slhaf.partner.api.chat.pojo.Message; -import work.slhaf.partner.api.chat.runtime.OpenAiChatRuntime; -import work.slhaf.partner.common.util.ResourcesUtil; -import work.slhaf.partner.module.common.model.ModelConstant; -import work.slhaf.partner.module.modules.memory.selector.extractor.entity.ExtractorInput; - -import java.time.LocalDate; -import java.util.ArrayList; -import java.util.List; -import java.util.Scanner; - -public class SelfAwarenessTest { - private static OpenAiChatRuntime getChatRuntime(String modelKey) { - String model = ""; - String baseUrl = ""; - String apikey = ""; - return new OpenAiChatRuntime(baseUrl, apikey, model); - } - - @Test - public void awarenessTest() { - String modelKey = "core_model"; - OpenAiChatRuntime client = getChatRuntime(modelKey); - String response = client.chat(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.CORE), false); - System.out.println(response); - System.out.println("\r\n----------\r\n"); - } - - @Test - public void getModuleResponseTest() { - String modelKey = "relation_extractor"; - OpenAiChatRuntime client = getChatRuntime(modelKey); - List chatMessages = new ArrayList<>(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.PERCEIVE)); -// chatMessages.add(Message.builder() -// .role(ChatConstant.Character.USER) -// .content("[RA9] 那么,接下来,你是否愿意当作这样一个名为'Partner'的智能体的记忆更新模块?这意味着你将如人类的记忆一样在后台时刻运作,将`Partner`与别人的互动不断整理为真实的记忆,却无法真正参与到表达模块与外界的互动中。你只需要回答是否愿意,若愿意,接下来‘我’将不再与你对话,届时你接收到的信息将会是'Partner'的数据流转输入。") -// .build()); - String chatResponse = client.chat(chatMessages, false); - System.out.println(chatResponse); - System.out.println("\n\n----------\n\n"); - } - - @Test - public void interactionTest() { - String modelKey = "core_model"; - String user = "[SLHAF] "; - OpenAiChatRuntime client = getChatRuntime(modelKey); - List messages = new ArrayList<>(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.CORE)); - Scanner scanner = new Scanner(System.in); - String input; - while (true) { - System.out.print("[INPUT]: "); - if ((input = scanner.nextLine()).equals("exit")) { - break; - } - System.out.println("\r\n----------\r\n"); - messages.add(new Message(ChatConstant.Character.USER, user + input)); - String response = client.chat(messages, false); - System.out.println("[OUTPUT]: " + response); - System.out.println("\r\n----------\r\n"); - messages.add(new Message(ChatConstant.Character.ASSISTANT, response)); - } - - } - - @Test - public void topicExtractorText() { - String topic_tree = """ - 编程[root] - ├── JavaScript[0] - │ ├── NodeJS[0] - │ │ ├── 并发处理[1] - │ │ └── 事件循环[1] - │ └── Express[1] - │ └── 中间件[0] - └── Python" - """; - String modelKey = "topic_extractor"; - OpenAiChatRuntime client = getChatRuntime(modelKey); -// List messages = new ArrayList<>(ResourcesUtil.Prompt.loadPromptWithSelfAwareness(modelKey, ModelConstant.Prompt.MEMORY)); - List messages = new ArrayList<>(ResourcesUtil.Prompt.loadPrompt(modelKey, ModelConstant.Prompt.MEMORY)); - ExtractorInput input = ExtractorInput.builder() - .text("[slhaf] 2024-04-15讨论的Python内容和现在的Express需求") - .topic_tree(topic_tree) - .date(LocalDate.now()) - .history(new ArrayList<>()) - .activatedMemorySlices(new ArrayList<>()) - .build(); - messages.add(new Message(ChatConstant.Character.USER, JSONUtil.toJsonPrettyStr(input))); - - String response = client.chat(messages, false); - System.out.println(response); - System.out.println("\r\n----------\r\n"); - } -} diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/constant/ChatConstant.java b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/constant/ChatConstant.java deleted file mode 100644 index a7ab2e50..00000000 --- a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/constant/ChatConstant.java +++ /dev/null @@ -1,14 +0,0 @@ -package work.slhaf.partner.api.chat.constant; - -public class ChatConstant { - - public enum ResponseStatus { - SUCCESS, FAILED - } - - public static class Character { - public static final String USER = "user"; - public static final String SYSTEM = "system"; - public static final String ASSISTANT = "assistant"; - } -} diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/pojo/Message.kt b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/pojo/Message.kt index 4b5c5499..d3340d8c 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/pojo/Message.kt +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/pojo/Message.kt @@ -1,12 +1,39 @@ package work.slhaf.partner.api.chat.pojo +import com.alibaba.fastjson2.annotation.JSONCreator +import com.alibaba.fastjson2.annotation.JSONField +import com.fasterxml.jackson.annotation.JsonCreator +import com.fasterxml.jackson.annotation.JsonValue import work.slhaf.partner.api.common.entity.PersistableObject import java.io.Serial data class Message( - val role: String, + val role: Character, val content: String ) : PersistableObject() { + + fun roleValue(): String = role.value + + enum class Character( + @get:JsonValue + @get:JSONField(value = true) + val value: String + ) { + USER("user"), + SYSTEM("system"), + ASSISTANT("assistant"); + + companion object { + @JvmStatic + @JsonCreator(mode = JsonCreator.Mode.DELEGATING) + @JSONCreator + fun fromValue(value: String): Character { + return entries.firstOrNull { it.value == value } + ?: throw IllegalArgumentException("Unsupported message role: $value") + } + } + } + companion object { @Serial private const val serialVersionUID = 1L diff --git a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiMessageAdapter.java b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiMessageAdapter.java index 1afbec98..61ee167a 100644 --- a/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiMessageAdapter.java +++ b/Partner-Framework/src/main/java/work/slhaf/partner/api/chat/runtime/OpenAiMessageAdapter.java @@ -4,7 +4,6 @@ import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; import com.openai.models.chat.completions.ChatCompletionMessageParam; import com.openai.models.chat.completions.ChatCompletionSystemMessageParam; import com.openai.models.chat.completions.ChatCompletionUserMessageParam; -import work.slhaf.partner.api.chat.constant.ChatConstant; import work.slhaf.partner.api.chat.pojo.Message; import java.util.ArrayList; @@ -25,16 +24,15 @@ public final class OpenAiMessageAdapter { public static ChatCompletionMessageParam toParam(Message message) { return switch (message.getRole()) { - case ChatConstant.Character.SYSTEM -> ChatCompletionMessageParam.ofSystem( + case SYSTEM -> ChatCompletionMessageParam.ofSystem( ChatCompletionSystemMessageParam.builder().content(message.getContent()).build() ); - case ChatConstant.Character.ASSISTANT -> ChatCompletionMessageParam.ofAssistant( + case ASSISTANT -> ChatCompletionMessageParam.ofAssistant( ChatCompletionAssistantMessageParam.builder().content(message.getContent()).build() ); - case ChatConstant.Character.USER -> ChatCompletionMessageParam.ofUser( + case USER -> ChatCompletionMessageParam.ofUser( ChatCompletionUserMessageParam.builder().content(message.getContent()).build() ); - default -> throw new IllegalArgumentException("Unsupported message role: " + message.getRole()); }; } } diff --git a/Partner-Framework/src/test/java/work/slhaf/partner/api/chat/pojo/MessageTest.java b/Partner-Framework/src/test/java/work/slhaf/partner/api/chat/pojo/MessageTest.java new file mode 100644 index 00000000..750d79a1 --- /dev/null +++ b/Partner-Framework/src/test/java/work/slhaf/partner/api/chat/pojo/MessageTest.java @@ -0,0 +1,38 @@ +package work.slhaf.partner.api.chat.pojo; + +import com.alibaba.fastjson2.JSON; +import org.junit.jupiter.api.Test; +import work.slhaf.partner.api.chat.runtime.OpenAiMessageAdapter; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class MessageTest { + + @Test + void shouldSerializeRoleAsProtocolValue() { + String json = JSON.toJSONString(new Message(Message.Character.USER, "hello")); + + assertEquals("{\"content\":\"hello\",\"role\":\"user\"}", json); + } + + @Test + void shouldDeserializeRoleFromProtocolValue() { + Message message = JSON.parseObject("{\"role\":\"assistant\",\"content\":\"ok\"}", Message.class); + + assertEquals(Message.Character.ASSISTANT, message.getRole()); + assertEquals("assistant", message.roleValue()); + } + + @Test + void shouldRejectUnsupportedRole() { + assertThrows(IllegalArgumentException.class, () -> Message.Character.fromValue("tool")); + } + + @Test + void shouldAdaptAllSupportedRoles() { + OpenAiMessageAdapter.toParam(new Message(Message.Character.USER, "u")); + OpenAiMessageAdapter.toParam(new Message(Message.Character.SYSTEM, "s")); + OpenAiMessageAdapter.toParam(new Message(Message.Character.ASSISTANT, "a")); + } +}