Mirror of https://github.com/slhaf/Partner.git, synced 2026-05-12 08:43:02 +08:00
Advance the ActionExtractor semantic-cache mechanism: both embedding-model integration paths have been tested. On a high-performance host, a model like mxbai-embed-large can be called through ollama, but on a 4-core/8GB Orange Pi 3B its inference time becomes far too long, and switching to ONNX Runtime Java does not avoid that. Switching to nomic-embed-text + ONNX Runtime Java, however, still achieves roughly 70 ms per inference, far below the call latency of the extractor model or the vector-model API. Two integration modes for the semantic cache's embedding model are planned: calling a local ollama endpoint over HTTP, or directly invoking an embedding model in ONNX format.
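The diff below only exercises the second mode (direct ONNX invocation, in OnnxTest.java). As a rough illustration of how the two planned backends could sit behind one seam, here is a minimal sketch of the HTTP mode against a local ollama instance. The interface name EmbeddingClient and the class OllamaEmbeddingClient are hypothetical (they do not appear in this commit); ollama's documented /api/embeddings endpoint and its {"model", "prompt"} request shape are real, but the hand-rolled JSON handling is a stand-in for a proper JSON library.

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

// Hypothetical seam: the semantic cache would depend only on this interface,
// keeping the ollama-backed and ONNX-backed implementations swappable.
interface EmbeddingClient {
    float[] embed(String text) throws Exception;
}

// Mode 1: HTTP call to a local ollama endpoint.
class OllamaEmbeddingClient implements EmbeddingClient {
    private final HttpClient http = HttpClient.newHttpClient();
    private final String model;

    OllamaEmbeddingClient(String model) {
        this.model = model;
    }

    @Override
    public float[] embed(String text) throws Exception {
        // Naive JSON assembly; assumes text needs no escaping.
        String body = "{\"model\": \"" + model + "\", \"prompt\": \"" + text + "\"}";
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:11434/api/embeddings"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
        String json = http.send(request, HttpResponse.BodyHandlers.ofString()).body();
        // Naive extraction of the "embedding" array from the response.
        int start = json.indexOf('[') + 1;
        int end = json.indexOf(']', start);
        String[] parts = json.substring(start, end).split(",");
        float[] vector = new float[parts.length];
        for (int i = 0; i < parts.length; i++) {
            vector[i] = Float.parseFloat(parts[i].trim());
        }
        return vector;
    }
}

Mode 2 would implement the same interface around the tokenizer-plus-OrtSession flow that OnnxTest.java walks through below.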
@@ -8,8 +8,8 @@ import work.slhaf.partner.api.chat.pojo.Message;
 import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
 import work.slhaf.partner.core.action.ActionCapability;
 import work.slhaf.partner.core.action.entity.*;
-import work.slhaf.partner.core.cache.CacheCapability;
 import work.slhaf.partner.core.cognation.CognationCapability;
+import work.slhaf.partner.core.memory.MemoryCapability;
 import work.slhaf.partner.core.perceive.PerceiveCapability;
 import work.slhaf.partner.module.common.module.PreRunningModule;
 import work.slhaf.partner.module.modules.action.planner.confirmer.ActionConfirmer;
@@ -42,7 +42,7 @@ public class ActionPlanner extends PreRunningModule {
     @InjectCapability
     private PerceiveCapability perceiveCapability;
     @InjectCapability
-    private CacheCapability cacheCapability;
+    private MemoryCapability memoryCapability;

     @InjectModule
     private ActionEvaluator actionEvaluator;
@@ -197,7 +197,7 @@ public class ActionPlanner extends PreRunningModule {
         input.setTendencies(extractorResult.getTendencies());
         input.setUser(perceiveCapability.getUser(userId));
         input.setRecentMessages(cognationCapability.getChatMessages());
-        input.setActivatedSlices(cacheCapability.getActivatedSlices(userId));
+        input.setActivatedSlices(memoryCapability.getActivatedSlices(userId));
         return input;
     }

@@ -133,8 +133,10 @@ public class MemorySelector extends PreRunningModule {
         return "[记忆模块]";
     }

-    protected HashMap<String, String> getPromptDataMap(String userId) {
+    @Override
+    protected HashMap<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
         HashMap<String, String> map = new HashMap<>();
+        String userId = context.getUserId();
         String dialogMapStr = memoryCapability.getDialogMapStr();
         if (!dialogMapStr.isEmpty()) {
             map.put("[记忆缓存] <你最近两日和所有聊天者的对话记忆印象>", dialogMapStr);
Partner-Main/src/test/java/OnnxTest.java (new file, 103 lines)
@@ -0,0 +1,103 @@
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.onnxruntime.OnnxTensor;
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+import lombok.extern.slf4j.Slf4j;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.nio.FloatBuffer;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+@Slf4j
+public class OnnxTest {
+    static String tokenizer_json;
+    static String base;
+    static String model;
+
+    @BeforeAll
+    static void init() {
+        base = "/home/slhaf/IdeaProjects/Projects/Partner/data/vector/";
+        tokenizer_json = base + "tokenizer.json";
+        model = base + "model_quantized.onnx";
+    }
+
+    @Test
+    void tokenizerTest() throws IOException {
+        long l1 = System.currentTimeMillis();
+        HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Path.of(tokenizer_json));
+        long l2 = System.currentTimeMillis(); // tokenizer load time
+        Encoding encode = tokenizer.encode("test: Hello World");
+        long l3 = System.currentTimeMillis(); // encode time
+        long[] ids = encode.getIds();
+        long[] attentionMask = encode.getAttentionMask();
+        log.info(Arrays.toString(ids));
+        log.info("-----------------------------");
+        log.info(Arrays.toString(attentionMask));
+        log.info("-----------------------------");
+        log.info("加载耗时: {}ms", l2 - l1);
+        log.info("计算耗时: {}ms", l3 - l2);
+        tokenizer.close();
+        /* Output:
+         * [101, 3231, 1024, 7592, 2088, 102]
+         * -----------------------------
+         * [1, 1, 1, 1, 1, 1]
+         * -----------------------------
+         * 加载耗时: 4206ms
+         * 计算耗时: 1ms
+         */
+    }
+
+    @Test
+    void onnxTest() throws IOException, OrtException {
+        long l1 = System.currentTimeMillis();
+        HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Path.of(tokenizer_json));
+        long l2 = System.currentTimeMillis(); // tokenizer load time
+        Encoding encode = tokenizer.encode("test: Hello World");
+        long l3 = System.currentTimeMillis(); // encode time
+
+        long[] ids = encode.getIds();
+        long[] attentionMask = encode.getAttentionMask();
+
+        long l4 = System.currentTimeMillis();
+        OrtEnvironment env = OrtEnvironment.getEnvironment();
+        OrtSession.SessionOptions ops = new OrtSession.SessionOptions();
+        OrtSession session = env.createSession(model, ops);
+        long l5 = System.currentTimeMillis(); // model load time
+
+        long[][] inputIdsBatch = {ids};
+        long[][] attentionMaskBatch = {attentionMask};
+        long[][] tokenTypeIdsBatch = {new long[ids.length]}; // new long[] is already all zeros
+
+        OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputIdsBatch);
+        OnnxTensor maskTensor = OnnxTensor.createTensor(env, attentionMaskBatch);
+        OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(env, tokenTypeIdsBatch);
+
+        Map<String, OnnxTensor> inputs = new HashMap<>();
+        inputs.put("input_ids", inputTensor);
+        inputs.put("attention_mask", maskTensor);
+        inputs.put("token_type_ids", tokenTypeTensor);
+
+        long l6 = System.currentTimeMillis();
+        OrtSession.Result result = session.run(inputs);
+        long l7 = System.currentTimeMillis(); // inference time
+        OnnxTensor embeddingTensor = (OnnxTensor) result.get(0);
+        // getFloatBuffer() returns a direct buffer with no backing array,
+        // so copy the data out instead of calling array().
+        FloatBuffer buffer = embeddingTensor.getFloatBuffer();
+        float[] embeddings = new float[buffer.remaining()];
+        buffer.get(embeddings);
+
+        log.info(Arrays.toString(embeddings));
+        log.info("------------------------");
+        log.info("tokenizer加载耗时: {}ms", l2 - l1);
+        log.info("tokenizer计算耗时: {}ms", l3 - l2);
+        log.info("模型加载耗时: {}ms", l5 - l4);
+        log.info("模型数据准备耗时: {}ms", l6 - l5);
+        log.info("模型计算耗时: {}ms", l7 - l6);
+        result.close();
+        session.close();
+        tokenizer.close();
+    }
+}
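The test logs the model's raw first output, which for a BERT-style encoder is last_hidden_state with shape [batch, tokens, hidden], i.e. one vector per token rather than one per sentence. A step the test does not yet show is pooling those token vectors into a single embedding for the semantic cache. Below is a minimal sketch (the class Pooling and method meanPool are hypothetical, not part of this commit), assuming attention-mask-weighted mean pooling followed by L2 normalization, the usual recipe for sentence-embedding models of this kind; the nested array can be obtained via (float[][][]) embeddingTensor.getValue().

final class Pooling {
    // Collapse last_hidden_state [1, tokens, hidden] into one sentence vector.
    static float[] meanPool(float[][][] lastHiddenState, long[] attentionMask) {
        int tokens = lastHiddenState[0].length;
        int hiddenSize = lastHiddenState[0][0].length;
        float[] pooled = new float[hiddenSize];
        int valid = 0;
        for (int t = 0; t < tokens; t++) {
            if (attentionMask[t] == 0) continue; // ignore padding positions
            valid++;
            for (int h = 0; h < hiddenSize; h++) {
                pooled[h] += lastHiddenState[0][t][h];
            }
        }
        for (int h = 0; h < hiddenSize; h++) {
            pooled[h] /= Math.max(valid, 1);
        }
        // L2-normalize so cosine similarity in the cache reduces to a dot product.
        double norm = 0;
        for (float v : pooled) norm += v * v;
        float inv = (float) (1.0 / Math.sqrt(norm + 1e-12));
        for (int h = 0; h < hiddenSize; h++) {
            pooled[h] *= inv;
        }
        return pooled;
    }
}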