diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/cognition/impression/Entity.kt b/Partner-Core/src/main/java/work/slhaf/partner/core/cognition/impression/Entity.kt index dc85a413..494a5889 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/cognition/impression/Entity.kt +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/cognition/impression/Entity.kt @@ -138,7 +138,7 @@ class Entity @JvmOverloads constructor( val relationObject = relationValue as? JSONObject ?: return@forEach val relationMap = mutableMapOf() relationObject.forEach { (relation, strengthValue) -> - numberValue(strengthValue)?.let { relationMap[relation] = it } + doubleValue(strengthValue)?.let { relationMap[relation] = it } } if (relationMap.isNotEmpty()) { relations[target] = relationMap @@ -192,7 +192,7 @@ class Entity @JvmOverloads constructor( state.forEach { (key, value) -> val data = when (value) { is JSONObject -> loadIndexableData(value) - else -> IndexableData(numberValue(value) ?: return@forEach) + else -> IndexableData(doubleValue(value) ?: return@forEach) } loaded[key] = data } @@ -203,9 +203,9 @@ class Entity @JvmOverloads constructor( val data = IndexableData(state.getDouble("confidence") ?: 1.0) state.getJSONObject("vectors")?.forEach { (embeddingModel, vectorValue) -> val vectorArray = vectorValue as? JSONArray ?: return@forEach - val vector = DoubleArray(vectorArray.size) - for (index in 0 until vectorArray.size) { - vector[index] = numberValue(vectorArray[index]) ?: return@forEach + val vector = FloatArray(vectorArray.size) + for (index in vectorArray.indices) { + vector[index] = floatValue(vectorArray[index]) ?: return@forEach } data.updateVector(embeddingModel, vector) } @@ -220,29 +220,35 @@ class Entity @JvmOverloads constructor( ) } - private fun numberValue(value: Any?): Double? = when (value) { + private fun doubleValue(value: Any?): Double? = when (value) { is Number -> value.toDouble() is String -> value.toDoubleOrNull() else -> null } + private fun floatValue(value: Any?): Float? = when (value) { + is Number -> value.toFloat() + is String -> value.toFloatOrNull() + else -> null + } + data class IndexableData( var confidence: Double ) { - private val vectors: ConcurrentHashMap = ConcurrentHashMap() + private val vectors: ConcurrentHashMap = ConcurrentHashMap() fun updateVector( embeddingModel: String, - vector: DoubleArray + vector: FloatArray ) { vectors[embeddingModel] = vector.copyOf() } - fun getVector(embeddingModel: String): DoubleArray? { + fun getVector(embeddingModel: String): FloatArray? { return vectors[embeddingModel]?.copyOf() } - fun snapshotVectors(): Map { + fun snapshotVectors(): Map { return vectors.mapValues { (_, vector) -> vector.copyOf() } } } @@ -261,6 +267,6 @@ class Entity @JvmOverloads constructor( data class ImpressionView( val impression: String, val confidence: Double, - val vector: DoubleArray? + val vector: FloatArray? ) } diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/cognition/impression/ImpressionVectorIndex.java b/Partner-Core/src/main/java/work/slhaf/partner/core/cognition/impression/ImpressionVectorIndex.java index ac638b4b..b097ab31 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/cognition/impression/ImpressionVectorIndex.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/cognition/impression/ImpressionVectorIndex.java @@ -1,13 +1,38 @@ package work.slhaf.partner.core.cognition.impression; +import work.slhaf.partner.common.vector.VectorClient; + +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + public class ImpressionVectorIndex { - public void sync(Entity entity){ - // TODO sync entity impressions/features with vector index. + private final Executor executor = Executors.newFixedThreadPool(2, r -> { + Thread thread = new Thread(r, "impression-vector-index"); + thread.setDaemon(true); + return thread; + }); + + public void sync(Entity entity) { + if (!VectorClient.status){ + return; + } + entity.snapshotFeatures().forEach(this::upsert); + entity.snapshotImpressions().forEach(this::upsert); } - public void upsert(String content, Entity.IndexableData indexableData){ - // TODO update vector for content when embedding/vector client boundary is finalized. + public void upsert(String text, Entity.IndexableData indexableData){ + if (VectorClient.status){ + return; + } + String modelId = VectorClient.VECTOR_MODEL_ID; + if (indexableData.getVector(modelId) != null) { + return; + } + executor.execute(() -> { + float[] vector = VectorClient.INSTANCE.compute(text); + indexableData.updateVector(modelId,vector); + }); } }