From 18ba8accb97aab27793a085714849ed05c8d4c66 Mon Sep 17 00:00:00 2001
From: Jim Bosch
Date: Fri, 1 Mar 2024 13:40:09 -0500
Subject: [PATCH] WIP: put dataset field name shrinking in SqlBuilder.

This means result handlers that return dataset fields will need to
re-expand the shrunk column names, but we don't have any of those yet.
---
 .../direct_query_driver/_analyzed_query.py    |  1 -
 .../direct_query_driver/_convert_results.py   |  2 +-
 .../daf/butler/direct_query_driver/_driver.py | 12 +---
 .../direct_query_driver/_sql_builder.py       | 67 ++++++++++++++-----
 4 files changed, 52 insertions(+), 30 deletions(-)

diff --git a/python/lsst/daf/butler/direct_query_driver/_analyzed_query.py b/python/lsst/daf/butler/direct_query_driver/_analyzed_query.py
index 71d98c7898..63f9b4a9fc 100644
--- a/python/lsst/daf/butler/direct_query_driver/_analyzed_query.py
+++ b/python/lsst/daf/butler/direct_query_driver/_analyzed_query.py
@@ -43,7 +43,6 @@
 @dataclasses.dataclass
 class AnalyzedDatasetSearch:
     name: str
-    shrunk: str
     dimensions: DimensionGroup
     collection_records: list[CollectionRecord] = dataclasses.field(default_factory=list)
     messages: list[str] = dataclasses.field(default_factory=list)
diff --git a/python/lsst/daf/butler/direct_query_driver/_convert_results.py b/python/lsst/daf/butler/direct_query_driver/_convert_results.py
index 899c1760b4..bf0bda3948 100644
--- a/python/lsst/daf/butler/direct_query_driver/_convert_results.py
+++ b/python/lsst/daf/butler/direct_query_driver/_convert_results.py
@@ -37,9 +37,9 @@
 from ..dimensions import DimensionRecordSet
 
 if TYPE_CHECKING:
+    from ..name_shrinker import NameShrinker
     from ..queries.driver import DimensionRecordResultPage, PageKey
     from ..queries.result_specs import DimensionRecordResultSpec
-    from ..registry.nameShrinker import NameShrinker
 
 
 def convert_dimension_record_results(
diff --git a/python/lsst/daf/butler/direct_query_driver/_driver.py b/python/lsst/daf/butler/direct_query_driver/_driver.py
index 7cbb59c304..2e3fe2324f 100644
--- a/python/lsst/daf/butler/direct_query_driver/_driver.py
+++ b/python/lsst/daf/butler/direct_query_driver/_driver.py
@@ -40,6 +40,7 @@
 
 from .. import ddl
 from ..dimensions import DataIdValue, DimensionGroup, DimensionUniverse
+from ..name_shrinker import NameShrinker
 from ..queries import tree as qt
 from ..queries.driver import (
     DataCoordinateResultPage,
@@ -60,7 +61,6 @@
 from ..registry import CollectionSummary, CollectionType, NoDefaultCollectionError, RegistryDefaults
 from ..registry.interfaces import ChainedCollectionRecord, CollectionRecord
 from ..registry.managers import RegistryManagerInstances
-from ..registry.nameShrinker import NameShrinker
 from ._analyzed_query import AnalyzedDatasetSearch, AnalyzedQuery, DataIdExtractionVisitor
 from ._convert_results import convert_dimension_record_results
 from ._sql_column_visitor import SqlColumnVisitor
@@ -432,11 +432,8 @@ def analyze_query(
         query.data_coordinate_uploads.update(tree.data_coordinate_uploads)
         # Add dataset_searches and filter out collections that don't have the
         # right dataset type or governor dimensions.
-        name_shrinker = make_dataset_name_shrinker(self.db.dialect)
         for dataset_type_name, dataset_search in tree.datasets.items():
-            dataset = AnalyzedDatasetSearch(
-                dataset_type_name, name_shrinker.shrink(dataset_type_name), dataset_search.dimensions
-            )
+            dataset = AnalyzedDatasetSearch(dataset_type_name, dataset_search.dimensions)
             for collection_record, collection_summary in self.resolve_collection_path(
                 dataset_search.collections
             ):
@@ -784,8 +781,3 @@ def _process_page(
                 )
             case _:
                 raise NotImplementedError("TODO")
-
-
-def make_dataset_name_shrinker(dialect: sqlalchemy.Dialect) -> NameShrinker:
-    max_dataset_field_length = max(len(field) for field in qt.DATASET_FIELD_NAMES)
-    return NameShrinker(dialect.max_identifier_length - max_dataset_field_length - 1, 6)
diff --git a/python/lsst/daf/butler/direct_query_driver/_sql_builder.py b/python/lsst/daf/butler/direct_query_driver/_sql_builder.py
index 84726f4bc5..3d39f875ae 100644
--- a/python/lsst/daf/butler/direct_query_driver/_sql_builder.py
+++ b/python/lsst/daf/butler/direct_query_driver/_sql_builder.py
@@ -37,6 +37,7 @@
 import sqlalchemy
 
 from .. import ddl
+from ..name_shrinker import NameShrinker
 from ..nonempty_mapping import NonemptyMapping
 from ..queries import tree as qt
 from ._postprocessing import Postprocessing
@@ -65,6 +66,8 @@ class SqlBuilder:
 
     special: dict[str, sqlalchemy.ColumnElement[Any]] = dataclasses.field(default_factory=dict)
 
+    name_shrinker: NameShrinker | None = None
+
     EMPTY_COLUMNS_NAME: ClassVar[str] = "IGNORED"
     """Name of the column added to a SQL ``SELECT`` query in order to
     represent relations that have no real columns.
@@ -111,15 +114,19 @@ def select(
         distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = False,
         group_by: Sequence[sqlalchemy.ColumnElement] = (),
     ) -> sqlalchemy.Select:
+        if self.name_shrinker is None:
+            self.name_shrinker = self._make_name_shrinker()
         sql_columns: list[sqlalchemy.ColumnElement[Any]] = []
         for logical_table, field in columns:
             name = columns.get_qualified_name(logical_table, field)
             if field is None:
                 sql_columns.append(self.dimension_keys[logical_table][0].label(name))
-            elif columns.is_timespan(logical_table, field):
-                sql_columns.extend(self.timespans[logical_table].flatten(name))
             else:
-                sql_columns.append(self.fields[logical_table][field].label(name))
+                name = self.name_shrinker.shrink(name)
+                if columns.is_timespan(logical_table, field):
+                    sql_columns.extend(self.timespans[logical_table].flatten(name))
+                else:
+                    sql_columns.append(self.fields[logical_table][field].label(name))
         if postprocessing is not None:
             for element in postprocessing.iter_missing(columns):
                 assert (
@@ -127,7 +134,7 @@
                 ), "Region aggregates not handled by this method."
                 sql_columns.append(
                     self.fields[element.name]["region"].label(
-                        columns.get_qualified_name(element.name, "region")
+                        self.name_shrinker.shrink(columns.get_qualified_name(element.name, "region"))
                     )
                 )
         for label, sql_column in self.special.items():
@@ -148,18 +155,23 @@
         return result
 
     def make_table_spec(
-        self,
-        columns: qt.ColumnSet,
-        postprocessing: Postprocessing | None = None,
+        self, columns: qt.ColumnSet, postprocessing: Postprocessing | None = None
     ) -> ddl.TableSpec:
         assert not self.special, "special columns not supported in make_table_spec"
+        if self.name_shrinker is None:
+            self.name_shrinker = self._make_name_shrinker()
         results = ddl.TableSpec(
-            [columns.get_column_spec(logical_table, field).to_sql_spec() for logical_table, field in columns]
+            [
+                columns.get_column_spec(logical_table, field).to_sql_spec(name_shrinker=self.name_shrinker)
+                for logical_table, field in columns
+            ]
         )
         if postprocessing:
             for element in postprocessing.iter_missing(columns):
                 results.fields.add(
-                    ddl.FieldSpec.for_region(columns.get_qualified_name(element.name, "region"))
+                    ddl.FieldSpec.for_region(
+                        self.name_shrinker.shrink(columns.get_qualified_name(element.name, "region"))
+                    )
                 )
         return results
 
@@ -175,19 +187,25 @@ def extract_columns(
         self, columns: qt.ColumnSet, postprocessing: Postprocessing | None = None
     ) -> SqlBuilder:
         assert self.sql_from_clause is not None, "Cannot extract columns with no FROM clause."
+        if self.name_shrinker is None:
+            self.name_shrinker = self._make_name_shrinker()
         for logical_table, field in columns:
             name = columns.get_qualified_name(logical_table, field)
             if field is None:
                 self.dimension_keys[logical_table].append(self.sql_from_clause.columns[name])
-            elif columns.is_timespan(logical_table, field):
-                self.timespans[logical_table] = self.db.getTimespanRepresentation().from_columns(
-                    self.sql_from_clause.columns, name
-                )
             else:
-                self.fields[logical_table][field] = self.sql_from_clause.columns[name]
+                name = self.name_shrinker.shrink(name)
+                if columns.is_timespan(logical_table, field):
+                    self.timespans[logical_table] = self.db.getTimespanRepresentation().from_columns(
+                        self.sql_from_clause.columns, name
+                    )
+                else:
+                    self.fields[logical_table][field] = self.sql_from_clause.columns[name]
         if postprocessing is not None:
             for element in postprocessing.iter_missing(columns):
-                self.fields[element.name]["region"] = self.sql_from_clause.columns[name]
+                self.fields[element.name]["region"] = self.sql_from_clause.columns[
+                    self.name_shrinker.shrink(columns.get_qualified_name(element.name, "region"))
+                ]
             if postprocessing.check_validity_match_count:
                 self.special[postprocessing.VALIDITY_MATCH_COUNT] = self.sql_from_clause.columns[
                     postprocessing.VALIDITY_MATCH_COUNT
@@ -211,6 +229,11 @@ def join(self, other: SqlBuilder) -> SqlBuilder:
         self.sql_where_terms += other.sql_where_terms
         self.needs_distinct = self.needs_distinct or other.needs_distinct
         self.special.update(other.special)
+        if other.name_shrinker:
+            if self.name_shrinker is not None:
+                self.name_shrinker.update(other.name_shrinker)
+            else:
+                self.name_shrinker = other.name_shrinker
         return self
 
     def where_sql(self, *arg: sqlalchemy.ColumnElement[bool]) -> SqlBuilder:
@@ -227,7 +250,8 @@ def cte(
     ) -> SqlBuilder:
         return SqlBuilder(
             self.db,
-            self.select(columns, postprocessing, distinct=distinct, group_by=group_by).cte(),
+            sql_from_clause=self.select(columns, postprocessing, distinct=distinct, group_by=group_by).cte(),
+            name_shrinker=self.name_shrinker,
         ).extract_columns(columns, postprocessing)
 
     def subquery(
@@ -240,7 +264,10 @@ def subquery(
    ) -> SqlBuilder:
         return SqlBuilder(
             self.db,
-            self.select(columns, postprocessing, distinct=distinct, group_by=group_by).subquery(),
+            sql_from_clause=self.select(
+                columns, postprocessing, distinct=distinct, group_by=group_by
+            ).subquery(),
+            name_shrinker=self.name_shrinker,
         ).extract_columns(columns, postprocessing)
 
     def union_subquery(
@@ -253,5 +280,9 @@
         other_selects = [other.select(columns, postprocessing) for other in others]
         return SqlBuilder(
             self.db,
-            select0.union(*other_selects).subquery(),
+            sql_from_clause=select0.union(*other_selects).subquery(),
+            name_shrinker=self.name_shrinker,
         ).extract_columns(columns, postprocessing)
+
+    def _make_name_shrinker(self) -> NameShrinker:
+        return NameShrinker(self.db.dialect.max_identifier_length, 6)
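
For reference, the patch relies on a small contract from NameShrinker: shrink() maps a
possibly over-long qualified column name to one that fits the dialect's identifier limit,
the same instance (threaded through cte(), subquery(), union_subquery(), and merged in
join()) remembers what it shrank, and a future result handler re-expands shrunk names when
reading rows back. The sketch below illustrates only that contract; ToyNameShrinker, its
(max_length, hash_size) constructor parameters, and expand() are assumptions for
illustration, not the real lsst.daf.butler.name_shrinker API. The patch itself shows only
that NameShrinker is constructed as NameShrinker(max_identifier_length, 6) and must
support shrink() and update().

import hashlib


class ToyNameShrinker:
    """Minimal stand-in for the NameShrinker used above (assumed API).

    Shortens identifiers that exceed a database's maximum identifier
    length, deterministically, and remembers the mapping so shrunk names
    can be re-expanded later and merged across builders.
    """

    def __init__(self, max_length: int, hash_size: int = 6) -> None:
        self.max_length = max_length
        self.hash_size = hash_size
        self._expansions: dict[str, str] = {}  # shrunk -> original

    def shrink(self, original: str) -> str:
        # Names that already fit are passed through unchanged.
        if len(original) <= self.max_length:
            return original
        # Keep a recognizable prefix and append a short hash so distinct
        # long names remain distinct after truncation.
        digest = hashlib.blake2b(original.encode(), digest_size=self.hash_size).hexdigest()
        shrunk = original[: self.max_length - len(digest) - 1] + "_" + digest
        self._expansions[shrunk] = original
        return shrunk

    def expand(self, shrunk: str) -> str:
        # What a future result handler would call to recover the original
        # qualified name from a shrunk SQL column name (none exist yet,
        # per the commit message).
        return self._expansions.get(shrunk, shrunk)

    def update(self, other: "ToyNameShrinker") -> None:
        # Merge remembered mappings, as SqlBuilder.join() does above.
        self._expansions.update(other._expansions)


# Round trip against a 30-character identifier limit (e.g. Oracle's
# historical maximum).
shrinker = ToyNameShrinker(max_length=30)
name = "very_long_dataset_type_name_collection_key"
label = shrinker.shrink(name)
assert len(label) <= 30 and shrinker.expand(label) == name

Because shrink() here is deterministic (a prefix plus a hash of the full name), select()
and extract_columns() can compute the same label independently; the shrinker instance is
carried from builder to builder only so the shrunk-to-original mapping survives for the
eventual re-expansion step.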