Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feature/advanced-boolean' into f…
Browse files Browse the repository at this point in the history
…eature/advanced-boolean

# Conflicts:
#	vitrivr-engine-module-pgvector/src/main/kotlin/org/vitrivr/engine/database/pgvector/descriptor/AbstractDescriptorReader.kt
  • Loading branch information
Ralph Gasser committed Dec 2, 2024
2 parents 1abd673 + 9e9c3f2 commit 9b05f96
Showing 1 changed file with 31 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import org.vitrivr.engine.core.model.query.bool.Logical
import org.vitrivr.engine.core.model.retrievable.RetrievableId
import org.vitrivr.engine.core.model.retrievable.Retrieved
import org.vitrivr.engine.database.pgvector.*
import org.vitrivr.engine.database.pgvector.descriptor.scalar.ScalarDescriptorReader
import org.vitrivr.engine.database.pgvector.descriptor.struct.StructDescriptorReader
import java.sql.PreparedStatement
import java.sql.ResultSet
import java.util.*
Expand Down Expand Up @@ -211,9 +213,35 @@ abstract class AbstractDescriptorReader<D : Descriptor<*>>(final override val fi
* @return Set of [RetrievableId]s that match the [BooleanPredicate].
*/
protected fun getMatches(predicate: BooleanPredicate): Set<RetrievableId> = when (predicate) {
is Comparison<*> -> predicate.field.getReader().query(Query(predicate)).map { it.id }.toSet()
is Logical.And -> predicate.predicates.map { getMatches(it) }.reduce { acc, set -> acc.intersect(set) }
is Logical.Or -> predicate.predicates.flatMap { getMatches(it) }.toSet()
is Comparison<*> -> {
val field = predicate.field
val reader = field.getReader()
when (reader) {
is ScalarDescriptorReader -> reader.query(Query(predicate)).map { it.id }.toSet()
is StructDescriptorReader -> reader.query(Query(predicate)).map { it.id }.toSet()
else -> throw IllegalArgumentException("Cannot resolve predicate $predicate.")
}
}

is Logical.And -> {
val intersection = mutableSetOf<RetrievableId>()
for ((index, child) in predicate.predicates.withIndex()) {
if (index == 0) {
intersection.addAll(getMatches(child))
} else {
intersection.intersect(getMatches(child))
}
}
intersection
}

is Logical.Or -> {
val union = mutableSetOf<RetrievableId>()
for (child in predicate.predicates) {
union.addAll(getMatches(child))
}
union
}
}

/**
Expand Down

0 comments on commit 9b05f96

Please sign in to comment.