diff --git a/python/lsst/daf/butler/direct_query_driver.py b/python/lsst/daf/butler/direct_query_driver.py
index 1522255e0a..b0fb12f221 100644
--- a/python/lsst/daf/butler/direct_query_driver.py
+++ b/python/lsst/daf/butler/direct_query_driver.py
@@ -36,9 +36,10 @@
 from typing import TYPE_CHECKING, ClassVar, cast, overload
 
 import sqlalchemy
 
+from lsst.sphgeom import Region
 from ._dataset_type import DatasetType
-from .dimensions import DataIdValue, DimensionElement, DimensionGroup, DimensionUniverse, SkyPixDimension
+from .dimensions import DataIdValue, DimensionElement, DimensionGroup, DimensionUniverse
 from .queries import relation_tree as rt
 from .queries.data_coordinate_results import DataCoordinateResultPage, DataCoordinateResultSpec
 from .queries.dataset_results import DatasetRefResultPage, DatasetRefResultSpec
@@ -188,14 +189,18 @@ def _build_sql_select(
         match tree:
             case rt.Select():
                 sql_builder, postprocessing = self._process_select_tree(tree, columns_required)
-                return sql_builder.sql_select(
-                    columns_to_select, postprocessing, order_by=order_by, limit=limit, offset=offset
+                sql_select, postprocessing = sql_builder.sql_select(
+                    columns_to_select,
+                    postprocessing,
+                    order_by=order_by,
+                    limit=limit,
+                    offset=offset,
                 )
             case rt.OrderedSlice():
                 assert (
                     not order_by and limit is None and not offset
                 ), "order_by/limit/offset args are for recursion only"
-                return self._build_sql_select(
+                sql_select, postprocessing = self._build_sql_select(
                     tree.operand,
                     columns_to_select,
                     columns_required=columns_required,
@@ -204,232 +209,386 @@
                     offset=tree.offset,
                 )
             case rt.FindFirst():
-                # The query we're building looks like this:
-                #
-                # WITH {dst}_search AS (
-                #     {target}
-                #     ...
-                # )
-                # SELECT
-                #     {dst}_window.*,
-                # FROM (
-                #     SELECT
-                #         {dst}_search.*,
-                #         ROW_NUMBER() OVER (
-                #             PARTITION BY {dst_search}.{operation.dimensions}
-                #             ORDER BY {operation.rank}
-                #         ) AS rownum
-                #     ) {dst}_window
-                # WHERE
-                #     {dst}_window.rownum = 1;
-                #
-
-                # We'll start with the Common Table Expression (CTE) at the
-                # top, which we mostly get from recursing on the FindFirst
-                # relation's operand.  Note that we need to use
-                # 'columns_required' to populate the SELECT clause list,
-                # because this isn't the outermost query.
-                rank_column = rt.DatasetFieldReference.model_construct(
-                    dataset_type=tree.dataset_type, field="rank"
-                )
-                search_select, postprocessing = self._build_sql_select(
-                    tree.operand, columns_required | {rank_column}
-                )
-                columns_required.update(postprocessing.gather_columns_required())
-                search_cte = search_select.cte(f"{tree.dataset_type}_search")
-                # Now we fill out the SELECT from the CTE, and the subquery it
-                # contains (at the same time, since they have the same columns,
-                # aside from the special 'rownum' window-function column).  Once
-                # again the SELECT clause is populated by 'columns_required'.
-                partition_by = [search_cte.columns[d] for d in tree.dimensions.required]
-                rownum_column = sqlalchemy.sql.func.row_number()
-                if partition_by:
-                    rownum_column = rownum_column.over(
-                        partition_by=partition_by, order_by=search_cte.columns[rank_column.qualified_name]
-                    )
-                else:
-                    rownum_column = rownum_column.over(
-                        order_by=search_cte.columns[rank_column.qualified_name]
-                    )
-                window_select, postprocessing = (
-                    SqlBuilder(search_cte)
-                    .with_data_coordinate_columns(tree.dimensions.names)
-                    .with_qualified_name_fields(columns_required, self._timespan_db_repr)
-                    .sql_select(
-                        columns_required,
-                        postprocessing,
-                        sql_columns_to_select=[rownum_column.label("rownum")],
-                    )
-                )
-                window_subquery = window_select.subquery(f"{tree.dataset_type}_window")
-                # Finally we make the outermost select, which is where we put
-                # the order_by, limit, and offset clauses, and use the actual
-                # 'columns_to_select` list for the SELECT clause.
-                full_select, postprocessing = (
-                    SqlBuilder(window_subquery)
-                    .with_data_coordinate_columns(tree.dimensions.names)
-                    .with_qualified_name_fields(columns_required, self._timespan_db_repr)
-                    .sql_select(
-                        columns_to_select, postprocessing, order_by=order_by, limit=limit, offset=offset
-                    )
+                sql_builder, postprocessing = self._process_find_first_tree(tree, columns_required)
+                sql_select, postprocessing = sql_builder.sql_select(
+                    columns_to_select,
+                    postprocessing,
+                    order_by=order_by,
+                    limit=limit,
+                    offset=offset,
                 )
-                return full_select.where(window_subquery.columns["rownum"] == 1), postprocessing
-        raise AssertionError(f"Invalid root relation: {tree}.")
+            case _:
+                raise AssertionError(f"Invalid root relation: {tree}.")
+        return sql_select, postprocessing
 
     def _process_select_tree(
-        self, tree: rt.Select, columns_required: Set[rt.ColumnReference]
+        self,
+        tree: rt.Select,
+        columns_required: Set[rt.ColumnReference],
     ) -> tuple[EmptySqlBuilder | SqlBuilder, Postprocessing]:
         columns_required = set(columns_required)
-        # We start by processing spatial joins by joining together overlap
-        # tables and adding any postprocess filtering.  We do this first
-        # because that postprocessing can require additional columns to be
-        # included in the SQL query, and that can affect what else we join in.
-        sql_builder, postprocessing = self._process_spatial_joins(tree.complete_spatial_joins())
-        # Process temporal joins next, but only to put them in a bidirectional
-        # map and add their timespans to the columns_required set.
-        temporal_join_map: dict[str, list[rt.DimensionFieldReference | rt.DatasetFieldReference]] = {}
-        for name_pair in tree.complete_temporal_joins():
-            for i in (False, True):
-                col_ref: rt.DimensionFieldReference | rt.DatasetFieldReference
-                if element := self._universe.get(name_pair[i]):
-                    col_ref = rt.DimensionFieldReference.model_construct(element=element, field="timespan")
-                else:
-                    col_ref = rt.DatasetFieldReference.model_construct(
-                        dataset_type=name_pair[i], field="timespan"
-                    )
-                columns_required.add(col_ref)
-                temporal_join_map.setdefault(name_pair[not i], []).append(col_ref)
-        # Gather more required columns, then categorize them by where they need
-        # to come from.
+        # A pair of empty-for-now objects that represent the query we're
+        # building: the SQL query, and any post-SQL processing we'll do in
+        # Python.
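+        # (The SqlBuilder accumulates the evolving FROM/WHERE state and the
+        # mappings from relation-tree column references to SQLAlchemy
+        # columns; the Postprocessing object records region-overlap tests
+        # that SQL cannot express, to be applied to result rows in Python.)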
+        sql_builder: SqlBuilder | EmptySqlBuilder = EmptySqlBuilder()
+        postprocessing = Postprocessing()
+        # Process temporal joins, but only to put them in a bidirectional
+        # map for easier lookup later and mark their timespans as required
+        # columns.
+        temporal_join_map = self._make_temporal_join_map(tree)
+        columns_required.update(itertools.chain.from_iterable(temporal_join_map.values()))
+        # Process spatial constraints and joins first, since those define the
+        # postprocessing we'll need to do, and we need to know early which
+        # columns that will involve.  This can also rewrite the WHERE clause
+        # terms, by transforming "column overlaps literal region" comparisons
+        # into a combination of "common skypix in ()" and
+        # postprocessing, and since there's no guarantee common skypix is in
+        # tree.dimensions, it can update the dimensions of the query, too.
+        full_dimensions, processed_where_terms = self._process_spatial(sql_builder, postprocessing, tree)
         columns_required.update(postprocessing.gather_columns_required())
-        for predicate in tree.where_terms:
-            columns_required.update(predicate.gather_required_columns())
-        fields = CategorizedFields().categorize(
-            columns_required, tree.dimensions, tree.available_dataset_types
+        # Now that we're done gathering columns_required, we categorize them
+        # into mappings keyed by what kind of table they come from:
+        dimension_tables_to_join, datasets_to_join = self._categorize_columns(
+            full_dimensions, columns_required, tree.available_dataset_types
         )
+        # From here down, 'columns_required' is no longer authoritative.
         del columns_required
-        # Process explicit join terms.
-        for join_operand in tree.join_operands:
-            match join_operand:
-                case rt.DatasetSearch():
-                    raise NotImplementedError("TODO")
-                case rt.DataCoordinateUpload():
-                    sql_builder = sql_builder.join(
-                        SqlBuilder(self._upload_tables[join_operand.key]).with_data_coordinate_columns(
-                            join_operand.dimensions.required
-                        )
-                    )
-                case rt.Materialization():
-                    if join_operand.dataset_types:
-                        raise NotImplementedError("TODO")
-                    sql_builder = sql_builder.join(
-                        SqlBuilder(
-                            self._materialization_tables[join_operand.key]
-                        ).with_data_coordinate_columns(join_operand.dimensions.names)
-                    )
-        for element, fields_for_element in fields.dimension_fields.items():
-            assert fields_for_element, "element should be absent if it does not provide any required fields"
-            element_sql_builder = SqlBuilder(
-                self._dimension_tables[element.name]
-            ).with_dimension_record_columns(element, fields_for_element, self._timespan_db_repr)
-            sql_builder = sql_builder.join(
-                element_sql_builder,
-                sql_join_on=[
-                    temporal_join_sql_column.overlaps(
-                        element_sql_builder.timespans_provided[
-                            rt.DimensionFieldReference.model_construct(element=element, field="timespan")
-                        ]
-                    )
-                    for temporal_join_col_ref in temporal_join_map.get(element.name, [])
-                    if (temporal_join_sql_column := sql_builder.timespans_provided.get(temporal_join_col_ref))
-                    is not None
-                ],
+        # Process explicit join operands.  This also returns a set of
+        # dimension elements whose tables should be joined in to enforce
+        # one-to-many or many-to-many relationships that should be part of
+        # this query's dimensions but were not provided by any join operand.
+        sql_builder, relationship_elements = self._process_join_operands(
+            sql_builder, tree.join_operands, full_dimensions, temporal_join_map, datasets_to_join
+        )
+        # Actually join in all of the dimension tables that provide either
+        # fields or relationships.
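+        # Elements needed only for their relationships get an empty field
+        # set here, so the loop below joins each dimension table exactly
+        # once whether it contributes fields, relationships, or both.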
+        for element in relationship_elements:
+            dimension_tables_to_join.setdefault(element, set())
+        for element, fields_for_element in dimension_tables_to_join.items():
+            sql_builder = self._join_dimension_table(
+                sql_builder, element, temporal_join_map, fields_for_element
             )
+        # See if any dimension keys are still missing, and if so join in their
+        # tables.  Note that we know there are no fields needed from these,
+        # and no temporal joins in play.
+        missing_dimension_names = full_dimensions.names - sql_builder.dimensions_provided.keys()
+        while missing_dimension_names:
+            # Look for opportunities to join in multiple dimensions at once.
+            best = self._universe[
+                max(
+                    missing_dimension_names,
+                    key=lambda name: len(self._universe[name].dimensions.names & missing_dimension_names),
+                )
+            ]
+            if best.viewOf:
+                best = self._universe[best.viewOf]
+            elif not best.hasTable():
+                raise NotImplementedError(f"No way to join missing dimension {best.name!r} into query.")
+            sql_builder = self._join_dimension_table(sql_builder, best, {})
         raise NotImplementedError("TODO")
 
-    def _process_spatial_joins(
-        self, spatial_joins: Set[rt.JoinTuple]
+    def _process_find_first_tree(
+        self, tree: rt.FindFirst, columns_required: Set[rt.ColumnReference]
     ) -> tuple[EmptySqlBuilder | SqlBuilder, Postprocessing]:
-        # Set up empty output objects to update to update or replace as we go.
-        sql_builder: EmptySqlBuilder | SqlBuilder = EmptySqlBuilder()
-        postprocessing = Postprocessing()
-        for name_pair in spatial_joins:
-            table_pair = (
-                self._dimension_skypix_overlap_tables.get(name_pair[0]),
-                self._dimension_skypix_overlap_tables.get(name_pair[1]),
+        # The query we're building looks like this:
+        #
+        # WITH {dst}_search AS (
+        #     {target}
+        #     ...
+        # )
+        # SELECT
+        #     {dst}_window.*,
+        # FROM (
+        #     SELECT
+        #         {dst}_search.*,
+        #         ROW_NUMBER() OVER (
+        #             PARTITION BY {dst_search}.{operation.dimensions}
+        #             ORDER BY {operation.rank}
+        #         ) AS rownum
+        #     ) {dst}_window
+        # WHERE
+        #     {dst}_window.rownum = 1;
+        #
+
+        # We'll start with the Common Table Expression (CTE) at the
+        # top, which we mostly get from recursing on the FindFirst
+        # relation's operand.  Note that we need to use
+        # 'columns_required' to populate the SELECT clause list,
+        # because this isn't the outermost query.
+        rank_column = rt.DatasetFieldReference.model_construct(dataset_type=tree.dataset_type, field="rank")
+        search_select, postprocessing = self._build_sql_select(tree.operand, columns_required | {rank_column})
+        columns_required |= postprocessing.gather_columns_required()
+        search_cte = search_select.cte(f"{tree.dataset_type}_search")
+        # Now we fill out the SELECT from the CTE, and the subquery it
+        # contains (at the same time, since they have the same columns,
+        # aside from the special 'rownum' window-function column).  Once
+        # again the SELECT clause is populated by 'columns_required'.
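+        # Partitioning by the query's dimensions ranks the matching datasets
+        # independently for each data ID; ordering each window by the 'rank'
+        # dataset field makes the rownum == 1 row the find-first result.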
+        partition_by = [search_cte.columns[d] for d in tree.dimensions.required]
+        rownum_column = sqlalchemy.sql.func.row_number()
+        if partition_by:
+            rownum_column = rownum_column.over(
+                partition_by=partition_by, order_by=search_cte.columns[rank_column.qualified_name]
             )
-            sql_columns: list[sqlalchemy.ColumnElement] = []
-            where_terms: list[sqlalchemy.ColumnElement] = []
-            for name, table in zip(name_pair, table_pair):
-                if table:
-                    for dimension_name in self._universe[name].required.names:
-                        sql_columns.append(table.columns[dimension_name].label(dimension_name))
-                    where_terms.append(table.c.skypix_system == self._universe.commonSkyPix.system.name)
-                    where_terms.append(table.c.skypix_level == self._universe.commonSkyPix.level)
-                elif name != self._universe.commonSkyPix.name:
-                    raise NotImplementedError(
-                        f"Only {self._universe.commonSkyPix.name} and non-skypix dimensions "
-                        "can participate in spatial joins."
-                    )
-            from_clause: sqlalchemy.FromClause
-            if table_pair[0] and table_pair[1]:
-                from_clause = table_pair[0].join(
-                    table_pair[1], onclause=(table_pair[0].c.skypix_index == table_pair[1].c.skypix_index)
+        else:
+            rownum_column = rownum_column.over(order_by=search_cte.columns[rank_column.qualified_name])
+        window_select, postprocessing = (
+            SqlBuilder(search_cte)
+            .extract_keys(tree.dimensions.names)
+            .extract_fields(columns_required, self._timespan_db_repr)
+            .sql_select(
+                columns_required,
+                postprocessing,
+                sql_columns_to_select=[rownum_column.label("rownum")],
+            )
+        )
+        window_subquery = window_select.subquery(f"{tree.dataset_type}_window")
+        return (
+            SqlBuilder(window_subquery)
+            .extract_keys(tree.dimensions.names)
+            .extract_fields(columns_required, self._timespan_db_repr)
+            .where_sql(window_subquery.c.rownum == 1)
+        ), postprocessing
+
+    def _process_spatial(
+        self,
+        sql_builder: SqlBuilder | EmptySqlBuilder,
+        postprocessing: Postprocessing,
+        tree: rt.Select,
+    ) -> tuple[DimensionGroup, list[rt.Predicate]]:
+        processed_where_terms: list[rt.Predicate] = []
+        dimensions = tree.dimensions
+        for where_term in tree.where_terms:
+            processed_where_terms.append(where_term)
+            raise NotImplementedError(
+                "TODO: extract overlaps and turn them into postprocessing + common skypix ID tests."
+            )
+        # Add automatic spatial joins to connect all spatial dimensions.
+        spatial_joins = rt.joins.complete_joins(
+            dimensions,
+            [operand.dimensions.spatial for operand in tree.join_operands],
+            tree.spatial_joins,
+            "spatial",
+        )
+        # Categorize the joins into:
+        # - joins that directly involve the "common" skypix dimension and
+        #   a non-skypix dimension, so direct overlaps are in the database;
+        common_skypix_joins: set[DimensionElement] = set()
+        # - joins that don't involve any skypix dimension, but will have to use
+        #   the common skypix as an intermediate and then use post-query
+        #   spatial region filtering;
+        postprocess_filter_joins: set[tuple[DimensionElement, DimensionElement]] = set()
+        # - joins that involve a skypix dimension other than the common one;
+        #   we don't support these, but we hope to in the future.
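+        # Sort each join pair into one of the three buckets above; e.g. a
+        # (visit, htm7) pair is a direct overlap-table join when htm7 is the
+        # common skypix dimension, while a (visit, tract) pair needs common
+        # skypix as an intermediate plus postprocess filtering.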
+        for name_a, name_b in spatial_joins:
+            if (
+                self._universe.commonSkyPix.name == name_a
+                and name_b not in self._universe.skypix_dimensions.names
+            ):
+                common_skypix_joins.add(self._universe[name_b])
+            elif (
+                self._universe.commonSkyPix.name == name_b
+                and name_a not in self._universe.skypix_dimensions.names
+            ):
+                common_skypix_joins.add(self._universe[name_a])
+            elif name_a in self._universe.skypix_dimensions.names:
+                raise NotImplementedError(
+                    f"Joins to skypix dimensions other than {self._universe.commonSkyPix.name} "
+                    "are not yet supported."
                 )
-                postprocessing.spatial_join_filtering.append(
-                    (self._universe[name_pair[0]], self._universe[name_pair[1]])
-                )
+            elif name_b in self._universe.skypix_dimensions.names:
+                raise NotImplementedError(
+                    f"Joins to skypix dimensions other than {self._universe.commonSkyPix.name} "
+                    "are not yet supported."
+                )
-            elif table_pair[0]:
-                from_clause = table_pair[0]
-                sql_columns.append(table_pair[0].c.skypix_index.label(name_pair[1]))
-            elif table_pair[1]:
-                from_clause = table_pair[1]
-                sql_columns.append(table_pair[1].c.skypix_index.label(name_pair[0]))
             else:
-                raise AssertionError("Should be impossible due to join validation.")
-            sql_builder = sql_builder.join(
-                SqlBuilder(
-                    sqlalchemy.select(*sql_columns)
-                    .select_from(from_clause)
-                    .where(*where_terms)
-                    .subquery(f"{name_pair[0]}_{name_pair[1]}_overlap")
-                )
-            )
-        return sql_builder, postprocessing
-
-
-@dataclasses.dataclass
-class CategorizedFields:
-    dimension_fields: dict[DimensionElement, list[rt.DimensionFieldReference]] = dataclasses.field(
-        default_factory=dict
-    )
-    dataset_fields: dict[str, list[rt.DatasetFieldReference]] = dataclasses.field(default_factory=dict)
+                postprocess_filter_joins.add((self._universe[name_a], self._universe[name_b]))
+        done: set[DimensionElement] = set()
+        # Join in overlap tables for fully in-database joins to common skypix.
+        for element in common_skypix_joins:
+            sql_builder = self._join_skypix_overlap(sql_builder, element)
+            done.add(element)
+        # Join in overlap tables for in-database + postprocess-filtered joins
+        # mediated by the common skypix dimension.
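+        # Joining two overlap tables on the common skypix index over-selects:
+        # two regions can overlap the same pixel without overlapping each
+        # other, so the exact region-overlap test for each pair is repeated
+        # in Python via postprocessing.spatial_join_filtering.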
+        for element_a, element_b in postprocess_filter_joins:
+            if element_a not in done:
+                sql_builder = self._join_skypix_overlap(sql_builder, element_a)
+                done.add(element_a)
+            if element_b not in done:
+                sql_builder = self._join_skypix_overlap(sql_builder, element_b)
+                done.add(element_b)
+            postprocessing.spatial_join_filtering.append((element_a, element_b))
+        return dimensions, processed_where_terms
+
+    def _make_temporal_join_map(
+        self, tree: rt.Select
+    ) -> Mapping[str, list[rt.DimensionFieldReference | rt.DatasetFieldReference]]:
+        temporal_join_map: dict[str, list[rt.DimensionFieldReference | rt.DatasetFieldReference]] = {}
+        for name_pair in rt.joins.complete_joins(
+            tree.dimensions,
+            [operand.dimensions.temporal for operand in tree.join_operands],
+            tree.temporal_joins,
+            "temporal",
+        ):
+            for i in (False, True):
+                col_ref: rt.DimensionFieldReference | rt.DatasetFieldReference
+                if element := self._universe.get(name_pair[i]):
+                    col_ref = rt.DimensionFieldReference.model_construct(element=element, field="timespan")
+                else:
+                    col_ref = rt.DatasetFieldReference.model_construct(
+                        dataset_type=name_pair[i], field="timespan"
+                    )
+                temporal_join_map.setdefault(name_pair[not i], []).append(col_ref)
+        return temporal_join_map
 
-    def categorize(
+    def _categorize_columns(
         self,
-        columns_required: Iterable[rt.ColumnReference],
         dimensions: DimensionGroup,
-        dataset_types: Set[str],
-    ) -> CategorizedFields:
+        columns: Iterable[rt.ColumnReference],
+        available_dataset_types: frozenset[str],
+    ) -> tuple[
+        dict[DimensionElement, set[rt.DimensionFieldReference]],
+        dict[str, set[rt.DatasetFieldReference]],
+    ]:
         only_dataset_type: str | None = None
-        for col_ref in columns_required:
+        dimension_tables: dict[DimensionElement, set[rt.DimensionFieldReference]] = {}
+        datasets: dict[str, set[rt.DatasetFieldReference]] = {name: set() for name in available_dataset_types}
+        for col_ref in columns:
             if col_ref.expression_type == "dimension_key":
                 assert col_ref.dimension.name in dimensions
             elif col_ref.expression_type == "dimension_field":
-                self.dimension_fields.setdefault(col_ref.element, []).append(col_ref)
+                # The only field for SkyPix dimensions is their region, and we
+                # can always compute those from the ID outside the query
+                # system.
+                if col_ref.element in dimensions.universe.skypix_dimensions:
+                    assert col_ref.element.name in dimensions
+                else:
+                    dimension_tables.setdefault(col_ref.element, set()).add(col_ref)
             elif col_ref.expression_type == "dataset_field":
                 if col_ref.dataset_type is ...:
                     if only_dataset_type is None:
-                        if len(dataset_types) > 1:
+                        if len(datasets) > 1:
                             raise ValueError(
                                 f"Reference to dataset field {col_ref} with no dataset type is "
                                 "ambiguous."
                             )
-                        elif not dataset_types:
+                        elif not datasets:
                             raise ValueError(f"No datasets in query for reference to dataset field {col_ref}.")
-                        (only_dataset_type,) = dataset_types
+                        (only_dataset_type,) = datasets
                     col_ref = col_ref.model_copy(update={"dataset_type": only_dataset_type})
-                self.dataset_fields.setdefault(cast(str, col_ref.dataset_type), []).append(col_ref)
-        return self
+                datasets[cast(str, col_ref.dataset_type)].add(col_ref)
+        return dimension_tables, datasets
+
+    def _process_join_operands(
+        self,
+        sql_builder: SqlBuilder | EmptySqlBuilder,
+        join_operands: Iterable[rt.JoinOperand],
+        dimensions: DimensionGroup,
+        temporal_join_map: Mapping[str, Iterable[rt.DimensionFieldReference | rt.DatasetFieldReference]],
+        datasets: Mapping[str, Set[rt.DatasetFieldReference]],
+    ) -> tuple[SqlBuilder | EmptySqlBuilder, Set[DimensionElement]]:
+        # Make a set of DimensionElements whose tables need to be joined in to
+        # make sure the output rows reflect one-to-many and many-to-many
+        # relationships.  We'll remove from this as we join in dataset
+        # searches, data ID uploads, and materializations, because those can
+        # also provide those relationships.
+        relationship_elements: set[DimensionElement] = {
+            element
+            for name in dimensions.elements
+            if (element := self._universe[name]).implied or element.alwaysJoin
+        }
+        for join_operand in join_operands:
+            match join_operand:
+                case rt.DatasetSearch():
+                    # Drop relationship elements whose dimensions are a subset
+                    # of the *required* dimensions of the dataset, since
+                    # dataset tables only have columns for required dimensions.
+                    relationship_elements = {
+                        element
+                        for element in relationship_elements
+                        if not (element.minimal_group.names <= join_operand.dimensions.required)
+                    }
+                    raise NotImplementedError("TODO")
+                case rt.DataCoordinateUpload():
+                    sql_builder = sql_builder.join(
+                        SqlBuilder(self._upload_tables[join_operand.key]).extract_keys(
+                            join_operand.dimensions.required
+                        )
+                    )
+                    # Drop relationship elements whose dimensions are a subset
+                    # of the *required* dimensions of the upload, since uploads
+                    # only have columns for required dimensions.
+                    relationship_elements = {
+                        element
+                        for element in relationship_elements
+                        if not (element.minimal_group.names <= join_operand.dimensions.required)
+                    }
+                case rt.Materialization():
+                    if join_operand.dataset_types:
+                        raise NotImplementedError("TODO")
+                    sql_builder = sql_builder.join(
+                        SqlBuilder(self._materialization_tables[join_operand.key]).extract_keys(
+                            join_operand.dimensions.names
+                        )
+                    )
+                    # Drop relationship elements whose dimensions are a subset
+                    # of the dimensions of the materialization, since
+                    # materializations have full dimension columns.
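+                    # (A materialization was written with these relationships
+                    # already enforced, so re-joining those element tables
+                    # would not change its rows.)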
+                    relationship_elements = {
+                        element
+                        for element in relationship_elements
+                        if not (element.minimal_group <= join_operand.dimensions)
+                    }
+                case _:
+                    raise AssertionError(f"Invalid join operand {join_operand}.")
+        return sql_builder, relationship_elements
+
+    def _join_dimension_table(
+        self,
+        sql_builder: SqlBuilder | EmptySqlBuilder,
+        element: DimensionElement,
+        temporal_join_map: Mapping[str, Iterable[rt.DimensionFieldReference | rt.DatasetFieldReference]],
+        fields: Set[rt.DimensionFieldReference] = frozenset(),
+    ) -> SqlBuilder:
+        table = self._dimension_tables[element.name]
+        element_sql_builder = SqlBuilder(table)
+        for dimension_name, column_name in zip(element.required.names, element.schema.required.names):
+            element_sql_builder.dimensions_provided[dimension_name] = [table.columns[column_name]]
+        element_sql_builder.extract_keys(element.implied.names)
+        sql_join_on: list[sqlalchemy.ColumnElement[bool]] = []
+        for col_ref in fields:
+            if col_ref.column_type == "timespan":
+                timespan = self._timespan_db_repr.from_columns(table.columns, col_ref.field)
+                element_sql_builder.timespans_provided[col_ref] = timespan
+                sql_join_on.extend(
+                    [
+                        temporal_join_sql_column.overlaps(timespan)
+                        for temporal_join_col_ref in temporal_join_map.get(element.name, [])
+                        if (
+                            temporal_join_sql_column := sql_builder.timespans_provided.get(
+                                temporal_join_col_ref
+                            )
+                        )
+                    ]
+                )
+            else:
+                element_sql_builder.fields_provided[col_ref] = table.columns[col_ref.field]
+        return sql_builder.join(element_sql_builder, sql_join_on=sql_join_on)
+
+    def _join_skypix_overlap(
+        self, sql_builder: SqlBuilder | EmptySqlBuilder, element: DimensionElement
+    ) -> SqlBuilder:
+        table = self._dimension_skypix_overlap_tables[element.name]
+        overlap_sql_builder = (
+            SqlBuilder(table)
+            .extract_keys(element.required.names)
+            .where_sql(
+                table.c.skypix_system == element.universe.commonSkyPix.system.name,
+                table.c.skypix_level == element.universe.commonSkyPix.level,
+            )
+        )
+        overlap_sql_builder.dimensions_provided[element.universe.commonSkyPix.name] = [
+            table.c.skypix_index.label(element.universe.commonSkyPix.name)
+        ]
+        return sql_builder.join(overlap_sql_builder)
@@ -446,9 +605,6 @@ class BaseSqlBuilder:
         rt.DimensionFieldReference | rt.DatasetFieldReference, TimespanDatabaseRepresentation
     ] = dataclasses.field(default_factory=dict, kw_only=True)
 
-
-@dataclasses.dataclass
-class EmptySqlBuilder(BaseSqlBuilder):
     EMPTY_COLUMNS_NAME: ClassVar[str] = "IGNORED"
     """Name of the column added to a SQL ``SELECT`` query in order to
     represent relations that have no real columns.
@@ -459,27 +615,6 @@ class EmptySqlBuilder(BaseSqlBuilder):
     relations that have no real columns.
""" - def join(self, other: SqlBuilder, sql_join_on: Iterable[sqlalchemy.ColumnElement] = ()) -> SqlBuilder: - return other - - def sql_select( - self, - columns_to_select: Iterable[rt.ColumnReference], - postprocessing: Postprocessing, - *, - sql_columns_to_select: Iterable[sqlalchemy.ColumnElement] = (), - order_by: Iterable[rt.OrderExpression] = (), - limit: int | None = None, - offset: int = 0, - ) -> tuple[sqlalchemy.Select, Postprocessing]: - assert not columns_to_select - assert not postprocessing - assert not order_by - result = sqlalchemy.select(*self.handle_empty_columns([]), *sql_columns_to_select) - if offset > 0 or limit == 0: - result = result.where(sqlalchemy.literal(False)) - return result, postprocessing - @classmethod def handle_empty_columns( cls, columns: list[sqlalchemy.sql.ColumnElement] @@ -504,34 +639,41 @@ def handle_empty_columns( return columns +@dataclasses.dataclass +class EmptySqlBuilder(BaseSqlBuilder): + def join(self, other: SqlBuilder, sql_join_on: Iterable[sqlalchemy.ColumnElement] = ()) -> SqlBuilder: + return other + + def sql_select( + self, + columns_to_select: Iterable[rt.ColumnReference], + postprocessing: Postprocessing, + *, + sql_columns_to_select: Iterable[sqlalchemy.ColumnElement] = (), + order_by: Iterable[rt.OrderExpression] = (), + limit: int | None = None, + offset: int = 0, + ) -> tuple[sqlalchemy.Select, Postprocessing]: + assert not columns_to_select + assert not postprocessing + assert not order_by + result = sqlalchemy.select(*self.handle_empty_columns([]), *sql_columns_to_select) + if offset > 0 or limit == 0: + result = result.where(sqlalchemy.literal(False)) + return result, postprocessing + + @dataclasses.dataclass class SqlBuilder(BaseSqlBuilder): sql_from_clause: sqlalchemy.FromClause + sql_where_terms: list[sqlalchemy.ColumnElement[bool]] = dataclasses.field(default_factory=list) - def with_data_coordinate_columns(self, dimensions: Iterable[str]) -> SqlBuilder: + def extract_keys(self, dimensions: Iterable[str]) -> SqlBuilder: for dimension_name in dimensions: self.dimensions_provided[dimension_name] = [self.sql_from_clause.columns[dimension_name]] return self - def with_dimension_record_columns( - self, - element: DimensionElement, - fields: Iterable[rt.DimensionFieldReference], - timespan_db_repr: type[TimespanDatabaseRepresentation], - ) -> SqlBuilder: - for dimension_name, column_name in zip(element.required.names, element.schema.required.names): - self.dimensions_provided[dimension_name] = [self.sql_from_clause.columns[column_name]] - self.with_data_coordinate_columns(element.implied.names) - for col_ref in fields: - if col_ref.column_type == "timespan": - self.timespans_provided[col_ref] = timespan_db_repr.from_columns( - self.sql_from_clause.columns, col_ref.field - ) - else: - self.fields_provided[col_ref] = self.sql_from_clause.columns[col_ref.field] - return self - - def with_qualified_name_fields( + def extract_fields( self, fields: Iterable[rt.ColumnReference], timespan_db_repr: type[TimespanDatabaseRepresentation] ) -> SqlBuilder: for col_ref in fields: @@ -582,21 +724,24 @@ def sql_select( else: sql_columns.append(self.fields_provided[col_ref]) sql_columns.extend(sql_columns_to_select) - EmptySqlBuilder.handle_empty_columns(sql_columns) + self.handle_empty_columns(sql_columns) result = sqlalchemy.select(*sql_columns).select_from(self.sql_from_clause) # Add ORDER BY, LIMIT, and OFFSET clauses as appropriate. 
         if order_by:
             result = result.order_by(*[self.build_sql_order_by_expression(term) for term in order_by])
-        if not postprocessing.can_remove_rows:
+        if not postprocessing:
             if offset:
                 result = result.offset(offset)
             if limit is not None:
                 result = result.limit(limit)
         else:
-            postprocessing.limit = limit
-            postprocessing.offset = offset
+            raise NotImplementedError("TODO")
         return result, postprocessing
 
+    def where_sql(self, *arg: sqlalchemy.ColumnElement[bool]) -> SqlBuilder:
+        self.sql_where_terms.extend(arg)
+        return self
+
     def build_sql_order_by_expression(self, term: rt.OrderExpression) -> sqlalchemy.ColumnElement:
         if term.expression_type == "reversed":
             return self.build_sql_column_expression(term.operand).desc()
@@ -611,29 +756,22 @@ def build_sql_predicate(self, predicate: rt.Predicate) -> sqlalchemy.ColumnEleme
 
 
 @dataclasses.dataclass
 class Postprocessing:
-    skypix_regions_provided: set[SkyPixDimension] = dataclasses.field(default_factory=set)
     spatial_join_filtering: list[tuple[DimensionElement, DimensionElement]] = dataclasses.field(
         default_factory=list
     )
-    limit: int | None = None
-    offset: int = 0
+    spatial_where_filtering: list[tuple[DimensionElement, Region]] = dataclasses.field(default_factory=list)
 
     def gather_columns_required(self) -> Set[rt.ColumnReference]:
         result: set[rt.ColumnReference] = set()
         for element in itertools.chain.from_iterable(self.spatial_join_filtering):
             result.add(rt.DimensionFieldReference.model_construct(element=element, field="region"))
-        for dimension in self.skypix_regions_provided:
-            result.add(rt.DimensionKeyReference.model_construct(dimension=dimension))
+        for element, _ in self.spatial_where_filtering:
+            result.add(rt.DimensionFieldReference.model_construct(element=element, field="region"))
         return result
 
-    @property
-    def can_remove_rows(self) -> bool:
-        return bool(self.offset) or self.limit is not None
-
     def __bool__(self) -> bool:
-        return bool(
-            self.skypix_regions_provided
-            or self.spatial_join_filtering
-            or self.limit is not None
-            or self.offset
-        )
+        return bool(self.spatial_join_filtering) or bool(self.spatial_where_filtering)
+
+    def extend(self, other: Postprocessing) -> None:
+        self.spatial_join_filtering.extend(other.spatial_join_filtering)
+        self.spatial_where_filtering.extend(other.spatial_where_filtering)