diff --git a/build.gradle.kts b/build.gradle.kts index 5737ac3..87bd57e 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -16,6 +16,8 @@ dependencies{ implementation("cn.bigmodel.openapi:oapi-java-sdk:release-V4-2.0.2") implementation ("com.aliyun:ocr_api20210707:3.1.1") implementation ("junit:junit:4.13.2") + implementation ("org.apache.logging.log4j:log4j-core:2.23.1") + implementation ("org.apache.logging.log4j:log4j-api:2.23.1") } diff --git a/src/main/java/plugin/App.java b/src/main/java/plugin/App.java index 9f8d099..025dafd 100644 --- a/src/main/java/plugin/App.java +++ b/src/main/java/plugin/App.java @@ -6,18 +6,21 @@ import net.mamoe.mirai.event.GlobalEventChannel; import net.mamoe.mirai.event.events.FriendMessageEvent; import net.mamoe.mirai.event.events.GroupMessageEvent; import net.mamoe.mirai.utils.MiraiLogger; -import plugin.listener.UserMessageListener; +import plugin.constant.ChatConstant; +import plugin.listener.FriendMessageListener; +import plugin.listener.GroupMessageListener; +import plugin.listener.OwnerMessageListener; import plugin.pojo.Config; import plugin.utils.ConfigUtil; import java.io.IOException; - +import java.util.HashMap; public final class App extends JavaPlugin { public static final App INSTANCE = new App(); - public static MiraiLogger logger; private String owner, bot; + private static HashMap customCommands; private App() { super(new JvmPluginDescriptionBuilder("com.plugin.chatAI-InGroup-v2", "0.1.0") @@ -31,12 +34,12 @@ public final class App extends JavaPlugin { public void onEnable() { //加载配置 try { - logger = getLogger(); ConfigUtil.load(); Thread.sleep(1500); Config config = ConfigUtil.getConfig(); owner = config.getOwner().substring(1); bot = config.getBot().substring(1); + customCommands = config.getCustomCommands(); } catch (IOException | ClassNotFoundException | InterruptedException e) { throw new RuntimeException(e); @@ -48,13 +51,21 @@ public final class App extends JavaPlugin { GlobalEventChannel.INSTANCE.filterIsInstance(GroupMessageEvent.class) .filter(event -> { String msg = event.getMessage().contentToString(); - return (msg.startsWith(".") && msg.length() != 1) || msg.startsWith("@"+bot) || msg.startsWith("/c "); - }).registerListenerHost(new UserMessageListener()); + return (msg.startsWith(".") && msg.length() != 1) || msg.startsWith("@"+bot) || customCommands.containsKey(msg.split(" ")[0]); + }).registerListenerHost(new GroupMessageListener()); + + //所有者监听 + GlobalEventChannel.INSTANCE.filterIsInstance(GroupMessageEvent.class) + .filter(event -> { + String msg = event.getMessage().contentToString(); + String sender = String.valueOf(event.getSender().getId()); + return msg.startsWith(ChatConstant.SET) && sender.equals(owner); + }).registerListenerHost(new OwnerMessageListener()); //私聊监听器 GlobalEventChannel.INSTANCE.filterIsInstance(FriendMessageEvent.class) .filter(event -> true) - .registerListenerHost(new UserMessageListener()); + .registerListenerHost(new FriendMessageListener()); } diff --git a/src/main/java/plugin/constant/ChatConstant.java b/src/main/java/plugin/constant/ChatConstant.java index d8b785d..0eae303 100644 --- a/src/main/java/plugin/constant/ChatConstant.java +++ b/src/main/java/plugin/constant/ChatConstant.java @@ -23,12 +23,14 @@ public class ChatConstant { /** * 普通对话标志 */ - public static final String NORMAL_MESSAGE_START = "@"; + public static final String DEFAULT_MESSAGE_START = "@"; /** - * code对话标志 + * 分隔符(空格) */ - public static final String CODE_MESSAGE_START = "/c "; + public static final String BLANK = " "; + + public static final String SET = "/set "; /** * 切换模型 @@ -44,4 +46,9 @@ public class ChatConstant { * 清理消息 */ public static final String CLEAR = "clear"; + + /** + * 当前模型 + */ + public static final String CURRENT_MODEL = "当前模型"; } diff --git a/src/main/java/plugin/constant/ConfigConstant.java b/src/main/java/plugin/constant/ConfigConstant.java new file mode 100644 index 0000000..de0e7b7 --- /dev/null +++ b/src/main/java/plugin/constant/ConfigConstant.java @@ -0,0 +1,13 @@ +package plugin.constant; + +/** + * @author SLHAF + */ +public class ConfigConstant { + + public static final String DEFAULT = "default"; + + public static final String NULL = "null"; + + public static final String CUSTOM_SPLIT = "\\|"; +} diff --git a/src/main/java/plugin/constant/MethodsConstant.java b/src/main/java/plugin/constant/MethodsConstant.java new file mode 100644 index 0000000..7cffce8 --- /dev/null +++ b/src/main/java/plugin/constant/MethodsConstant.java @@ -0,0 +1,23 @@ +package plugin.constant; + +public enum MethodsConstant { + + /** + * 正常对话 + */ + NORMAL, + + /** + * 单次对话 + */ + ONCE, + /** + * 预设对话 + */ + CUSTOM, + + /** + * 未匹配 + */ + NONE +} diff --git a/src/main/java/plugin/listener/FriendMessageListener.java b/src/main/java/plugin/listener/FriendMessageListener.java new file mode 100644 index 0000000..80abc71 --- /dev/null +++ b/src/main/java/plugin/listener/FriendMessageListener.java @@ -0,0 +1,63 @@ +package plugin.listener; + +import net.mamoe.mirai.event.EventHandler; +import net.mamoe.mirai.event.SimpleListenerHost; +import net.mamoe.mirai.event.events.FriendMessageEvent; +import plugin.constant.ChatConstant; +import plugin.constant.MethodsConstant; +import plugin.pojo.Config; +import plugin.utils.AIUtil; +import plugin.utils.ConfigUtil; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class FriendMessageListener extends SimpleListenerHost { + + private static final Config config = ConfigUtil.getConfig(); + + @EventHandler + public void friendMessageHandler(FriendMessageEvent event) { + /*String id = String.valueOf(event.getFriend().getId()); + String content = event.getMessage().contentToString(); + String miraiCode = event.getMessage().serializeToMiraiCode(); + String url = null; + if (miraiCode.matches(ChatConstant.MATCH_MESSAGE)) { + String regex = ChatConstant.MATCH_IMAGE; + Pattern pattern = Pattern.compile(regex); + + // 创建Matcher对象 + Matcher matcher = pattern.matcher(miraiCode); + + // 查找并提取链接 + if (matcher.find()) { + //提取第一个括号内的内容 + url = matcher.group(1); + } + } + MethodsConstant method = MethodsConstant.NORMAL; + + + if (content.contains(ChatConstant.CHANGE_MODEL) && !id.equals(config.getOwner().substring(1))) { + event.getFriend().sendMessage("没有权限!"); + return; + } + + //处理消息头 + if (content.startsWith(ChatConstant.CODE_MESSAGE_START)) { + content = content.substring(3); + method = MethodsConstant.CUSTOM; + } else if (content.startsWith(ChatConstant.ONCE_MESSAGE_START)) { + content = content.substring(1); + method = MethodsConstant.ONCE; + } + //发送请求并获取回应 + String response = switch (method) { + case CUSTOM -> AIUtil.customChat(Long.valueOf(id), content, url, chatCommand); + case ONCE -> AIUtil.chatOnce(content, url); + case NORMAL -> AIUtil.defaultChat(Long.valueOf(id), content, url); + default -> "ERROR!"; + }; + event.getFriend().sendMessage(response);*/ + } +} diff --git a/src/main/java/plugin/listener/GroupMessageListener.java b/src/main/java/plugin/listener/GroupMessageListener.java new file mode 100644 index 0000000..0c48dc7 --- /dev/null +++ b/src/main/java/plugin/listener/GroupMessageListener.java @@ -0,0 +1,90 @@ +package plugin.listener; + +import kotlin.coroutines.CoroutineContext; +import net.mamoe.mirai.event.EventHandler; +import net.mamoe.mirai.event.SimpleListenerHost; +import net.mamoe.mirai.event.events.GroupMessageEvent; +import net.mamoe.mirai.message.data.At; +import org.jetbrains.annotations.NotNull; +import plugin.constant.ChatConstant; +import plugin.constant.MethodsConstant; +import plugin.pojo.Config; +import plugin.utils.AIUtil; +import plugin.utils.ConfigUtil; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static plugin.utils.ConfigUtil.logger; + + +/** + * @author SLHAF + */ +public class GroupMessageListener extends SimpleListenerHost { + + private static final Config config = ConfigUtil.getConfig(); + + @Override + public void handleException(@NotNull CoroutineContext context, @NotNull Throwable exception) { + super.handleException(context, exception); + logger.error(exception.getMessage()); + } + + /** + * 负责提取链接、获取消息内容(去除指令头),提取指令、提取发送者id + * @param event 接收群聊消息 + */ + @EventHandler + public void groupMessageHandler(GroupMessageEvent event) { + //处理消息 + String id = String.valueOf(event.getSender().getId()); + String content = event.getMessage().contentToString(); + String miraiCode = event.getMessage().serializeToMiraiCode(); + String url = null; + String chatCommand = null; + if (miraiCode.matches(ChatConstant.MATCH_MESSAGE)) { + String regex = ChatConstant.MATCH_IMAGE; + Pattern pattern = Pattern.compile(regex); + + // 创建Matcher对象 + Matcher matcher = pattern.matcher(miraiCode); + + // 查找并提取链接 + if (matcher.find()) { + //提取第一个括号内的内容 + url = matcher.group(1); + } + } + + MethodsConstant method = MethodsConstant.NONE; + + //消息头处理 + if (content.startsWith(ChatConstant.ONCE_MESSAGE_START)) { + //单次对话 + content = content.substring(1); + method = MethodsConstant.ONCE; + } else if (content.startsWith(ChatConstant.DEFAULT_MESSAGE_START + event.getBot().getId())) { + //默认对话 + content = content.substring((ChatConstant.DEFAULT_MESSAGE_START + event.getBot().getId()).length()); + method = MethodsConstant.NORMAL; + } else if (config.getCustomCommands().containsKey(content.split(ChatConstant.BLANK)[0])) { + //预设对话 + content = content.split(ChatConstant.BLANK)[1]; + method = MethodsConstant.CUSTOM; + chatCommand = content.split(ChatConstant.BLANK)[0]; + } + //消息内容处理 + if (content.isBlank()) { + content = "在吗"; + } + //发送请求并获取回应 + String response = switch (method) { + case CUSTOM -> AIUtil.customChat(Long.valueOf(id), content, url,chatCommand); + case NORMAL -> AIUtil.defaultChat(Long.valueOf(id), content, url); + case ONCE -> AIUtil.chatOnce(content, url); + default -> "ERROR!"; + }; + event.getGroup().sendMessage(new At(Long.parseLong(id)).plus("\r\n").plus(response)); + } +} diff --git a/src/main/java/plugin/listener/OwnerMessageListener.java b/src/main/java/plugin/listener/OwnerMessageListener.java new file mode 100644 index 0000000..8c445dc --- /dev/null +++ b/src/main/java/plugin/listener/OwnerMessageListener.java @@ -0,0 +1,6 @@ +package plugin.listener; + +import net.mamoe.mirai.event.SimpleListenerHost; + +public class OwnerMessageListener extends SimpleListenerHost { +} diff --git a/src/main/java/plugin/listener/UserMessageListener.java b/src/main/java/plugin/listener/UserMessageListener.java deleted file mode 100644 index 2df0619..0000000 --- a/src/main/java/plugin/listener/UserMessageListener.java +++ /dev/null @@ -1,152 +0,0 @@ -package plugin.listener; - -import kotlin.coroutines.CoroutineContext; -import net.mamoe.mirai.event.EventHandler; -import net.mamoe.mirai.event.SimpleListenerHost; -import net.mamoe.mirai.event.events.FriendMessageEvent; -import net.mamoe.mirai.event.events.GroupMessageEvent; -import net.mamoe.mirai.message.data.At; -import org.jetbrains.annotations.NotNull; -import plugin.constant.ChatConstant; -import plugin.pojo.Config; -import plugin.utils.AIUtil; -import plugin.utils.ConfigUtil; - -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -import static plugin.App.logger; - -/** - * @author SLHAF - */ -public class UserMessageListener extends SimpleListenerHost { - - private static final Config config = ConfigUtil.getConfig(); - - public enum Methods { - - /** - * 正常对话 - */ - NORMAL, - - /** - * 单次对话 - */ - ONCE, - - /** - * 预设:code - */ - CODE, - - /** - * 未匹配 - */ - NONE - } - - @Override - public void handleException(@NotNull CoroutineContext context, @NotNull Throwable exception) { - super.handleException(context, exception); - logger.error(exception.getMessage()); - } - - @EventHandler - public void groupMessageHandler(GroupMessageEvent event) { - //处理消息 - String id = String.valueOf(event.getSender().getId()); - String content = event.getMessage().contentToString(); - String miraiCode = event.getMessage().serializeToMiraiCode(); - String url = null; - if (miraiCode.matches(ChatConstant.MATCH_MESSAGE)) { - String regex = ChatConstant.MATCH_IMAGE; - Pattern pattern = Pattern.compile(regex); - - // 创建Matcher对象 - Matcher matcher = pattern.matcher(miraiCode); - - // 查找并提取链接 - if (matcher.find()) { - //提取第一个括号内的内容 - url = matcher.group(1); - } - } - Methods method = Methods.NONE; - - if (content.contains(ChatConstant.CHANGE_MODEL) && !id.equals(config.getOwner().substring(1))) { - event.getGroup().sendMessage(new At(Long.parseLong(id)).plus("没有权限!")); - return; - } - - //消息头处理 - if (content.startsWith(ChatConstant.ONCE_MESSAGE_START)) { - content = content.substring(1); - method = Methods.ONCE; - } else if (content.startsWith(ChatConstant.NORMAL_MESSAGE_START + event.getBot().getId())) { - content = content.substring((ChatConstant.NORMAL_MESSAGE_START + event.getBot().getId()).length()); - method = Methods.NORMAL; - } else if (content.startsWith(ChatConstant.CODE_MESSAGE_START)) { - content = content.substring(3); - method = Methods.CODE; - } - //消息内容处理 - if (content.isBlank()) { - content = "在吗"; - } - //发送请求并获取回应 - String response = switch (method) { - case CODE -> AIUtil.chatCode(Long.valueOf(id), content, url); - case NORMAL -> AIUtil.chatNormal(Long.valueOf(id), content, url); - case ONCE -> AIUtil.chatOnce(content, url); - default -> "ERROR!"; - }; - event.getGroup().sendMessage(new At(Long.parseLong(id)).plus("\r\n").plus(response)); - } - - @EventHandler - public void friendMessageHandler(FriendMessageEvent event) { - String id = String.valueOf(event.getFriend().getId()); - String content = event.getMessage().contentToString(); - String miraiCode = event.getMessage().serializeToMiraiCode(); - String url = null; - if (miraiCode.matches(ChatConstant.MATCH_MESSAGE)) { - String regex = ChatConstant.MATCH_IMAGE; - Pattern pattern = Pattern.compile(regex); - - // 创建Matcher对象 - Matcher matcher = pattern.matcher(miraiCode); - - // 查找并提取链接 - if (matcher.find()) { - //提取第一个括号内的内容 - url = matcher.group(1); - } - } - Methods method = Methods.NORMAL; - - - if (content.contains(ChatConstant.CHANGE_MODEL) && !id.equals(config.getOwner().substring(1))) { - event.getFriend().sendMessage("没有权限!"); - return; - } - - //处理消息头 - if (content.startsWith(ChatConstant.CODE_MESSAGE_START)) { - content = content.substring(3); - method = Methods.CODE; - } else if (content.startsWith(ChatConstant.ONCE_MESSAGE_START)) { - content = content.substring(1); - method = Methods.ONCE; - } - //发送请求并获取回应 - String response = switch (method) { - case CODE -> AIUtil.chatCode(Long.valueOf(id), content, url); - case ONCE -> AIUtil.chatOnce(content, url); - case NORMAL -> AIUtil.chatNormal(Long.valueOf(id), content, url); - default -> "ERROR!"; - }; - event.getFriend().sendMessage(response); - } -} diff --git a/src/main/java/plugin/pojo/Config.java b/src/main/java/plugin/pojo/Config.java index a35ad69..4b5fd3e 100644 --- a/src/main/java/plugin/pojo/Config.java +++ b/src/main/java/plugin/pojo/Config.java @@ -1,6 +1,7 @@ package plugin.pojo; import java.util.HashMap; +import java.util.LinkedHashMap; public class Config { /** @@ -18,8 +19,7 @@ public class Config { * 基础配置 */ private String owner; - private String modelNormal; - private String modelCode; + private String defaultModel; private String bot; private String timeout; private String timeCheck; @@ -27,18 +27,17 @@ public class Config { /** * 自定义预设 */ - private HashMap customCommands; + private LinkedHashMap customCommands; public Config() { } - public Config(String apikey, String accessKeyId, String accessKeySecret, String owner, String modelNormal, String modelCode, String bot, String timeout, String timeCheck, HashMap customCommands) { + public Config(String apikey, String accessKeyId, String accessKeySecret, String owner, String defaultModel, String bot, String timeout, String timeCheck, LinkedHashMap customCommands) { this.apikey = apikey; this.accessKeyId = accessKeyId; this.accessKeySecret = accessKeySecret; this.owner = owner; - this.modelNormal = modelNormal; - this.modelCode = modelCode; + this.defaultModel = defaultModel; this.bot = bot; this.timeout = timeout; this.timeCheck = timeCheck; @@ -113,32 +112,16 @@ public class Config { * 获取 * @return modelNormal */ - public String getModelNormal() { - return modelNormal; + public String getDefaultModel() { + return defaultModel; } /** * 设置 - * @param modelNormal + * @param defaultModel */ - public void setModelNormal(String modelNormal) { - this.modelNormal = modelNormal; - } - - /** - * 获取 - * @return modelCode - */ - public String getModelCode() { - return modelCode; - } - - /** - * 设置 - * @param modelCode - */ - public void setModelCode(String modelCode) { - this.modelCode = modelCode; + public void setDefaultModel(String defaultModel) { + this.defaultModel = defaultModel; } /** @@ -201,12 +184,12 @@ public class Config { * 设置 * @param customCommands */ - public void setCustomCommands(HashMap customCommands) { + public void setCustomCommands(LinkedHashMap customCommands) { this.customCommands = customCommands; } @Override public String toString() { - return "Config{apikey = " + apikey + ", accessKeyId = " + accessKeyId + ", accessKeySecret = " + accessKeySecret + ", owner = " + owner + ", modelNormal = " + modelNormal + ", modelCode = " + modelCode + ", bot = " + bot + ", timeout = " + timeout + ", timeCheck = " + timeCheck + ", customCommands = " + customCommands + "}"; + return "Config{apikey = " + apikey + ", accessKeyId = " + accessKeyId + ", accessKeySecret = " + accessKeySecret + ", owner = " + owner + ", modelNormal = " + defaultModel + ", modelCode = " + ", bot = " + bot + ", timeout = " + timeout + ", timeCheck = " + timeCheck + ", customCommands = " + customCommands + "}"; } } diff --git a/src/main/java/plugin/pojo/UserCustomMessage.java b/src/main/java/plugin/pojo/UserCustomMessage.java new file mode 100644 index 0000000..427fd58 --- /dev/null +++ b/src/main/java/plugin/pojo/UserCustomMessage.java @@ -0,0 +1,56 @@ +package plugin.pojo; + +import com.zhipu.oapi.service.v4.model.ChatMessage; + +import java.util.List; + +public class UserCustomMessage { + private String command; + private List messages; + + + public UserCustomMessage() { + } + + public UserCustomMessage(String command, List messages) { + this.command = command; + this.messages = messages; + } + + /** + * 获取 + * @return command + */ + public String getCommand() { + return command; + } + + /** + * 设置 + * @param command + */ + public void setCommand(String command) { + this.command = command; + } + + /** + * 获取 + * @return messages + */ + public List getMessages() { + return messages; + } + + /** + * 设置 + * @param messages + */ + public void setMessages(List messages) { + this.messages = messages; + } + + @Override + public String toString() { + return "UserCustomMessage{command = " + command + ", messages = " + messages + "}"; + } +} diff --git a/src/main/java/plugin/utils/AIUtil.java b/src/main/java/plugin/utils/AIUtil.java index 7a53ecb..6b88f9a 100644 --- a/src/main/java/plugin/utils/AIUtil.java +++ b/src/main/java/plugin/utils/AIUtil.java @@ -8,13 +8,17 @@ import com.zhipu.oapi.service.v4.model.ChatMessageRole; import com.zhipu.oapi.service.v4.model.ModelApiResponse; import plugin.constant.AIConstant; import plugin.constant.ChatConstant; +import plugin.constant.ConfigConstant; +import plugin.constant.MethodsConstant; import plugin.pojo.Config; +import plugin.pojo.UserCustomMessage; import java.util.ArrayList; import java.util.HashMap; import java.util.List; -import static plugin.App.logger; +import static plugin.utils.ConfigUtil.logger; + /** * @author SLHAF @@ -23,13 +27,18 @@ public class AIUtil { private static final String APIKEY; private static final ClientV4 CLIENT; private static final String REQUEST_ID_TEMPLATE = "ChatAI_InGroup_v2"; - private static final HashMap> userMessagesNormal = new HashMap<>(); - private static final HashMap userLatestTimeNormal = new HashMap<>(); - private static String modelNormal; + private static final HashMap> userDefaultMessages = new HashMap<>(); + private static final HashMap userLatestTimeOfDefault = new HashMap<>(); + private static String defaultModel; - private static final HashMap> userMessagesCode = new HashMap<>(); - private static final HashMap userLatestTimeCode = new HashMap<>(); - private static String modelCode; + private static final HashMap> userCustomMessages = new HashMap<>(); + private static final HashMap userLatestTimeOfCustom = new HashMap<>(); + /*private static String modelCode;*/ + /** + * 结构: + * /c : glm-4-flush|预设内容 + */ + private static HashMap customCommands; private static final Long CHECK_TIME, TIMEOUT; @@ -37,8 +46,8 @@ public class AIUtil { Config config = ConfigUtil.getConfig(); APIKEY = config.getApikey(); CLIENT = new ClientV4.Builder(APIKEY).build(); - modelNormal = config.getModelNormal(); - modelCode = config.getModelCode(); + defaultModel = config.getDefaultModel(); + customCommands = config.getCustomCommands(); CHECK_TIME = Long.valueOf(config.getTimeCheck().substring(1)); TIMEOUT = Long.valueOf(config.getTimeout().substring(1)); new Thread(() -> { @@ -49,117 +58,120 @@ public class AIUtil { throw new RuntimeException(e); } - synchronized (userMessagesNormal) { - if (!userLatestTimeNormal.isEmpty()) { - //查看user最近时间,如果超过30min,则清理对应记录 - userLatestTimeNormal.forEach((id, latestTime) -> { + if (!userLatestTimeOfDefault.isEmpty()) { + //查看user最近时间,如果超过30min,则清理对应记录 + userLatestTimeOfDefault.forEach((id, latestTime) -> { + synchronized (userDefaultMessages.get(id)) { Long currentTime = System.currentTimeMillis(); - if (currentTime - latestTime > TIMEOUT && userMessagesNormal.containsKey(id)) { - userMessagesNormal.remove(id); - logger.info("Normal记录清理:" + id); + if (currentTime - latestTime > TIMEOUT && userDefaultMessages.containsKey(id)) { + userDefaultMessages.remove(id); + logger.info("default记录清理:" + id); } - }); - } + } + }); } - synchronized (userMessagesCode) { - if (!userLatestTimeCode.isEmpty()) { - //查看user最近时间,如果超过30min,则清理对应记录 - userLatestTimeCode.forEach((id, latestTime) -> { + + if (!userLatestTimeOfCustom.isEmpty()) { + //查看user最近时间,如果超过30min,则清理对应记录 + userLatestTimeOfCustom.forEach((id, latestTime) -> { + synchronized (userCustomMessages.get(id)) { Long currentTime = System.currentTimeMillis(); - if (currentTime - latestTime > 30 * 60 * 1000 && userMessagesCode.containsKey(id)) { - userMessagesCode.remove(id); - logger.info("Code记录清理:" + id); + if (currentTime - latestTime > 30 * 60 * 1000 && userCustomMessages.containsKey(id)) { + userCustomMessages.remove(id); + logger.info("custom记录清理:" + id); } - }); - } + } + }); } + } }).start(); logger.info("清理线程启动"); - logger.info("当前代码模型: " + modelCode); - logger.info("当前聊天模型: " + modelNormal); + logger.info("当前默认聊天模型: " + defaultModel); } private AIUtil() { } - public static String chatCode(Long id, String content, String url) { - synchronized (userMessagesCode) { - if (ChatConstant.CLEAR.equals(content.replace(" ", ""))) { - userMessagesCode.remove(id); - return "消息记录已清空"; - } else if (content.replace(" ", "").startsWith(ChatConstant.CHANGE_MODEL)) { - content = content.replace(" ", ""); - modelCode = content.substring(4); - ConfigUtil.modelCodeChange(modelCode); - return "聊天模型切换为: " + modelCode; - } else if (AIConstant.CURRENT_MODEL.equals(content.replace(" ", ""))) { - return "当前模型为: " + modelCode; + public static String customChat(Long id, String content, String url, String chatCommand) { + if (ChatConstant.CLEAR.equals(content.replace(ChatConstant.BLANK, ""))) { + userCustomMessages.remove(id); + return "消息记录已清空"; + } /*else if (content.replace(ChatConstant.BLANK, "").startsWith(ChatConstant.CHANGE_MODEL)) { + content = content.replace(ChatConstant.BLANK, ""); + modelCode = content.substring(4); + ConfigUtil.modelCodeChange(modelCode); + return "聊天模型切换为: " + modelCode; + } */ else if (AIConstant.CURRENT_MODEL.equals(content.replace(ChatConstant.BLANK, ""))) { + String modelName = customCommands.get(chatCommand).split(ConfigConstant.CUSTOM_SPLIT)[0]; + return "当前模型为: " + modelName; + } + //查看本次id是否有记录存在 + if (!userCustomMessages.containsKey(id)) { + //创建消息list + List chatMessages = new ArrayList<>(); + if (!customCommands.get(chatCommand).split(ConfigConstant.CUSTOM_SPLIT)[1].equals(ConfigConstant.NULL)) { + ChatMessage customMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), customCommands.get(chatCommand).split(ConfigConstant.CUSTOM_SPLIT)[1]); + chatMessages.add(customMessage); } - //查看本次id是否有记录存在 - if (!userMessagesCode.containsKey(id)) { - //创建消息list - List chatMessage = new ArrayList<>(); - chatMessage.add(new ChatMessage(ChatMessageRole.SYSTEM.value(), "你是一位智能编程助手,你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。 请用中文回答。")); - userMessagesCode.put(id, chatMessage); - } - return getChatResponse(id, content, url,modelCode, userMessagesCode, userLatestTimeCode); + List userCustomMessageList = new ArrayList<>(); + userCustomMessageList.add(new UserCustomMessage(chatCommand, chatMessages)); + userCustomMessages.put(id, userCustomMessageList); + } + String modelName = customCommands.get(chatCommand).split(ConfigConstant.CUSTOM_SPLIT)[0]; + synchronized (userCustomMessages.get(id)) { + return getChatResponse(id, content, url, modelName, userLatestTimeOfCustom, chatCommand); } } - public static String chatNormal(Long id, String content,String url) { - synchronized (userMessagesNormal) { - if (ChatConstant.CLEAR.equals(content.replace(" ", ""))) { - userMessagesNormal.remove(id); - return "消息记录已清空"; - } else if (content.replace(" ", "").startsWith(ChatConstant.CHANGE_MODEL)) { - content = content.replace(" ", ""); - modelNormal = content.substring(4); - ConfigUtil.modelNormalChange(modelNormal); - return "聊天模型切换为: " + modelNormal; - } else if (ChatConstant.CHANGE_MODEL.equals(content.replace(" ", ""))) { - return "当前模型为: " + modelNormal; + public static String defaultChat(Long id, String content, String url) { + if (ChatConstant.CLEAR.equals(content.replace(ChatConstant.BLANK, ""))) { + userDefaultMessages.remove(id); + return "消息记录已清空"; + }else if (ChatConstant.CURRENT_MODEL.equals(content.replace(ChatConstant.BLANK, ""))) { + return "当前模型为: " + defaultModel; + } + //查看本次id是否有记录存在 + if (!userDefaultMessages.containsKey(id)) { + //创建消息list + List chatMessages = new ArrayList<>(); + if (customCommands.containsKey(ConfigConstant.DEFAULT) && !ConfigConstant.NULL.equals(customCommands.get(ConfigConstant.DEFAULT))) { + ChatMessage customMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), customCommands.get(ConfigConstant.DEFAULT)); + chatMessages.add(customMessage); } - //查看本次id是否有记录存在 - if (!userMessagesNormal.containsKey(id)) { - //创建消息list - List chatMessage = new ArrayList<>(); - userMessagesNormal.put(id, chatMessage); - } - return getChatResponse(id, content, url,modelNormal, userMessagesNormal, userLatestTimeNormal); + userDefaultMessages.put(id, chatMessages); + } + synchronized (userDefaultMessages.get(id)) { + return getChatResponse(id, content, url, defaultModel, userLatestTimeOfDefault, null); } } - public static String chatOnce(String content,String url) { - if (content.replace(" ", "").startsWith(ChatConstant.CHANGE_MODEL)) { - content = content.replace(" ", ""); - modelNormal = content.substring(4); - ConfigUtil.modelNormalChange(modelNormal); - return "代码模型切换为: " + modelNormal; - } else if (AIConstant.CURRENT_MODEL.equals(content.replace(" ", ""))) { - return "当前模型为: " + modelNormal; + public static String chatOnce(String content, String url) { + + if (AIConstant.CURRENT_MODEL.equals(content.replace(ChatConstant.BLANK, ""))) { + return "当前模型为: " + defaultModel; } String result = ""; - if(url != null){ - if (!OCRUtil.isSupported){ + if (url != null) { + if (!OCRUtil.isSupported) { return "当前不支持文字识别,请检查阿里云OCR相关配置。"; } else { String contentOfImage = OCRUtil.getContentOfImage(url); - if (contentOfImage == null){ + if (contentOfImage == null) { result = "未识别出图片内容。"; - }else if (AIConstant.ERROR.equals(contentOfImage)){ + } else if (AIConstant.ERROR.equals(contentOfImage)) { result = "识别图片内容出错,请查看控制台。"; - }else { + } else { content = content.replace("[图片]", "\r\n[" + contentOfImage + "]\r\n"); } } } - String requestId = String.format(REQUEST_ID_TEMPLATE, System.currentTimeMillis()); + String requestId = REQUEST_ID_TEMPLATE + "_once_"+System.currentTimeMillis(); List messages = new ArrayList<>(); messages.add(new ChatMessage(ChatMessageRole.USER.value(), content)); ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() - .model(modelNormal) + .model(defaultModel) .stream(Boolean.FALSE) .invokeMethod(Constants.invokeMethod) .messages(messages) @@ -174,43 +186,84 @@ public class AIUtil { logger.warning("code: " + code); logger.warning("ErrorCode:" + invokeModelApiResp.getData().getError().getCode()); logger.warning("msg:" + invokeModelApiResp.getData().getError().getMessage()); - return invokeModelApiResp.getMsg()+"\r\n<"+result+">"; + return invokeModelApiResp.getMsg() + "\r\n<" + result + ">"; } } - private static String getChatResponse(Long id, String content, String url,String model, HashMap> userMessages, HashMap userLatestTime) { + /** + * 取得回复内容 + *
chatMessages内容只有两个static变量,只需根据传入的chatCommand是否为null进行判断 + * + * @param id 聊天用户ID + * @param content 聊天用户发送的内容,content已在Listener处进行处理 + * @param url 获取到的图片url,如果没有图片,则为null + * @param model 本次对话所需的模型名称 + *
当为userDefaultMessages时: + *
+ *
qq1 : [msg1,msg2,msg3] + *
qq2 : [msg1,msg2,msg3] + *
+ *
当为userCustomMessages时: + *
+ *
qq1 : [{command1,[msg1.msg2,msg3]},{command2,[msg1,msg2,msg3]},{command3,[msg1,msg2,msg3]}] + *
qq2 : [{command1,[msg1.msg2,msg3]},{command2,[msg1,msg2,msg3]},{command3,[msg1,msg2,msg3]}] + *
+ * @param userLatestTime 最近操作时间 + * @param chatCommand + * @return 得到的模型响应内容 + */ + private static String getChatResponse(Long id, String content, String url, String model, HashMap userLatestTime, String chatCommand) { userLatestTime.put(id, System.currentTimeMillis()); - String requestId = String.format(REQUEST_ID_TEMPLATE, System.currentTimeMillis()); + String requestId = REQUEST_ID_TEMPLATE + "_"+ model + "_" + System.currentTimeMillis(); //处理url内容 String result = ""; - if(url != null){ - if (!OCRUtil.isSupported){ + if (url != null) { + if (!OCRUtil.isSupported) { logger.warning("unSupportedOCR"); return "当前不支持文字识别,请检查阿里云OCR相关配置。"; } else { String contentOfImage = OCRUtil.getContentOfImage(url); - if (contentOfImage == null){ + if (contentOfImage == null) { result = "未识别出图片内容。"; - }else if (AIConstant.ERROR.equals(contentOfImage)){ + } else if (AIConstant.ERROR.equals(contentOfImage)) { result = "识别图片内容出错,请查看控制台。"; - }else { + } else { content = content.replace("[图片]", "\r\n[" + contentOfImage + "]\r\n"); } } } logger.info("final content:" + content); + //根据primaryUserMessages中的内容来定义userMessages + List chatMessages = null; + if (chatCommand == null) { + chatMessages = userDefaultMessages.get(id); + } else { + List userCustomMessageList = userCustomMessages.get(id); + for (UserCustomMessage userCustomMessage : userCustomMessageList) { + //在调用时已确保存在指令对应的消息记录 + if (userCustomMessage.getCommand().equals(chatCommand)) { + chatMessages = userCustomMessage.getMessages(); + break; + } + } + } + + if (chatMessages == null) { + return "消息记录读取失败"; + } + //添加消息 - userMessages.get(id).add(new ChatMessage(ChatMessageRole.USER.value(), content)); + chatMessages.add(new ChatMessage(ChatMessageRole.USER.value(), content)); //创建并发送请求 ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() .model(model) .stream(Boolean.FALSE) .invokeMethod(Constants.invokeMethod) - .messages(userMessages.get(id)) + .messages(chatMessages) .requestId(requestId) .build(); @@ -219,13 +272,13 @@ public class AIUtil { if (code == 200) { printTokenInfo(invokeModelApiResp); String response = invokeModelApiResp.getData().getChoices().get(0).getMessage().getContent().toString(); - userMessages.get(id).add(new ChatMessage(ChatMessageRole.ASSISTANT.value(), response)); + chatMessages.add(new ChatMessage(ChatMessageRole.ASSISTANT.value(), response)); return response; } else { logger.warning("code: " + code); logger.warning("ErrorCode:" + invokeModelApiResp.getData().getError().getCode()); logger.warning("msg:" + invokeModelApiResp.getData().getError().getMessage()); - return invokeModelApiResp.getMsg()+"\r\n<"+result+">"; + return invokeModelApiResp.getMsg() + "\r\n<" + result + ">"; } } diff --git a/src/main/java/plugin/utils/ConfigUtil.java b/src/main/java/plugin/utils/ConfigUtil.java index 823935a..8687dea 100644 --- a/src/main/java/plugin/utils/ConfigUtil.java +++ b/src/main/java/plugin/utils/ConfigUtil.java @@ -1,23 +1,27 @@ package plugin.utils; import cn.hutool.core.bean.BeanUtil; -import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import net.mamoe.mirai.utils.LoggerAdapters; +import net.mamoe.mirai.utils.MiraiLogger; import org.yaml.snakeyaml.DumperOptions; import org.yaml.snakeyaml.Yaml; +import plugin.constant.ConfigConstant; import plugin.pojo.Config; import java.io.*; -import java.util.HashMap; +import java.util.LinkedHashMap; -import static plugin.App.logger; /** * @author SLHAF */ +@Slf4j public class ConfigUtil { private static final String CONFIG_PATH = "./config/ChatAIinGroup/config.yaml"; private static final Yaml yaml; private static Config config; + public static MiraiLogger logger = LoggerAdapters.asMiraiLogger(log); private ConfigUtil() { } @@ -45,14 +49,14 @@ public class ConfigUtil { config.setAccessKeyId("your_ali_access_key_id"); config.setAccessKeySecret("your_ali_access_key_secret"); config.setOwner("your_bot_owner_qq_number(e.g. Q1145141919810)"); - config.setModelNormal("glm-4-flash"); - config.setModelCode("glm-4-flash"); + config.setDefaultModel("glm-4-flash"); config.setBot("your_bot_qq_number(e.g. Q1145141919810)"); config.setTimeout("M3600000"); config.setTimeCheck("M60000"); - HashMap commands = new HashMap<>(); - commands.put("/c ", "你是一位智能编程助手,你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。 请用中文回答。"); - commands.put("/example", "预设内容"); + LinkedHashMap commands = new LinkedHashMap<>(); + commands.put("default","glm-4-flash|null"); + commands.put("/c ", "glm-4-flash|你是一位智能编程助手,你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。 请用中文回答。"); + commands.put("/example ", "模型名称|预设内容"); config.setCustomCommands(commands); dump(); logger.warning("配置文件创建成功,请关闭后进行配置"); @@ -78,18 +82,19 @@ public class ConfigUtil { * * @param modelName 模型名称 */ - public static void modelNormalChange(String modelName) { + public static void defaultModelChange(String modelName) { try { - config.setModelNormal(modelName); + config.setDefaultModel(modelName); dump(); } catch (IOException e) { throw new RuntimeException(e); } } - public static void modelCodeChange(String modelName) { + public static void customModelChange(String command,String modelName) { try { - config.setModelCode(modelName); + String customContent = config.getCustomCommands().get(command).split(ConfigConstant.CUSTOM_SPLIT)[1]; + config.getCustomCommands().put(command,modelName+ConfigConstant.CUSTOM_SPLIT+customContent); dump(); } catch (IOException e) { throw new RuntimeException(e); diff --git a/src/main/java/plugin/utils/OCRUtil.java b/src/main/java/plugin/utils/OCRUtil.java index 3103de2..313f164 100644 --- a/src/main/java/plugin/utils/OCRUtil.java +++ b/src/main/java/plugin/utils/OCRUtil.java @@ -9,7 +9,8 @@ import com.aliyun.teautil.models.RuntimeOptions; import plugin.pojo.Config; import plugin.pojo.OCRDataInfo; -import static plugin.App.logger; +import static plugin.utils.ConfigUtil.logger; + public class OCRUtil { private static Client client; diff --git a/src/test/java/MyTest.java b/src/test/java/MyTest.java index b82d4f7..f274c8b 100644 --- a/src/test/java/MyTest.java +++ b/src/test/java/MyTest.java @@ -10,8 +10,13 @@ import org.apache.http.impl.client.HttpClients; import org.apache.http.message.BasicHeader; import org.apache.http.util.EntityUtils; import org.junit.Test; +import plugin.App; +import plugin.utils.AIUtil; +import plugin.utils.ConfigUtil; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -132,4 +137,15 @@ public class MyTest { response.close(); client.close(); } + + @Test + public void mainTest() throws ClassNotFoundException, IOException { + ConfigUtil.load(); + + Long id = 2998813882L; + String content = "hello"; + String chatCommand = "/c "; + String s = AIUtil.customChat(id, content, null, chatCommand); + System.out.println(s); + } }