Skip to content

Commit

Permalink
Adds support for general BooleanPredicates.
Browse files Browse the repository at this point in the history
Signed-off-by: Ralph Gasser <[email protected]>
  • Loading branch information
ppanopticon committed Nov 28, 2024
1 parent 3249028 commit 19423e7
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import org.vitrivr.engine.core.model.retrievable.Retrieved
import org.vitrivr.engine.database.pgvector.*
import org.vitrivr.engine.database.pgvector.descriptor.scalar.ScalarDescriptorReader
import org.vitrivr.engine.database.pgvector.descriptor.struct.StructDescriptorReader
import java.sql.PreparedStatement
import java.sql.ResultSet
import java.util.*

Expand Down Expand Up @@ -211,7 +212,7 @@ abstract class AbstractDescriptorReader<D : Descriptor<*>>(final override val fi
* @param predicate [BooleanPredicate] to resolve.
* @return Set of [RetrievableId]s that match the [BooleanPredicate].
*/
internal fun resolveBooleanPredicate(predicate: BooleanPredicate): Set<RetrievableId> = when (predicate) {
protected fun getMatches(predicate: BooleanPredicate): Set<RetrievableId> = when (predicate) {
is Comparison<*> -> {
val field = predicate.field
val reader = field.getReader()
Expand All @@ -226,9 +227,9 @@ abstract class AbstractDescriptorReader<D : Descriptor<*>>(final override val fi
val intersection = mutableSetOf<RetrievableId>()
for ((index, child) in predicate.predicates.withIndex()) {
if (index == 0) {
intersection.addAll(resolveBooleanPredicate(child))
intersection.addAll(getMatches(child))
} else {
intersection.intersect(resolveBooleanPredicate(child))
intersection.intersect(getMatches(child))
}
}
intersection
Expand All @@ -237,12 +238,27 @@ abstract class AbstractDescriptorReader<D : Descriptor<*>>(final override val fi
is Logical.Or -> {
val union = mutableSetOf<RetrievableId>()
for (child in predicate.predicates) {
union.addAll(resolveBooleanPredicate(child))
union.addAll(getMatches(child))
}
union
}
}

/**
* [PreparedStatement] a [BooleanPredicate] predicate and returns a [PreparedStatement].
*
* @param query The [BooleanPredicate] to prepare.
* @return [PreparedStatement]s.
*/
protected fun prepareBoolean(query: BooleanPredicate, limit: Long? = null): PreparedStatement {
val tableName = "\"${this.tableName.lowercase()}\""
val retrievableIds = this.getMatches(query)
val sql = "SELECT * FROM $tableName WHERE $RETRIEVABLE_ID_COLUMN_NAME = ANY(?) ${limit.toLimitClause()}"
val stmt = this.connection.jdbc.prepareStatement(sql)
stmt.setArray(1, this.connection.jdbc.createArrayOf("OTHER", retrievableIds.toTypedArray()))
return stmt
}

/**
* Converts a [ResultSet] to a [Descriptor] of type [D].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.vitrivr.engine.core.model.descriptor.struct.StructDescriptor
import org.vitrivr.engine.core.model.descriptor.vector.VectorDescriptor
import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.query.Query
import org.vitrivr.engine.core.model.query.bool.BooleanPredicate
import org.vitrivr.engine.core.model.query.bool.Comparison
import org.vitrivr.engine.core.model.query.fulltext.SimpleFulltextPredicate
import org.vitrivr.engine.core.model.retrievable.Retrieved
Expand Down Expand Up @@ -34,6 +35,7 @@ class ScalarDescriptorReader(field: Schema.Field<*, ScalarDescriptor<*, *>>, con
when (val predicate = query.predicate) {
is SimpleFulltextPredicate -> prepareFulltext(predicate, query.limit)
is Comparison<*> -> prepareComparison(predicate, query.limit)
is BooleanPredicate -> prepareBoolean(predicate, query.limit)
else -> throw IllegalArgumentException("Query of type ${query::class} is not supported by ScalarDescriptorReader.")
}.use { stmt ->
stmt.executeQuery().use { result ->
Expand Down Expand Up @@ -69,7 +71,7 @@ class ScalarDescriptorReader(field: Schema.Field<*, ScalarDescriptor<*, *>>, con
/**
* Prepares a [SimpleFulltextPredicate] and returns a [Sequence] of [ScalarDescriptor]s.
*
* @param query The [SimpleFulltextPredicate] to execute.
* @param query The [SimpleFulltextPredicate] to prepare.
* @param limit The maximum number of results to return.
* @return [PreparedStatement]s.
*/
Expand All @@ -84,7 +86,7 @@ class ScalarDescriptorReader(field: Schema.Field<*, ScalarDescriptor<*, *>>, con
return stmt
} else {
val sql = "SELECT * FROM $tableName WHERE $VALUE_ATTRIBUTE_NAME @@ to_tsquery(?) AND $RETRIEVABLE_ID_COLUMN_NAME = ANY(?) ${limit.toLimitClause()}"
val retrievableIds = this.resolveBooleanPredicate(filter)
val retrievableIds = this.getMatches(filter)
val stmt = this.connection.jdbc.prepareStatement(sql)
stmt.setString(1, fulltextQueryString)
stmt.setArray(2, this.connection.jdbc.createArrayOf("OTHER", retrievableIds.toTypedArray()))
Expand All @@ -95,7 +97,7 @@ class ScalarDescriptorReader(field: Schema.Field<*, ScalarDescriptor<*, *>>, con
/**
* [PreparedStatement] a [Comparison] predicate and returns a [PreparedStatement].
*
* @param query The [Comparison] to execute.
* @param query The [Comparison] to prepare.
* @return [PreparedStatement]s.
*/
private fun prepareComparison(query: Comparison<*>, limit: Long? = null): PreparedStatement {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import org.vitrivr.engine.core.model.descriptor.scalar.ScalarDescriptor
import org.vitrivr.engine.core.model.descriptor.struct.StructDescriptor
import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.query.Query
import org.vitrivr.engine.core.model.query.bool.BooleanPredicate
import org.vitrivr.engine.core.model.query.bool.Comparison
import org.vitrivr.engine.core.model.query.fulltext.SimpleFulltextPredicate
import org.vitrivr.engine.core.model.retrievable.Retrieved
Expand Down Expand Up @@ -37,6 +38,7 @@ class StructDescriptorReader(field: Schema.Field<*, StructDescriptor<*>>, connec
when (val predicate = query.predicate) {
is SimpleFulltextPredicate -> prepareFulltext(predicate)
is Comparison<*> -> prepareComparison(predicate)
is BooleanPredicate -> prepareBoolean(predicate)
else -> throw IllegalArgumentException("Query of type ${query::class} is not supported by StructDescriptorReader.")
}.use { stmt ->
stmt.executeQuery().use { result ->
Expand Down Expand Up @@ -110,7 +112,7 @@ class StructDescriptorReader(field: Schema.Field<*, StructDescriptor<*>>, connec
return stmt
} else {
val sql = "SELECT * FROM $tableName WHERE ${query.attributeName} @@ to_tsquery(?) AND $RETRIEVABLE_ID_COLUMN_NAME = ANY(?) ${limit.toLimitClause()}"
val retrievableIds = this.resolveBooleanPredicate(filter)
val retrievableIds = this.getMatches(filter)
val stmt = this.connection.jdbc.prepareStatement(sql)
stmt.setString(1, fulltextQueryString)
stmt.setArray(2, this.connection.jdbc.createArrayOf("OTHER", retrievableIds.toTypedArray()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.vitrivr.engine.core.model.descriptor.vector.VectorDescriptor.Companio
import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.query.Predicate
import org.vitrivr.engine.core.model.query.Query
import org.vitrivr.engine.core.model.query.bool.BooleanPredicate
import org.vitrivr.engine.core.model.query.proximity.ProximityPredicate
import org.vitrivr.engine.core.model.retrievable.Retrieved
import org.vitrivr.engine.core.model.retrievable.attributes.DistanceAttribute
Expand Down Expand Up @@ -38,7 +39,7 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con
}
}
}

is BooleanPredicate -> prepareBoolean(predicate)
else -> throw UnsupportedOperationException("Query of typ ${query::class} is not supported by VectorDescriptorReader.")
}
}
Expand Down Expand Up @@ -136,7 +137,7 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con
"ORDER BY $DISTANCE_COLUMN_NAME ${query.order} " +
"LIMIT ${query.k}"

val retrievableIds = this.resolveBooleanPredicate(filter)
val retrievableIds = this.getMatches(filter)
val stmt = this@VectorDescriptorReader.connection.jdbc.prepareStatement(sql)
stmt.setValue(1, query.value)
stmt.setArray(2, this.connection.jdbc.createArrayOf("OTHER", retrievableIds.toTypedArray()))
Expand Down Expand Up @@ -177,7 +178,7 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con
"FROM $cteTable INNER JOIN $RETRIEVABLE_ENTITY_NAME ON ($RETRIEVABLE_ENTITY_NAME.$RETRIEVABLE_ID_COLUMN_NAME = $cteTable.$RETRIEVABLE_ID_COLUMN_NAME)" +
"ORDER BY $cteTable.$DISTANCE_COLUMN_NAME ${query.order}"

val retrievableIds = this.resolveBooleanPredicate(filter)
val retrievableIds = this.getMatches(filter)
val stmt = this@VectorDescriptorReader.connection.jdbc.prepareStatement(sql)
stmt.setValue(1, query.value)
stmt.setArray(2, this.connection.jdbc.createArrayOf("OTHER", retrievableIds.toTypedArray()))
Expand Down

0 comments on commit 19423e7

Please sign in to comment.