mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 08:43:02 +08:00
refactor(framework): migrate module abstracts/ActivateModel to Kotlin and introduce shared model/context structures
This commit is contained in:
@@ -4,11 +4,13 @@ import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import work.slhaf.partner.api.agent.factory.capability.annotation.InjectCapability;
|
||||
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentRunningModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.abstracts.ActivateModel;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.CoreModule;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
|
||||
import work.slhaf.partner.api.chat.ChatClient;
|
||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
@@ -33,13 +35,12 @@ public class CoreModel extends AbstractAgentRunningModule<PartnerRunningFlowCont
|
||||
|
||||
@InjectCapability
|
||||
private CognationCapability cognationCapability;
|
||||
private List<Message> appendedMessages;
|
||||
private List<Message> appendedMessages = new ArrayList<>();
|
||||
|
||||
@Init
|
||||
public void init(){
|
||||
List<Message> chatMessages = this.cognationCapability.getChatMessages();
|
||||
this.getModel().setChatMessages(chatMessages);
|
||||
this.appendedMessages = new ArrayList<>();
|
||||
this.getModel().getChatMessages().addAll(chatMessages);
|
||||
|
||||
updateChatClientSettings();
|
||||
log.info("[CoreModel] CoreModel注册完毕...");
|
||||
@@ -47,12 +48,13 @@ public class CoreModel extends AbstractAgentRunningModule<PartnerRunningFlowCont
|
||||
|
||||
@Override
|
||||
public void updateChatClientSettings() {
|
||||
chatClient().setTemperature(0.3);
|
||||
chatClient().setTop_p(0.7);
|
||||
ChatClient chatClient = getModel().getChatClient();
|
||||
chatClient.setTemperature(0.3);
|
||||
chatClient.setTop_p(0.7);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String modelKey() {
|
||||
public @NotNull String modelKey() {
|
||||
return "core_model";
|
||||
}
|
||||
|
||||
@@ -75,7 +77,7 @@ public class CoreModel extends AbstractAgentRunningModule<PartnerRunningFlowCont
|
||||
activateModule(runningFlowContext);
|
||||
setMessageCount(runningFlowContext);
|
||||
|
||||
log.debug("[CoreModel] 当前消息列表大小: {}", chatMessages().size());
|
||||
log.debug("[CoreModel] 当前消息列表大小: {}", getModel().getChatMessages().size());
|
||||
log.debug("[CoreModel] 当前核心prompt内容: {}", runningFlowContext.getCoreContext().toString());
|
||||
|
||||
setMessage(runningFlowContext.getCoreContext().toString());
|
||||
@@ -110,13 +112,13 @@ public class CoreModel extends AbstractAgentRunningModule<PartnerRunningFlowCont
|
||||
log.error("[CoreModel] CoreModel执行异常: {}", e.getLocalizedMessage());
|
||||
if (count > 3) {
|
||||
handleExceptionResponse(response, "主模型交互出错: " + e.getLocalizedMessage());
|
||||
chatMessages().removeLast();
|
||||
getModel().getChatMessages().removeLast();
|
||||
break;
|
||||
}
|
||||
} finally {
|
||||
updateCoreResponse(runningFlowContext, response);
|
||||
resetAppendedMessages();
|
||||
log.debug("[CoreModel] 消息列表更新大小: {}", chatMessages().size());
|
||||
log.debug("[CoreModel] 消息列表更新大小: {}", getModel().getChatMessages().size());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -145,17 +147,20 @@ public class CoreModel extends AbstractAgentRunningModule<PartnerRunningFlowCont
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse chat() {
|
||||
List<Message> temp = new ArrayList<>(baseMessages().subList(0, baseMessages().size() - 2));
|
||||
public @NotNull ChatResponse chat() {
|
||||
List<@NotNull Message> baseMessages = getModel().getBaseMessages();
|
||||
List<@NotNull Message> chatMessages = getModel().getChatMessages();
|
||||
List<Message> temp = new ArrayList<>(baseMessages.subList(0, baseMessages.size() - 2));
|
||||
temp.addAll(appendedMessages);
|
||||
temp.addAll(baseMessages().subList(baseMessages().size() - 2, baseMessages().size()));
|
||||
temp.addAll(chatMessages());
|
||||
return chatClient().runChat(temp);
|
||||
temp.addAll(baseMessages.subList(baseMessages.size() - 2, baseMessages.size()));
|
||||
temp.addAll(chatMessages);
|
||||
return getModel().getChatClient().runChat(temp);
|
||||
}
|
||||
|
||||
private void updateModuleContextAndChatMessages(PartnerRunningFlowContext runningFlowContext, String response, ChatResponse chatResponse) {
|
||||
cognationCapability.getMessageLock().lock();
|
||||
chatMessages().removeIf(m -> {
|
||||
List<@NotNull Message> chatMessages = getModel().getChatMessages();
|
||||
chatMessages.removeIf(m -> {
|
||||
if (m.getRole().equals(ChatConstant.Character.ASSISTANT)) {
|
||||
return false;
|
||||
}
|
||||
@@ -169,9 +174,9 @@ public class CoreModel extends AbstractAgentRunningModule<PartnerRunningFlowCont
|
||||
//添加时间标志
|
||||
String dateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("\r\n**[yyyy-MM-dd HH:mm:ss]"));
|
||||
Message primaryUserMessage = new Message(ChatConstant.Character.USER, runningFlowContext.getCoreContext().getText() + dateTime);
|
||||
chatMessages().add(primaryUserMessage);
|
||||
chatMessages.add(primaryUserMessage);
|
||||
Message assistantMessage = new Message(ChatConstant.Character.ASSISTANT, response);
|
||||
chatMessages().add(assistantMessage);
|
||||
chatMessages.add(assistantMessage);
|
||||
cognationCapability.getMessageLock().unlock();
|
||||
//设置上下文
|
||||
runningFlowContext.getModuleContext().getExtraContext().put("total_token", chatResponse.getUsageBean().getTotal_tokens());
|
||||
@@ -184,7 +189,7 @@ public class CoreModel extends AbstractAgentRunningModule<PartnerRunningFlowCont
|
||||
|
||||
private void setMessage(String coreContextStr) {
|
||||
Message userMessage = new Message(ChatConstant.Character.USER, coreContextStr);
|
||||
chatMessages().add(userMessage);
|
||||
getModel().getChatMessages().add(userMessage);
|
||||
}
|
||||
|
||||
private void handleExceptionResponse(JSONObject response, String chatResponse) {
|
||||
@@ -193,7 +198,7 @@ public class CoreModel extends AbstractAgentRunningModule<PartnerRunningFlowCont
|
||||
}
|
||||
|
||||
private void setMessageCount(PartnerRunningFlowContext runningFlowContext) {
|
||||
runningFlowContext.getModuleContext().getExtraContext().put("message_count", chatMessages().size());
|
||||
runningFlowContext.getModuleContext().getExtraContext().put("message_count", getModel().getChatMessages().size());
|
||||
}
|
||||
|
||||
private void setAppendedPromptMessage(List<AppendPromptData> appendPrompt) {
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package work.slhaf.partner.api.agent.factory.context
|
||||
|
||||
import com.alibaba.fastjson2.JSONArray
|
||||
import work.slhaf.partner.api.agent.factory.module.abstracts.AbstractAgentModule
|
||||
|
||||
object AgentContext {
|
||||
|
||||
}
|
||||
|
||||
sealed class ModuleContextData<T : AbstractAgentModule> {
|
||||
abstract val name: String
|
||||
abstract val clazz: Class<T>
|
||||
abstract val instance: T
|
||||
abstract val prompt: JSONArray
|
||||
abstract val modelActivated: Boolean
|
||||
|
||||
data class RunningModule<T : AbstractAgentModule>(
|
||||
override val name: String,
|
||||
override val clazz: Class<T>,
|
||||
override val instance: T,
|
||||
override val prompt: JSONArray,
|
||||
override val modelActivated: Boolean,
|
||||
|
||||
val order: Int
|
||||
) : ModuleContextData<T>()
|
||||
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
package work.slhaf.partner.api.agent.factory.module.abstracts;
|
||||
|
||||
import cn.hutool.core.bean.BeanUtil;
|
||||
import work.slhaf.partner.api.agent.factory.config.pojo.ModelConfig;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.Init;
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.Model;
|
||||
import work.slhaf.partner.api.chat.ChatClient;
|
||||
import work.slhaf.partner.api.chat.constant.ChatConstant;
|
||||
import work.slhaf.partner.api.chat.pojo.ChatResponse;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public interface ActivateModel {
|
||||
|
||||
AgentConfigManager AGENT_CONFIG_MANAGER = AgentConfigManager.INSTANCE;
|
||||
|
||||
@Init(order = -1)
|
||||
default void modelSettings() {
|
||||
Model model = getModel();
|
||||
ModelConfig modelConfig = AgentConfigManager.INSTANCE.loadModelConfig(modelKey());
|
||||
model.setBaseMessages(withBasicPrompt() ? loadSpecificPromptAndBasicPrompt(modelKey()) : loadSpecificPrompt(modelKey()));
|
||||
model.setChatClient(new ChatClient(modelConfig.getBaseUrl(), modelConfig.getApikey(), modelConfig.getModel()));
|
||||
}
|
||||
|
||||
default void updateModelSettings(ChatClient newChatClient) {
|
||||
BeanUtil.copyProperties(newChatClient, chatClient());
|
||||
}
|
||||
|
||||
private List<Message> loadSpecificPrompt(String modelKey) {
|
||||
return AGENT_CONFIG_MANAGER.loadModelPrompt(modelKey);
|
||||
}
|
||||
|
||||
private List<Message> loadSpecificPromptAndBasicPrompt(String modelKey) {
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.addAll(AGENT_CONFIG_MANAGER.loadModelPrompt("basic"));
|
||||
messages.addAll(AGENT_CONFIG_MANAGER.loadModelPrompt(modelKey));
|
||||
return messages;
|
||||
}
|
||||
|
||||
default ChatResponse chat() {
|
||||
Model model = getModel();
|
||||
List<Message> temp = new ArrayList<>();
|
||||
temp.addAll(model.getBaseMessages());
|
||||
temp.addAll(model.getChatMessages());
|
||||
return model.getChatClient().runChat(temp);
|
||||
}
|
||||
|
||||
default ChatResponse singleChat(String input) {
|
||||
Model model = getModel();
|
||||
List<Message> temp = new ArrayList<>(model.getBaseMessages());
|
||||
temp.add(new Message(ChatConstant.Character.USER, input));
|
||||
return model.getChatClient().runChat(temp);
|
||||
}
|
||||
|
||||
default void updateChatClientSettings() {
|
||||
Model model = getModel();
|
||||
model.getChatClient().setTemperature(0.4);
|
||||
model.getChatClient().setTop_p(0.8);
|
||||
}
|
||||
|
||||
default List<Message> chatMessages() {
|
||||
return getModel().getChatMessages();
|
||||
}
|
||||
|
||||
default List<Message> baseMessages() {
|
||||
return getModel().getBaseMessages();
|
||||
}
|
||||
|
||||
default ChatClient chatClient() {
|
||||
return getModel().getChatClient();
|
||||
}
|
||||
|
||||
/**
|
||||
* 仅适用Module子类,否则需要重写
|
||||
*
|
||||
* @return 持有的model实例
|
||||
*/
|
||||
default Model getModel() {
|
||||
return ((AbstractAgentModule) this).getModel();
|
||||
}
|
||||
|
||||
default void setModel(Model model) {
|
||||
((AbstractAgentModule) this).setModel(model);
|
||||
}
|
||||
|
||||
/**
|
||||
* 对应调用的模型配置名称
|
||||
*/
|
||||
String modelKey();
|
||||
|
||||
boolean withBasicPrompt();
|
||||
|
||||
}
|
||||
@@ -1,16 +1,105 @@
|
||||
package work.slhaf.partner.api.agent.factory.module.abstracts;
|
||||
package work.slhaf.partner.api.agent.factory.module.abstracts
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.Model;
|
||||
import work.slhaf.partner.api.agent.factory.module.annotation.Init
|
||||
import work.slhaf.partner.api.agent.runtime.config.AgentConfigManager
|
||||
import work.slhaf.partner.api.agent.runtime.interaction.flow.entity.RunningFlowContext
|
||||
import work.slhaf.partner.api.chat.ChatClient
|
||||
import work.slhaf.partner.api.chat.constant.ChatConstant
|
||||
import work.slhaf.partner.api.chat.pojo.ChatResponse
|
||||
import work.slhaf.partner.api.chat.pojo.Message
|
||||
|
||||
/**
|
||||
* 模块基类
|
||||
*/
|
||||
public abstract class AbstractAgentModule {
|
||||
abstract class AbstractAgentModule {
|
||||
var moduleName: String = javaClass.simpleName
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
protected Model model = new Model();
|
||||
data class Model(
|
||||
val chatClient: ChatClient,
|
||||
val chatMessages: MutableList<Message> = mutableListOf(),
|
||||
val baseMessages: MutableList<Message> = mutableListOf()
|
||||
)
|
||||
|
||||
interface RunningModule<T : RunningFlowContext> {
|
||||
fun execute(context: T)
|
||||
}
|
||||
|
||||
interface SubModule<I, O> {
|
||||
fun execute(input: I): O
|
||||
}
|
||||
|
||||
interface StandaloneModule
|
||||
|
||||
interface ActivateModel {
|
||||
|
||||
companion object {
|
||||
val configManager: AgentConfigManager = AgentConfigManager.INSTANCE
|
||||
val modelMap: MutableMap<String, Model> = mutableMapOf()
|
||||
}
|
||||
|
||||
fun getModel(): Model {
|
||||
fun buildModel(): Model {
|
||||
val modelConfig = configManager.loadModelConfig(modelKey())
|
||||
val chatClient = ChatClient(modelConfig.baseUrl, modelConfig.apikey, modelConfig.model)
|
||||
val model = Model(chatClient)
|
||||
|
||||
val baseMessages = if (withBasicPrompt()) {
|
||||
loadSpecificPromptAndBasicPrompt(modelKey())
|
||||
} else {
|
||||
configManager.loadModelPrompt(modelKey())
|
||||
}
|
||||
model.baseMessages.addAll(baseMessages)
|
||||
return model
|
||||
}
|
||||
|
||||
val model = modelMap.computeIfAbsent(modelKey()) {
|
||||
buildModel()
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
@Init(order = -1)
|
||||
fun modelSettings() {
|
||||
val model = getModel()
|
||||
modelMap[modelKey()] = model
|
||||
}
|
||||
|
||||
private fun loadSpecificPromptAndBasicPrompt(modelKey: String): MutableList<Message> {
|
||||
val messages: MutableList<Message> = ArrayList()
|
||||
messages.addAll(configManager.loadModelPrompt("basic"))
|
||||
messages.addAll(configManager.loadModelPrompt(modelKey))
|
||||
return messages
|
||||
}
|
||||
|
||||
fun chat(): ChatResponse {
|
||||
val model = this.getModel()
|
||||
val temp = ArrayList<Message?>()
|
||||
temp.addAll(model.baseMessages)
|
||||
temp.addAll(model.chatMessages)
|
||||
return model.chatClient.runChat(temp)
|
||||
}
|
||||
|
||||
fun singleChat(input: String): ChatResponse {
|
||||
val model = this.getModel()
|
||||
val temp = ArrayList<Message>(model.baseMessages)
|
||||
temp.add(Message(ChatConstant.Character.USER, input))
|
||||
return model.chatClient.runChat(temp)
|
||||
}
|
||||
|
||||
fun updateChatClientSettings() {
|
||||
val model = this.getModel()
|
||||
model.chatClient.temperature = 0.4
|
||||
model.chatClient.top_p = 0.8
|
||||
}
|
||||
|
||||
/**
|
||||
* 对应调用的模型配置名称
|
||||
*/
|
||||
fun modelKey(): String {
|
||||
return (this as AbstractAgentModule).moduleName
|
||||
}
|
||||
|
||||
fun withBasicPrompt(): Boolean
|
||||
|
||||
}
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
package work.slhaf.partner.api.agent.runtime.interaction.flow.entity;
|
||||
|
||||
import lombok.Data;
|
||||
import work.slhaf.partner.api.chat.ChatClient;
|
||||
import work.slhaf.partner.api.chat.pojo.Message;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class Model {
|
||||
|
||||
protected ChatClient chatClient;
|
||||
protected List<Message> chatMessages;
|
||||
protected List<Message> baseMessages;
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user