From 19423e72e7c7b079d0caf36ad17f99be46ee5990 Mon Sep 17 00:00:00 2001 From: Ralph Gasser Date: Thu, 28 Nov 2024 10:39:38 +0100 Subject: [PATCH] Adds support for general BooleanPredicates. Signed-off-by: Ralph Gasser --- .../descriptor/AbstractDescriptorReader.kt | 24 +++++++++++++++---- .../scalar/ScalarDescriptorReader.kt | 8 ++++--- .../struct/StructDescriptorReader.kt | 4 +++- .../vector/VectorDescriptorReader.kt | 7 +++--- 4 files changed, 32 insertions(+), 11 deletions(-) diff --git a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/AbstractDescriptorReader.kt b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/AbstractDescriptorReader.kt index 5370c934..94f80991 100644 --- a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/AbstractDescriptorReader.kt +++ b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/AbstractDescriptorReader.kt @@ -14,6 +14,7 @@ import org.vitrivr.engine.core.model.retrievable.Retrieved import org.vitrivr.engine.database.pgvector.* import org.vitrivr.engine.database.pgvector.descriptor.scalar.ScalarDescriptorReader import org.vitrivr.engine.database.pgvector.descriptor.struct.StructDescriptorReader +import java.sql.PreparedStatement import java.sql.ResultSet import java.util.* @@ -211,7 +212,7 @@ abstract class AbstractDescriptorReader>(final override val fi * @param predicate [BooleanPredicate] to resolve. * @return Set of [RetrievableId]s that match the [BooleanPredicate]. 
*/ - internal fun resolveBooleanPredicate(predicate: BooleanPredicate): Set = when (predicate) { + protected fun getMatches(predicate: BooleanPredicate): Set = when (predicate) { is Comparison<*> -> { val field = predicate.field val reader = field.getReader() @@ -226,9 +227,9 @@ abstract class AbstractDescriptorReader>(final override val fi val intersection = mutableSetOf() for ((index, child) in predicate.predicates.withIndex()) { if (index == 0) { - intersection.addAll(resolveBooleanPredicate(child)) + intersection.addAll(getMatches(child)) } else { - intersection.intersect(resolveBooleanPredicate(child)) + intersection.retainAll(getMatches(child)) /* in-place; intersect() returns a new Set and the result would be discarded */ } } intersection @@ -237,12 +238,27 @@ abstract class AbstractDescriptorReader>(final override val fi is Logical.Or -> { val union = mutableSetOf() for (child in predicate.predicates) { - union.addAll(resolveBooleanPredicate(child)) + union.addAll(getMatches(child)) } union } } + /** + * Prepares a [BooleanPredicate]: resolves it to the matching [RetrievableId]s via [getMatches] and returns a [PreparedStatement] that selects this reader's descriptors for those ids; the optional limit is appended via `toLimitClause()`. + * + * @param query The [BooleanPredicate] to prepare. + * @return The resulting [PreparedStatement] (bound but not yet executed). + */ + protected fun prepareBoolean(query: BooleanPredicate, limit: Long? = null): PreparedStatement { + val tableName = "\"${this.tableName.lowercase()}\"" + val retrievableIds = this.getMatches(query) + val sql = "SELECT * FROM $tableName WHERE $RETRIEVABLE_ID_COLUMN_NAME = ANY(?) ${limit.toLimitClause()}" + val stmt = this.connection.jdbc.prepareStatement(sql) + stmt.setArray(1, this.connection.jdbc.createArrayOf("OTHER", retrievableIds.toTypedArray())) + return stmt + } + /** * Converts a [ResultSet] to a [Descriptor] of type [D]. 
* diff --git a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/scalar/ScalarDescriptorReader.kt b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/scalar/ScalarDescriptorReader.kt index dda0bfab..78cbcba9 100644 --- a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/scalar/ScalarDescriptorReader.kt +++ b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/scalar/ScalarDescriptorReader.kt @@ -6,6 +6,7 @@ import org.vitrivr.engine.core.model.descriptor.struct.StructDescriptor import org.vitrivr.engine.core.model.descriptor.vector.VectorDescriptor import org.vitrivr.engine.core.model.metamodel.Schema import org.vitrivr.engine.core.model.query.Query +import org.vitrivr.engine.core.model.query.bool.BooleanPredicate import org.vitrivr.engine.core.model.query.bool.Comparison import org.vitrivr.engine.core.model.query.fulltext.SimpleFulltextPredicate import org.vitrivr.engine.core.model.retrievable.Retrieved @@ -34,6 +35,7 @@ class ScalarDescriptorReader(field: Schema.Field<*, ScalarDescriptor<*, *>>, con when (val predicate = query.predicate) { is SimpleFulltextPredicate -> prepareFulltext(predicate, query.limit) is Comparison<*> -> prepareComparison(predicate, query.limit) + is BooleanPredicate -> prepareBoolean(predicate, query.limit) else -> throw IllegalArgumentException("Query of type ${query::class} is not supported by ScalarDescriptorReader.") }.use { stmt -> stmt.executeQuery().use { result -> @@ -69,7 +71,7 @@ class ScalarDescriptorReader(field: Schema.Field<*, ScalarDescriptor<*, *>>, con /** * Prepares a [SimpleFulltextPredicate] and returns a [Sequence] of [ScalarDescriptor]s. * - * @param query The [SimpleFulltextPredicate] to execute. + * @param query The [SimpleFulltextPredicate] to prepare. * @param limit The maximum number of results to return. * @return [PreparedStatement]s. 
*/ @@ -84,7 +86,7 @@ class ScalarDescriptorReader(field: Schema.Field<*, ScalarDescriptor<*, *>>, con return stmt } else { val sql = "SELECT * FROM $tableName WHERE $VALUE_ATTRIBUTE_NAME @@ to_tsquery(?) AND $RETRIEVABLE_ID_COLUMN_NAME = ANY(?) ${limit.toLimitClause()}" - val retrievableIds = this.resolveBooleanPredicate(filter) + val retrievableIds = this.getMatches(filter) val stmt = this.connection.jdbc.prepareStatement(sql) stmt.setString(1, fulltextQueryString) stmt.setArray(2, this.connection.jdbc.createArrayOf("OTHER", retrievableIds.toTypedArray())) @@ -95,7 +97,7 @@ class ScalarDescriptorReader(field: Schema.Field<*, ScalarDescriptor<*, *>>, con /** * [PreparedStatement] a [Comparison] predicate and returns a [PreparedStatement]. * - * @param query The [Comparison] to execute. + * @param query The [Comparison] to prepare. * @return [PreparedStatement]s. */ private fun prepareComparison(query: Comparison<*>, limit: Long? = null): PreparedStatement { diff --git a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/struct/StructDescriptorReader.kt b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/struct/StructDescriptorReader.kt index 582aeede..5eb8cca5 100644 --- a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/struct/StructDescriptorReader.kt +++ b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/struct/StructDescriptorReader.kt @@ -5,6 +5,7 @@ import org.vitrivr.engine.core.model.descriptor.scalar.ScalarDescriptor import org.vitrivr.engine.core.model.descriptor.struct.StructDescriptor import org.vitrivr.engine.core.model.metamodel.Schema import org.vitrivr.engine.core.model.query.Query +import org.vitrivr.engine.core.model.query.bool.BooleanPredicate import org.vitrivr.engine.core.model.query.bool.Comparison import 
org.vitrivr.engine.core.model.query.fulltext.SimpleFulltextPredicate import org.vitrivr.engine.core.model.retrievable.Retrieved @@ -37,6 +38,7 @@ class StructDescriptorReader(field: Schema.Field<*, StructDescriptor<*>>, connec when (val predicate = query.predicate) { is SimpleFulltextPredicate -> prepareFulltext(predicate) is Comparison<*> -> prepareComparison(predicate) + is BooleanPredicate -> prepareBoolean(predicate) else -> throw IllegalArgumentException("Query of type ${query::class} is not supported by StructDescriptorReader.") }.use { stmt -> stmt.executeQuery().use { result -> @@ -110,7 +112,7 @@ class StructDescriptorReader(field: Schema.Field<*, StructDescriptor<*>>, connec return stmt } else { val sql = "SELECT * FROM $tableName WHERE ${query.attributeName} @@ to_tsquery(?) AND $RETRIEVABLE_ID_COLUMN_NAME = ANY(?) ${limit.toLimitClause()}" - val retrievableIds = this.resolveBooleanPredicate(filter) + val retrievableIds = this.getMatches(filter) val stmt = this.connection.jdbc.prepareStatement(sql) stmt.setString(1, fulltextQueryString) stmt.setArray(2, this.connection.jdbc.createArrayOf("OTHER", retrievableIds.toTypedArray())) diff --git a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/vector/VectorDescriptorReader.kt b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/vector/VectorDescriptorReader.kt index d38285bb..1566237e 100644 --- a/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/vector/VectorDescriptorReader.kt +++ b/vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/vector/VectorDescriptorReader.kt @@ -6,6 +6,7 @@ import org.vitrivr.engine.core.model.descriptor.vector.VectorDescriptor.Companio import org.vitrivr.engine.core.model.metamodel.Schema import org.vitrivr.engine.core.model.query.Predicate import org.vitrivr.engine.core.model.query.Query +import 
org.vitrivr.engine.core.model.query.bool.BooleanPredicate import org.vitrivr.engine.core.model.query.proximity.ProximityPredicate import org.vitrivr.engine.core.model.retrievable.Retrieved import org.vitrivr.engine.core.model.retrievable.attributes.DistanceAttribute @@ -38,7 +39,7 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con } } } - + is BooleanPredicate -> prepareBoolean(predicate) else -> throw UnsupportedOperationException("Query of typ ${query::class} is not supported by VectorDescriptorReader.") } } @@ -136,7 +137,7 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con "ORDER BY $DISTANCE_COLUMN_NAME ${query.order} " + "LIMIT ${query.k}" - val retrievableIds = this.resolveBooleanPredicate(filter) + val retrievableIds = this.getMatches(filter) val stmt = this@VectorDescriptorReader.connection.jdbc.prepareStatement(sql) stmt.setValue(1, query.value) stmt.setArray(2, this.connection.jdbc.createArrayOf("OTHER", retrievableIds.toTypedArray())) @@ -177,7 +178,7 @@ class VectorDescriptorReader(field: Schema.Field<*, VectorDescriptor<*, *>>, con "FROM $cteTable INNER JOIN $RETRIEVABLE_ENTITY_NAME ON ($RETRIEVABLE_ENTITY_NAME.$RETRIEVABLE_ID_COLUMN_NAME = $cteTable.$RETRIEVABLE_ID_COLUMN_NAME)" + "ORDER BY $cteTable.$DISTANCE_COLUMN_NAME ${query.order}" - val retrievableIds = this.resolveBooleanPredicate(filter) + val retrievableIds = this.getMatches(filter) val stmt = this@VectorDescriptorReader.connection.jdbc.prepareStatement(sql) stmt.setValue(1, query.value) stmt.setArray(2, this.connection.jdbc.createArrayOf("OTHER", retrievableIds.toTypedArray()))