diff --git a/python/lsst/daf/butler/direct_query_driver/__init__.py b/python/lsst/daf/butler/direct_query_driver/__init__.py
new file mode 100644
index 0000000000..d8aae48e4b
--- /dev/null
+++ b/python/lsst/daf/butler/direct_query_driver/__init__.py
@@ -0,0 +1,29 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from ._postprocessing import Postprocessing
+from ._sql_builder import SqlBuilder
diff --git a/python/lsst/daf/butler/direct_query_driver/_analyzed_query.py b/python/lsst/daf/butler/direct_query_driver/_analyzed_query.py
new file mode 100644
index 0000000000..71d98c7898
--- /dev/null
+++ b/python/lsst/daf/butler/direct_query_driver/_analyzed_query.py
@@ -0,0 +1,170 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ("AnalyzedQuery", "AnalyzedDatasetSearch", "DataIdExtractionVisitor")
+
+import dataclasses
+from collections.abc import Iterator
+from typing import Any
+
+from ..dimensions import DataIdValue, DimensionElement, DimensionGroup, DimensionUniverse
+from ..queries import tree as qt
+from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, SimplePredicateVisitor
+from ..registry.interfaces import CollectionRecord
+from ._postprocessing import Postprocessing
+
+
+@dataclasses.dataclass
+class AnalyzedDatasetSearch:
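+    """A struct describing a single dataset search joined into a query, with
+    its collection search path resolved to concrete collection records.
+    """
+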
+ name: str
+ shrunk: str
+ dimensions: DimensionGroup
+ collection_records: list[CollectionRecord] = dataclasses.field(default_factory=list)
+ messages: list[str] = dataclasses.field(default_factory=list)
+ is_calibration_search: bool = False
+
+
+@dataclasses.dataclass
+class AnalyzedQuery:
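+    """A struct describing the columns, joins, and postprocessing of a butler
+    query, as worked out by `DirectQueryDriver.analyze_query`.
+    """
+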
+ predicate: qt.Predicate
+ postprocessing: Postprocessing
+ base_columns: qt.ColumnSet
+ projection_columns: qt.ColumnSet
+ final_columns: qt.ColumnSet
+ find_first_dataset: str | None
+ materializations: dict[qt.MaterializationKey, DimensionGroup] = dataclasses.field(default_factory=dict)
+ datasets: dict[str, AnalyzedDatasetSearch] = dataclasses.field(default_factory=dict)
+ messages: list[str] = dataclasses.field(default_factory=list)
+ constraint_data_id: dict[str, DataIdValue] = dataclasses.field(default_factory=dict)
+ data_coordinate_uploads: dict[qt.DataCoordinateUploadKey, DimensionGroup] = dataclasses.field(
+ default_factory=dict
+ )
+ needs_dimension_distinct: bool = False
+ needs_find_first_resolution: bool = False
+ projection_region_aggregates: list[DimensionElement] = dataclasses.field(default_factory=list)
+
+ @property
+ def universe(self) -> DimensionUniverse:
+ return self.base_columns.dimensions.universe
+
+ @property
+ def needs_projection(self) -> bool:
+ return self.needs_dimension_distinct or self.postprocessing.check_validity_match_count
+
+ def iter_mandatory_base_elements(self) -> Iterator[DimensionElement]:
+ for element_name in self.base_columns.dimensions.elements:
+ element = self.universe[element_name]
+ if self.base_columns.dimension_fields[element_name]:
+ # We need to get dimension record fields for this element, and
+ # its table is the only place to get those.
+ yield element
+ elif element.defines_relationships:
+                # We also need to join in DimensionElement tables that define
+ # one-to-many and many-to-many relationships, but data
+ # coordinate uploads, materializations, and datasets can also
+ # provide these relationships. Data coordinate uploads and
+ # dataset tables only have required dimensions, and can hence
+ # only provide relationships involving those.
+ if any(
+ element.minimal_group.names <= upload_dimensions.required
+ for upload_dimensions in self.data_coordinate_uploads.values()
+ ):
+ continue
+ if any(
+ element.minimal_group.names <= dataset_spec.dimensions.required
+ for dataset_spec in self.datasets.values()
+ ):
+ continue
+ # Materializations have all key columns for their dimensions.
+ if any(
+ element in materialization_dimensions.names
+ for materialization_dimensions in self.materializations.values()
+ ):
+ continue
+ yield element
+
+
+class DataIdExtractionVisitor(
+ SimplePredicateVisitor,
+ ColumnExpressionVisitor[tuple[str, None] | tuple[None, Any] | tuple[None, None]],
+):
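+    """A predicate visitor that extracts a data ID from equality constraints
+    on dimension keys, recording contradictory constraints as messages.
+    """
+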
+ def __init__(self, data_id: dict[str, DataIdValue], messages: list[str]):
+ self.data_id = data_id
+ self.messages = messages
+
+ def visit_comparison(
+ self,
+ a: qt.ColumnExpression,
+ operator: qt.ComparisonOperator,
+ b: qt.ColumnExpression,
+ flags: PredicateVisitFlags,
+ ) -> None:
+ if flags & PredicateVisitFlags.HAS_OR_SIBLINGS:
+ return None
+ if flags & PredicateVisitFlags.INVERTED:
+ if operator == "!=":
+ operator = "=="
+ else:
+ return None
+ if operator != "==":
+ return None
+ k_a, v_a = a.visit(self)
+ k_b, v_b = b.visit(self)
+ if k_a is not None and v_b is not None:
+ key = k_a
+ value = v_b
+ elif k_b is not None and v_a is not None:
+ key = k_b
+ value = v_a
+ else:
+ return None
+ if (old := self.data_id.setdefault(key, value)) != value:
+ self.messages.append(f"'where' expression requires both {key}={value!r} and {key}={old!r}.")
+ return None
+
+ def visit_binary_expression(self, expression: qt.BinaryExpression) -> tuple[None, None]:
+ return None, None
+
+ def visit_unary_expression(self, expression: qt.UnaryExpression) -> tuple[None, None]:
+ return None, None
+
+ def visit_literal(self, expression: qt.ColumnLiteral) -> tuple[None, Any]:
+ return None, expression.get_literal_value()
+
+ def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> tuple[str, None]:
+ return expression.dimension.name, None
+
+ def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> tuple[None, None]:
+ return None, None
+
+ def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> tuple[None, None]:
+ return None, None
+
+ def visit_reversed(self, expression: qt.Reversed) -> tuple[None, None]:
+ raise AssertionError("No Reversed expressions in predicates.")
diff --git a/python/lsst/daf/butler/direct_query_driver/_convert_results.py b/python/lsst/daf/butler/direct_query_driver/_convert_results.py
new file mode 100644
index 0000000000..899c1760b4
--- /dev/null
+++ b/python/lsst/daf/butler/direct_query_driver/_convert_results.py
@@ -0,0 +1,61 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ("convert_dimension_record_results",)
+
+from collections.abc import Iterable
+from typing import TYPE_CHECKING
+
+import sqlalchemy
+
+from ..dimensions import DimensionRecordSet
+
+if TYPE_CHECKING:
+ from ..queries.driver import DimensionRecordResultPage, PageKey
+ from ..queries.result_specs import DimensionRecordResultSpec
+ from ..registry.nameShrinker import NameShrinker
+
+
+def convert_dimension_record_results(
+ raw_rows: Iterable[sqlalchemy.Row],
+ spec: DimensionRecordResultSpec,
+ next_key: PageKey | None,
+ name_shrinker: NameShrinker,
+) -> DimensionRecordResultPage:
+ record_set = DimensionRecordSet(spec.element)
+ columns = spec.get_result_columns()
+ column_mapping = [
+ (field, name_shrinker.shrink(columns.get_qualified_name(spec.element.name, field)))
+ for field in spec.element.schema.names
+ ]
+ record_cls = spec.element.RecordClass
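+    # Only non-temporal elements are handled here; timespan fields would have
+    # to be reassembled from their component database columns.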
+ if not spec.element.temporal:
+ for raw_row in raw_rows:
+ record_set.add(record_cls(**{k: raw_row._mapping[v] for k, v in column_mapping}))
+ return DimensionRecordResultPage(spec=spec, next_key=next_key, rows=record_set)
diff --git a/python/lsst/daf/butler/direct_query_driver/_driver.py b/python/lsst/daf/butler/direct_query_driver/_driver.py
new file mode 100644
index 0000000000..654c4f80bc
--- /dev/null
+++ b/python/lsst/daf/butler/direct_query_driver/_driver.py
@@ -0,0 +1,789 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ("DirectQueryDriver",)
+
+import logging
+import uuid
+from collections.abc import Iterable, Iterator, Sequence
+from contextlib import ExitStack
+from typing import TYPE_CHECKING, Any, cast, overload
+
+import sqlalchemy
+
+from .. import ddl
+from ..dimensions import DataIdValue, DimensionGroup, DimensionUniverse
+from ..queries import tree as qt
+from ..queries.driver import (
+ DataCoordinateResultPage,
+ DatasetRefResultPage,
+ DimensionRecordResultPage,
+ GeneralResultPage,
+ PageKey,
+ QueryDriver,
+ ResultPage,
+)
+from ..queries.result_specs import (
+ DataCoordinateResultSpec,
+ DatasetRefResultSpec,
+ DimensionRecordResultSpec,
+ GeneralResultSpec,
+ ResultSpec,
+)
+from ..registry import CollectionSummary, CollectionType, NoDefaultCollectionError, RegistryDefaults
+from ..registry.interfaces import ChainedCollectionRecord, CollectionRecord
+from ..registry.managers import RegistryManagerInstances
+from ..registry.nameShrinker import NameShrinker
+from ._analyzed_query import AnalyzedDatasetSearch, AnalyzedQuery, DataIdExtractionVisitor
+from ._convert_results import convert_dimension_record_results
+from ._sql_column_visitor import SqlColumnVisitor
+
+if TYPE_CHECKING:
+ from ..registry.interfaces import Database
+ from ._postprocessing import Postprocessing
+ from ._sql_builder import SqlBuilder
+
+
+_LOG = logging.getLogger(__name__)
+
+
+class DirectQueryDriver(QueryDriver):
+ """The `QueryDriver` implementation for `DirectButler`.
+
+ Parameters
+ ----------
+ db : `Database`
+ Abstraction for the SQL database.
+ universe : `DimensionUniverse`
+ Definitions of all dimensions.
+ managers : `RegistryManagerInstances`
+ Struct of registry manager objects.
+ defaults : `RegistryDefaults`
+ Struct holding the default collection search path and governor
+ dimensions.
+ raw_page_size : `int`, optional
+ Number of database rows to fetch for each result page. The actual
+ number of rows in a page may be smaller due to postprocessing.
+ postprocessing_filter_factor : `int`, optional
+ The number of database rows we expect to have to fetch to yield a
+ single output row for queries that involve postprocessing. This is
+ purely a performance tuning parameter that attempts to balance between
+ fetching too much and requiring multiple fetches; the true value is
+ highly dependent on the actual query.
+ """
+
+ def __init__(
+ self,
+ db: Database,
+ universe: DimensionUniverse,
+ managers: RegistryManagerInstances,
+ defaults: RegistryDefaults,
+ raw_page_size: int = 10000,
+ postprocessing_filter_factor: int = 10,
+ ):
+ self.db = db
+ self.managers = managers
+ self._universe = universe
+ self._defaults = defaults
+ self._materializations: dict[qt.MaterializationKey, tuple[sqlalchemy.Table, Postprocessing]] = {}
+ self._upload_tables: dict[qt.DataCoordinateUploadKey, sqlalchemy.Table] = {}
+ self._exit_stack: ExitStack | None = None
+ self._raw_page_size = raw_page_size
+ self._postprocessing_filter_factor = postprocessing_filter_factor
+ self._active_pages: dict[PageKey, tuple[Iterator[Sequence[sqlalchemy.Row]], Postprocessing]] = {}
+ self._name_shrinker = NameShrinker(self.db.dialect.max_identifier_length)
+
+ def __enter__(self) -> None:
+ self._exit_stack = ExitStack()
+
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+ assert self._exit_stack is not None
+ self._exit_stack.__exit__(exc_type, exc_value, traceback)
+ self._exit_stack = None
+
+ @property
+ def universe(self) -> DimensionUniverse:
+ return self._universe
+
+ @overload
+ def execute(
+ self, result_spec: DataCoordinateResultSpec, tree: qt.QueryTree
+ ) -> DataCoordinateResultPage: ...
+
+ @overload
+ def execute(
+ self, result_spec: DimensionRecordResultSpec, tree: qt.QueryTree
+ ) -> DimensionRecordResultPage: ...
+
+ @overload
+ def execute(self, result_spec: DatasetRefResultSpec, tree: qt.QueryTree) -> DatasetRefResultPage: ...
+
+ @overload
+ def execute(self, result_spec: GeneralResultSpec, tree: qt.QueryTree) -> GeneralResultPage: ...
+
+ def execute(self, result_spec: ResultSpec, tree: qt.QueryTree) -> ResultPage:
+ # Docstring inherited.
+ if self._exit_stack is None:
+            raise RuntimeError("QueryDriver context must be entered before 'execute' is called.")
+        # Make a set of the columns the query needs to make available to the
+        # SELECT clause and any ORDER BY or GROUP BY clauses. This does not
+        # include columns needed only by the WHERE or JOIN ON clauses; those
+        # are gathered separately inside `analyze_query`.
+
+        # Build the FROM and WHERE clauses and identify any post-query
+        # processing we need to run.
+ query, sql_builder = self.analyze_query(
+ tree,
+ final_columns=result_spec.get_result_columns(),
+ order_by=result_spec.order_by,
+ find_first_dataset=result_spec.find_first_dataset,
+ )
+ sql_builder = self.build_query(query, sql_builder)
+ sql_select = sql_builder.select(query.final_columns, query.postprocessing)
+ if result_spec.order_by:
+ visitor = SqlColumnVisitor(sql_builder, self)
+ sql_select = sql_select.order_by(*[visitor.expect_scalar(term) for term in result_spec.order_by])
+ if result_spec.limit is not None:
+ if query.postprocessing:
+ query.postprocessing.limit = result_spec.limit
+ else:
+ sql_select = sql_select.limit(result_spec.limit)
+ if result_spec.offset:
+            if query.postprocessing:
+                query.postprocessing.offset = result_spec.offset
+            else:
+                sql_select = sql_select.offset(result_spec.offset)
+ if query.postprocessing.limit is not None:
+            # We might want to fetch many fewer rows than the default page
+ # size if we have to implement offset and limit in postprocessing.
+ raw_page_size = min(
+ self._postprocessing_filter_factor
+ * (query.postprocessing.offset + query.postprocessing.limit),
+ self._raw_page_size,
+ )
+ cursor = self._exit_stack.enter_context(
+ self.db.query(sql_select.execution_options(yield_per=raw_page_size))
+ )
+ raw_page_iter = cursor.partitions()
+ return self._process_page(raw_page_iter, result_spec, query.postprocessing)
+
+ @overload
+ def fetch_next_page(
+ self, result_spec: DataCoordinateResultSpec, key: PageKey
+ ) -> DataCoordinateResultPage: ...
+
+ @overload
+ def fetch_next_page(
+ self, result_spec: DimensionRecordResultSpec, key: PageKey
+ ) -> DimensionRecordResultPage: ...
+
+ @overload
+ def fetch_next_page(self, result_spec: DatasetRefResultSpec, key: PageKey) -> DatasetRefResultPage: ...
+
+ @overload
+ def fetch_next_page(self, result_spec: GeneralResultSpec, key: PageKey) -> GeneralResultPage: ...
+
+ def fetch_next_page(self, result_spec: ResultSpec, key: PageKey) -> ResultPage:
+ raw_page_iter, postprocessing = self._active_pages.pop(key)
+ return self._process_page(raw_page_iter, result_spec, postprocessing)
+
+ def materialize(
+ self,
+ tree: qt.QueryTree,
+ dimensions: DimensionGroup,
+ datasets: frozenset[str],
+ ) -> qt.MaterializationKey:
+ # Docstring inherited.
+ if self._exit_stack is None:
+ raise RuntimeError("QueryDriver context must be entered before 'materialize' is called.")
+ query, sql_builder = self.analyze_query(tree, qt.ColumnSet(dimensions))
+ # Current implementation ignores 'datasets' because figuring out what
+ # to put in the temporary table for them is tricky, especially if
+ # calibration collections are involved.
+ sql_builder = self.build_query(query, sql_builder)
+ sql_select = sql_builder.select(query.final_columns, query.postprocessing)
+ table = self._exit_stack.enter_context(
+ self.db.temporary_table(sql_builder.make_table_spec(query.final_columns, query.postprocessing))
+ )
+ self.db.insert(table, select=sql_select)
+ key = uuid.uuid4()
+ self._materializations[key] = (table, query.postprocessing)
+ return key
+
+ def upload_data_coordinates(
+ self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]]
+ ) -> qt.DataCoordinateUploadKey:
+ # Docstring inherited.
+ if self._exit_stack is None:
+            raise RuntimeError(
+                "QueryDriver context must be entered before 'upload_data_coordinates' is called."
+            )
+ table_spec = ddl.TableSpec(
+ [
+ self.universe.dimensions[name].primary_key.model_copy(update=dict(name=name)).to_sql_spec()
+ for name in dimensions.required
+ ]
+ )
+ if not dimensions:
+ table_spec.fields.add(
+ ddl.FieldSpec(
+ SqlBuilder.EMPTY_COLUMNS_NAME, dtype=SqlBuilder.EMPTY_COLUMNS_TYPE, nullable=True
+ )
+ )
+ table = self._exit_stack.enter_context(self.db.temporary_table(table_spec))
+ self.db.insert(table, *(dict(zip(dimensions.required, values)) for values in rows))
+ key = uuid.uuid4()
+ self._upload_tables[key] = table
+ return key
+
+ def count(
+ self,
+ tree: qt.QueryTree,
+ columns: qt.ColumnSet,
+ find_first_dataset: str | None,
+ *,
+ exact: bool,
+ discard: bool,
+ ) -> int:
+ # Docstring inherited.
+ query, sql_builder = self.analyze_query(tree, columns, find_first_dataset=find_first_dataset)
+ sql_builder = self.build_query(query, sql_builder)
+ if query.postprocessing and exact:
+ if not discard:
+ raise RuntimeError("Cannot count query rows exactly without discarding them.")
+ sql_select = sql_builder.select(columns, query.postprocessing)
+ n = 0
+ with self.db.query(sql_select.execution_options(yield_per=self._raw_page_size)) as results:
+ for _ in query.postprocessing.apply(results):
+                    n += 1
+ return n
+ # Do COUNT(*) on the original query's FROM clause.
+ sql_builder.special["_ROWCOUNT"] = sqlalchemy.func.count()
+ sql_select = sql_builder.select(qt.ColumnSet(self._universe.empty.as_group()))
+ with self.db.query(sql_select) as result:
+ return cast(int, result.scalar())
+
+ def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool:
+ # Docstring inherited.
+ query, sql_builder = self.analyze_query(tree, qt.ColumnSet(tree.dimensions))
+ if not all(d.collection_records for d in query.datasets.values()):
+ return False
+ if not execute:
+ if exact:
+ raise RuntimeError("Cannot obtain exact result for 'any' without executing.")
+ return True
+ sql_builder = self.build_query(query, sql_builder)
+ if query.postprocessing and exact:
+ sql_select = sql_builder.select(query.final_columns, query.postprocessing)
+ with self.db.query(
+ sql_select.execution_options(yield_per=self._postprocessing_filter_factor)
+ ) as result:
+ for _ in query.postprocessing.apply(result):
+ return True
+ return False
+ sql_select = sql_builder.select(query.final_columns).limit(1)
+ with self.db.query(sql_select) as result:
+ return result.first() is not None
+
+ def explain_no_results(self, tree: qt.QueryTree, execute: bool) -> Iterable[str]:
+ # Docstring inherited.
+ query, _ = self.analyze_query(tree, qt.ColumnSet(tree.dimensions))
+ if query.messages or not execute:
+ return query.messages
+ # TODO: guess at ways to split up query that might fail or succeed if
+ # run separately, execute them with LIMIT 1 and report the results.
+ return []
+
+ def get_dataset_dimensions(self, name: str) -> DimensionGroup:
+ # Docstring inherited
+ return self.managers.datasets[name].datasetType.dimensions.as_group()
+
+ def get_default_collections(self) -> tuple[str, ...]:
+ # Docstring inherited.
+ if not self._defaults.collections:
+ raise NoDefaultCollectionError("No collections provided and no default collections.")
+ return tuple(self._defaults.collections)
+
+ def resolve_collection_path(
+ self, collections: Iterable[str]
+ ) -> list[tuple[CollectionRecord, CollectionSummary]]:
+ result: list[tuple[CollectionRecord, CollectionSummary]] = []
+ done: set[str] = set()
+
+ def recurse(collection_names: Iterable[str]) -> None:
+ for collection_name in collection_names:
+ if collection_name not in done:
+ done.add(collection_name)
+ record = self.managers.collections.find(collection_name)
+
+ if record.type is CollectionType.CHAINED:
+ recurse(cast(ChainedCollectionRecord, record).children)
+ else:
+ result.append((record, self.managers.datasets.getCollectionSummary(record)))
+
+ recurse(collections)
+
+ return result
+
+ def analyze_query(
+ self,
+ tree: qt.QueryTree,
+ final_columns: qt.ColumnSet,
+ order_by: Iterable[qt.OrderExpression] = (),
+ find_first_dataset: str | None = None,
+ ) -> tuple[AnalyzedQuery, SqlBuilder]:
+ # Delegate to the dimensions manager to rewrite the predicate and
+ # start a SqlBuilder and Postprocessing to cover any spatial overlap
+ # joins or constraints. We'll return that SqlBuilder at the end.
+ (
+ predicate,
+ sql_builder,
+ postprocessing,
+ ) = self.managers.dimensions.process_query_overlaps(
+ tree.dimensions,
+ tree.predicate,
+ tree.join_operand_dimensions,
+ )
+ # Initialize the AnalyzedQuery instance we'll update throughout this
+ # method.
+ query = AnalyzedQuery(
+ predicate,
+ postprocessing,
+ base_columns=qt.ColumnSet(tree.dimensions),
+ projection_columns=final_columns.copy(),
+ final_columns=final_columns,
+ find_first_dataset=find_first_dataset,
+ )
+ # The base query needs to join in all columns required by the
+ # predicate.
+ predicate.gather_required_columns(query.base_columns)
+ # The "projection" query differs from the final query by not omitting
+ # any dimension keys (since that makes it easier to reason about),
+ # including any columns needed by order_by terms, and including
+ # the dataset rank if there's a find-first search in play.
+ query.projection_columns.restore_dimension_keys()
+ for term in order_by:
+ term.gather_required_columns(query.projection_columns)
+ if query.find_first_dataset is not None:
+ query.projection_columns.dataset_fields[query.find_first_dataset].add("collection_key")
+ # The base query also needs to include all columns needed by the
+ # downstream projection query.
+ query.base_columns.update(query.projection_columns)
+ # Extract the data ID implied by the predicate; we can use the governor
+ # dimensions in that to constrain the collections we search for
+ # datasets later.
+ query.predicate.visit(DataIdExtractionVisitor(query.constraint_data_id, query.messages))
+ # We also check that the predicate doesn't reference any dimensions
+ # without constraining their governor dimensions, since that's a
+ # particularly easy mistake to make and it's almost never intentional.
+        # We also allow the registry defaults data ID to provide governor values.
+ where_columns = qt.ColumnSet(query.universe.empty.as_group())
+ query.predicate.gather_required_columns(where_columns)
+ for governor in where_columns.dimensions.governors:
+ if governor not in query.constraint_data_id:
+ if governor in self._defaults.dataId.dimensions:
+ query.constraint_data_id[governor] = self._defaults.dataId[governor]
+ else:
+ raise qt.InvalidQueryTreeError(
+ f"Query 'where' expression references a dimension dependent on {governor} without "
+ "constraining it directly."
+ )
+ # Add materializations, which can also bring in more postprocessing.
+ for m_key, m_dimensions in tree.materializations.items():
+ _, m_postprocessing = self._materializations[m_key]
+ query.materializations[m_key] = m_dimensions
+            # When a query is materialized, the new tree has an empty
+ # (trivially true) predicate, and the materialization prevents the
+ # creation of automatic spatial joins that are already included in
+ # the materialization, so we don't need to deduplicate these
+ # filters. It's possible for there to be duplicates, but only if
+ # the user explicitly adds a redundant constraint, and we'll still
+ # behave correctly (just less efficiently) if that happens.
+ postprocessing.spatial_join_filtering.extend(m_postprocessing.spatial_join_filtering)
+ postprocessing.spatial_where_filtering.extend(m_postprocessing.spatial_where_filtering)
+ # Add data coordinate uploads.
+ query.data_coordinate_uploads.update(tree.data_coordinate_uploads)
+ # Add dataset_searches and filter out collections that don't have the
+ # right dataset type or governor dimensions.
+ name_shrinker = make_dataset_name_shrinker(self.db.dialect)
+ for dataset_type_name, dataset_search in tree.datasets.items():
+ dataset = AnalyzedDatasetSearch(
+ dataset_type_name, name_shrinker.shrink(dataset_type_name), dataset_search.dimensions
+ )
+ for collection_record, collection_summary in self.resolve_collection_path(
+ dataset_search.collections
+ ):
+ rejected: bool = False
+ if dataset.name not in collection_summary.dataset_types.names:
+ dataset.messages.append(
+ f"No datasets of type {dataset.name!r} in collection {collection_record.name}."
+ )
+ rejected = True
+ for governor in query.constraint_data_id.keys() & collection_summary.governors.keys():
+ if query.constraint_data_id[governor] not in collection_summary.governors[governor]:
+ dataset.messages.append(
+ f"No datasets with {governor}={query.constraint_data_id[governor]!r} "
+ f"in collection {collection_record.name}."
+ )
+ rejected = True
+ if not rejected:
+ if collection_record.type is CollectionType.CALIBRATION:
+ dataset.is_calibration_search = True
+ dataset.collection_records.append(collection_record)
+ if dataset.dimensions != self.get_dataset_type(dataset_type_name).dimensions.as_group():
+ # This is really for server-side defensiveness; it's hard to
+ # imagine the query getting different dimensions for a dataset
+ # type in two calls to the same query driver.
+ raise qt.InvalidQueryTreeError(
+ f"Incorrect dimensions {dataset.dimensions} for dataset {dataset_type_name} "
+ f"in query (vs. {self.get_dataset_type(dataset_type_name).dimensions.as_group()})."
+ )
+ query.datasets[dataset_type_name] = dataset
+ if not dataset.collection_records:
+ query.messages.append(f"Search for dataset type {dataset_type_name!r} is doomed to fail.")
+ query.messages.extend(dataset.messages)
+ # Set flags that indicate certain kinds of special processing the query
+ # will need, mostly in the "projection" stage, where we might do a
+ # GROUP BY or DISTINCT [ON].
+ if query.find_first_dataset is not None:
+ # If we're doing a find-first search and there's a calibration
+ # collection in play, we need to make sure the rows coming out of
+ # the base query have only one timespan for each data ID +
+ # collection, and we can only do that with a GROUP BY and COUNT.
+ query.postprocessing.check_validity_match_count = query.datasets[
+ query.find_first_dataset
+ ].is_calibration_search
+ # We only actually need to include the find-first resolution query
+ # logic if there's more than one collection.
+ query.needs_find_first_resolution = (
+ len(query.datasets[query.find_first_dataset].collection_records) > 1
+ )
+ if query.projection_columns.dimensions != query.base_columns.dimensions:
+ # We're going from a larger set of dimensions to a smaller set,
+ # that means we'll be doing a SELECT DISTINCT [ON] or GROUP BY.
+ query.needs_dimension_distinct = True
+ # If there are any dataset fields being propagated through that
+ # projection and there is more than one collection, we need to
+ # include the collection_key column so we can use that as one of
+ # the DISTINCT ON or GROUP BY columns.
+ for dataset_type, fields_for_dataset in query.projection_columns.dataset_fields.items():
+ if len(query.datasets[dataset_type].collection_records) > 1:
+ fields_for_dataset.add("collection_key")
+ # If there's a projection and we're doing postprocessing, we might
+ # be collapsing the dimensions of the postprocessing regions. When
+ # that happens, we want to apply an aggregate function to them that
+ # computes the union of the regions that are grouped together.
+ for element in query.postprocessing.iter_missing(query.projection_columns):
+ if element.name not in query.projection_columns.dimensions.elements:
+ query.projection_region_aggregates.append(element)
+ break
+ return query, sql_builder
+
+ def build_query(self, query: AnalyzedQuery, sql_builder: SqlBuilder) -> SqlBuilder:
+ sql_builder = self._build_base_query(query, sql_builder)
+ if query.needs_projection:
+ sql_builder = self._project_query(query, sql_builder)
+ if query.needs_find_first_resolution:
+ sql_builder = self._apply_find_first(query, sql_builder)
+ elif query.needs_find_first_resolution:
+ sql_builder = self._apply_find_first(
+ query, sql_builder.cte(query.projection_columns, query.postprocessing)
+ )
+ return sql_builder
+
+ def _build_base_query(self, query: AnalyzedQuery, sql_builder: SqlBuilder) -> SqlBuilder:
+ # Process data coordinate upload joins.
+ for upload_key, upload_dimensions in query.data_coordinate_uploads.items():
+ sql_builder = sql_builder.join(
+ SqlBuilder(self.db, self._upload_tables[upload_key]).extract_dimensions(
+ upload_dimensions.required
+ )
+ )
+ # Process materialization joins.
+ for materialization_key, materialization_spec in query.materializations.items():
+ sql_builder = self._join_materialization(sql_builder, materialization_key, materialization_spec)
+ # Process dataset joins.
+ for dataset_type, dataset_search in query.datasets.items():
+ sql_builder = self._join_dataset_search(
+ sql_builder,
+ dataset_type,
+ dataset_search,
+ query.base_columns,
+ )
+ # Join in dimension element tables that we know we need relationships
+ # or columns from.
+ for element in query.iter_mandatory_base_elements():
+ sql_builder = sql_builder.join(
+ self.managers.dimensions.make_sql_builder(
+ element, query.base_columns.dimension_fields[element.name]
+ )
+ )
+ # See if any dimension keys are still missing, and if so join in their
+ # tables. Note that we know there are no fields needed from these.
+ while not (sql_builder.dimension_keys.keys() >= query.base_columns.dimensions.names):
+ # Look for opportunities to join in multiple dimensions via single
+ # table, to reduce the total number of tables joined in.
+ missing_dimension_names = query.base_columns.dimensions.names - sql_builder.dimension_keys.keys()
+ best = self._universe[
+ max(
+ missing_dimension_names,
+ key=lambda name: len(self._universe[name].dimensions.names & missing_dimension_names),
+ )
+ ]
+ sql_builder = sql_builder.join(self.managers.dimensions.make_sql_builder(best, frozenset()))
+ # Add the WHERE clause to the builder.
+ return sql_builder.where_sql(query.predicate.visit(SqlColumnVisitor(sql_builder, self)))
+
+ def _project_query(self, query: AnalyzedQuery, sql_builder: SqlBuilder) -> SqlBuilder:
+ assert query.needs_projection
+        # This method generates a Common Table Expression (CTE) using either
+        # a SELECT DISTINCT [ON] or a SELECT with GROUP BY; we'll work out
+        # which as we go.
+ have_aggregates: bool = False
+ # Dimension key columns form at least most of our GROUP BY or DISTINCT
+ # ON clause; we'll work out which of those we'll use.
+ unique_keys: list[sqlalchemy.ColumnElement[Any]] = [
+ sql_builder.dimension_keys[k][0] for k in query.projection_columns.dimensions.data_coordinate_keys
+ ]
+ # There are two reasons we might need an aggregate function:
+ # - to make sure temporal constraints and joins have resulted in at
+ # most one validity range match for each data ID and collection,
+ # when we're doing a find-first query.
+ # - to compute the unions of regions we need for postprocessing, when
+ # the data IDs for those regions are not wholly included in the
+ # results (i.e. we need to postprocess on
+ # visit_detector_region.region, but the output rows don't have
+ # detector, just visit - so we compute the union of the
+ # visit_detector region over all matched detectors).
+ if query.postprocessing.check_validity_match_count:
+ sql_builder.special[query.postprocessing.VALIDITY_MATCH_COUNT] = sqlalchemy.func.count().label(
+ query.postprocessing.VALIDITY_MATCH_COUNT
+ )
+ have_aggregates = True
+ for element in query.projection_region_aggregates:
+ sql_builder.fields[element.name]["region"] = ddl.Base64Region.union_aggregate(
+ sql_builder.fields[element.name]["region"]
+ )
+ have_aggregates = True
+ # Many of our fields derive their uniqueness from the unique_key
+        # fields: if rows are unique over the 'unique_key' fields, then they're
+ # automatically unique over these 'derived_fields'. We just remember
+ # these as pairs of (logical_table, field) for now.
+ derived_fields: list[tuple[str, str]] = []
+ # All dimension record fields are derived fields.
+ for element_name, fields_for_element in query.projection_columns.dimension_fields.items():
+ for element_field in fields_for_element:
+ derived_fields.append((element_name, element_field))
+ # Some dataset fields are derived fields and some are unique keys, and
+ # it depends on the kinds of collection(s) we're searching and whether
+ # it's a find-first query.
+ for dataset_type, fields_for_dataset in query.projection_columns.dataset_fields.items():
+ for dataset_field in fields_for_dataset:
+ if dataset_field == "collection_key":
+ # If the collection_key field is present, it's needed for
+ # uniqueness if we're looking in more than one collection.
+ # If not, it's a derived field.
+ if len(query.datasets[dataset_type].collection_records) > 1:
+ unique_keys.append(sql_builder.fields[dataset_type]["collection_key"])
+ else:
+ derived_fields.append((dataset_type, "collection_key"))
+ elif dataset_field == "timespan" and query.datasets[dataset_type].is_calibration_search:
+ # If we're doing a non-find-first query against a
+ # CALIBRATION collection, the timespan is also a unique
+ # key...
+ if dataset_type == query.find_first_dataset:
+ # ...unless we're doing a find-first search on this
+ # dataset, in which case we need to use ANY_VALUE on
+ # the timespan and check that _VALIDITY_MATCH_COUNT
+ # (added earlier) is one, indicating that there was
+ # indeed only one timespan for each data ID in each
+ # collection that survived the base query's WHERE
+ # clauses and JOINs.
+ if not self.db.has_any_aggregate:
+ raise NotImplementedError(
+                                f"Cannot generate query that returns {dataset_type}.timespan after a "
+                                "find-first search, because this database does not support the ANY_VALUE "
+                                "aggregate function (or equivalent)."
+ )
+ sql_builder.timespans[dataset_type] = sql_builder.timespans[
+ dataset_type
+ ].apply_any_aggregate(self.db.apply_any_aggregate)
+ else:
+ unique_keys.extend(sql_builder.timespans[dataset_type].flatten())
+ else:
+ # Other dataset fields derive their uniqueness from key
+ # fields.
+ derived_fields.append((dataset_type, dataset_field))
+ if not have_aggregates and not derived_fields:
+ # SELECT DISTINCT is sufficient.
+ return sql_builder.cte(query.projection_columns, query.postprocessing, distinct=True)
+ elif not have_aggregates and self.db.has_distinct_on:
+ # SELECT DISTINCT ON is sufficient and works.
+ return sql_builder.cte(query.projection_columns, query.postprocessing, distinct=unique_keys)
+ else:
+ # GROUP BY is the only option.
+ if derived_fields:
+ if self.db.has_any_aggregate:
+ for logical_table, field in derived_fields:
+ if field == "timespan":
+ sql_builder.timespans[logical_table] = sql_builder.timespans[
+ logical_table
+ ].apply_any_aggregate(self.db.apply_any_aggregate)
+ else:
+ sql_builder.fields[logical_table][field] = self.db.apply_any_aggregate(
+ sql_builder.fields[logical_table][field]
+ )
+ else:
+ _LOG.warning(
+ "Adding %d fields to GROUP BY because this database backend does not support the "
+ "ANY_VALUE aggregate function (or equivalent). This may result in a poor query "
+ "plan. Materializing the query first sometimes avoids this problem.",
+ len(derived_fields),
+ )
+ for logical_table, field in derived_fields:
+ if field == "timespan":
+ unique_keys.extend(sql_builder.timespans[logical_table].flatten())
+ else:
+ unique_keys.append(sql_builder.fields[logical_table][field])
+ return sql_builder.cte(query.projection_columns, query.postprocessing, group_by=unique_keys)
+
+ def _apply_find_first(self, query: AnalyzedQuery, sql_builder: SqlBuilder) -> SqlBuilder:
+ assert query.needs_find_first_resolution
+ assert query.find_first_dataset is not None
+ assert sql_builder.sql_from_clause is not None
+ # The query we're building looks like this:
+ #
+ # WITH {dst}_base AS (
+ # {target}
+ # ...
+ # )
+ # SELECT
+        #     {dst}_window.*
+        # FROM (
+        #     SELECT
+        #         {dst}_base.*,
+        #         ROW_NUMBER() OVER (
+        #             PARTITION BY {dst}_base.{dimensions}
+        #             ORDER BY {rank}
+        #         ) AS rownum
+        #     FROM {dst}_base
+        # ) {dst}_window
+ # WHERE
+ # {dst}_window.rownum = 1;
+ #
+ # The outermost SELECT will be represented by the SqlBuilder we return.
+
+ # The sql_builder we're given corresponds to the Common Table
+ # Expression (CTE) at the top, and is guaranteed to have
+        # ``query.projection_columns`` (+ postprocessing columns).
+ # We start by filling out the "window" SELECT statement...
+ partition_by = [sql_builder.dimension_keys[d][0] for d in query.base_columns.dimensions.required]
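+        # Rank collections by their position in the search path; ROW_NUMBER()
+        # then orders matches within each data ID partition so the rownum = 1
+        # constraint below keeps only the best match.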
+ rank_sql_column = sqlalchemy.case(
+ {
+ record.key: n
+ for n, record in enumerate(query.datasets[query.find_first_dataset].collection_records)
+ },
+ value=sql_builder.fields[query.find_first_dataset]["collection_key"],
+ )
+ if partition_by:
+ sql_builder.special["_ROWNUM"] = sqlalchemy.sql.func.row_number().over(
+ partition_by=partition_by, order_by=rank_sql_column
+ )
+ else:
+ sql_builder.special["_ROWNUM"] = sqlalchemy.sql.func.row_number().over(order_by=rank_sql_column)
+ # ... and then turn that into a subquery with a constraint on rownum.
+ sql_builder = sql_builder.subquery(query.base_columns, query.postprocessing)
+        # The rebuilt SqlBuilder does not track "_ROWNUM" as a special column,
+        # so apply the rownum = 1 constraint via the subquery's columns; the
+        # column is not propagated to the outer SELECT.
+        sql_builder = sql_builder.where_sql(sql_builder.sql_columns["_ROWNUM"] == 1)
+ return sql_builder
+
+ def _join_materialization(
+ self,
+ sql_builder: SqlBuilder,
+ materialization_key: qt.MaterializationKey,
+ dimensions: DimensionGroup,
+ ) -> SqlBuilder:
+ columns = qt.ColumnSet(dimensions)
+ table, postprocessing = self._materializations[materialization_key]
+ return sql_builder.join(SqlBuilder(self.db, table).extract_columns(columns, postprocessing))
+
+ def _join_dataset_search(
+ self,
+ sql_builder: SqlBuilder,
+ dataset_type: str,
+ processed_dataset_search: AnalyzedDatasetSearch,
+ columns: qt.ColumnSet,
+ ) -> SqlBuilder:
+ storage = self.managers.datasets[dataset_type]
+ # The next two asserts will need to be dropped (and the implications
+ # dealt with instead) if materializations start having dataset fields.
+ assert (
+ dataset_type not in sql_builder.fields
+        ), "Dataset fields have unexpectedly already been joined in."
+ assert (
+ dataset_type not in sql_builder.timespans
+        ), "Dataset timespan has unexpectedly already been joined in."
+ return sql_builder.join(
+ storage.make_sql_builder(
+ processed_dataset_search.collection_records, columns.dataset_fields[dataset_type]
+ )
+ )
+
+ def _process_page(
+ self,
+ raw_page_iter: Iterator[Sequence[sqlalchemy.Row]],
+ result_spec: ResultSpec,
+ postprocessing: Postprocessing,
+ ) -> ResultPage:
+ try:
+ raw_page = next(raw_page_iter)
+ except StopIteration:
+ raw_page = tuple()
+ if len(raw_page) == self._raw_page_size:
+ # There's some chance we got unlucky and this page exactly finishes
+ # off the query, and we won't know the next page does not exist
+ # until we try to fetch it. But that's better than always fetching
+ # the next page up front.
+ next_key = uuid.uuid4()
+ self._active_pages[next_key] = (raw_page_iter, postprocessing)
+ else:
+ next_key = None
+ match result_spec:
+ case DimensionRecordResultSpec():
+ return convert_dimension_record_results(
+ postprocessing.apply(raw_page),
+ result_spec,
+ next_key,
+ self._name_shrinker,
+ )
+ case _:
+ raise NotImplementedError("TODO")
+
+
+def make_dataset_name_shrinker(dialect: sqlalchemy.Dialect) -> NameShrinker:
+ max_dataset_field_length = max(len(field) for field in qt.DATASET_FIELD_NAMES)
+ return NameShrinker(dialect.max_identifier_length - max_dataset_field_length - 1, 6)
diff --git a/python/lsst/daf/butler/direct_query_driver/_postprocessing.py b/python/lsst/daf/butler/direct_query_driver/_postprocessing.py
new file mode 100644
index 0000000000..2dc4ead742
--- /dev/null
+++ b/python/lsst/daf/butler/direct_query_driver/_postprocessing.py
@@ -0,0 +1,138 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ("Postprocessing", "ValidityRangeMatchError")
+
+from collections.abc import Iterable, Iterator
+from typing import TYPE_CHECKING, ClassVar
+
+import sqlalchemy
+from lsst.sphgeom import DISJOINT, Region
+
+from ..queries import tree as qt
+
+if TYPE_CHECKING:
+ from ..dimensions import DimensionElement
+
+
+class ValidityRangeMatchError(RuntimeError):
+ pass
+
+
+class Postprocessing:
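+    """Aspects of a query that must be applied to result rows in Python after
+    SQL execution, such as spatial-overlap filtering that the database cannot
+    perform, along with any LIMIT and OFFSET deferred with it.
+    """
+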
+ def __init__(self) -> None:
+ self.spatial_join_filtering: list[tuple[DimensionElement, DimensionElement]] = []
+ self.spatial_where_filtering: list[tuple[DimensionElement, Region]] = []
+ self.check_validity_match_count: bool = False
+ self._offset: int = 0
+ self._limit: int | None = None
+
+ VALIDITY_MATCH_COUNT: ClassVar[str] = "_VALIDITY_MATCH_COUNT"
+
+ @property
+ def offset(self) -> int:
+ return self._offset
+
+ @offset.setter
+ def offset(self, value: int) -> None:
+ if value and not self:
+ raise RuntimeError(
+ "Postprocessing should only implement 'offset' if it needs to do spatial filtering."
+ )
+ self._offset = value
+
+ @property
+ def limit(self) -> int | None:
+ return self._limit
+
+ @limit.setter
+ def limit(self, value: int | None) -> None:
+ if value and not self:
+ raise RuntimeError(
+ "Postprocessing should only implement 'limit' if it needs to do spatial filtering."
+ )
+ self._limit = value
+
+ def __bool__(self) -> bool:
+ return bool(self.spatial_join_filtering or self.spatial_where_filtering)
+
+ def gather_columns_required(self, columns: qt.ColumnSet) -> None:
+ for element in self.iter_region_dimension_elements():
+ columns.update_dimensions(element.minimal_group)
+ columns.dimension_fields[element.name].add("region")
+
+ def iter_region_dimension_elements(self) -> Iterator[DimensionElement]:
+ for a, b in self.spatial_join_filtering:
+ yield a
+ yield b
+ for element, _ in self.spatial_where_filtering:
+ yield element
+
+ def iter_missing(self, columns: qt.ColumnSet) -> Iterator[DimensionElement]:
+ done: set[DimensionElement] = set()
+ for element in self.iter_region_dimension_elements():
+ if element not in done:
+ if "region" not in columns.dimension_fields.get(element.name, frozenset()):
+ yield element
+ done.add(element)
+
+ def apply(self, rows: Iterable[sqlalchemy.Row]) -> Iterable[sqlalchemy.Row]:
+        if not self:
+            yield from rows
+            return
+ joins = [
+ (
+ qt.ColumnSet.get_qualified_name(a.name, "region"),
+ qt.ColumnSet.get_qualified_name(b.name, "region"),
+ )
+ for a, b in self.spatial_join_filtering
+ ]
+ where = [
+ (qt.ColumnSet.get_qualified_name(element.name, "region"), region)
+ for element, region in self.spatial_where_filtering
+ ]
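+        # Drop rows whose regions are disjoint with a joined region or a
+        # 'where' region, then apply any offset/limit that could not be pushed
+        # down into the SQL query.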
+ for row in rows:
+ m = row._mapping
+ if any(m[a].relate(m[b]) & DISJOINT for a, b in joins) or any(
+ m[field].relate(region) & DISJOINT for field, region in where
+ ):
+ continue
+ if self.check_validity_match_count and m[self.VALIDITY_MATCH_COUNT] > 1:
+ raise ValidityRangeMatchError(
+ "Ambiguous calibration validity range match. This usually means a temporal join or "
+ "'where' needs to be added, but it could also mean that multiple validity ranges "
+ "overlap a single output data ID."
+ )
+ if self._offset:
+ self._offset -= 1
+ continue
+ if self._limit == 0:
+ break
+ yield row
+ if self._limit is not None:
+ self._limit -= 1
diff --git a/python/lsst/daf/butler/direct_query_driver/_sql_builder.py b/python/lsst/daf/butler/direct_query_driver/_sql_builder.py
new file mode 100644
index 0000000000..84726f4bc5
--- /dev/null
+++ b/python/lsst/daf/butler/direct_query_driver/_sql_builder.py
@@ -0,0 +1,257 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ("SqlBuilder",)
+
+import dataclasses
+import itertools
+from collections.abc import Iterable, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar
+
+import sqlalchemy
+
+from .. import ddl
+from ..nonempty_mapping import NonemptyMapping
+from ..queries import tree as qt
+from ._postprocessing import Postprocessing
+
+if TYPE_CHECKING:
+ from ..registry.interfaces import Database
+ from ..timespan_database_representation import TimespanDatabaseRepresentation
+
+
+@dataclasses.dataclass
+class SqlBuilder:
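+    """A mutable struct that incrementally builds up the FROM clause, WHERE
+    terms, and column mappings of a SQL query over butler dimensions and
+    datasets.
+    """
+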
+ db: Database
+ sql_from_clause: sqlalchemy.FromClause | None = None
+ sql_where_terms: list[sqlalchemy.ColumnElement[bool]] = dataclasses.field(default_factory=list)
+ needs_distinct: bool = False
+
+ dimension_keys: NonemptyMapping[str, list[sqlalchemy.ColumnElement]] = dataclasses.field(
+ default_factory=lambda: NonemptyMapping(list)
+ )
+
+ fields: NonemptyMapping[str, dict[str, sqlalchemy.ColumnElement[Any]]] = dataclasses.field(
+ default_factory=lambda: NonemptyMapping(dict)
+ )
+
+ timespans: dict[str, TimespanDatabaseRepresentation] = dataclasses.field(default_factory=dict)
+
+ special: dict[str, sqlalchemy.ColumnElement[Any]] = dataclasses.field(default_factory=dict)
+
+ EMPTY_COLUMNS_NAME: ClassVar[str] = "IGNORED"
+ """Name of the column added to a SQL ``SELECT`` query in order to represent
+ relations that have no real columns.
+ """
+
+ EMPTY_COLUMNS_TYPE: ClassVar[type] = sqlalchemy.Boolean
+ """Type of the column added to a SQL ``SELECT`` query in order to represent
+ relations that have no real columns.
+ """
+
+ @property
+ def sql_columns(self) -> sqlalchemy.ColumnCollection:
+ assert self.sql_from_clause is not None
+ return self.sql_from_clause.columns
+
+ @classmethod
+ def handle_empty_columns(
+ cls, columns: list[sqlalchemy.sql.ColumnElement]
+ ) -> list[sqlalchemy.ColumnElement]:
+ """Handle the edge case where a SELECT statement has no columns, by
+ adding a literal column that should be ignored.
+
+ Parameters
+ ----------
+ columns : `list` [ `sqlalchemy.ColumnElement` ]
+ List of SQLAlchemy column objects. This may have no elements when
+ this method is called, and will always have at least one element
+ when it returns.
+
+ Returns
+ -------
+ columns : `list` [ `sqlalchemy.ColumnElement` ]
+ The same list that was passed in, after any modification.
+ """
+ if not columns:
+ columns.append(sqlalchemy.sql.literal(True).label(cls.EMPTY_COLUMNS_NAME))
+ return columns
+
+ def select(
+ self,
+ columns: qt.ColumnSet,
+ postprocessing: Postprocessing | None = None,
+ *,
+ distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = False,
+ group_by: Sequence[sqlalchemy.ColumnElement] = (),
+ ) -> sqlalchemy.Select:
+ sql_columns: list[sqlalchemy.ColumnElement[Any]] = []
+ for logical_table, field in columns:
+ name = columns.get_qualified_name(logical_table, field)
+ if field is None:
+ sql_columns.append(self.dimension_keys[logical_table][0].label(name))
+ elif columns.is_timespan(logical_table, field):
+ sql_columns.extend(self.timespans[logical_table].flatten(name))
+ else:
+ sql_columns.append(self.fields[logical_table][field].label(name))
+ if postprocessing is not None:
+ for element in postprocessing.iter_missing(columns):
+ assert (
+ element.name in columns.dimensions.elements
+ ), "Region aggregates not handled by this method."
+ sql_columns.append(
+ self.fields[element.name]["region"].label(
+ columns.get_qualified_name(element.name, "region")
+ )
+ )
+ for label, sql_column in self.special.items():
+ sql_columns.append(sql_column.label(label))
+ self.handle_empty_columns(sql_columns)
+ result = sqlalchemy.select(*sql_columns)
+ if self.sql_from_clause is not None:
+ result = result.select_from(self.sql_from_clause)
+ if self.needs_distinct or distinct:
+ if distinct is True or distinct is False:
+ result = result.distinct()
+ else:
+ result = result.distinct(*distinct)
+ if group_by:
+ result = result.group_by(*group_by)
+ if self.sql_where_terms:
+ result = result.where(*self.sql_where_terms)
+ return result
+
+ def make_table_spec(
+ self,
+ columns: qt.ColumnSet,
+ postprocessing: Postprocessing | None = None,
+ ) -> ddl.TableSpec:
+ assert not self.special, "special columns not supported in make_table_spec"
+ results = ddl.TableSpec(
+ [columns.get_column_spec(logical_table, field).to_sql_spec() for logical_table, field in columns]
+ )
+ if postprocessing:
+ for element in postprocessing.iter_missing(columns):
+ results.fields.add(
+ ddl.FieldSpec.for_region(columns.get_qualified_name(element.name, "region"))
+ )
+ return results
+
+ def extract_dimensions(self, dimensions: Iterable[str], **kwargs: str) -> SqlBuilder:
+ assert self.sql_from_clause is not None, "Cannot extract columns with no FROM clause."
+ for dimension_name in dimensions:
+ self.dimension_keys[dimension_name].append(self.sql_from_clause.columns[dimension_name])
+ for k, v in kwargs.items():
+ self.dimension_keys[v].append(self.sql_from_clause.columns[k])
+ return self
+
+ def extract_columns(
+ self, columns: qt.ColumnSet, postprocessing: Postprocessing | None = None
+ ) -> SqlBuilder:
+ assert self.sql_from_clause is not None, "Cannot extract columns with no FROM clause."
+ for logical_table, field in columns:
+ name = columns.get_qualified_name(logical_table, field)
+ if field is None:
+ self.dimension_keys[logical_table].append(self.sql_from_clause.columns[name])
+ elif columns.is_timespan(logical_table, field):
+ self.timespans[logical_table] = self.db.getTimespanRepresentation().from_columns(
+ self.sql_from_clause.columns, name
+ )
+ else:
+ self.fields[logical_table][field] = self.sql_from_clause.columns[name]
+ if postprocessing is not None:
+ for element in postprocessing.iter_missing(columns):
+                self.fields[element.name]["region"] = self.sql_from_clause.columns[
+                    columns.get_qualified_name(element.name, "region")
+                ]
+ if postprocessing.check_validity_match_count:
+ self.special[postprocessing.VALIDITY_MATCH_COUNT] = self.sql_from_clause.columns[
+ postprocessing.VALIDITY_MATCH_COUNT
+ ]
+ return self
+
+ def join(self, other: SqlBuilder) -> SqlBuilder:
+ join_on: list[sqlalchemy.ColumnElement] = []
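+        # Equate every dimension key column the two builders have in common;
+        # this amounts to a natural join on shared dimensions.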
+ for dimension_name in self.dimension_keys.keys() & other.dimension_keys.keys():
+ for column1, column2 in itertools.product(
+ self.dimension_keys[dimension_name], other.dimension_keys[dimension_name]
+ ):
+ join_on.append(column1 == column2)
+ self.dimension_keys[dimension_name].extend(other.dimension_keys[dimension_name])
+ if self.sql_from_clause is None:
+ self.sql_from_clause = other.sql_from_clause
+ elif other.sql_from_clause is not None:
+ self.sql_from_clause = self.sql_from_clause.join(
+ other.sql_from_clause, onclause=sqlalchemy.and_(*join_on)
+ )
+ self.sql_where_terms += other.sql_where_terms
+ self.needs_distinct = self.needs_distinct or other.needs_distinct
+ self.special.update(other.special)
+ return self
+
+ def where_sql(self, *arg: sqlalchemy.ColumnElement[bool]) -> SqlBuilder:
+ self.sql_where_terms.extend(arg)
+ return self
+
+ def cte(
+ self,
+ columns: qt.ColumnSet,
+ postprocessing: Postprocessing | None = None,
+ *,
+ distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = False,
+ group_by: Sequence[sqlalchemy.ColumnElement] = (),
+ ) -> SqlBuilder:
+ return SqlBuilder(
+ self.db,
+ self.select(columns, postprocessing, distinct=distinct, group_by=group_by).cte(),
+ ).extract_columns(columns, postprocessing)
+
+ def subquery(
+ self,
+ columns: qt.ColumnSet,
+ postprocessing: Postprocessing | None = None,
+ *,
+ distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = False,
+ group_by: Sequence[sqlalchemy.ColumnElement] = (),
+ ) -> SqlBuilder:
+ return SqlBuilder(
+ self.db,
+ self.select(columns, postprocessing, distinct=distinct, group_by=group_by).subquery(),
+ ).extract_columns(columns, postprocessing)
+
+ def union_subquery(
+ self,
+ others: Iterable[SqlBuilder],
+ columns: qt.ColumnSet,
+ postprocessing: Postprocessing | None = None,
+ ) -> SqlBuilder:
+ select0 = self.select(columns, postprocessing)
+ other_selects = [other.select(columns, postprocessing) for other in others]
+ return SqlBuilder(
+ self.db,
+ select0.union(*other_selects).subquery(),
+ ).extract_columns(columns, postprocessing)
diff --git a/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py b/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py
new file mode 100644
index 0000000000..d3a6de201f
--- /dev/null
+++ b/python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py
@@ -0,0 +1,239 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+__all__ = ("SqlColumnVisitor",)
+
+from typing import TYPE_CHECKING, Any, cast
+
+import sqlalchemy
+
+from .. import ddl
+from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor
+from ..timespan_database_representation import TimespanDatabaseRepresentation
+
+if TYPE_CHECKING:
+ from ..queries import tree as qt
+ from ._driver import DirectQueryDriver
+ from ._sql_builder import SqlBuilder
+
+
+class SqlColumnVisitor(
+ ColumnExpressionVisitor[sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation],
+ PredicateVisitor[
+ sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool]
+ ],
+):
+ def __init__(self, sql_builder: SqlBuilder, driver: DirectQueryDriver):
+ self._driver = driver
+ self._sql_builder = sql_builder
+
+ def visit_literal(
+ self, expression: qt.ColumnLiteral
+ ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation:
+ # Docstring inherited.
+ if expression.column_type == "timespan":
+ return self._driver.db.getTimespanRepresentation().fromLiteral(expression.get_literal_value())
+ return sqlalchemy.literal(
+ expression.get_literal_value(), type_=ddl.VALID_CONFIG_COLUMN_TYPES[expression.column_type]
+ )
+
+ def visit_dimension_key_reference(
+ self, expression: qt.DimensionKeyReference
+ ) -> sqlalchemy.ColumnElement[int | str]:
+ # Docstring inherited.
+ return self._sql_builder.dimension_keys[expression.dimension.name][0]
+
+ def visit_dimension_field_reference(
+ self, expression: qt.DimensionFieldReference
+ ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation:
+ # Docstring inherited.
+ if expression.column_type == "timespan":
+ return self._sql_builder.timespans[expression.element.name]
+ return self._sql_builder.fields[expression.element.name][expression.field]
+
+ def visit_dataset_field_reference(
+ self, expression: qt.DatasetFieldReference
+ ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation:
+ # Docstring inherited.
+ if expression.column_type == "timespan":
+ return self._sql_builder.timespans[expression.dataset_type]
+ return self._sql_builder.fields[expression.dataset_type][expression.field]
+
+ def visit_unary_expression(self, expression: qt.UnaryExpression) -> sqlalchemy.ColumnElement[Any]:
+ # Docstring inherited.
+ match expression.operator:
+ case "-":
+ return -self.expect_scalar(expression.operand)
+ case "begin_of":
+ return self.expect_timespan(expression.operand).lower()
+ case "end_of":
+ return self.expect_timespan(expression.operand).upper()
+ raise AssertionError(f"Invalid unary expression operator {expression.operator!r}.")
+
+ def visit_binary_expression(self, expression: qt.BinaryExpression) -> sqlalchemy.ColumnElement[Any]:
+ # Docstring inherited.
+ a = self.expect_scalar(expression.a)
+ b = self.expect_scalar(expression.b)
+ match expression.operator:
+ case "+":
+ return a + b
+ case "-":
+ return a - b
+ case "*":
+ return a * b
+ case "/":
+ return a / b
+ case "%":
+ return a % b
+ raise AssertionError(f"Invalid binary expression operator {expression.operator!r}.")
+
+ def visit_reversed(self, expression: qt.Reversed) -> sqlalchemy.ColumnElement[Any]:
+ # Docstring inherited.
+ return self.expect_scalar(expression.operand).desc()
+
+ def visit_comparison(
+ self,
+ a: qt.ColumnExpression,
+ operator: qt.ComparisonOperator,
+ b: qt.ColumnExpression,
+ flags: PredicateVisitFlags,
+ ) -> sqlalchemy.ColumnElement[bool]:
+ # Docstring inherited.
+ if operator == "overlaps":
+ assert a.column_type == "timespan", "Spatial overlaps should be transformed away by now."
+ return self.expect_timespan(a).overlaps(self.expect_timespan(b))
+ lhs = self.expect_scalar(a)
+ rhs = self.expect_scalar(b)
+ match operator:
+ case "==":
+ return lhs == rhs
+ case "!=":
+ return lhs != rhs
+ case "<":
+ return lhs < rhs
+ case ">":
+ return lhs > rhs
+ case "<=":
+ return lhs <= rhs
+ case ">=":
+ return lhs >= rhs
+ raise AssertionError(f"Invalid comparison operator {operator!r}.")
+
+ def visit_is_null(
+ self, operand: qt.ColumnExpression, flags: PredicateVisitFlags
+ ) -> sqlalchemy.ColumnElement[bool]:
+ # Docstring inherited.
+ if operand.column_type == "timespan":
+ return self.expect_timespan(operand).isNull()
+ return self.expect_scalar(operand) == sqlalchemy.null()
+
+ def visit_in_container(
+ self,
+ member: qt.ColumnExpression,
+ container: tuple[qt.ColumnExpression, ...],
+ flags: PredicateVisitFlags,
+ ) -> sqlalchemy.ColumnElement[bool]:
+ # Docstring inherited.
+ return self.expect_scalar(member).in_([self.expect_scalar(item) for item in container])
+
+ def visit_in_range(
+ self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags
+ ) -> sqlalchemy.ColumnElement[bool]:
+ # Docstring inherited.
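+        # Translate Python range-membership into SQL comparisons; for example
+        # (illustrative), ``x in range(2, 11, 3)`` becomes
+        # ``x BETWEEN 2 AND 10 AND x % 3 = 2``.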
+ sql_member = self.expect_scalar(member)
+ if stop is None:
+ target = sql_member >= sqlalchemy.literal(start)
+ else:
+ stop_inclusive = stop - 1
+ if start == stop_inclusive:
+ return sql_member == sqlalchemy.literal(start)
+ else:
+ target = sqlalchemy.sql.between(
+ sql_member,
+ sqlalchemy.literal(start),
+ sqlalchemy.literal(stop_inclusive),
+ )
+ if step != 1:
+ return sqlalchemy.sql.and_(
+ *[
+ target,
+ sql_member % sqlalchemy.literal(step) == sqlalchemy.literal(start % step),
+ ]
+ )
+ else:
+ return target
+
+ def visit_in_query_tree(
+ self,
+ member: qt.ColumnExpression,
+ column: qt.ColumnExpression,
+ query_tree: qt.QueryTree,
+ flags: PredicateVisitFlags,
+ ) -> sqlalchemy.ColumnElement[bool]:
+ # Docstring inherited.
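+        # Build the right-hand side as its own subquery and emit,
+        # schematically (names illustrative),
+        #     <member> IN (SELECT <column> AS _MEMBER FROM ...)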
+ columns = qt.ColumnSet(self._driver.universe.empty.as_group())
+ column.gather_required_columns(columns)
+ query, sql_builder = self._driver.analyze_query(query_tree, columns)
+ self._driver.build_query(query, sql_builder)
+ if query.postprocessing:
+ raise NotImplementedError(
+ "Right-hand side subquery in IN expression would require postprocessing."
+ )
+ subquery_visitor = SqlColumnVisitor(sql_builder, self._driver)
+ sql_builder.special["_MEMBER"] = subquery_visitor.expect_scalar(column)
+ subquery_select = sql_builder.select(qt.ColumnSet(self._driver.universe.empty.as_group()))
+ sql_member = self.expect_scalar(member)
+ return sql_member.in_(subquery_select)
+
+ def apply_logical_and(
+ self, originals: qt.PredicateOperands, results: tuple[sqlalchemy.ColumnElement[bool], ...]
+ ) -> sqlalchemy.ColumnElement[bool]:
+ # Docstring inherited.
+ return sqlalchemy.and_(*results)
+
+ def apply_logical_or(
+ self,
+ originals: tuple[qt.PredicateLeaf, ...],
+ results: tuple[sqlalchemy.ColumnElement[bool], ...],
+ flags: PredicateVisitFlags,
+ ) -> sqlalchemy.ColumnElement[bool]:
+ # Docstring inherited.
+ return sqlalchemy.or_(*results)
+
+ def apply_logical_not(
+ self, original: qt.PredicateLeaf, result: sqlalchemy.ColumnElement[bool], flags: PredicateVisitFlags
+ ) -> sqlalchemy.ColumnElement[bool]:
+ # Docstring inherited.
+ return sqlalchemy.not_(result)
+
+ def expect_scalar(self, expression: qt.OrderExpression) -> sqlalchemy.ColumnElement[Any]:
+ return cast(sqlalchemy.ColumnElement[Any], expression.visit(self))
+
+ def expect_timespan(self, expression: qt.ColumnExpression) -> TimespanDatabaseRepresentation:
+ return cast(TimespanDatabaseRepresentation, expression.visit(self))
diff --git a/python/lsst/daf/butler/registry/collections/nameKey.py b/python/lsst/daf/butler/registry/collections/nameKey.py
index ccd7d26b6a..7da4390656 100644
--- a/python/lsst/daf/butler/registry/collections/nameKey.py
+++ b/python/lsst/daf/butler/registry/collections/nameKey.py
@@ -179,6 +179,12 @@ def getParentChains(self, key: str) -> set[str]:
parent_names = set(sql_result.scalars().all())
return parent_names
+ def lookup_name_sql(
+ self, sql_key: sqlalchemy.ColumnElement[str], sql_from_clause: sqlalchemy.FromClause
+ ) -> tuple[sqlalchemy.ColumnElement[str], sqlalchemy.FromClause]:
+ # Docstring inherited.
+ return sql_key, sql_from_clause
+
def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]:
# Docstring inherited from base class.
return self._fetch_by_key(names)
diff --git a/python/lsst/daf/butler/registry/collections/synthIntKey.py b/python/lsst/daf/butler/registry/collections/synthIntKey.py
index b96a42f0fc..38605a6ad9 100644
--- a/python/lsst/daf/butler/registry/collections/synthIntKey.py
+++ b/python/lsst/daf/butler/registry/collections/synthIntKey.py
@@ -180,6 +180,17 @@ def getParentChains(self, key: int) -> set[str]:
parent_names = set(sql_result.scalars().all())
return parent_names
+ def lookup_name_sql(
+ self, sql_key: sqlalchemy.ColumnElement[int], sql_from_clause: sqlalchemy.FromClause
+ ) -> tuple[sqlalchemy.ColumnElement[str], sqlalchemy.FromClause]:
+ # Docstring inherited.
+ return (
+ self._tables.collection.c.name,
+ sql_from_clause.join(
+ self._tables.collection, onclause=self._tables.collection.c[_KEY_FIELD_SPEC.name] == sql_key
+ ),
+ )
+
def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]:
# Docstring inherited from base class.
_LOG.debug("Fetching collection records using names %s.", names)
diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py
index c67d0b6fb8..7a7a7ae7fb 100644
--- a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py
+++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py
@@ -34,7 +34,7 @@
import datetime
from collections.abc import Callable, Iterable, Iterator, Sequence, Set
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
import astropy.time
import sqlalchemy
@@ -46,11 +46,13 @@
from ...._dataset_type import DatasetType
from ...._timespan import Timespan
from ....dimensions import DataCoordinate
+from ....direct_query_driver import SqlBuilder # new query system, server+direct only
+from ....queries import tree as qt # new query system, both clients + server
from ..._collection_summary import CollectionSummary
from ..._collection_type import CollectionType
from ..._exceptions import CollectionTypeError, ConflictingDefinitionError
from ...interfaces import DatasetRecordStorage
-from ...queries import SqlQueryContext
+from ...queries import SqlQueryContext # old registry query system
from .tables import makeTagTableSpec
if TYPE_CHECKING:
@@ -552,6 +554,212 @@ def _finish_single_relation(
)
return leaf
+ def make_sql_builder(
+ self,
+ collections: Sequence[CollectionRecord],
+ fields: Set[qt.DatasetFieldName | Literal["collection_key"]],
+ ) -> SqlBuilder:
+        # This method largely mimics `make_relation`, but it uses the new
+        # query system's primitives instead of the old one's. In terms of
+        # the SQL queries it builds, there are two other main differences:
+ #
+ # - Collection and run columns are now string names rather than IDs.
+ # This insulates the query result-processing code from collection
+ # caching and the collection manager subclass details.
+ #
+ # - The subquery always has unique rows, which is achieved by using
+ # SELECT DISTINCT when necessary.
+ #
+ collection_types = {collection.type for collection in collections}
+ assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened."
+ #
+ # There are two kinds of table in play here:
+ #
+ # - the static dataset table (with the dataset ID, dataset type ID,
+ # run ID/name, and ingest date);
+ #
+        # - the dynamic tags/calibs table (with the dataset ID, dataset
+        #   type ID, collection ID/name, data ID, and possibly validity
+ # range).
+ #
+ # That means that we might want to return a query against either table
+ # or a JOIN of both, depending on which quantities the caller wants.
+ # But the data ID is always included, which means we'll always include
+ # the tags/calibs table and join in the static dataset table only if we
+ # need things from it that we can't get from the tags/calibs table.
+ #
+ # Note that it's important that we include a WHERE constraint on both
+ # tables for any column (e.g. dataset_type_id) that is in both when
+        # it's given explicitly; not doing so can prevent the query planner
+        # from using very important indexes. At present, we don't include those
+ # redundant columns in the JOIN ON expression, however, because the
+ # FOREIGN KEY (and its index) are defined only on dataset_id.
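+        # Schematically (table and column names illustrative), the result is
+        # either
+        #     SELECT ... FROM <type>_tags [JOIN dataset ON dataset_id = id]
+        # or the analogous query against <type>_calibs, or a UNION of the two
+        # subqueries when both kinds of collection are being searched.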
+ tag_sql_builder: SqlBuilder | None = None
+ if collection_types != {CollectionType.CALIBRATION}:
+ # We'll need a subquery for the tags table if any of the given
+ # collections are not a CALIBRATION collection. This intentionally
+ # also fires when the list of collections is empty as a way to
+            # create a dummy subquery that we know will yield no rows.
+ # We give the table an alias because it might appear multiple times
+ # in the same query, for different dataset types.
+ tag_sql_builder = SqlBuilder(self._db, self._tags.alias(f"{self.datasetType.name}_tags"))
+ if "timespan" in fields:
+ tag_sql_builder.timespans[self.datasetType.name] = (
+ self._db.getTimespanRepresentation().fromLiteral(Timespan(None, None))
+ )
+ tag_sql_builder = self._finish_sql_builder(
+ tag_sql_builder,
+ [
+ (record, rank)
+ for rank, record in enumerate(collections)
+ if record.type is not CollectionType.CALIBRATION
+ ],
+ fields,
+ )
+ calib_sql_builder: SqlBuilder | None = None
+ if CollectionType.CALIBRATION in collection_types:
+ # If at least one collection is a CALIBRATION collection, we'll
+ # need a subquery for the calibs table, and could include the
+ # timespan as a result or constraint.
+ assert (
+ self._calibs is not None
+ ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
+ calib_sql_builder = SqlBuilder(self._db, self._calibs.alias(f"{self.datasetType.name}_calibs"))
+ if "timespan" in fields:
+ calib_sql_builder.timespans[self.datasetType.name] = (
+ self._db.getTimespanRepresentation().from_columns(self._calibs.columns)
+ )
+ calib_sql_builder = self._finish_sql_builder(
+ calib_sql_builder,
+ [
+ (record, rank)
+ for rank, record in enumerate(collections)
+ if record.type is CollectionType.CALIBRATION
+ ],
+ fields,
+ )
+ # In calibration collections, we need timespan as well as data ID
+ # to ensure unique rows.
+ calib_sql_builder.needs_distinct = calib_sql_builder.needs_distinct and "timespan" not in fields
+ columns = qt.ColumnSet(self.datasetType.dimensions.as_group())
+ columns.dataset_fields[self.datasetType.name].update(fields)
+ columns.drop_implied_dimension_keys()
+ if tag_sql_builder is not None:
+ if calib_sql_builder is not None:
+ # Need a UNION subquery.
+ return tag_sql_builder.union_subquery([calib_sql_builder], columns)
+ elif tag_sql_builder.needs_distinct:
+ # Need a SELECT DISTINCT subquery.
+ return tag_sql_builder.subquery(columns)
+ else:
+ return tag_sql_builder
+ elif calib_sql_builder is not None:
+ if calib_sql_builder.needs_distinct:
+ return calib_sql_builder.subquery(columns)
+ else:
+ return calib_sql_builder
+ else:
+ raise AssertionError("Branch should be unreachable.")
+
+ def _finish_sql_builder(
+ self,
+ sql_builder: SqlBuilder,
+ collections: Sequence[tuple[CollectionRecord, int]],
+ fields: Set[qt.DatasetFieldName | Literal["collection_key"]],
+ ) -> SqlBuilder:
+        # This method plays the same role in the new query system that
+        # _finish_single_relation plays in the old one. It is called exactly
+        # one or two times by make_sql_builder, just as _finish_single_relation
+        # is called exactly one or two times by make_relation. See
+        # make_sql_builder comments for what's different.
+ assert sql_builder.sql_from_clause is not None
+ run_collections_only = all(record.type is CollectionType.RUN for record, _ in collections)
+ sql_builder.where_sql(sql_builder.sql_from_clause.c.dataset_type_id == self._dataset_type_id)
+ dataset_id_col = sql_builder.sql_from_clause.c.dataset_id
+ collection_col = sql_builder.sql_from_clause.c[self._collections.getCollectionForeignKeyName()]
+ fields_provided = sql_builder.fields[self.datasetType.name]
+ # We always constrain and optionally retrieve the collection(s) via the
+ # tags/calibs table.
+ if "collection_key" in fields:
+ sql_builder.fields[self.datasetType.name]["collection_key"] = collection_col
+ if len(collections) == 1:
+ only_collection_record, _ = collections[0]
+ sql_builder.where_sql(collection_col == only_collection_record.key)
+ if "collection" in fields:
+ fields_provided["collection"] = sqlalchemy.literal(only_collection_record.name)
+ elif not collections:
+ sql_builder.where_sql(sqlalchemy.literal(False))
+ if "collection" in fields:
+ fields_provided["collection"] = sqlalchemy.literal("NO COLLECTIONS")
+ else:
+ sql_builder.where_sql(collection_col.in_([collection.key for collection, _ in collections]))
+ if "collection" in fields:
+ # Avoid a join to the collection table to get the name by using
+ # a CASE statement. The SQL will be a bit more verbose but
+ # more efficient.
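+                # The generated SQL is roughly (keys and names illustrative):
+                #     CASE <collection_key> WHEN 1 THEN 'coll/a'
+                #                           WHEN 2 THEN 'coll/b' END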
+ fields_provided["collection"] = sqlalchemy.case(
+ {record.key: record.name for record, _ in collections}, value=collection_col
+ )
+ # Add more column definitions, starting with the data ID.
+ sql_builder.extract_dimensions(self.datasetType.dimensions.required.names)
+        # We can always get the dataset_id from the tags/calibs table, even
+        # though we could also get it from the 'static' dataset table.
+ if "dataset_id" in fields:
+ fields_provided["dataset_id"] = dataset_id_col
+
+ # It's possible we now have everything we need, from just the
+ # tags/calibs table. The things we might need to get from the static
+ # dataset table are the run key and the ingest date.
+ need_static_table = False
+ if "run" in fields:
+ if len(collections) == 1 and run_collections_only:
+ # If we are searching exactly one RUN collection, we
+ # know that if we find the dataset in that collection,
+                # then that's the dataset's run; we don't need to
+ # query for it.
+ fields_provided["run"] = sqlalchemy.literal(only_collection_record.name)
+ elif run_collections_only:
+ # Once again we can avoid joining to the collection table by
+ # adding a CASE statement.
+ fields_provided["run"] = sqlalchemy.case(
+ {record.key: record.name for record, _ in collections},
+ value=self._static.dataset.c[self._runKeyColumn],
+ )
+ need_static_table = True
+ else:
+ # Here we can't avoid a join to the collection table, because
+ # we might find a dataset via something other than its RUN
+ # collection.
+ fields_provided["run"], sql_builder.sql_from_clause = self._collections.lookup_name_sql(
+ self._static.dataset.c[self._runKeyColumn],
+ sql_builder.sql_from_clause,
+ )
+ need_static_table = True
+ # Ingest date can only come from the static table.
+ if "ingest_date" in fields:
+ fields_provided["ingest_date"] = self._static.dataset.c.ingest_date
+ need_static_table = True
+ if need_static_table:
+ # If we need the static table, join it in via dataset_id.
+ sql_builder.sql_from_clause = sql_builder.sql_from_clause.join(
+ self._static.dataset, onclause=(dataset_id_col == self._static.dataset.c.id)
+ )
+ # Also constrain dataset_type_id in static table in case that helps
+ # generate a better plan. We could also include this in the JOIN ON
+            # clause, but that's probably only a good idea if the column is
+            # part of the foreign key, and right now it isn't.
+ sql_builder.where_sql(self._static.dataset.c.dataset_type_id == self._dataset_type_id)
+ sql_builder.needs_distinct = (
+ # If there are multiple collections and we're searching any non-RUN
+ # collection, we could find the same dataset twice, which would
+            # yield duplicate rows unless "collection_key" is there to
+            # make those rows unique.
+ len(collections) > 1
+ and not run_collections_only
+ and ("collection_key" not in fields)
+ )
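+        # For example (illustrative): a dataset tagged into two TAGGED
+        # collections in the search path would otherwise yield two identical
+        # rows.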
+ return sql_builder
+
def getDataId(self, id: DatasetId) -> DataCoordinate:
"""Return DataId for a dataset.
diff --git a/python/lsst/daf/butler/registry/dimensions/static.py b/python/lsst/daf/butler/registry/dimensions/static.py
index 3a3903e855..ceda38b95a 100644
--- a/python/lsst/daf/butler/registry/dimensions/static.py
+++ b/python/lsst/daf/butler/registry/dimensions/static.py
@@ -30,17 +30,19 @@
import itertools
import logging
from collections import defaultdict
-from collections.abc import Mapping, Sequence, Set
+from collections.abc import Iterable, Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any
import sqlalchemy
from lsst.daf.relation import Calculation, ColumnExpression, Join, Relation, sql
+from lsst.sphgeom import Region
from ... import ddl
from ..._column_tags import DimensionKeyColumnTag, DimensionRecordColumnTag
from ..._column_type_info import LogicalColumn
from ..._named import NamedKeyDict
from ...dimensions import (
+ DatabaseDimensionElement,
DatabaseTopologicalFamily,
DataCoordinate,
Dimension,
@@ -53,11 +55,15 @@
addDimensionForeignKey,
)
from ...dimensions.record_cache import DimensionRecordCache
+from ...direct_query_driver import Postprocessing, SqlBuilder # Future query system (direct,server).
+from ...queries import tree as qt # Future query system (direct,client,server)
+from ...queries.overlaps import OverlapsVisitor
+from ...queries.visitors import PredicateVisitFlags
from .._exceptions import MissingSpatialOverlapError
from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext, VersionTuple
if TYPE_CHECKING:
- from .. import queries
+ from .. import queries # Current Registry.query* system.
# This has to be updated on every schema change
@@ -426,6 +432,38 @@ def make_spatial_join_relation(
)
return overlaps, needs_refinement
+ def make_sql_builder(self, element: DimensionElement, fields: Set[str]) -> SqlBuilder:
+ if element.implied_union_target is not None:
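+            # Implied-union dimensions have no table of their own; their key
+            # values are defined as the union of a column in the target
+            # element's table (e.g. "band" from "physical_filter" in the
+            # default universe; illustrative), so we select DISTINCT keys
+            # from the target's query instead.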
+ assert not fields, "Dimensions with implied-union storage never have fields."
+ return self.make_sql_builder(element.implied_union_target, fields).subquery(
+ qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True
+ )
+ if not element.has_own_table:
+ raise NotImplementedError(f"Cannot join dimension element {element} with no table.")
+ table = self._tables[element.name]
+ result = SqlBuilder(self._db, table)
+ for dimension_name, column_name in zip(element.required.names, element.schema.required.names):
+ result.dimension_keys[dimension_name].append(table.columns[column_name])
+ result.extract_dimensions(element.implied.names)
+ for field in fields:
+ if field == "timespan":
+ result.timespans[element.name] = self._db.getTimespanRepresentation().from_columns(
+ table.columns
+ )
+ else:
+ result.fields[element.name][field] = table.columns[field]
+ return result
+
+ def process_query_overlaps(
+ self,
+ dimensions: DimensionGroup,
+ predicate: qt.Predicate,
+ join_operands: Iterable[DimensionGroup],
+ ) -> tuple[qt.Predicate, SqlBuilder, Postprocessing]:
+ overlaps_visitor = _CommonSkyPixMediatedOverlapsVisitor(self._db, dimensions, self._overlap_tables)
+ new_predicate = overlaps_visitor.run(predicate, join_operands)
+ return new_predicate, overlaps_visitor.sql_builder, overlaps_visitor.postprocessing
+
def _make_relation(
self,
element: DimensionElement,
@@ -948,3 +986,159 @@ def load(self, key: int) -> DimensionGroup:
self.refresh()
graph = self._groupsByKey[key]
return graph
+
+
+class _CommonSkyPixMediatedOverlapsVisitor(OverlapsVisitor):
+ def __init__(
+ self,
+ db: Database,
+ dimensions: DimensionGroup,
+ overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]],
+ ):
+ super().__init__(dimensions)
+ self.sql_builder: SqlBuilder = SqlBuilder(db)
+ self.postprocessing = Postprocessing()
+ self.common_skypix = dimensions.universe.commonSkyPix
+ self.overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]] = overlap_tables
+ self.common_skypix_overlaps_done: set[DatabaseDimensionElement] = set()
+
+ def visit_spatial_constraint(
+ self,
+ element: DimensionElement,
+ region: Region,
+ flags: PredicateVisitFlags,
+ ) -> qt.Predicate | None:
+ # Reject spatial constraints that are nested inside OR or NOT, because
+ # the postprocessing needed for those would be a lot harder.
+ if flags & PredicateVisitFlags.INVERTED or flags & PredicateVisitFlags.HAS_OR_SIBLINGS:
+ raise NotImplementedError(
+ "Spatial overlap constraints nested inside OR or NOT are not supported."
+ )
+ # Delegate to super just because that's good practice with
+        # OverlapsVisitor.
+ super().visit_spatial_constraint(element, region, flags)
+ match element:
+ case DatabaseDimensionElement():
+ # If this is a database dimension element like tract, patch, or
+ # visit, we need to:
+ # - join in the common skypix overlap table for this element;
+ # - constrain the common skypix index to be inside the
+ # ranges that overlap the region as a SQL where clause;
+ # - add postprocessing to reject rows where the database
+ # dimension element's region doesn't actually overlap the
+ # region.
+ self.postprocessing.spatial_where_filtering.append((element, region))
+ if self.common_skypix.name in self.dimensions:
+ # The common skypix dimension should be part of the query
+ # as a first-class dimension, so we can join in the overlap
+ # table directly, and fall through to the end of this
+ # function to construct a Predicate that will turn into the
+ # SQL WHERE clause we want.
+ self._join_common_skypix_overlap(element)
+ skypix = self.common_skypix
+ else:
+ # We need to hide the common skypix dimension from the
+ # larger query, so we make a subquery out of the overlap
+ # table that embeds the SQL WHERE clause we want and then
+ # projects out that dimension (with SELECT DISTINCT, to
+ # avoid introducing duplicate rows into the larger query).
+ overlap_sql_builder = self._make_common_skypix_overlap_sql_builder(element)
+ sql_where_or: list[sqlalchemy.ColumnElement[bool]] = []
+ sql_skypix_col = overlap_sql_builder.dimension_keys[self.common_skypix.name][0]
+ for begin, end in self.common_skypix.pixelization.envelope(region):
+ sql_where_or.append(sqlalchemy.and_(sql_skypix_col >= begin, sql_skypix_col < end))
+ overlap_sql_builder.where_sql(sqlalchemy.or_(*sql_where_or))
+ self.sql_builder = self.sql_builder.join(
+ overlap_sql_builder.subquery(
+ qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True
+ )
+ )
+ # Short circuit here since the SQL WHERE clause has already
+ # been embedded in the subquery.
+ return qt.Predicate.from_bool(True)
+ case SkyPixDimension():
+                # If this is a skypix dimension, we can do an index-in-ranges
+                # test directly on that dimension. Note that this doesn't on
+                # its own guarantee the skypix dimension column will be in the
+                # query; that'll be the job of the DirectQueryDriver to sort
+                # out (generally this will require a dataset using that skypix
+                # dimension to be joined in, unless this is the common skypix
+                # system).
+                assert (
+                    element.name in self.dimensions
+                ), "QueryTree guarantees dimensions are expanded when constraints are added."
+                skypix = element
+ case _:
+ raise NotImplementedError(
+ f"Spatial overlap constraint for dimension {element} not supported."
+ )
+ # Convert the region-overlap constraint into a skypix
+ # index range-membership constraint in SQL...
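+        # e.g. (ranges illustrative) the returned predicate is
+        #     skypix IN range(b0, e0) OR skypix IN range(b1, e1) OR ...
+        # with one term per envelope range; SqlColumnVisitor later renders
+        # each term as a BETWEEN or equality test.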
+ result = qt.Predicate.from_bool(False)
+ skypix_col_ref = qt.DimensionKeyReference.model_construct(dimension=skypix)
+ for begin, end in skypix.pixelization.envelope(region):
+ result = result.logical_or(qt.Predicate.in_range(skypix_col_ref, start=begin, stop=end))
+ return result
+
+ def visit_spatial_join(
+ self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags
+ ) -> qt.Predicate | None:
+ # Reject spatial joins that are nested inside OR or NOT, because the
+ # postprocessing needed for those would be a lot harder.
+ if flags & PredicateVisitFlags.INVERTED or flags & PredicateVisitFlags.HAS_OR_SIBLINGS:
+ raise NotImplementedError("Spatial overlap joins nested inside OR or NOT are not supported.")
+ # Delegate to super to check for invalid joins and record this
+ # "connection" for use when seeing whether to add an automatic join
+ # later.
+ super().visit_spatial_join(a, b, flags)
+ match (a, b):
+ case (self.common_skypix, DatabaseDimensionElement() as b):
+ self._join_common_skypix_overlap(b)
+ case (DatabaseDimensionElement() as a, self.common_skypix):
+ self._join_common_skypix_overlap(a)
+ case (DatabaseDimensionElement() as a, DatabaseDimensionElement() as b):
+ if self.common_skypix.name in self.dimensions:
+ # We want the common skypix dimension to appear in the
+ # query as a first-class dimension, so just join in the
+ # two overlap tables directly.
+ self._join_common_skypix_overlap(a)
+ self._join_common_skypix_overlap(b)
+ else:
+ # We do not want the common skypix system to appear in the
+ # query or cause duplicate rows, so we join the two overlap
+ # tables in a subquery that projects out the common skypix
+ # index column with SELECT DISTINCT.
+
+ self.sql_builder = self.sql_builder.join(
+ self._make_common_skypix_overlap_sql_builder(a)
+ .join(self._make_common_skypix_overlap_sql_builder(b))
+ .subquery(
+ qt.ColumnSet(a.minimal_group | b.minimal_group).drop_implied_dimension_keys(),
+ distinct=True,
+ )
+ )
+ # In both cases we add postprocessing to check that the regions
+ # really do overlap, since overlapping the same common skypix
+ # tile is necessary but not sufficient for that.
+ self.postprocessing.spatial_join_filtering.append((a, b))
+ case _:
+ raise NotImplementedError(f"Unsupported combination for spatial join: {a, b}.")
+ return qt.Predicate.from_bool(True)
+
+ def _join_common_skypix_overlap(self, element: DatabaseDimensionElement) -> None:
+ if element not in self.common_skypix_overlaps_done:
+ self.sql_builder = self.sql_builder.join(self._make_common_skypix_overlap_sql_builder(element))
+ self.common_skypix_overlaps_done.add(element)
+
+ def _make_common_skypix_overlap_sql_builder(self, element: DatabaseDimensionElement) -> SqlBuilder:
+ _, overlap_table = self.overlap_tables[element.name]
+        # Build a standalone builder for the overlap table; callers are
+        # responsible for joining it into the larger query.
+        return (
+            SqlBuilder(self.sql_builder.db, overlap_table)
+            .extract_dimensions(element.required.names, skypix_index=self.common_skypix.name)
+            .where_sql(
+                sqlalchemy.and_(
+                    overlap_table.c.skypix_system == self.common_skypix.system.name,
+                    overlap_table.c.skypix_level == self.common_skypix.level,
+                )
+            )
+        )
diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py
index cef7b9741f..418d264460 100644
--- a/python/lsst/daf/butler/registry/interfaces/_collections.py
+++ b/python/lsst/daf/butler/registry/interfaces/_collections.py
@@ -39,6 +39,8 @@
from collections.abc import Iterable, Set
from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar
+import sqlalchemy
+
from ..._timespan import Timespan
from .._collection_type import CollectionType
from ..wildcards import CollectionWildcard
@@ -621,3 +623,27 @@ def update_chain(
`~CollectionType.CHAINED` collections in ``children`` first.
"""
raise NotImplementedError()
+
+ @abstractmethod
+ def lookup_name_sql(
+ self, sql_key: sqlalchemy.ColumnElement[_Key], sql_from_clause: sqlalchemy.FromClause
+ ) -> tuple[sqlalchemy.ColumnElement[str], sqlalchemy.FromClause]:
+ """Return a SQLAlchemy column and FROM clause that enable a query
+ to look up a collection name from the key.
+
+ Parameters
+ ----------
+ sql_key : `sqlalchemy.ColumnElement`
+ SQL column expression that evaluates to the collection key.
+ sql_from_clause : `sqlalchemy.FromClause`
+ SQL FROM clause from which ``sql_key`` was obtained.
+
+ Returns
+ -------
+ sql_name : `sqlalchemy.ColumnElement` [ `str` ]
+            SQL column expression that evaluates to the collection name.
+ sql_from_clause : `sqlalchemy.FromClause`
+ SQL FROM clause that includes the given ``sql_from_clause`` and
+            any table needed to provide ``sql_name``.
+ """
+ raise NotImplementedError()
diff --git a/python/lsst/daf/butler/registry/interfaces/_datasets.py b/python/lsst/daf/butler/registry/interfaces/_datasets.py
index abc85d4a05..6c401a9e47 100644
--- a/python/lsst/daf/butler/registry/interfaces/_datasets.py
+++ b/python/lsst/daf/butler/registry/interfaces/_datasets.py
@@ -32,8 +32,8 @@
__all__ = ("DatasetRecordStorageManager", "DatasetRecordStorage")
from abc import ABC, abstractmethod
-from collections.abc import Iterable, Iterator, Mapping, Set
-from typing import TYPE_CHECKING, Any
+from collections.abc import Iterable, Iterator, Mapping, Sequence, Set
+from typing import TYPE_CHECKING, Any, Literal
from lsst.daf.relation import Relation
@@ -45,9 +45,11 @@
from ._versioning import VersionedExtension, VersionTuple
if TYPE_CHECKING:
+ from ...direct_query_driver import SqlBuilder # new query system, server+direct only
+ from ...queries import tree as qt # new query system, both clients + server
from .._caching_context import CachingContext
from .._collection_summary import CollectionSummary
- from ..queries import SqlQueryContext
+ from ..queries import SqlQueryContext # old registry query system
from ._collections import CollectionManager, CollectionRecord, RunRecord
from ._database import Database, StaticTablesContext
from ._dimensions import DimensionRecordStorageManager
@@ -311,6 +313,14 @@ def make_relation(
"""
raise NotImplementedError()
+ @abstractmethod
+ def make_sql_builder(
+ self,
+ collections: Sequence[CollectionRecord],
+ fields: Set[qt.DatasetFieldName | Literal["collection_key"]],
+ ) -> SqlBuilder:
+ raise NotImplementedError()
+
datasetType: DatasetType
"""Dataset type whose records this object manages (`DatasetType`).
"""
diff --git a/python/lsst/daf/butler/registry/interfaces/_dimensions.py b/python/lsst/daf/butler/registry/interfaces/_dimensions.py
index c14e19ca29..9f6aff6b6d 100644
--- a/python/lsst/daf/butler/registry/interfaces/_dimensions.py
+++ b/python/lsst/daf/butler/registry/interfaces/_dimensions.py
@@ -29,7 +29,7 @@
__all__ = ("DimensionRecordStorageManager",)
from abc import abstractmethod
-from collections.abc import Set
+from collections.abc import Iterable, Set
from typing import TYPE_CHECKING, Any
from lsst.daf.relation import Join, Relation
@@ -46,7 +46,9 @@
from ._versioning import VersionedExtension, VersionTuple
if TYPE_CHECKING:
- from .. import queries
+ from ...direct_query_driver import Postprocessing, SqlBuilder # Future query system (direct,server).
+ from ...queries.tree import Predicate # Future query system (direct,client,server).
+ from .. import queries # Old Registry.query* system.
from ._database import Database, StaticTablesContext
@@ -357,6 +359,19 @@ def make_spatial_join_relation(
"""
raise NotImplementedError()
+ @abstractmethod
+ def make_sql_builder(self, element: DimensionElement, fields: Set[str]) -> SqlBuilder:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def process_query_overlaps(
+ self,
+ dimensions: DimensionGroup,
+ predicate: Predicate,
+ join_operands: Iterable[DimensionGroup],
+ ) -> tuple[Predicate, SqlBuilder, Postprocessing]:
+ raise NotImplementedError()
+
universe: DimensionUniverse
"""Universe of all dimensions and dimension elements known to the
`Registry` (`DimensionUniverse`).