Skip to content

Commit

Permalink
Merge pull request #123 from vitrivr/feature/filtercriteria
Browse files Browse the repository at this point in the history
[Feature] Filter Criteria for Late Filtering
  • Loading branch information
net-cscience-raphael authored Dec 18, 2024
2 parents 0533298 + 15bd7f6 commit 68fd3d4
Show file tree
Hide file tree
Showing 17 changed files with 436 additions and 16 deletions.
2 changes: 1 addition & 1 deletion example-configs/schema/dense.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"clip": {
"factory": "DenseEmbedding",
"parameters": {
"host": "http://10.34.64.84:8888/",
"host": "http://10.34.64.83:8888/",
"model": "open-clip-vit-b32",
"length": "512",
"timeoutSeconds": "100",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,9 @@ enum class ComparisonOperator(val value: String) {
when (v1) {
is Value.String,
is Value.Text -> {
(v1.value as String).replace("\\", "\\\\").replace("[", "\\[").replace("]", "\\]")
.replace("*", "\\*").replace("%", "*").toRegex().matches(v2.value as String)
(v2.value as String).replace("\\", "\\\\").replace("[", "\\[").replace("]", "\\]")
.replace("*", "\\*").replace("%", ".*").replace("_", ".?").toRegex().matches(v1.value as String)
}

else -> false
}

Expand All @@ -128,7 +127,7 @@ enum class ComparisonOperator(val value: String) {
* @param str The [String] which should be one of the [ComparisonOperator]
* @throws IllegalArgumentException In case the given string is not one of the defined ones.
*/
fun fromString(str: String): ComparisonOperator {
infix fun fromString(str: String): ComparisonOperator {
return when (str.trim()) {
EQ.value -> EQ
NEQ.value -> NEQ
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,19 @@ enum class Distance {
override fun invoke(v1: Value.DoubleVector, v2: Value.DoubleVector): Double = throw UnsupportedOperationException("Jaccard distance is not supported for float vectors.")
};

companion object {
infix fun fromString(value: String): Distance {
return when (value) {
"manhattan" -> MANHATTAN
"euclidean" -> EUCLIDEAN
"cosine" -> COSINE
"hamming" -> HAMMING
"jaccard" -> JACCARD
else -> throw IllegalArgumentException("Distance function $value is not supported.")
}
}
}

/**
* Calculates this [Distance] between two [Value.FloatVector].
*
Expand All @@ -115,4 +128,5 @@ enum class Distance {
* @return [Double]
*/
abstract operator fun invoke(v1: Value.DoubleVector, v2: Value.DoubleVector): Double
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ abstract class AbstractFileMetadataDescriptorReaderTest(schemaPath: String) : Ab

/* Check results. */
val result = reader.query(query).toList()
Assertions.assertTrue(result.isNotEmpty())
for (r in result) {
Assertions.assertTrue(r.path.value.endsWith(".jpg"))
}
Expand Down Expand Up @@ -139,6 +140,7 @@ abstract class AbstractFileMetadataDescriptorReaderTest(schemaPath: String) : Ab

/* Check results. */
val result = reader.query(query).toList()
Assertions.assertTrue(result.isNotEmpty())
for (r in result) {
Assertions.assertTrue(r.size.value > size.value)
}
Expand Down Expand Up @@ -166,6 +168,7 @@ abstract class AbstractFileMetadataDescriptorReaderTest(schemaPath: String) : Ab

/* Check results. */
val result = reader.query(query).toList()
Assertions.assertTrue(result.isNotEmpty())
for (r in result) {
Assertions.assertTrue(r.size.value < size.value)
}
Expand All @@ -191,6 +194,7 @@ abstract class AbstractFileMetadataDescriptorReaderTest(schemaPath: String) : Ab

/* Check results. */
val result = reader.query(query).toList()
// TODO enable Assertions.assertTrue(result.isNotEmpty())
for (r in result) {
Assertions.assertTrue(r.path.value.contains("var"))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import org.vitrivr.engine.core.model.descriptor.vector.FloatVectorDescriptor
import org.vitrivr.engine.core.model.metamodel.Analyser.Companion.merge
import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.query.Query
import org.vitrivr.engine.core.model.query.basics.Distance
import org.vitrivr.engine.core.model.query.proximity.ProximityQuery
import org.vitrivr.engine.core.model.retrievable.Retrievable
import org.vitrivr.engine.core.model.types.Value
Expand All @@ -36,6 +37,8 @@ class DenseEmbedding : ExternalFesAnalyser<ContentElement<*>, FloatVectorDescrip
companion object {
const val LENGTH_PARAMETER_DEFAULT = 512
const val LENGTH_PARAMETER_NAME = "length"
const val DISTANCE_PARAMETER_DEFAULT = "euclidean"
const val DISTANCE_PARAMETER_NAME = "distance"
}
override val contentClasses = setOf(ImageContent::class, TextContent::class)
override val descriptorClass = FloatVectorDescriptor::class
Expand Down Expand Up @@ -103,6 +106,7 @@ class DenseEmbedding : ExternalFesAnalyser<ContentElement<*>, FloatVectorDescrip
val retries = field.parameters[RETRIES_PARAMETER_NAME]?.toIntOrNull() ?: RETRIES_PARAMETER_DEFAULT
val model = field.parameters[MODEL_PARAMETER_NAME] ?: throw IllegalStateException("Model parameter not set.")
val k = context.getProperty(field.fieldName, "limit")?.toLongOrNull() ?: 1000L
val distance = Distance fromString (field.parameters[DISTANCE_PARAMETER_NAME] ?: DISTANCE_PARAMETER_DEFAULT)
val fetchVector = context.getProperty(field.fieldName, "returnDescriptor")?.toBooleanStrictOrNull() ?: false

/* Generate vector for content element. */
Expand All @@ -116,6 +120,6 @@ class DenseEmbedding : ExternalFesAnalyser<ContentElement<*>, FloatVectorDescrip
}

/* Return retriever. */
return this.newRetrieverForQuery(field, ProximityQuery(value = vector, k = k, fetchVector = fetchVector), context)
return this.newRetrieverForQuery(field, ProximityQuery(value = vector, distance = distance, k = k, fetchVector = fetchVector), context)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ class ScalarDescriptorReader(field: Schema.Field<*, ScalarDescriptor<*, *>>, con
val descriptorId = result.getObject(DESCRIPTOR_ID_COLUMN_NAME, UUID::class.java)
val retrievableId = result.getObject(RETRIEVABLE_ID_COLUMN_NAME, UUID::class.java)
return when (this.prototype) {
is BooleanDescriptor -> BooleanDescriptor(descriptorId, retrievableId, Value.Boolean(result.getBoolean(VALUE_ATTRIBUTE_NAME)))
is ByteDescriptor -> ByteDescriptor(descriptorId, retrievableId, Value.Byte(result.getByte(VALUE_ATTRIBUTE_NAME)))
is ShortDescriptor -> ShortDescriptor(descriptorId, retrievableId, Value.Short(result.getShort(VALUE_ATTRIBUTE_NAME)))
is IntDescriptor -> IntDescriptor(descriptorId, retrievableId, Value.Int(result.getInt(VALUE_ATTRIBUTE_NAME)))
is LongDescriptor -> LongDescriptor(descriptorId, retrievableId, Value.Long(result.getLong(VALUE_ATTRIBUTE_NAME)))
is FloatDescriptor -> FloatDescriptor(descriptorId, retrievableId, Value.Float(result.getFloat(VALUE_ATTRIBUTE_NAME)))
is DoubleDescriptor -> DoubleDescriptor(descriptorId, retrievableId, Value.Double(result.getDouble(VALUE_ATTRIBUTE_NAME)))
is StringDescriptor -> StringDescriptor(descriptorId, retrievableId, Value.String(result.getString(VALUE_ATTRIBUTE_NAME)))
is TextDescriptor -> TextDescriptor(descriptorId, retrievableId, Value.Text(result.getString(VALUE_ATTRIBUTE_NAME)))
is BooleanDescriptor -> BooleanDescriptor(descriptorId, retrievableId, Value.Boolean(result.getBoolean(VALUE_ATTRIBUTE_NAME)), this.field as Schema.Field<*, BooleanDescriptor>)
is ByteDescriptor -> ByteDescriptor(descriptorId, retrievableId, Value.Byte(result.getByte(VALUE_ATTRIBUTE_NAME)), this.field as Schema.Field<*, ByteDescriptor>)
is ShortDescriptor -> ShortDescriptor(descriptorId, retrievableId, Value.Short(result.getShort(VALUE_ATTRIBUTE_NAME)), this.field as Schema.Field<*, ShortDescriptor>)
is IntDescriptor -> IntDescriptor(descriptorId, retrievableId, Value.Int(result.getInt(VALUE_ATTRIBUTE_NAME)), this.field as Schema.Field<*, IntDescriptor>)
is LongDescriptor -> LongDescriptor(descriptorId, retrievableId, Value.Long(result.getLong(VALUE_ATTRIBUTE_NAME)), this.field as Schema.Field<*, LongDescriptor>)
is FloatDescriptor -> FloatDescriptor(descriptorId, retrievableId, Value.Float(result.getFloat(VALUE_ATTRIBUTE_NAME)), this.field as Schema.Field<*, FloatDescriptor>)
is DoubleDescriptor -> DoubleDescriptor(descriptorId, retrievableId, Value.Double(result.getDouble(VALUE_ATTRIBUTE_NAME)), this.field as Schema.Field<*, DoubleDescriptor>)
is StringDescriptor -> StringDescriptor(descriptorId, retrievableId, Value.String(result.getString(VALUE_ATTRIBUTE_NAME)), this.field as Schema.Field<*, StringDescriptor>)
is TextDescriptor -> TextDescriptor(descriptorId, retrievableId, Value.Text(result.getString(VALUE_ATTRIBUTE_NAME)),this.field as Schema.Field<*, TextDescriptor>)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.vitrivr.engine.query.operators.transform.benchmark

import io.github.oshai.kotlinlogging.KLogger
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.encodeToJsonElement
import java.io.*
import java.nio.file.Path
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue


class BenchmarkLogger(val logfile: Path) : Runnable {
private val logger: KLogger = KotlinLogging.logger {}

private val queue: BlockingQueue<BenchmarkMessage> = LinkedBlockingQueue()

infix fun log(message: BenchmarkMessage) {
queue.add(message)
}

override fun run() {
while (true) {

val log = queue.take()
logger.info { log }


FileOutputStream(File(logfile.toString()), true).bufferedWriter().use { writer ->
writer.appendLine("${Json.encodeToJsonElement(log).toString()},")
writer.close()
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.vitrivr.engine.query.operators.transform.benchmark

import kotlinx.serialization.Serializable

@Serializable
data class BenchmarkMessage (
val name: String,
val source: String,
val timestamp: String,
val inputSize: Int,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package org.vitrivr.engine.query.operators.transform.benchmark

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.emitAll
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.toList
import org.vitrivr.engine.core.database.descriptor.DescriptorReader
import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.retrievable.Retrievable
import org.vitrivr.engine.core.model.retrievable.Retrieved
import org.vitrivr.engine.core.model.retrievable.attributes.PropertyAttribute
import org.vitrivr.engine.core.model.types.Value
import org.vitrivr.engine.core.operators.Operator
import org.vitrivr.engine.core.operators.general.Transformer
import java.nio.file.Path
import java.time.LocalDateTime
import java.util.Timer
import javax.management.Descriptor

/**
* Appends [Descriptor] to a [Retrieved] based on the values of a [Schema.Field], if available.
*
* @version 1.1.2
* @author Luca Rossetto
* @author Ralph Gasser
*/
class TimeBenchmark(
override val input: Operator<out Retrievable>,
val path: Path,
val pretty: String,
override val name: String
) : Transformer {

companion object {
@Volatile
private var bl: BenchmarkLogger? = null
}

init {
if (bl == null) {
bl = BenchmarkLogger(path)
Thread(bl).start()
}
}

override fun toFlow(scope: CoroutineScope): Flow<Retrievable> = flow {
val inputRetrieved = input.toFlow(scope).toList()
bl!! log BenchmarkMessage(name, pretty, LocalDateTime.now().toString(), inputRetrieved.size)
inputRetrieved.forEach { emit(it) }
}
}


Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org.vitrivr.engine.query.operators.transform.benchmark

import org.vitrivr.engine.core.context.Context
import org.vitrivr.engine.core.context.QueryContext
import org.vitrivr.engine.core.model.retrievable.Retrievable
import org.vitrivr.engine.core.operators.Operator
import org.vitrivr.engine.core.operators.general.TransformerFactory
import kotlin.io.path.Path

class TimeBenchmarkFactory() : TransformerFactory {
override fun newTransformer(name: String, input: Operator<out Retrievable>, context: Context): TimeBenchmark {
require(context is QueryContext)
val logfilePath = Path(context[name, "logfile"]?.toString() ?: "benchmark.log")
val prettyName = context[name, "pretty"]?.toString() ?: name
return TimeBenchmark(input, logfilePath, prettyName, name)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package org.vitrivr.engine.query.operators.transform.filter

import io.github.oshai.kotlinlogging.KLogger
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.toList
import org.vitrivr.engine.core.database.descriptor.DescriptorReader
import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.query.basics.ComparisonOperator
import org.vitrivr.engine.core.model.retrievable.Retrievable
import org.vitrivr.engine.core.model.retrievable.Retrieved
import org.vitrivr.engine.core.model.retrievable.attributes.PropertyAttribute
import org.vitrivr.engine.core.model.types.Value
import org.vitrivr.engine.core.operators.Operator
import org.vitrivr.engine.core.operators.general.Transformer
import java.sql.Date
import javax.management.Descriptor

/**
* Appends [Descriptor] to a [Retrieved] based on the values of a [Schema.Field], if available.
*
* @version 1.1.2
* @author Luca Rossetto
* @author Ralph Gasser
*/
class FieldLookupLateFilter(
override val input: Operator<out Retrievable>,
/* The reader for a given field. */
private val reader: DescriptorReader<*>,
/* keys to filter on */
val keys: List<String>,
/* boolean operator*/
val comparison: ComparisonOperator = ComparisonOperator.EQ,
/* value to compare to */
val value: String,
/* append field*/
val append: Boolean,
/* appends late filter */
val limit: Int = Int.MAX_VALUE,
override val name: String
) : Transformer {
private val logger: KLogger = KotlinLogging.logger {}

override fun toFlow(scope: CoroutineScope): Flow<Retrievable> = flow {
/* Parse input IDs.*/
val inputRetrieved = input.toFlow(scope).toList()

/* Fetch entries for the provided IDs. */
val ids = inputRetrieved.map { it.id }.toSet()
val descriptors = if (ids.isEmpty()) {
emptyMap()
} else {
this@FieldLookupLateFilter.reader.getAllForRetrievable(ids).associateBy { it.retrievableId!! }
}

// Multi keys for
if (keys.size > 1)
throw IllegalArgumentException("only one key is supported yet")

var emitted = 0
/* Emit retrievable with added attribute. */
inputRetrieved.forEach { retrieved ->
val descriptor = descriptors[retrieved.id]
if (descriptor != null) {
//retrieved.addDescriptor(descriptor)
/* Somewhat experimental. Goal: Attach information in a meaningful manner, such that it can be serialised */
val values = descriptor.values().toMap()
val attribute = keys.map {
(when (values[it]) {
is Value.String -> Pair(it to (values[it] as Value.String), Value.of(value.toString()))
is Value.Text -> Pair(it to (values[it] as Value.Text), Value.of(value.toString()))
is Value.Boolean -> Pair(it to (values[it] as Value.Boolean), Value.of(value.toBoolean()))
is Value.Int -> Pair(it to (values[it] as Value.Int), Value.of(value.toInt()))
is Value.Long -> Pair(it to (values[it] as Value.Long), Value.of(value.toLong()))
is Value.Float -> Pair(it to (values[it] as Value.Float), Value.of(value.toFloat()))
is Value.Double -> Pair(it to (values[it] as Value.Double), Value.of(value.toDouble()))
is Value.Byte -> Pair(it to (values[it] as Value.Byte), Value.of(value.toByte()))
is Value.Short -> Pair(it to (values[it] as Value.Short), Value.of(value.toShort()))
is Value.DateTime -> Pair(it to (values[it] as Value.DateTime), Value.of(Date.valueOf(value)))
else -> Pair(it to null, null)
})
}

retrieved.takeIf { append == true }?.let {
retrieved.addDescriptor(descriptor)
retrieved.addAttribute(PropertyAttribute(attribute.map { it.first.first.toString() to it.first.second!!.value.toString() }
.toMap()))
}

attribute[0].takeIf { it.first.second != null && it.second != null }?.let {
it.takeIf { ++emitted <= limit && comparison.compare(it.first.second!!, it.second!!) }?.let {
emit(retrieved)
}
}
}
}
}
}
Loading

0 comments on commit 68fd3d4

Please sign in to comment.