From 9b11866abf49ed6b9b5e6e7e84ab77334dca3bfb Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Wed, 31 Jan 2024 14:33:43 -0500 Subject: [PATCH] WIP: DirectQueryDriver. --- .../butler/direct_query_driver/__init__.py | 29 + .../direct_query_driver/_analyzed_query.py | 170 ++++ .../direct_query_driver/_convert_results.py | 61 ++ .../daf/butler/direct_query_driver/_driver.py | 789 ++++++++++++++++++ .../direct_query_driver/_postprocessing.py | 138 +++ .../direct_query_driver/_sql_builder.py | 257 ++++++ .../_sql_column_visitor.py | 239 ++++++ .../butler/registry/collections/nameKey.py | 6 + .../registry/collections/synthIntKey.py | 11 + .../datasets/byDimensions/_storage.py | 212 ++++- .../daf/butler/registry/dimensions/static.py | 198 ++++- .../registry/interfaces/_collections.py | 26 + .../butler/registry/interfaces/_datasets.py | 16 +- .../butler/registry/interfaces/_dimensions.py | 19 +- 14 files changed, 2162 insertions(+), 9 deletions(-) create mode 100644 python/lsst/daf/butler/direct_query_driver/__init__.py create mode 100644 python/lsst/daf/butler/direct_query_driver/_analyzed_query.py create mode 100644 python/lsst/daf/butler/direct_query_driver/_convert_results.py create mode 100644 python/lsst/daf/butler/direct_query_driver/_driver.py create mode 100644 python/lsst/daf/butler/direct_query_driver/_postprocessing.py create mode 100644 python/lsst/daf/butler/direct_query_driver/_sql_builder.py create mode 100644 python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py diff --git a/python/lsst/daf/butler/direct_query_driver/__init__.py b/python/lsst/daf/butler/direct_query_driver/__init__.py new file mode 100644 index 0000000000..d8aae48e4b --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/__init__.py @@ -0,0 +1,29 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from ._postprocessing import Postprocessing +from ._sql_builder import SqlBuilder diff --git a/python/lsst/daf/butler/direct_query_driver/_analyzed_query.py b/python/lsst/daf/butler/direct_query_driver/_analyzed_query.py new file mode 100644 index 0000000000..71d98c7898 --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/_analyzed_query.py @@ -0,0 +1,170 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("AnalyzedQuery", "AnalyzedDatasetSearch", "DataIdExtractionVisitor") + +import dataclasses +from collections.abc import Iterator +from typing import Any + +from ..dimensions import DataIdValue, DimensionElement, DimensionGroup, DimensionUniverse +from ..queries import tree as qt +from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, SimplePredicateVisitor +from ..registry.interfaces import CollectionRecord +from ._postprocessing import Postprocessing + + +@dataclasses.dataclass +class AnalyzedDatasetSearch: + name: str + shrunk: str + dimensions: DimensionGroup + collection_records: list[CollectionRecord] = dataclasses.field(default_factory=list) + messages: list[str] = dataclasses.field(default_factory=list) + is_calibration_search: bool = False + + +@dataclasses.dataclass +class AnalyzedQuery: + predicate: qt.Predicate + postprocessing: Postprocessing + base_columns: qt.ColumnSet + projection_columns: qt.ColumnSet + final_columns: qt.ColumnSet + find_first_dataset: str | None + materializations: dict[qt.MaterializationKey, DimensionGroup] = dataclasses.field(default_factory=dict) + datasets: dict[str, AnalyzedDatasetSearch] = dataclasses.field(default_factory=dict) + messages: list[str] = dataclasses.field(default_factory=list) + constraint_data_id: dict[str, DataIdValue] = dataclasses.field(default_factory=dict) + data_coordinate_uploads: dict[qt.DataCoordinateUploadKey, DimensionGroup] = dataclasses.field( + default_factory=dict + ) + needs_dimension_distinct: bool = False + needs_find_first_resolution: bool = False + projection_region_aggregates: list[DimensionElement] = dataclasses.field(default_factory=list) + + @property + def universe(self) -> DimensionUniverse: + return self.base_columns.dimensions.universe + + @property + def needs_projection(self) -> bool: + return self.needs_dimension_distinct or self.postprocessing.check_validity_match_count + + def iter_mandatory_base_elements(self) -> Iterator[DimensionElement]: + for element_name in self.base_columns.dimensions.elements: + element = self.universe[element_name] + if self.base_columns.dimension_fields[element_name]: + # We need to get dimension record fields for this element, and + # its table is the only place to get those. + yield element + elif element.defines_relationships: + # We als need to join in DimensionElements tables that define + # one-to-many and many-to-many relationships, but data + # coordinate uploads, materializations, and datasets can also + # provide these relationships. Data coordinate uploads and + # dataset tables only have required dimensions, and can hence + # only provide relationships involving those. + if any( + element.minimal_group.names <= upload_dimensions.required + for upload_dimensions in self.data_coordinate_uploads.values() + ): + continue + if any( + element.minimal_group.names <= dataset_spec.dimensions.required + for dataset_spec in self.datasets.values() + ): + continue + # Materializations have all key columns for their dimensions. + if any( + element in materialization_dimensions.names + for materialization_dimensions in self.materializations.values() + ): + continue + yield element + + +class DataIdExtractionVisitor( + SimplePredicateVisitor, + ColumnExpressionVisitor[tuple[str, None] | tuple[None, Any] | tuple[None, None]], +): + def __init__(self, data_id: dict[str, DataIdValue], messages: list[str]): + self.data_id = data_id + self.messages = messages + + def visit_comparison( + self, + a: qt.ColumnExpression, + operator: qt.ComparisonOperator, + b: qt.ColumnExpression, + flags: PredicateVisitFlags, + ) -> None: + if flags & PredicateVisitFlags.HAS_OR_SIBLINGS: + return None + if flags & PredicateVisitFlags.INVERTED: + if operator == "!=": + operator = "==" + else: + return None + if operator != "==": + return None + k_a, v_a = a.visit(self) + k_b, v_b = b.visit(self) + if k_a is not None and v_b is not None: + key = k_a + value = v_b + elif k_b is not None and v_a is not None: + key = k_b + value = v_a + else: + return None + if (old := self.data_id.setdefault(key, value)) != value: + self.messages.append(f"'where' expression requires both {key}={value!r} and {key}={old!r}.") + return None + + def visit_binary_expression(self, expression: qt.BinaryExpression) -> tuple[None, None]: + return None, None + + def visit_unary_expression(self, expression: qt.UnaryExpression) -> tuple[None, None]: + return None, None + + def visit_literal(self, expression: qt.ColumnLiteral) -> tuple[None, Any]: + return None, expression.get_literal_value() + + def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> tuple[str, None]: + return expression.dimension.name, None + + def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> tuple[None, None]: + return None, None + + def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> tuple[None, None]: + return None, None + + def visit_reversed(self, expression: qt.Reversed) -> tuple[None, None]: + raise AssertionError("No Reversed expressions in predicates.") diff --git a/python/lsst/daf/butler/direct_query_driver/_convert_results.py b/python/lsst/daf/butler/direct_query_driver/_convert_results.py new file mode 100644 index 0000000000..899c1760b4 --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/_convert_results.py @@ -0,0 +1,61 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("convert_dimension_record_results",) + +from collections.abc import Iterable +from typing import TYPE_CHECKING + +import sqlalchemy + +from ..dimensions import DimensionRecordSet + +if TYPE_CHECKING: + from ..queries.driver import DimensionRecordResultPage, PageKey + from ..queries.result_specs import DimensionRecordResultSpec + from ..registry.nameShrinker import NameShrinker + + +def convert_dimension_record_results( + raw_rows: Iterable[sqlalchemy.Row], + spec: DimensionRecordResultSpec, + next_key: PageKey | None, + name_shrinker: NameShrinker, +) -> DimensionRecordResultPage: + record_set = DimensionRecordSet(spec.element) + columns = spec.get_result_columns() + column_mapping = [ + (field, name_shrinker.shrink(columns.get_qualified_name(spec.element.name, field))) + for field in spec.element.schema.names + ] + record_cls = spec.element.RecordClass + if not spec.element.temporal: + for raw_row in raw_rows: + record_set.add(record_cls(**{k: raw_row._mapping[v] for k, v in column_mapping})) + return DimensionRecordResultPage(spec=spec, next_key=next_key, rows=record_set) diff --git a/python/lsst/daf/butler/direct_query_driver/_driver.py b/python/lsst/daf/butler/direct_query_driver/_driver.py new file mode 100644 index 0000000000..654c4f80bc --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/_driver.py @@ -0,0 +1,789 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import uuid + +__all__ = ("DirectQueryDriver",) + +import logging +from collections.abc import Iterable, Iterator, Sequence +from contextlib import ExitStack +from typing import TYPE_CHECKING, Any, cast, overload + +import sqlalchemy + +from .. import ddl +from ..dimensions import DataIdValue, DimensionGroup, DimensionUniverse +from ..queries import tree as qt +from ..queries.driver import ( + DataCoordinateResultPage, + DatasetRefResultPage, + DimensionRecordResultPage, + GeneralResultPage, + PageKey, + QueryDriver, + ResultPage, +) +from ..queries.result_specs import ( + DataCoordinateResultSpec, + DatasetRefResultSpec, + DimensionRecordResultSpec, + GeneralResultSpec, + ResultSpec, +) +from ..registry import CollectionSummary, CollectionType, NoDefaultCollectionError, RegistryDefaults +from ..registry.interfaces import ChainedCollectionRecord, CollectionRecord +from ..registry.managers import RegistryManagerInstances +from ..registry.nameShrinker import NameShrinker +from ._analyzed_query import AnalyzedDatasetSearch, AnalyzedQuery, DataIdExtractionVisitor +from ._convert_results import convert_dimension_record_results +from ._sql_column_visitor import SqlColumnVisitor + +if TYPE_CHECKING: + from ..registry.interfaces import Database + from ._postprocessing import Postprocessing + from ._sql_builder import SqlBuilder + + +_LOG = logging.getLogger(__name__) + + +class DirectQueryDriver(QueryDriver): + """The `QueryDriver` implementation for `DirectButler`. + + Parameters + ---------- + db : `Database` + Abstraction for the SQL database. + universe : `DimensionUniverse` + Definitions of all dimensions. + managers : `RegistryManagerInstances` + Struct of registry manager objects. + defaults : `RegistryDefaults` + Struct holding the default collection search path and governor + dimensions. + raw_page_size : `int`, optional + Number of database rows to fetch for each result page. The actual + number of rows in a page may be smaller due to postprocessing. + postprocessing_filter_factor : `int`, optional + The number of database rows we expect to have to fetch to yield a + single output row for queries that involve postprocessing. This is + purely a performance tuning parameter that attempts to balance between + fetching too much and requiring multiple fetches; the true value is + highly dependent on the actual query. + """ + + def __init__( + self, + db: Database, + universe: DimensionUniverse, + managers: RegistryManagerInstances, + defaults: RegistryDefaults, + raw_page_size: int = 10000, + postprocessing_filter_factor: int = 10, + ): + self.db = db + self.managers = managers + self._universe = universe + self._defaults = defaults + self._materializations: dict[qt.MaterializationKey, tuple[sqlalchemy.Table, Postprocessing]] = {} + self._upload_tables: dict[qt.DataCoordinateUploadKey, sqlalchemy.Table] = {} + self._exit_stack: ExitStack | None = None + self._raw_page_size = raw_page_size + self._postprocessing_filter_factor = postprocessing_filter_factor + self._active_pages: dict[PageKey, tuple[Iterator[Sequence[sqlalchemy.Row]], Postprocessing]] = {} + self._name_shrinker = NameShrinker(self.db.dialect.max_identifier_length) + + def __enter__(self) -> None: + self._exit_stack = ExitStack() + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + assert self._exit_stack is not None + self._exit_stack.__exit__(exc_type, exc_value, traceback) + self._exit_stack = None + + @property + def universe(self) -> DimensionUniverse: + return self._universe + + @overload + def execute( + self, result_spec: DataCoordinateResultSpec, tree: qt.QueryTree + ) -> DataCoordinateResultPage: ... + + @overload + def execute( + self, result_spec: DimensionRecordResultSpec, tree: qt.QueryTree + ) -> DimensionRecordResultPage: ... + + @overload + def execute(self, result_spec: DatasetRefResultSpec, tree: qt.QueryTree) -> DatasetRefResultPage: ... + + @overload + def execute(self, result_spec: GeneralResultSpec, tree: qt.QueryTree) -> GeneralResultPage: ... + + def execute(self, result_spec: ResultSpec, tree: qt.QueryTree) -> ResultPage: + # Docstring inherited. + if self._exit_stack is None: + raise RuntimeError("QueryDriver context must be entered before 'materialize' is called.") + # Make a set of the columns the query needs to make available to the + # SELECT clause and any ORDER BY or GROUP BY clauses. This does not + # include columns needed only by the WHERE or JOIN ON clauses (those + # will be handled inside `_make_vanilla_sql_builder`). + + # Build the FROM and WHERE clauses and identify any post-query + # processing we need to run. + query, sql_builder = self.analyze_query( + tree, + final_columns=result_spec.get_result_columns(), + order_by=result_spec.order_by, + find_first_dataset=result_spec.find_first_dataset, + ) + sql_builder = self.build_query(query, sql_builder) + sql_select = sql_builder.select(query.final_columns, query.postprocessing) + if result_spec.order_by: + visitor = SqlColumnVisitor(sql_builder, self) + sql_select = sql_select.order_by(*[visitor.expect_scalar(term) for term in result_spec.order_by]) + if result_spec.limit is not None: + if query.postprocessing: + query.postprocessing.limit = result_spec.limit + else: + sql_select = sql_select.limit(result_spec.limit) + if result_spec.offset: + if query.postprocessing: + sql_select = sql_select.offset(result_spec.offset) + else: + query.postprocessing.offset = result_spec.offset + if query.postprocessing.limit is not None: + # We might want to fetch many fewer rows that the default page + # size if we have to implement offset and limit in postprocessing. + raw_page_size = min( + self._postprocessing_filter_factor + * (query.postprocessing.offset + query.postprocessing.limit), + self._raw_page_size, + ) + cursor = self._exit_stack.enter_context( + self.db.query(sql_select.execution_options(yield_per=raw_page_size)) + ) + raw_page_iter = cursor.partitions() + return self._process_page(raw_page_iter, result_spec, query.postprocessing) + + @overload + def fetch_next_page( + self, result_spec: DataCoordinateResultSpec, key: PageKey + ) -> DataCoordinateResultPage: ... + + @overload + def fetch_next_page( + self, result_spec: DimensionRecordResultSpec, key: PageKey + ) -> DimensionRecordResultPage: ... + + @overload + def fetch_next_page(self, result_spec: DatasetRefResultSpec, key: PageKey) -> DatasetRefResultPage: ... + + @overload + def fetch_next_page(self, result_spec: GeneralResultSpec, key: PageKey) -> GeneralResultPage: ... + + def fetch_next_page(self, result_spec: ResultSpec, key: PageKey) -> ResultPage: + raw_page_iter, postprocessing = self._active_pages.pop(key) + return self._process_page(raw_page_iter, result_spec, postprocessing) + + def materialize( + self, + tree: qt.QueryTree, + dimensions: DimensionGroup, + datasets: frozenset[str], + ) -> qt.MaterializationKey: + # Docstring inherited. + if self._exit_stack is None: + raise RuntimeError("QueryDriver context must be entered before 'materialize' is called.") + query, sql_builder = self.analyze_query(tree, qt.ColumnSet(dimensions)) + # Current implementation ignores 'datasets' because figuring out what + # to put in the temporary table for them is tricky, especially if + # calibration collections are involved. + sql_builder = self.build_query(query, sql_builder) + sql_select = sql_builder.select(query.final_columns, query.postprocessing) + table = self._exit_stack.enter_context( + self.db.temporary_table(sql_builder.make_table_spec(query.final_columns, query.postprocessing)) + ) + self.db.insert(table, select=sql_select) + key = uuid.uuid4() + self._materializations[key] = (table, query.postprocessing) + return key + + def upload_data_coordinates( + self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]] + ) -> qt.DataCoordinateUploadKey: + # Docstring inherited. + if self._exit_stack is None: + raise RuntimeError("QueryDriver context must be entered before 'materialize' is called.") + table_spec = ddl.TableSpec( + [ + self.universe.dimensions[name].primary_key.model_copy(update=dict(name=name)).to_sql_spec() + for name in dimensions.required + ] + ) + if not dimensions: + table_spec.fields.add( + ddl.FieldSpec( + SqlBuilder.EMPTY_COLUMNS_NAME, dtype=SqlBuilder.EMPTY_COLUMNS_TYPE, nullable=True + ) + ) + table = self._exit_stack.enter_context(self.db.temporary_table(table_spec)) + self.db.insert(table, *(dict(zip(dimensions.required, values)) for values in rows)) + key = uuid.uuid4() + self._upload_tables[key] = table + return key + + def count( + self, + tree: qt.QueryTree, + columns: qt.ColumnSet, + find_first_dataset: str | None, + *, + exact: bool, + discard: bool, + ) -> int: + # Docstring inherited. + query, sql_builder = self.analyze_query(tree, columns, find_first_dataset=find_first_dataset) + sql_builder = self.build_query(query, sql_builder) + if query.postprocessing and exact: + if not discard: + raise RuntimeError("Cannot count query rows exactly without discarding them.") + sql_select = sql_builder.select(columns, query.postprocessing) + n = 0 + with self.db.query(sql_select.execution_options(yield_per=self._raw_page_size)) as results: + for _ in query.postprocessing.apply(results): + n + 1 + return n + # Do COUNT(*) on the original query's FROM clause. + sql_builder.special["_ROWCOUNT"] = sqlalchemy.func.count() + sql_select = sql_builder.select(qt.ColumnSet(self._universe.empty.as_group())) + with self.db.query(sql_select) as result: + return cast(int, result.scalar()) + + def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool: + # Docstring inherited. + query, sql_builder = self.analyze_query(tree, qt.ColumnSet(tree.dimensions)) + if not all(d.collection_records for d in query.datasets.values()): + return False + if not execute: + if exact: + raise RuntimeError("Cannot obtain exact result for 'any' without executing.") + return True + sql_builder = self.build_query(query, sql_builder) + if query.postprocessing and exact: + sql_select = sql_builder.select(query.final_columns, query.postprocessing) + with self.db.query( + sql_select.execution_options(yield_per=self._postprocessing_filter_factor) + ) as result: + for _ in query.postprocessing.apply(result): + return True + return False + sql_select = sql_builder.select(query.final_columns).limit(1) + with self.db.query(sql_select) as result: + return result.first() is not None + + def explain_no_results(self, tree: qt.QueryTree, execute: bool) -> Iterable[str]: + # Docstring inherited. + query, _ = self.analyze_query(tree, qt.ColumnSet(tree.dimensions)) + if query.messages or not execute: + return query.messages + # TODO: guess at ways to split up query that might fail or succeed if + # run separately, execute them with LIMIT 1 and report the results. + return [] + + def get_dataset_dimensions(self, name: str) -> DimensionGroup: + # Docstring inherited + return self.managers.datasets[name].datasetType.dimensions.as_group() + + def get_default_collections(self) -> tuple[str, ...]: + # Docstring inherited. + if not self._defaults.collections: + raise NoDefaultCollectionError("No collections provided and no default collections.") + return tuple(self._defaults.collections) + + def resolve_collection_path( + self, collections: Iterable[str] + ) -> list[tuple[CollectionRecord, CollectionSummary]]: + result: list[tuple[CollectionRecord, CollectionSummary]] = [] + done: set[str] = set() + + def recurse(collection_names: Iterable[str]) -> None: + for collection_name in collection_names: + if collection_name not in done: + done.add(collection_name) + record = self.managers.collections.find(collection_name) + + if record.type is CollectionType.CHAINED: + recurse(cast(ChainedCollectionRecord, record).children) + else: + result.append((record, self.managers.datasets.getCollectionSummary(record))) + + recurse(collections) + + return result + + def analyze_query( + self, + tree: qt.QueryTree, + final_columns: qt.ColumnSet, + order_by: Iterable[qt.OrderExpression] = (), + find_first_dataset: str | None = None, + ) -> tuple[AnalyzedQuery, SqlBuilder]: + # Delegate to the dimensions manager to rewrite the predicate and + # start a SqlBuilder and Postprocessing to cover any spatial overlap + # joins or constraints. We'll return that SqlBuilder at the end. + ( + predicate, + sql_builder, + postprocessing, + ) = self.managers.dimensions.process_query_overlaps( + tree.dimensions, + tree.predicate, + tree.join_operand_dimensions, + ) + # Initialize the AnalyzedQuery instance we'll update throughout this + # method. + query = AnalyzedQuery( + predicate, + postprocessing, + base_columns=qt.ColumnSet(tree.dimensions), + projection_columns=final_columns.copy(), + final_columns=final_columns, + find_first_dataset=find_first_dataset, + ) + # The base query needs to join in all columns required by the + # predicate. + predicate.gather_required_columns(query.base_columns) + # The "projection" query differs from the final query by not omitting + # any dimension keys (since that makes it easier to reason about), + # including any columns needed by order_by terms, and including + # the dataset rank if there's a find-first search in play. + query.projection_columns.restore_dimension_keys() + for term in order_by: + term.gather_required_columns(query.projection_columns) + if query.find_first_dataset is not None: + query.projection_columns.dataset_fields[query.find_first_dataset].add("collection_key") + # The base query also needs to include all columns needed by the + # downstream projection query. + query.base_columns.update(query.projection_columns) + # Extract the data ID implied by the predicate; we can use the governor + # dimensions in that to constrain the collections we search for + # datasets later. + query.predicate.visit(DataIdExtractionVisitor(query.constraint_data_id, query.messages)) + # We also check that the predicate doesn't reference any dimensions + # without constraining their governor dimensions, since that's a + # particularly easy mistake to make and it's almost never intentional. + # We also also the registry data ID values to provide governor values. + where_columns = qt.ColumnSet(query.universe.empty.as_group()) + query.predicate.gather_required_columns(where_columns) + for governor in where_columns.dimensions.governors: + if governor not in query.constraint_data_id: + if governor in self._defaults.dataId.dimensions: + query.constraint_data_id[governor] = self._defaults.dataId[governor] + else: + raise qt.InvalidQueryTreeError( + f"Query 'where' expression references a dimension dependent on {governor} without " + "constraining it directly." + ) + # Add materializations, which can also bring in more postprocessing. + for m_key, m_dimensions in tree.materializations.items(): + _, m_postprocessing = self._materializations[m_key] + query.materializations[m_key] = m_dimensions + # When a query is materialized, the new tree's has an empty + # (trivially true) predicate, and the materialization prevents the + # creation of automatic spatial joins that are already included in + # the materialization, so we don't need to deduplicate these + # filters. It's possible for there to be duplicates, but only if + # the user explicitly adds a redundant constraint, and we'll still + # behave correctly (just less efficiently) if that happens. + postprocessing.spatial_join_filtering.extend(m_postprocessing.spatial_join_filtering) + postprocessing.spatial_where_filtering.extend(m_postprocessing.spatial_where_filtering) + # Add data coordinate uploads. + query.data_coordinate_uploads.update(tree.data_coordinate_uploads) + # Add dataset_searches and filter out collections that don't have the + # right dataset type or governor dimensions. + name_shrinker = make_dataset_name_shrinker(self.db.dialect) + for dataset_type_name, dataset_search in tree.datasets.items(): + dataset = AnalyzedDatasetSearch( + dataset_type_name, name_shrinker.shrink(dataset_type_name), dataset_search.dimensions + ) + for collection_record, collection_summary in self.resolve_collection_path( + dataset_search.collections + ): + rejected: bool = False + if dataset.name not in collection_summary.dataset_types.names: + dataset.messages.append( + f"No datasets of type {dataset.name!r} in collection {collection_record.name}." + ) + rejected = True + for governor in query.constraint_data_id.keys() & collection_summary.governors.keys(): + if query.constraint_data_id[governor] not in collection_summary.governors[governor]: + dataset.messages.append( + f"No datasets with {governor}={query.constraint_data_id[governor]!r} " + f"in collection {collection_record.name}." + ) + rejected = True + if not rejected: + if collection_record.type is CollectionType.CALIBRATION: + dataset.is_calibration_search = True + dataset.collection_records.append(collection_record) + if dataset.dimensions != self.get_dataset_type(dataset_type_name).dimensions.as_group(): + # This is really for server-side defensiveness; it's hard to + # imagine the query getting different dimensions for a dataset + # type in two calls to the same query driver. + raise qt.InvalidQueryTreeError( + f"Incorrect dimensions {dataset.dimensions} for dataset {dataset_type_name} " + f"in query (vs. {self.get_dataset_type(dataset_type_name).dimensions.as_group()})." + ) + query.datasets[dataset_type_name] = dataset + if not dataset.collection_records: + query.messages.append(f"Search for dataset type {dataset_type_name!r} is doomed to fail.") + query.messages.extend(dataset.messages) + # Set flags that indicate certain kinds of special processing the query + # will need, mostly in the "projection" stage, where we might do a + # GROUP BY or DISTINCT [ON]. + if query.find_first_dataset is not None: + # If we're doing a find-first search and there's a calibration + # collection in play, we need to make sure the rows coming out of + # the base query have only one timespan for each data ID + + # collection, and we can only do that with a GROUP BY and COUNT. + query.postprocessing.check_validity_match_count = query.datasets[ + query.find_first_dataset + ].is_calibration_search + # We only actually need to include the find-first resolution query + # logic if there's more than one collection. + query.needs_find_first_resolution = ( + len(query.datasets[query.find_first_dataset].collection_records) > 1 + ) + if query.projection_columns.dimensions != query.base_columns.dimensions: + # We're going from a larger set of dimensions to a smaller set, + # that means we'll be doing a SELECT DISTINCT [ON] or GROUP BY. + query.needs_dimension_distinct = True + # If there are any dataset fields being propagated through that + # projection and there is more than one collection, we need to + # include the collection_key column so we can use that as one of + # the DISTINCT ON or GROUP BY columns. + for dataset_type, fields_for_dataset in query.projection_columns.dataset_fields.items(): + if len(query.datasets[dataset_type].collection_records) > 1: + fields_for_dataset.add("collection_key") + # If there's a projection and we're doing postprocessing, we might + # be collapsing the dimensions of the postprocessing regions. When + # that happens, we want to apply an aggregate function to them that + # computes the union of the regions that are grouped together. + for element in query.postprocessing.iter_missing(query.projection_columns): + if element.name not in query.projection_columns.dimensions.elements: + query.projection_region_aggregates.append(element) + break + return query, sql_builder + + def build_query(self, query: AnalyzedQuery, sql_builder: SqlBuilder) -> SqlBuilder: + sql_builder = self._build_base_query(query, sql_builder) + if query.needs_projection: + sql_builder = self._project_query(query, sql_builder) + if query.needs_find_first_resolution: + sql_builder = self._apply_find_first(query, sql_builder) + elif query.needs_find_first_resolution: + sql_builder = self._apply_find_first( + query, sql_builder.cte(query.projection_columns, query.postprocessing) + ) + return sql_builder + + def _build_base_query(self, query: AnalyzedQuery, sql_builder: SqlBuilder) -> SqlBuilder: + # Process data coordinate upload joins. + for upload_key, upload_dimensions in query.data_coordinate_uploads.items(): + sql_builder = sql_builder.join( + SqlBuilder(self.db, self._upload_tables[upload_key]).extract_dimensions( + upload_dimensions.required + ) + ) + # Process materialization joins. + for materialization_key, materialization_spec in query.materializations.items(): + sql_builder = self._join_materialization(sql_builder, materialization_key, materialization_spec) + # Process dataset joins. + for dataset_type, dataset_search in query.datasets.items(): + sql_builder = self._join_dataset_search( + sql_builder, + dataset_type, + dataset_search, + query.base_columns, + ) + # Join in dimension element tables that we know we need relationships + # or columns from. + for element in query.iter_mandatory_base_elements(): + sql_builder = sql_builder.join( + self.managers.dimensions.make_sql_builder( + element, query.base_columns.dimension_fields[element.name] + ) + ) + # See if any dimension keys are still missing, and if so join in their + # tables. Note that we know there are no fields needed from these. + while not (sql_builder.dimension_keys.keys() >= query.base_columns.dimensions.names): + # Look for opportunities to join in multiple dimensions via single + # table, to reduce the total number of tables joined in. + missing_dimension_names = query.base_columns.dimensions.names - sql_builder.dimension_keys.keys() + best = self._universe[ + max( + missing_dimension_names, + key=lambda name: len(self._universe[name].dimensions.names & missing_dimension_names), + ) + ] + sql_builder = sql_builder.join(self.managers.dimensions.make_sql_builder(best, frozenset())) + # Add the WHERE clause to the builder. + return sql_builder.where_sql(query.predicate.visit(SqlColumnVisitor(sql_builder, self))) + + def _project_query(self, query: AnalyzedQuery, sql_builder: SqlBuilder) -> SqlBuilder: + assert query.needs_projection + # This method generates a Common Table Expresssion (CTE) using either a + # SELECT DISTINCT [ON] or a SELECT with GROUP BY. + # We'll work out which as we go + have_aggregates: bool = False + # Dimension key columns form at least most of our GROUP BY or DISTINCT + # ON clause; we'll work out which of those we'll use. + unique_keys: list[sqlalchemy.ColumnElement[Any]] = [ + sql_builder.dimension_keys[k][0] for k in query.projection_columns.dimensions.data_coordinate_keys + ] + # There are two reasons we might need an aggregate function: + # - to make sure temporal constraints and joins have resulted in at + # most one validity range match for each data ID and collection, + # when we're doing a find-first query. + # - to compute the unions of regions we need for postprocessing, when + # the data IDs for those regions are not wholly included in the + # results (i.e. we need to postprocess on + # visit_detector_region.region, but the output rows don't have + # detector, just visit - so we compute the union of the + # visit_detector region over all matched detectors). + if query.postprocessing.check_validity_match_count: + sql_builder.special[query.postprocessing.VALIDITY_MATCH_COUNT] = sqlalchemy.func.count().label( + query.postprocessing.VALIDITY_MATCH_COUNT + ) + have_aggregates = True + for element in query.projection_region_aggregates: + sql_builder.fields[element.name]["region"] = ddl.Base64Region.union_aggregate( + sql_builder.fields[element.name]["region"] + ) + have_aggregates = True + # Many of our fields derive their uniqueness from the unique_key + # fields: if rows are uniqe over the 'unique_key' fields, then they're + # automatically unique over these 'derived_fields'. We just remember + # these as pairs of (logical_table, field) for now. + derived_fields: list[tuple[str, str]] = [] + # All dimension record fields are derived fields. + for element_name, fields_for_element in query.projection_columns.dimension_fields.items(): + for element_field in fields_for_element: + derived_fields.append((element_name, element_field)) + # Some dataset fields are derived fields and some are unique keys, and + # it depends on the kinds of collection(s) we're searching and whether + # it's a find-first query. + for dataset_type, fields_for_dataset in query.projection_columns.dataset_fields.items(): + for dataset_field in fields_for_dataset: + if dataset_field == "collection_key": + # If the collection_key field is present, it's needed for + # uniqueness if we're looking in more than one collection. + # If not, it's a derived field. + if len(query.datasets[dataset_type].collection_records) > 1: + unique_keys.append(sql_builder.fields[dataset_type]["collection_key"]) + else: + derived_fields.append((dataset_type, "collection_key")) + elif dataset_field == "timespan" and query.datasets[dataset_type].is_calibration_search: + # If we're doing a non-find-first query against a + # CALIBRATION collection, the timespan is also a unique + # key... + if dataset_type == query.find_first_dataset: + # ...unless we're doing a find-first search on this + # dataset, in which case we need to use ANY_VALUE on + # the timespan and check that _VALIDITY_MATCH_COUNT + # (added earlier) is one, indicating that there was + # indeed only one timespan for each data ID in each + # collection that survived the base query's WHERE + # clauses and JOINs. + if not self.db.has_any_aggregate: + raise NotImplementedError( + f"Cannot generate query that returns {dataset_type}.timespan after a " + "find-first search, because this a database does not support the ANY_VALUE " + "aggregate function (or equivalent)." + ) + sql_builder.timespans[dataset_type] = sql_builder.timespans[ + dataset_type + ].apply_any_aggregate(self.db.apply_any_aggregate) + else: + unique_keys.extend(sql_builder.timespans[dataset_type].flatten()) + else: + # Other dataset fields derive their uniqueness from key + # fields. + derived_fields.append((dataset_type, dataset_field)) + if not have_aggregates and not derived_fields: + # SELECT DISTINCT is sufficient. + return sql_builder.cte(query.projection_columns, query.postprocessing, distinct=True) + elif not have_aggregates and self.db.has_distinct_on: + # SELECT DISTINCT ON is sufficient and works. + return sql_builder.cte(query.projection_columns, query.postprocessing, distinct=unique_keys) + else: + # GROUP BY is the only option. + if derived_fields: + if self.db.has_any_aggregate: + for logical_table, field in derived_fields: + if field == "timespan": + sql_builder.timespans[logical_table] = sql_builder.timespans[ + logical_table + ].apply_any_aggregate(self.db.apply_any_aggregate) + else: + sql_builder.fields[logical_table][field] = self.db.apply_any_aggregate( + sql_builder.fields[logical_table][field] + ) + else: + _LOG.warning( + "Adding %d fields to GROUP BY because this database backend does not support the " + "ANY_VALUE aggregate function (or equivalent). This may result in a poor query " + "plan. Materializing the query first sometimes avoids this problem.", + len(derived_fields), + ) + for logical_table, field in derived_fields: + if field == "timespan": + unique_keys.extend(sql_builder.timespans[logical_table].flatten()) + else: + unique_keys.append(sql_builder.fields[logical_table][field]) + return sql_builder.cte(query.projection_columns, query.postprocessing, group_by=unique_keys) + + def _apply_find_first(self, query: AnalyzedQuery, sql_builder: SqlBuilder) -> SqlBuilder: + assert query.needs_find_first_resolution + assert query.find_first_dataset is not None + assert sql_builder.sql_from_clause is not None + # The query we're building looks like this: + # + # WITH {dst}_base AS ( + # {target} + # ... + # ) + # SELECT + # {dst}_window.*, + # FROM ( + # SELECT + # {dst}_base.*, + # ROW_NUMBER() OVER ( + # PARTITION BY {dst_base}.{dimensions} + # ORDER BY {rank} + # ) AS rownum + # ) {dst}_window + # WHERE + # {dst}_window.rownum = 1; + # + # The outermost SELECT will be represented by the SqlBuilder we return. + + # The sql_builder we're given corresponds to the Common Table + # Expression (CTE) at the top, and is guaranteed to have + # ``query.projected_columns`` (+ postprocessing columns). + # We start by filling out the "window" SELECT statement... + partition_by = [sql_builder.dimension_keys[d][0] for d in query.base_columns.dimensions.required] + rank_sql_column = sqlalchemy.case( + { + record.key: n + for n, record in enumerate(query.datasets[query.find_first_dataset].collection_records) + }, + value=sql_builder.fields[query.find_first_dataset]["collection_key"], + ) + if partition_by: + sql_builder.special["_ROWNUM"] = sqlalchemy.sql.func.row_number().over( + partition_by=partition_by, order_by=rank_sql_column + ) + else: + sql_builder.special["_ROWNUM"] = sqlalchemy.sql.func.row_number().over(order_by=rank_sql_column) + # ... and then turn that into a subquery with a constraint on rownum. + sql_builder = sql_builder.subquery(query.base_columns, query.postprocessing) + sql_builder = sql_builder.where_sql(sql_builder.special["_ROWNUM"] == 1) + del sql_builder.special["_ROWNUM"] + return sql_builder + + def _join_materialization( + self, + sql_builder: SqlBuilder, + materialization_key: qt.MaterializationKey, + dimensions: DimensionGroup, + ) -> SqlBuilder: + columns = qt.ColumnSet(dimensions) + table, postprocessing = self._materializations[materialization_key] + return sql_builder.join(SqlBuilder(self.db, table).extract_columns(columns, postprocessing)) + + def _join_dataset_search( + self, + sql_builder: SqlBuilder, + dataset_type: str, + processed_dataset_search: AnalyzedDatasetSearch, + columns: qt.ColumnSet, + ) -> SqlBuilder: + storage = self.managers.datasets[dataset_type] + # The next two asserts will need to be dropped (and the implications + # dealt with instead) if materializations start having dataset fields. + assert ( + dataset_type not in sql_builder.fields + ), "Dataset fields have unexpected already been joined in." + assert ( + dataset_type not in sql_builder.timespans + ), "Dataset timespan has unexpected already been joined in." + return sql_builder.join( + storage.make_sql_builder( + processed_dataset_search.collection_records, columns.dataset_fields[dataset_type] + ) + ) + + def _process_page( + self, + raw_page_iter: Iterator[Sequence[sqlalchemy.Row]], + result_spec: ResultSpec, + postprocessing: Postprocessing, + ) -> ResultPage: + try: + raw_page = next(raw_page_iter) + except StopIteration: + raw_page = tuple() + if len(raw_page) == self._raw_page_size: + # There's some chance we got unlucky and this page exactly finishes + # off the query, and we won't know the next page does not exist + # until we try to fetch it. But that's better than always fetching + # the next page up front. + next_key = uuid.uuid4() + self._active_pages[next_key] = (raw_page_iter, postprocessing) + else: + next_key = None + match result_spec: + case DimensionRecordResultSpec(): + return convert_dimension_record_results( + postprocessing.apply(raw_page), + result_spec, + next_key, + self._name_shrinker, + ) + case _: + raise NotImplementedError("TODO") + + +def make_dataset_name_shrinker(dialect: sqlalchemy.Dialect) -> NameShrinker: + max_dataset_field_length = max(len(field) for field in qt.DATASET_FIELD_NAMES) + return NameShrinker(dialect.max_identifier_length - max_dataset_field_length - 1, 6) diff --git a/python/lsst/daf/butler/direct_query_driver/_postprocessing.py b/python/lsst/daf/butler/direct_query_driver/_postprocessing.py new file mode 100644 index 0000000000..2dc4ead742 --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/_postprocessing.py @@ -0,0 +1,138 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("Postprocessing", "ValidityRangeMatchError") + +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING, ClassVar + +import sqlalchemy +from lsst.sphgeom import DISJOINT, Region + +from ..queries import tree as qt + +if TYPE_CHECKING: + from ..dimensions import DimensionElement + + +class ValidityRangeMatchError(RuntimeError): + pass + + +class Postprocessing: + def __init__(self) -> None: + self.spatial_join_filtering: list[tuple[DimensionElement, DimensionElement]] = [] + self.spatial_where_filtering: list[tuple[DimensionElement, Region]] = [] + self.check_validity_match_count: bool = False + self._offset: int = 0 + self._limit: int | None = None + + VALIDITY_MATCH_COUNT: ClassVar[str] = "_VALIDITY_MATCH_COUNT" + + @property + def offset(self) -> int: + return self._offset + + @offset.setter + def offset(self, value: int) -> None: + if value and not self: + raise RuntimeError( + "Postprocessing should only implement 'offset' if it needs to do spatial filtering." + ) + self._offset = value + + @property + def limit(self) -> int | None: + return self._limit + + @limit.setter + def limit(self, value: int | None) -> None: + if value and not self: + raise RuntimeError( + "Postprocessing should only implement 'limit' if it needs to do spatial filtering." + ) + self._limit = value + + def __bool__(self) -> bool: + return bool(self.spatial_join_filtering or self.spatial_where_filtering) + + def gather_columns_required(self, columns: qt.ColumnSet) -> None: + for element in self.iter_region_dimension_elements(): + columns.update_dimensions(element.minimal_group) + columns.dimension_fields[element.name].add("region") + + def iter_region_dimension_elements(self) -> Iterator[DimensionElement]: + for a, b in self.spatial_join_filtering: + yield a + yield b + for element, _ in self.spatial_where_filtering: + yield element + + def iter_missing(self, columns: qt.ColumnSet) -> Iterator[DimensionElement]: + done: set[DimensionElement] = set() + for element in self.iter_region_dimension_elements(): + if element not in done: + if "region" not in columns.dimension_fields.get(element.name, frozenset()): + yield element + done.add(element) + + def apply(self, rows: Iterable[sqlalchemy.Row]) -> Iterable[sqlalchemy.Row]: + if not self: + yield from rows + joins = [ + ( + qt.ColumnSet.get_qualified_name(a.name, "region"), + qt.ColumnSet.get_qualified_name(b.name, "region"), + ) + for a, b in self.spatial_join_filtering + ] + where = [ + (qt.ColumnSet.get_qualified_name(element.name, "region"), region) + for element, region in self.spatial_where_filtering + ] + for row in rows: + m = row._mapping + if any(m[a].relate(m[b]) & DISJOINT for a, b in joins) or any( + m[field].relate(region) & DISJOINT for field, region in where + ): + continue + if self.check_validity_match_count and m[self.VALIDITY_MATCH_COUNT] > 1: + raise ValidityRangeMatchError( + "Ambiguous calibration validity range match. This usually means a temporal join or " + "'where' needs to be added, but it could also mean that multiple validity ranges " + "overlap a single output data ID." + ) + if self._offset: + self._offset -= 1 + continue + if self._limit == 0: + break + yield row + if self._limit is not None: + self._limit -= 1 diff --git a/python/lsst/daf/butler/direct_query_driver/_sql_builder.py b/python/lsst/daf/butler/direct_query_driver/_sql_builder.py new file mode 100644 index 0000000000..84726f4bc5 --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/_sql_builder.py @@ -0,0 +1,257 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("SqlBuilder",) + +import dataclasses +import itertools +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING, Any, ClassVar + +import sqlalchemy + +from .. import ddl +from ..nonempty_mapping import NonemptyMapping +from ..queries import tree as qt +from ._postprocessing import Postprocessing + +if TYPE_CHECKING: + from ..registry.interfaces import Database + from ..timespan_database_representation import TimespanDatabaseRepresentation + + +@dataclasses.dataclass +class SqlBuilder: + db: Database + sql_from_clause: sqlalchemy.FromClause | None = None + sql_where_terms: list[sqlalchemy.ColumnElement[bool]] = dataclasses.field(default_factory=list) + needs_distinct: bool = False + + dimension_keys: NonemptyMapping[str, list[sqlalchemy.ColumnElement]] = dataclasses.field( + default_factory=lambda: NonemptyMapping(list) + ) + + fields: NonemptyMapping[str, dict[str, sqlalchemy.ColumnElement[Any]]] = dataclasses.field( + default_factory=lambda: NonemptyMapping(dict) + ) + + timespans: dict[str, TimespanDatabaseRepresentation] = dataclasses.field(default_factory=dict) + + special: dict[str, sqlalchemy.ColumnElement[Any]] = dataclasses.field(default_factory=dict) + + EMPTY_COLUMNS_NAME: ClassVar[str] = "IGNORED" + """Name of the column added to a SQL ``SELECT`` query in order to represent + relations that have no real columns. + """ + + EMPTY_COLUMNS_TYPE: ClassVar[type] = sqlalchemy.Boolean + """Type of the column added to a SQL ``SELECT`` query in order to represent + relations that have no real columns. + """ + + @property + def sql_columns(self) -> sqlalchemy.ColumnCollection: + assert self.sql_from_clause is not None + return self.sql_from_clause.columns + + @classmethod + def handle_empty_columns( + cls, columns: list[sqlalchemy.sql.ColumnElement] + ) -> list[sqlalchemy.ColumnElement]: + """Handle the edge case where a SELECT statement has no columns, by + adding a literal column that should be ignored. + + Parameters + ---------- + columns : `list` [ `sqlalchemy.ColumnElement` ] + List of SQLAlchemy column objects. This may have no elements when + this method is called, and will always have at least one element + when it returns. + + Returns + ------- + columns : `list` [ `sqlalchemy.ColumnElement` ] + The same list that was passed in, after any modification. + """ + if not columns: + columns.append(sqlalchemy.sql.literal(True).label(cls.EMPTY_COLUMNS_NAME)) + return columns + + def select( + self, + columns: qt.ColumnSet, + postprocessing: Postprocessing | None = None, + *, + distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = False, + group_by: Sequence[sqlalchemy.ColumnElement] = (), + ) -> sqlalchemy.Select: + sql_columns: list[sqlalchemy.ColumnElement[Any]] = [] + for logical_table, field in columns: + name = columns.get_qualified_name(logical_table, field) + if field is None: + sql_columns.append(self.dimension_keys[logical_table][0].label(name)) + elif columns.is_timespan(logical_table, field): + sql_columns.extend(self.timespans[logical_table].flatten(name)) + else: + sql_columns.append(self.fields[logical_table][field].label(name)) + if postprocessing is not None: + for element in postprocessing.iter_missing(columns): + assert ( + element.name in columns.dimensions.elements + ), "Region aggregates not handled by this method." + sql_columns.append( + self.fields[element.name]["region"].label( + columns.get_qualified_name(element.name, "region") + ) + ) + for label, sql_column in self.special.items(): + sql_columns.append(sql_column.label(label)) + self.handle_empty_columns(sql_columns) + result = sqlalchemy.select(*sql_columns) + if self.sql_from_clause is not None: + result = result.select_from(self.sql_from_clause) + if self.needs_distinct or distinct: + if distinct is True or distinct is False: + result = result.distinct() + else: + result = result.distinct(*distinct) + if group_by: + result = result.group_by(*group_by) + if self.sql_where_terms: + result = result.where(*self.sql_where_terms) + return result + + def make_table_spec( + self, + columns: qt.ColumnSet, + postprocessing: Postprocessing | None = None, + ) -> ddl.TableSpec: + assert not self.special, "special columns not supported in make_table_spec" + results = ddl.TableSpec( + [columns.get_column_spec(logical_table, field).to_sql_spec() for logical_table, field in columns] + ) + if postprocessing: + for element in postprocessing.iter_missing(columns): + results.fields.add( + ddl.FieldSpec.for_region(columns.get_qualified_name(element.name, "region")) + ) + return results + + def extract_dimensions(self, dimensions: Iterable[str], **kwargs: str) -> SqlBuilder: + assert self.sql_from_clause is not None, "Cannot extract columns with no FROM clause." + for dimension_name in dimensions: + self.dimension_keys[dimension_name].append(self.sql_from_clause.columns[dimension_name]) + for k, v in kwargs.items(): + self.dimension_keys[v].append(self.sql_from_clause.columns[k]) + return self + + def extract_columns( + self, columns: qt.ColumnSet, postprocessing: Postprocessing | None = None + ) -> SqlBuilder: + assert self.sql_from_clause is not None, "Cannot extract columns with no FROM clause." + for logical_table, field in columns: + name = columns.get_qualified_name(logical_table, field) + if field is None: + self.dimension_keys[logical_table].append(self.sql_from_clause.columns[name]) + elif columns.is_timespan(logical_table, field): + self.timespans[logical_table] = self.db.getTimespanRepresentation().from_columns( + self.sql_from_clause.columns, name + ) + else: + self.fields[logical_table][field] = self.sql_from_clause.columns[name] + if postprocessing is not None: + for element in postprocessing.iter_missing(columns): + self.fields[element.name]["region"] = self.sql_from_clause.columns[name] + if postprocessing.check_validity_match_count: + self.special[postprocessing.VALIDITY_MATCH_COUNT] = self.sql_from_clause.columns[ + postprocessing.VALIDITY_MATCH_COUNT + ] + return self + + def join(self, other: SqlBuilder) -> SqlBuilder: + join_on: list[sqlalchemy.ColumnElement] = [] + for dimension_name in self.dimension_keys.keys() & other.dimension_keys.keys(): + for column1, column2 in itertools.product( + self.dimension_keys[dimension_name], other.dimension_keys[dimension_name] + ): + join_on.append(column1 == column2) + self.dimension_keys[dimension_name].extend(other.dimension_keys[dimension_name]) + if self.sql_from_clause is None: + self.sql_from_clause = other.sql_from_clause + elif other.sql_from_clause is not None: + self.sql_from_clause = self.sql_from_clause.join( + other.sql_from_clause, onclause=sqlalchemy.and_(*join_on) + ) + self.sql_where_terms += other.sql_where_terms + self.needs_distinct = self.needs_distinct or other.needs_distinct + self.special.update(other.special) + return self + + def where_sql(self, *arg: sqlalchemy.ColumnElement[bool]) -> SqlBuilder: + self.sql_where_terms.extend(arg) + return self + + def cte( + self, + columns: qt.ColumnSet, + postprocessing: Postprocessing | None = None, + *, + distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = False, + group_by: Sequence[sqlalchemy.ColumnElement] = (), + ) -> SqlBuilder: + return SqlBuilder( + self.db, + self.select(columns, postprocessing, distinct=distinct, group_by=group_by).cte(), + ).extract_columns(columns, postprocessing) + + def subquery( + self, + columns: qt.ColumnSet, + postprocessing: Postprocessing | None = None, + *, + distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = False, + group_by: Sequence[sqlalchemy.ColumnElement] = (), + ) -> SqlBuilder: + return SqlBuilder( + self.db, + self.select(columns, postprocessing, distinct=distinct, group_by=group_by).subquery(), + ).extract_columns(columns, postprocessing) + + def union_subquery( + self, + others: Iterable[SqlBuilder], + columns: qt.ColumnSet, + postprocessing: Postprocessing | None = None, + ) -> SqlBuilder: + select0 = self.select(columns, postprocessing) + other_selects = [other.select(columns, postprocessing) for other in others] + return SqlBuilder( + self.db, + select0.union(*other_selects).subquery(), + ).extract_columns(columns, postprocessing) diff --git a/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py b/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py new file mode 100644 index 0000000000..d3a6de201f --- /dev/null +++ b/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py @@ -0,0 +1,239 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("SqlColumnVisitor",) + +from typing import TYPE_CHECKING, Any, cast + +import sqlalchemy + +from .. import ddl +from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor +from ..timespan_database_representation import TimespanDatabaseRepresentation + +if TYPE_CHECKING: + from ..queries import tree as qt + from ._driver import DirectQueryDriver + from ._sql_builder import SqlBuilder + + +class SqlColumnVisitor( + ColumnExpressionVisitor[sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation], + PredicateVisitor[ + sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool] + ], +): + def __init__(self, sql_builder: SqlBuilder, driver: DirectQueryDriver): + self._driver = driver + self._sql_builder = sql_builder + + def visit_literal( + self, expression: qt.ColumnLiteral + ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: + # Docstring inherited. + if expression.column_type == "timespan": + return self._driver.db.getTimespanRepresentation().fromLiteral(expression.get_literal_value()) + return sqlalchemy.literal( + expression.get_literal_value(), type_=ddl.VALID_CONFIG_COLUMN_TYPES[expression.column_type] + ) + + def visit_dimension_key_reference( + self, expression: qt.DimensionKeyReference + ) -> sqlalchemy.ColumnElement[int | str]: + # Docstring inherited. + return self._sql_builder.dimension_keys[expression.dimension.name][0] + + def visit_dimension_field_reference( + self, expression: qt.DimensionFieldReference + ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: + # Docstring inherited. + if expression.column_type == "timespan": + return self._sql_builder.timespans[expression.element.name] + return self._sql_builder.fields[expression.element.name][expression.field] + + def visit_dataset_field_reference( + self, expression: qt.DatasetFieldReference + ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: + # Docstring inherited. + if expression.column_type == "timespan": + return self._sql_builder.timespans[expression.dataset_type] + return self._sql_builder.fields[expression.dataset_type][expression.field] + + def visit_unary_expression(self, expression: qt.UnaryExpression) -> sqlalchemy.ColumnElement[Any]: + # Docstring inherited. + match expression.operator: + case "-": + return -self.expect_scalar(expression.operand) + case "begin_of": + return self.expect_timespan(expression.operand).lower() + case "end_of": + return self.expect_timespan(expression.operand).upper() + raise AssertionError(f"Invalid unary expression operator {expression.operator!r}.") + + def visit_binary_expression(self, expression: qt.BinaryExpression) -> sqlalchemy.ColumnElement[Any]: + # Docstring inherited. + a = self.expect_scalar(expression.a) + b = self.expect_scalar(expression.b) + match expression.operator: + case "+": + return a + b + case "-": + return a - b + case "*": + return a * b + case "/": + return a / b + case "%": + return a % b + raise AssertionError(f"Invalid binary expression operator {expression.operator!r}.") + + def visit_reversed(self, expression: qt.Reversed) -> sqlalchemy.ColumnElement[Any]: + # Docstring inherited. + return self.expect_scalar(expression.operand).desc() + + def visit_comparison( + self, + a: qt.ColumnExpression, + operator: qt.ComparisonOperator, + b: qt.ColumnExpression, + flags: PredicateVisitFlags, + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + if operator == "overlaps": + assert a.column_type == "timespan", "Spatial overlaps should be transformed away by now." + return self.expect_timespan(a).overlaps(self.expect_timespan(b)) + lhs = self.expect_scalar(a) + rhs = self.expect_scalar(b) + match operator: + case "==": + return lhs == rhs + case "!=": + return lhs != rhs + case "<": + return lhs < rhs + case ">": + return lhs > rhs + case "<=": + return lhs <= rhs + case ">=": + return lhs >= rhs + raise AssertionError(f"Invalid comparison operator {operator!r}.") + + def visit_is_null( + self, operand: qt.ColumnExpression, flags: PredicateVisitFlags + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + if operand.column_type == "timespan": + return self.expect_timespan(operand).isNull() + return self.expect_scalar(operand) == sqlalchemy.null() + + def visit_in_container( + self, + member: qt.ColumnExpression, + container: tuple[qt.ColumnExpression, ...], + flags: PredicateVisitFlags, + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + return self.expect_scalar(member).in_([self.expect_scalar(item) for item in container]) + + def visit_in_range( + self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + sql_member = self.expect_scalar(member) + if stop is None: + target = sql_member >= sqlalchemy.literal(start) + else: + stop_inclusive = stop - 1 + if start == stop_inclusive: + return sql_member == sqlalchemy.literal(start) + else: + target = sqlalchemy.sql.between( + sql_member, + sqlalchemy.literal(start), + sqlalchemy.literal(stop_inclusive), + ) + if step != 1: + return sqlalchemy.sql.and_( + *[ + target, + sql_member % sqlalchemy.literal(step) == sqlalchemy.literal(start % step), + ] + ) + else: + return target + + def visit_in_query_tree( + self, + member: qt.ColumnExpression, + column: qt.ColumnExpression, + query_tree: qt.QueryTree, + flags: PredicateVisitFlags, + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + columns = qt.ColumnSet(self._driver.universe.empty.as_group()) + column.gather_required_columns(columns) + query, sql_builder = self._driver.analyze_query(query_tree, columns) + self._driver.build_query(query, sql_builder) + if query.postprocessing: + raise NotImplementedError( + "Right-hand side subquery in IN expression would require postprocessing." + ) + subquery_visitor = SqlColumnVisitor(sql_builder, self._driver) + sql_builder.special["_MEMBER"] = subquery_visitor.expect_scalar(column) + subquery_select = sql_builder.select(qt.ColumnSet(self._driver.universe.empty.as_group())) + sql_member = self.expect_scalar(member) + return sql_member.in_(subquery_select) + + def apply_logical_and( + self, originals: qt.PredicateOperands, results: tuple[sqlalchemy.ColumnElement[bool], ...] + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + return sqlalchemy.and_(*results) + + def apply_logical_or( + self, + originals: tuple[qt.PredicateLeaf, ...], + results: tuple[sqlalchemy.ColumnElement[bool], ...], + flags: PredicateVisitFlags, + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + return sqlalchemy.or_(*results) + + def apply_logical_not( + self, original: qt.PredicateLeaf, result: sqlalchemy.ColumnElement[bool], flags: PredicateVisitFlags + ) -> sqlalchemy.ColumnElement[bool]: + # Docstring inherited. + return sqlalchemy.not_(result) + + def expect_scalar(self, expression: qt.OrderExpression) -> sqlalchemy.ColumnElement[Any]: + return cast(sqlalchemy.ColumnElement[Any], expression.visit(self)) + + def expect_timespan(self, expression: qt.ColumnExpression) -> TimespanDatabaseRepresentation: + return cast(TimespanDatabaseRepresentation, expression.visit(self)) diff --git a/python/lsst/daf/butler/registry/collections/nameKey.py b/python/lsst/daf/butler/registry/collections/nameKey.py index ccd7d26b6a..7da4390656 100644 --- a/python/lsst/daf/butler/registry/collections/nameKey.py +++ b/python/lsst/daf/butler/registry/collections/nameKey.py @@ -179,6 +179,12 @@ def getParentChains(self, key: str) -> set[str]: parent_names = set(sql_result.scalars().all()) return parent_names + def lookup_name_sql( + self, sql_key: sqlalchemy.ColumnElement[str], sql_from_clause: sqlalchemy.FromClause + ) -> tuple[sqlalchemy.ColumnElement[str], sqlalchemy.FromClause]: + # Docstring inherited. + return sql_key, sql_from_clause + def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]: # Docstring inherited from base class. return self._fetch_by_key(names) diff --git a/python/lsst/daf/butler/registry/collections/synthIntKey.py b/python/lsst/daf/butler/registry/collections/synthIntKey.py index b96a42f0fc..38605a6ad9 100644 --- a/python/lsst/daf/butler/registry/collections/synthIntKey.py +++ b/python/lsst/daf/butler/registry/collections/synthIntKey.py @@ -180,6 +180,17 @@ def getParentChains(self, key: int) -> set[str]: parent_names = set(sql_result.scalars().all()) return parent_names + def lookup_name_sql( + self, sql_key: sqlalchemy.ColumnElement[int], sql_from_clause: sqlalchemy.FromClause + ) -> tuple[sqlalchemy.ColumnElement[str], sqlalchemy.FromClause]: + # Docstring inherited. + return ( + self._tables.collection.c.name, + sql_from_clause.join( + self._tables.collection, onclause=self._tables.collection.c[_KEY_FIELD_SPEC.name] == sql_key + ), + ) + def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]: # Docstring inherited from base class. _LOG.debug("Fetching collection records using names %s.", names) diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py index c67d0b6fb8..7a7a7ae7fb 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py @@ -34,7 +34,7 @@ import datetime from collections.abc import Callable, Iterable, Iterator, Sequence, Set -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import astropy.time import sqlalchemy @@ -46,11 +46,13 @@ from ...._dataset_type import DatasetType from ...._timespan import Timespan from ....dimensions import DataCoordinate +from ....direct_query_driver import SqlBuilder # new query system, server+direct only +from ....queries import tree as qt # new query system, both clients + server from ..._collection_summary import CollectionSummary from ..._collection_type import CollectionType from ..._exceptions import CollectionTypeError, ConflictingDefinitionError from ...interfaces import DatasetRecordStorage -from ...queries import SqlQueryContext +from ...queries import SqlQueryContext # old registry query system from .tables import makeTagTableSpec if TYPE_CHECKING: @@ -552,6 +554,212 @@ def _finish_single_relation( ) return leaf + def make_sql_builder( + self, + collections: Sequence[CollectionRecord], + fields: Set[qt.DatasetFieldName | Literal["collection_key"]], + ) -> SqlBuilder: + # This method largely mimics `make_relation`, but it uses the new query + # system primitives instead of the old one. In terms of the SQL + # queries it builds, there are two more main differences: + # + # - Collection and run columns are now string names rather than IDs. + # This insulates the query result-processing code from collection + # caching and the collection manager subclass details. + # + # - The subquery always has unique rows, which is achieved by using + # SELECT DISTINCT when necessary. + # + collection_types = {collection.type for collection in collections} + assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened." + # + # There are two kinds of table in play here: + # + # - the static dataset table (with the dataset ID, dataset type ID, + # run ID/name, and ingest date); + # + # - the dynamic tags/calibs table (with the dataset ID, dataset type + # type ID, collection ID/name, data ID, and possibly validity + # range). + # + # That means that we might want to return a query against either table + # or a JOIN of both, depending on which quantities the caller wants. + # But the data ID is always included, which means we'll always include + # the tags/calibs table and join in the static dataset table only if we + # need things from it that we can't get from the tags/calibs table. + # + # Note that it's important that we include a WHERE constraint on both + # tables for any column (e.g. dataset_type_id) that is in both when + # it's given explicitly; not doing can prevent the query planner from + # using very important indexes. At present, we don't include those + # redundant columns in the JOIN ON expression, however, because the + # FOREIGN KEY (and its index) are defined only on dataset_id. + tag_sql_builder: SqlBuilder | None = None + if collection_types != {CollectionType.CALIBRATION}: + # We'll need a subquery for the tags table if any of the given + # collections are not a CALIBRATION collection. This intentionally + # also fires when the list of collections is empty as a way to + # create a dummy subquery that we know will fail. + # We give the table an alias because it might appear multiple times + # in the same query, for different dataset types. + tag_sql_builder = SqlBuilder(self._db, self._tags.alias(f"{self.datasetType.name}_tags")) + if "timespan" in fields: + tag_sql_builder.timespans[self.datasetType.name] = ( + self._db.getTimespanRepresentation().fromLiteral(Timespan(None, None)) + ) + tag_sql_builder = self._finish_sql_builder( + tag_sql_builder, + [ + (record, rank) + for rank, record in enumerate(collections) + if record.type is not CollectionType.CALIBRATION + ], + fields, + ) + calib_sql_builder: SqlBuilder | None = None + if CollectionType.CALIBRATION in collection_types: + # If at least one collection is a CALIBRATION collection, we'll + # need a subquery for the calibs table, and could include the + # timespan as a result or constraint. + assert ( + self._calibs is not None + ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection." + calib_sql_builder = SqlBuilder(self._db, self._calibs.alias(f"{self.datasetType.name}_calibs")) + if "timespan" in fields: + calib_sql_builder.timespans[self.datasetType.name] = ( + self._db.getTimespanRepresentation().from_columns(self._calibs.columns) + ) + calib_sql_builder = self._finish_sql_builder( + calib_sql_builder, + [ + (record, rank) + for rank, record in enumerate(collections) + if record.type is CollectionType.CALIBRATION + ], + fields, + ) + # In calibration collections, we need timespan as well as data ID + # to ensure unique rows. + calib_sql_builder.needs_distinct = calib_sql_builder.needs_distinct and "timespan" not in fields + columns = qt.ColumnSet(self.datasetType.dimensions.as_group()) + columns.dataset_fields[self.datasetType.name].update(fields) + columns.drop_implied_dimension_keys() + if tag_sql_builder is not None: + if calib_sql_builder is not None: + # Need a UNION subquery. + return tag_sql_builder.union_subquery([calib_sql_builder], columns) + elif tag_sql_builder.needs_distinct: + # Need a SELECT DISTINCT subquery. + return tag_sql_builder.subquery(columns) + else: + return tag_sql_builder + elif calib_sql_builder is not None: + if calib_sql_builder.needs_distinct: + return calib_sql_builder.subquery(columns) + else: + return calib_sql_builder + else: + raise AssertionError("Branch should be unreachable.") + + def _finish_sql_builder( + self, + sql_builder: SqlBuilder, + collections: Sequence[tuple[CollectionRecord, int]], + fields: Set[qt.DatasetFieldName | Literal["collection_key"]], + ) -> SqlBuilder: + # This method plays the same role as _finish_single_relation in the new + # query system. It is called exactly one or two times by + # make_sql_builder, just as _finish_single_relation is called exactly + # one or two times by make_relation. See make_sql_builder comments for + # what's different. + assert sql_builder.sql_from_clause is not None + run_collections_only = all(record.type is CollectionType.RUN for record, _ in collections) + sql_builder.where_sql(sql_builder.sql_from_clause.c.dataset_type_id == self._dataset_type_id) + dataset_id_col = sql_builder.sql_from_clause.c.dataset_id + collection_col = sql_builder.sql_from_clause.c[self._collections.getCollectionForeignKeyName()] + fields_provided = sql_builder.fields[self.datasetType.name] + # We always constrain and optionally retrieve the collection(s) via the + # tags/calibs table. + if "collection_key" in fields: + sql_builder.fields[self.datasetType.name]["collection_key"] = collection_col + if len(collections) == 1: + only_collection_record, _ = collections[0] + sql_builder.where_sql(collection_col == only_collection_record.key) + if "collection" in fields: + fields_provided["collection"] = sqlalchemy.literal(only_collection_record.name) + elif not collections: + sql_builder.where_sql(sqlalchemy.literal(False)) + if "collection" in fields: + fields_provided["collection"] = sqlalchemy.literal("NO COLLECTIONS") + else: + sql_builder.where_sql(collection_col.in_([collection.key for collection, _ in collections])) + if "collection" in fields: + # Avoid a join to the collection table to get the name by using + # a CASE statement. The SQL will be a bit more verbose but + # more efficient. + fields_provided["collection"] = sqlalchemy.case( + {record.key: record.name for record, _ in collections}, value=collection_col + ) + # Add more column definitions, starting with the data ID. + sql_builder.extract_dimensions(self.datasetType.dimensions.required.names) + # We can always get the dataset_id from the tags/calibs table, even if + # could also get it from the 'static' dataset table. + if "dataset_id" in fields: + fields_provided["dataset_id"] = dataset_id_col + + # It's possible we now have everything we need, from just the + # tags/calibs table. The things we might need to get from the static + # dataset table are the run key and the ingest date. + need_static_table = False + if "run" in fields: + if len(collections) == 1 and run_collections_only: + # If we are searching exactly one RUN collection, we + # know that if we find the dataset in that collection, + # then that's the datasets's run; we don't need to + # query for it. + fields_provided["run"] = sqlalchemy.literal(only_collection_record.name) + elif run_collections_only: + # Once again we can avoid joining to the collection table by + # adding a CASE statement. + fields_provided["run"] = sqlalchemy.case( + {record.key: record.name for record, _ in collections}, + value=self._static.dataset.c[self._runKeyColumn], + ) + need_static_table = True + else: + # Here we can't avoid a join to the collection table, because + # we might find a dataset via something other than its RUN + # collection. + fields_provided["run"], sql_builder.sql_from_clause = self._collections.lookup_name_sql( + self._static.dataset.c[self._runKeyColumn], + sql_builder.sql_from_clause, + ) + need_static_table = True + # Ingest date can only come from the static table. + if "ingest_date" in fields: + fields_provided["ingest_date"] = self._static.dataset.c.ingest_date + need_static_table = True + if need_static_table: + # If we need the static table, join it in via dataset_id. + sql_builder.sql_from_clause = sql_builder.sql_from_clause.join( + self._static.dataset, onclause=(dataset_id_col == self._static.dataset.c.id) + ) + # Also constrain dataset_type_id in static table in case that helps + # generate a better plan. We could also include this in the JOIN ON + # clause, but my guess is that that's a good idea IFF it's in the + # foreign key, and right now it isn't. + sql_builder.where_sql(self._static.dataset.c.dataset_type_id == self._dataset_type_id) + sql_builder.needs_distinct = ( + # If there are multiple collections and we're searching any non-RUN + # collection, we could find the same dataset twice, which would + # yield duplicate rows unless "collection" or "rank" is there to + # make those rows unique. + len(collections) > 1 + and not run_collections_only + and ("collection_key" not in fields) + ) + return sql_builder + def getDataId(self, id: DatasetId) -> DataCoordinate: """Return DataId for a dataset. diff --git a/python/lsst/daf/butler/registry/dimensions/static.py b/python/lsst/daf/butler/registry/dimensions/static.py index 3a3903e855..ceda38b95a 100644 --- a/python/lsst/daf/butler/registry/dimensions/static.py +++ b/python/lsst/daf/butler/registry/dimensions/static.py @@ -30,17 +30,19 @@ import itertools import logging from collections import defaultdict -from collections.abc import Mapping, Sequence, Set +from collections.abc import Iterable, Mapping, Sequence, Set from typing import TYPE_CHECKING, Any import sqlalchemy from lsst.daf.relation import Calculation, ColumnExpression, Join, Relation, sql +from lsst.sphgeom import Region from ... import ddl from ..._column_tags import DimensionKeyColumnTag, DimensionRecordColumnTag from ..._column_type_info import LogicalColumn from ..._named import NamedKeyDict from ...dimensions import ( + DatabaseDimensionElement, DatabaseTopologicalFamily, DataCoordinate, Dimension, @@ -53,11 +55,15 @@ addDimensionForeignKey, ) from ...dimensions.record_cache import DimensionRecordCache +from ...direct_query_driver import Postprocessing, SqlBuilder # Future query system (direct,server). +from ...queries import tree as qt # Future query system (direct,client,server) +from ...queries.overlaps import OverlapsVisitor +from ...queries.visitors import PredicateVisitFlags from .._exceptions import MissingSpatialOverlapError from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext, VersionTuple if TYPE_CHECKING: - from .. import queries + from .. import queries # Current Registry.query* system. # This has to be updated on every schema change @@ -426,6 +432,38 @@ def make_spatial_join_relation( ) return overlaps, needs_refinement + def make_sql_builder(self, element: DimensionElement, fields: Set[str]) -> SqlBuilder: + if element.implied_union_target is not None: + assert not fields, "Dimensions with implied-union storage never have fields." + return self.make_sql_builder(element.implied_union_target, fields).subquery( + qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True + ) + if not element.has_own_table: + raise NotImplementedError(f"Cannot join dimension element {element} with no table.") + table = self._tables[element.name] + result = SqlBuilder(self._db, table) + for dimension_name, column_name in zip(element.required.names, element.schema.required.names): + result.dimension_keys[dimension_name].append(table.columns[column_name]) + result.extract_dimensions(element.implied.names) + for field in fields: + if field == "timespan": + result.timespans[element.name] = self._db.getTimespanRepresentation().from_columns( + table.columns + ) + else: + result.fields[element.name][field] = table.columns[field] + return result + + def process_query_overlaps( + self, + dimensions: DimensionGroup, + predicate: qt.Predicate, + join_operands: Iterable[DimensionGroup], + ) -> tuple[qt.Predicate, SqlBuilder, Postprocessing]: + overlaps_visitor = _CommonSkyPixMediatedOverlapsVisitor(self._db, dimensions, self._overlap_tables) + new_predicate = overlaps_visitor.run(predicate, join_operands) + return new_predicate, overlaps_visitor.sql_builder, overlaps_visitor.postprocessing + def _make_relation( self, element: DimensionElement, @@ -948,3 +986,159 @@ def load(self, key: int) -> DimensionGroup: self.refresh() graph = self._groupsByKey[key] return graph + + +class _CommonSkyPixMediatedOverlapsVisitor(OverlapsVisitor): + def __init__( + self, + db: Database, + dimensions: DimensionGroup, + overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]], + ): + super().__init__(dimensions) + self.sql_builder: SqlBuilder = SqlBuilder(db) + self.postprocessing = Postprocessing() + self.common_skypix = dimensions.universe.commonSkyPix + self.overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]] = overlap_tables + self.common_skypix_overlaps_done: set[DatabaseDimensionElement] = set() + + def visit_spatial_constraint( + self, + element: DimensionElement, + region: Region, + flags: PredicateVisitFlags, + ) -> qt.Predicate | None: + # Reject spatial constraints that are nested inside OR or NOT, because + # the postprocessing needed for those would be a lot harder. + if flags & PredicateVisitFlags.INVERTED or flags & PredicateVisitFlags.HAS_OR_SIBLINGS: + raise NotImplementedError( + "Spatial overlap constraints nested inside OR or NOT are not supported." + ) + # Delegate to super just because that's good practice with + # OverlapVisitor. + super().visit_spatial_constraint(element, region, flags) + match element: + case DatabaseDimensionElement(): + # If this is a database dimension element like tract, patch, or + # visit, we need to: + # - join in the common skypix overlap table for this element; + # - constrain the common skypix index to be inside the + # ranges that overlap the region as a SQL where clause; + # - add postprocessing to reject rows where the database + # dimension element's region doesn't actually overlap the + # region. + self.postprocessing.spatial_where_filtering.append((element, region)) + if self.common_skypix.name in self.dimensions: + # The common skypix dimension should be part of the query + # as a first-class dimension, so we can join in the overlap + # table directly, and fall through to the end of this + # function to construct a Predicate that will turn into the + # SQL WHERE clause we want. + self._join_common_skypix_overlap(element) + skypix = self.common_skypix + else: + # We need to hide the common skypix dimension from the + # larger query, so we make a subquery out of the overlap + # table that embeds the SQL WHERE clause we want and then + # projects out that dimension (with SELECT DISTINCT, to + # avoid introducing duplicate rows into the larger query). + overlap_sql_builder = self._make_common_skypix_overlap_sql_builder(element) + sql_where_or: list[sqlalchemy.ColumnElement[bool]] = [] + sql_skypix_col = overlap_sql_builder.dimension_keys[self.common_skypix.name][0] + for begin, end in self.common_skypix.pixelization.envelope(region): + sql_where_or.append(sqlalchemy.and_(sql_skypix_col >= begin, sql_skypix_col < end)) + overlap_sql_builder.where_sql(sqlalchemy.or_(*sql_where_or)) + self.sql_builder = self.sql_builder.join( + overlap_sql_builder.subquery( + qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True + ) + ) + # Short circuit here since the SQL WHERE clause has already + # been embedded in the subquery. + return qt.Predicate.from_bool(True) + case SkyPixDimension(): + # If this is a skypix dimension, we can do a index-in-ranges + # test directly on that dimension. Note that this doesn't on + # its own guarantee the skypix dimension column will be in the + # query; that'll be the job of the DirectQueryDriver to sort + # out (generally this will require a dataset using that skypix + # dimension to be joined in, unless this is the common skypix + # system). + assert ( + element.name in self.dimensions + ), "QueryTree guarantees dimensions are expanded when constraints are added." + skypix = element + case _: + raise NotImplementedError( + f"Spatial overlap constraint for dimension {element} not supported." + ) + # Convert the region-overlap constraint into a skypix + # index range-membership constraint in SQL... + result = qt.Predicate.from_bool(False) + skypix_col_ref = qt.DimensionKeyReference.model_construct(dimension=skypix) + for begin, end in skypix.pixelization.envelope(region): + result = result.logical_or(qt.Predicate.in_range(skypix_col_ref, start=begin, stop=end)) + return result + + def visit_spatial_join( + self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags + ) -> qt.Predicate | None: + # Reject spatial joins that are nested inside OR or NOT, because the + # postprocessing needed for those would be a lot harder. + if flags & PredicateVisitFlags.INVERTED or flags & PredicateVisitFlags.HAS_OR_SIBLINGS: + raise NotImplementedError("Spatial overlap joins nested inside OR or NOT are not supported.") + # Delegate to super to check for invalid joins and record this + # "connection" for use when seeing whether to add an automatic join + # later. + super().visit_spatial_join(a, b, flags) + match (a, b): + case (self.common_skypix, DatabaseDimensionElement() as b): + self._join_common_skypix_overlap(b) + case (DatabaseDimensionElement() as a, self.common_skypix): + self._join_common_skypix_overlap(a) + case (DatabaseDimensionElement() as a, DatabaseDimensionElement() as b): + if self.common_skypix.name in self.dimensions: + # We want the common skypix dimension to appear in the + # query as a first-class dimension, so just join in the + # two overlap tables directly. + self._join_common_skypix_overlap(a) + self._join_common_skypix_overlap(b) + else: + # We do not want the common skypix system to appear in the + # query or cause duplicate rows, so we join the two overlap + # tables in a subquery that projects out the common skypix + # index column with SELECT DISTINCT. + + self.sql_builder = self.sql_builder.join( + self._make_common_skypix_overlap_sql_builder(a) + .join(self._make_common_skypix_overlap_sql_builder(b)) + .subquery( + qt.ColumnSet(a.minimal_group | b.minimal_group).drop_implied_dimension_keys(), + distinct=True, + ) + ) + # In both cases we add postprocessing to check that the regions + # really do overlap, since overlapping the same common skypix + # tile is necessary but not sufficient for that. + self.postprocessing.spatial_join_filtering.append((a, b)) + case _: + raise NotImplementedError(f"Unsupported combination for spatial join: {a, b}.") + return qt.Predicate.from_bool(True) + + def _join_common_skypix_overlap(self, element: DatabaseDimensionElement) -> None: + if element not in self.common_skypix_overlaps_done: + self.sql_builder = self.sql_builder.join(self._make_common_skypix_overlap_sql_builder(element)) + self.common_skypix_overlaps_done.add(element) + + def _make_common_skypix_overlap_sql_builder(self, element: DatabaseDimensionElement) -> SqlBuilder: + _, overlap_table = self.overlap_tables[element.name] + return self.sql_builder.join( + SqlBuilder(self.sql_builder.db, overlap_table) + .extract_dimensions(element.required.names, skypix_index=self.common_skypix.name) + .where_sql( + sqlalchemy.and_( + overlap_table.c.skypix_system == self.common_skypix.system.name, + overlap_table.c.skypix_level == self.common_skypix.level, + ) + ) + ) diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py index cef7b9741f..418d264460 100644 --- a/python/lsst/daf/butler/registry/interfaces/_collections.py +++ b/python/lsst/daf/butler/registry/interfaces/_collections.py @@ -39,6 +39,8 @@ from collections.abc import Iterable, Set from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar +import sqlalchemy + from ..._timespan import Timespan from .._collection_type import CollectionType from ..wildcards import CollectionWildcard @@ -621,3 +623,27 @@ def update_chain( `~CollectionType.CHAINED` collections in ``children`` first. """ raise NotImplementedError() + + @abstractmethod + def lookup_name_sql( + self, sql_key: sqlalchemy.ColumnElement[_Key], sql_from_clause: sqlalchemy.FromClause + ) -> tuple[sqlalchemy.ColumnElement[str], sqlalchemy.FromClause]: + """Return a SQLAlchemy column and FROM clause that enable a query + to look up a collection name from the key. + + Parameters + ---------- + sql_key : `sqlalchemy.ColumnElement` + SQL column expression that evaluates to the collection key. + sql_from_clause : `sqlalchemy.FromClause` + SQL FROM clause from which ``sql_key`` was obtained. + + Returns + ------- + sql_name : `sqlalchemy.ColumnElement` [ `str` ] + SQL column expression that evalutes to the collection name. + sql_from_clause : `sqlalchemy.FromClause` + SQL FROM clause that includes the given ``sql_from_clause`` and + any table needed to provided ``sql_name``. + """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/registry/interfaces/_datasets.py b/python/lsst/daf/butler/registry/interfaces/_datasets.py index abc85d4a05..6c401a9e47 100644 --- a/python/lsst/daf/butler/registry/interfaces/_datasets.py +++ b/python/lsst/daf/butler/registry/interfaces/_datasets.py @@ -32,8 +32,8 @@ __all__ = ("DatasetRecordStorageManager", "DatasetRecordStorage") from abc import ABC, abstractmethod -from collections.abc import Iterable, Iterator, Mapping, Set -from typing import TYPE_CHECKING, Any +from collections.abc import Iterable, Iterator, Mapping, Sequence, Set +from typing import TYPE_CHECKING, Any, Literal from lsst.daf.relation import Relation @@ -45,9 +45,11 @@ from ._versioning import VersionedExtension, VersionTuple if TYPE_CHECKING: + from ...direct_query_driver import SqlBuilder # new query system, server+direct only + from ...queries import tree as qt # new query system, both clients + server from .._caching_context import CachingContext from .._collection_summary import CollectionSummary - from ..queries import SqlQueryContext + from ..queries import SqlQueryContext # old registry query system from ._collections import CollectionManager, CollectionRecord, RunRecord from ._database import Database, StaticTablesContext from ._dimensions import DimensionRecordStorageManager @@ -311,6 +313,14 @@ def make_relation( """ raise NotImplementedError() + @abstractmethod + def make_sql_builder( + self, + collections: Sequence[CollectionRecord], + fields: Set[qt.DatasetFieldName | Literal["collection_key"]], + ) -> SqlBuilder: + raise NotImplementedError() + datasetType: DatasetType """Dataset type whose records this object manages (`DatasetType`). """ diff --git a/python/lsst/daf/butler/registry/interfaces/_dimensions.py b/python/lsst/daf/butler/registry/interfaces/_dimensions.py index c14e19ca29..9f6aff6b6d 100644 --- a/python/lsst/daf/butler/registry/interfaces/_dimensions.py +++ b/python/lsst/daf/butler/registry/interfaces/_dimensions.py @@ -29,7 +29,7 @@ __all__ = ("DimensionRecordStorageManager",) from abc import abstractmethod -from collections.abc import Set +from collections.abc import Iterable, Set from typing import TYPE_CHECKING, Any from lsst.daf.relation import Join, Relation @@ -46,7 +46,9 @@ from ._versioning import VersionedExtension, VersionTuple if TYPE_CHECKING: - from .. import queries + from ...direct_query_driver import Postprocessing, SqlBuilder # Future query system (direct,server). + from ...queries.tree import Predicate # Future query system (direct,client,server). + from .. import queries # Old Registry.query* system. from ._database import Database, StaticTablesContext @@ -357,6 +359,19 @@ def make_spatial_join_relation( """ raise NotImplementedError() + @abstractmethod + def make_sql_builder(self, element: DimensionElement, fields: Set[str]) -> SqlBuilder: + raise NotImplementedError() + + @abstractmethod + def process_query_overlaps( + self, + dimensions: DimensionGroup, + predicate: Predicate, + join_operands: Iterable[DimensionGroup], + ) -> tuple[Predicate, SqlBuilder, Postprocessing]: + raise NotImplementedError() + universe: DimensionUniverse """Universe of all dimensions and dimension elements known to the `Registry` (`DimensionUniverse`).