Advance the ActionExtractor semantic cache mechanism: testing of the two embedding-model integration paths is complete. On a high-performance host, models such as mxbai-embed-large can be served through ollama, but on a 4-core/8GB Orange Pi 3B inference time becomes unacceptably long, even when switching to ONNX Runtime Java. With nomic-embed-text + ONNX Runtime Java, however, inference still comes in at roughly 70 ms, far below the call latency of the extraction model and of the remote embedding API. The plan is to support two ways of plugging an embedding model into the semantic cache: calling a local ollama instance over HTTP, or pointing at an ONNX-format embedding model and invoking it directly.

2025-10-16 23:04:41 +08:00
parent e78048f66d
commit 7094a8a68b
7 changed files with 230 additions and 38 deletions
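The first of the two paths above, calling a local ollama instance over HTTP, could look roughly like the sketch below. The class name, field names and error handling are hypothetical and not part of this commit; it assumes ollama's POST /api/embeddings endpoint on the default port 11434, and a real implementation would use the project's JSON library instead of the crude string parsing shown here.

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

/**
 * Hypothetical sketch of the ollama-backed embedding path for the semantic cache.
 * Assumes ollama is listening on localhost:11434 and exposes POST /api/embeddings.
 */
public class OllamaEmbeddingClient {

    private final HttpClient http = HttpClient.newHttpClient();
    private final String endpoint = "http://localhost:11434/api/embeddings";
    private final String model = "nomic-embed-text";

    public float[] embed(String text) throws Exception {
        // Request body: {"model": "...", "prompt": "..."}
        String body = String.format("{\"model\":\"%s\",\"prompt\":\"%s\"}",
                model, text.replace("\"", "\\\""));
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create(endpoint))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
        HttpResponse<String> response = http.send(request, HttpResponse.BodyHandlers.ofString());

        // The response looks like {"embedding":[0.1, -0.2, ...]}; parse the number list.
        String json = response.body();
        String vector = json.substring(json.indexOf('[') + 1, json.lastIndexOf(']'));
        String[] parts = vector.split(",");
        float[] embedding = new float[parts.length];
        for (int i = 0; i < parts.length; i++) {
            embedding[i] = Float.parseFloat(parts[i].trim());
        }
        return embedding;
    }
}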

.idea/vcs.xml (generated)

@@ -2,6 +2,5 @@
 <project version="4">
   <component name="VcsDirectoryMappings">
     <mapping directory="$PROJECT_DIR$" vcs="Git" />
-    <mapping directory="$USER_HOME$/Projects/IdeaProjects/Projects/Partner" vcs="Git" />
   </component>
 </project>


@@ -1,33 +0,0 @@
autoDetectedPackages:
- factory
- module
- work.slhaf
enableAutoDetect: true
entryDisplayConfig:
  excludedPathPatterns: []
  skipJsCss: true
funcDisplayConfig:
  skipConstructors: false
  skipFieldAccess: true
  skipFieldChange: true
  skipGetters: false
  skipNonProjectPackages: false
  skipPrivateMethods: false
  skipSetters: false
ignoreSameClassCall: null
ignoreSamePackageCall: null
includedPackagePrefixes: null
includedParentClasses: null
maxColSize: 32
maxNumFirst: 12
maxNumFirstImportant: 1024
maxNumHash: 3
maxNumHashImportant: 256
maxObjectDepth: 4
maxStrSize: 4096
name: xcodemap-filter
openMainWindow: true
recordMode: manual
sourceDisplayConfig:
  color: blue
startOnDebug: false

File diff suppressed because one or more lines are too long

pom.xml

@@ -34,6 +34,16 @@
             <artifactId>nd4j-api</artifactId>
             <version>1.0.0-M2.1</version>
         </dependency>
+        <dependency>
+            <groupId>com.microsoft.onnxruntime</groupId>
+            <artifactId>onnxruntime</artifactId>
+            <version>1.23.1</version>
+        </dependency>
+        <dependency>
+            <groupId>ai.djl.huggingface</groupId>
+            <artifactId>tokenizers</artifactId>
+            <version>0.34.0</version>
+        </dependency>
     </dependencies>
 
     <properties>

ActionPlanner.java

@@ -8,8 +8,8 @@ import work.slhaf.partner.api.chat.pojo.Message;
 import work.slhaf.partner.common.thread.InteractionThreadPoolExecutor;
 import work.slhaf.partner.core.action.ActionCapability;
 import work.slhaf.partner.core.action.entity.*;
-import work.slhaf.partner.core.cache.CacheCapability;
 import work.slhaf.partner.core.cognation.CognationCapability;
+import work.slhaf.partner.core.memory.MemoryCapability;
 import work.slhaf.partner.core.perceive.PerceiveCapability;
 import work.slhaf.partner.module.common.module.PreRunningModule;
 import work.slhaf.partner.module.modules.action.planner.confirmer.ActionConfirmer;
@@ -42,7 +42,7 @@ public class ActionPlanner extends PreRunningModule {
     @InjectCapability
     private PerceiveCapability perceiveCapability;
     @InjectCapability
-    private CacheCapability cacheCapability;
+    private MemoryCapability memoryCapability;
     @InjectModule
     private ActionEvaluator actionEvaluator;
@@ -197,7 +197,7 @@ public class ActionPlanner extends PreRunningModule {
         input.setTendencies(extractorResult.getTendencies());
         input.setUser(perceiveCapability.getUser(userId));
         input.setRecentMessages(cognationCapability.getChatMessages());
-        input.setActivatedSlices(cacheCapability.getActivatedSlices(userId));
+        input.setActivatedSlices(memoryCapability.getActivatedSlices(userId));
         return input;
     }

MemorySelector.java

@@ -133,8 +133,10 @@ public class MemorySelector extends PreRunningModule {
return "[记忆模块]";
}
protected HashMap<String, String> getPromptDataMap(String userId) {
@Override
protected HashMap<String, String> getPromptDataMap(PartnerRunningFlowContext context) {
HashMap<String, String> map = new HashMap<>();
String userId = context.getUserId();
String dialogMapStr = memoryCapability.getDialogMapStr();
if (!dialogMapStr.isEmpty()) {
map.put("[记忆缓存] <你最近两日和所有聊天者的对话记忆印象>", dialogMapStr);

OnnxTest.java

@@ -0,0 +1,103 @@
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.nio.FloatBuffer;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

@Slf4j
public class OnnxTest {

    static String tokenizer_json;
    static String base;
    static String model;

    @BeforeAll
    static void init() {
        base = "/home/slhaf/IdeaProjects/Projects/Partner/data/vector/";
        tokenizer_json = base + "tokenizer.json";
        model = base + "model_quantized.onnx";
    }

    @Test
    void tokenizerTest() throws IOException {
        long l1 = System.currentTimeMillis();
        HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Path.of(tokenizer_json));
        long l2 = System.currentTimeMillis();
        Encoding encode = tokenizer.encode("test: Hello World");
        long l3 = System.currentTimeMillis();
        long[] ids = encode.getIds();
        long[] attentionMask = encode.getAttentionMask();
        log.info(Arrays.toString(ids));
        log.info("-----------------------------");
        log.info(Arrays.toString(attentionMask));
        log.info("-----------------------------");
        log.info("Tokenizer load time: {}ms", l2 - l1);
        log.info("Encoding time: {}ms", l3 - l2);
        tokenizer.close();
        /* Output:
         * [101, 3231, 1024, 7592, 2088, 102]
         * -----------------------------
         * [1, 1, 1, 1, 1, 1]
         * -----------------------------
         * Tokenizer load time: 4206ms
         * Encoding time: 1ms
         */
    }

    @Test
    void onnxTest() throws IOException, OrtException {
        long l1 = System.currentTimeMillis();
        HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Path.of(tokenizer_json));
        long l2 = System.currentTimeMillis(); // tokenizer load time
        Encoding encode = tokenizer.encode("test: Hello World");
        long l3 = System.currentTimeMillis(); // encoding time
        long[] ids = encode.getIds();
        long[] attentionMask = encode.getAttentionMask();

        long l4 = System.currentTimeMillis();
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions ops = new OrtSession.SessionOptions();
        OrtSession session = env.createSession(model, ops);
        long l5 = System.currentTimeMillis(); // model load time

        // Wrap the single sequence into batch-shaped [1, seq_len] inputs.
        long[][] inputIdsBatch = {ids};
        long[][] attentionMaskBatch = {attentionMask};
        long[][] tokenTypeIdsBatch = {new long[ids.length]}; // token_type_ids stay all zeros for a single segment
        OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputIdsBatch);
        OnnxTensor maskTensor = OnnxTensor.createTensor(env, attentionMaskBatch);
        OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(env, tokenTypeIdsBatch);
        Map<String, OnnxTensor> inputs = new HashMap<>();
        inputs.put("input_ids", inputTensor);
        inputs.put("attention_mask", maskTensor);
        inputs.put("token_type_ids", tokenTypeTensor);
        long l6 = System.currentTimeMillis();

        OrtSession.Result result = session.run(inputs);
        long l7 = System.currentTimeMillis(); // inference time

        // Copy the first output into a plain float[]; the buffer returned by getFloatBuffer()
        // is not guaranteed to be array-backed, so read it out rather than calling array().
        OnnxTensor embeddingTensor = (OnnxTensor) result.get(0);
        FloatBuffer buffer = embeddingTensor.getFloatBuffer();
        float[] embeddings = new float[buffer.remaining()];
        buffer.get(embeddings);
        log.info(Arrays.toString(embeddings));
        log.info("------------------------");
        log.info("Tokenizer load time: {}ms", l2 - l1);
        log.info("Encoding time: {}ms", l3 - l2);
        log.info("Model load time: {}ms", l5 - l4);
        log.info("Input preparation time: {}ms", l6 - l5);
        log.info("Inference time: {}ms", l7 - l6);
        tokenizer.close();
    }
}
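Note that onnxTest reads the raw first output of the model. For a nomic-embed-text ONNX export that output is typically the per-token last_hidden_state with shape [batch, seq_len, hidden], so the semantic cache would still need a pooling step to turn it into one fixed-size vector per text. The helper below is a sketch of masked mean pooling under that assumption; the class and its name are hypothetical, and whether this particular model_quantized.onnx already pools internally is not verified here. For cosine-similarity lookups the pooled vector is usually L2-normalized afterwards.

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtException;

public final class EmbeddingPooling {

    private EmbeddingPooling() {
    }

    /**
     * Masked mean pooling: averages the token vectors of last_hidden_state
     * ([1, seq_len, hidden]) over the positions where attention_mask is 1.
     */
    public static float[] meanPool(OnnxTensor lastHiddenState, long[] attentionMask) throws OrtException {
        float[][][] hidden = (float[][][]) lastHiddenState.getValue(); // [batch=1][seq][hidden]
        int seqLen = hidden[0].length;
        int dim = hidden[0][0].length;
        float[] pooled = new float[dim];
        int counted = 0;
        for (int t = 0; t < seqLen; t++) {
            if (attentionMask[t] == 0) {
                continue; // skip padding positions
            }
            counted++;
            for (int d = 0; d < dim; d++) {
                pooled[d] += hidden[0][t][d];
            }
        }
        for (int d = 0; d < dim; d++) {
            pooled[d] /= Math.max(counted, 1);
        }
        return pooled;
    }
}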
}