diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py index 60e7b0f308..88e20c67f2 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py @@ -4,6 +4,7 @@ __all__ = ("ByDimensionsDatasetRecordStorageManagerUUID",) +import dataclasses import logging import warnings from collections import defaultdict @@ -55,16 +56,14 @@ class MissingDatabaseTableError(RuntimeError): """Exception raised when a table is not found in a database.""" -class _ExistingTableFactory: - """Factory for `sqlalchemy.schema.Table` instances that returns already - existing table instance. - """ - - def __init__(self, table: sqlalchemy.schema.Table): - self._table = table +@dataclasses.dataclass +class _DatasetTypeRecord: + """Contents of a single dataset type record.""" - def __call__(self) -> sqlalchemy.schema.Table: - return self._table + dataset_type: DatasetType + dataset_type_id: int + tag_table_name: str + calib_table_name: str | None class _SpecTableFactory: @@ -139,8 +138,6 @@ def __init__( self._dimensions = dimensions self._static = static self._summaries = summaries - self._byName: dict[str, ByDimensionsDatasetRecordStorage] = {} - self._byId: dict[int, ByDimensionsDatasetRecordStorage] = {} @classmethod def initialize( @@ -162,6 +159,7 @@ def initialize( context, collections=collections, dimensions=dimensions, + dataset_type_table=static.dataset_type, ) return cls( db=db, @@ -236,44 +234,33 @@ def addDatasetForeignKey( def refresh(self) -> None: # Docstring inherited from DatasetRecordStorageManager. - byName: dict[str, ByDimensionsDatasetRecordStorage] = {} - byId: dict[int, ByDimensionsDatasetRecordStorage] = {} - c = self._static.dataset_type.columns - with self._db.query(self._static.dataset_type.select()) as sql_result: - sql_rows = sql_result.mappings().fetchall() - for row in sql_rows: - name = row[c.name] - dimensions = self._dimensions.loadDimensionGraph(row[c.dimensions_key]) - calibTableName = row[c.calibration_association_table] - datasetType = DatasetType( - name, dimensions, row[c.storage_class], isCalibration=(calibTableName is not None) - ) - tags_spec = makeTagTableSpec(datasetType, type(self._collections), self.getIdColumnType()) - tags_table_factory = _SpecTableFactory(self._db, row[c.tag_association_table], tags_spec) - calibs_table_factory = None - if calibTableName is not None: - calibs_spec = makeCalibTableSpec( - datasetType, - type(self._collections), - self._db.getTimespanRepresentation(), - self.getIdColumnType(), - ) - calibs_table_factory = _SpecTableFactory(self._db, calibTableName, calibs_spec) - storage = self._recordStorageType( - db=self._db, - datasetType=datasetType, - static=self._static, - summaries=self._summaries, - tags_table_factory=tags_table_factory, - calibs_table_factory=calibs_table_factory, - dataset_type_id=row["id"], - collections=self._collections, - use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai, + pass + + def _make_storage(self, record: _DatasetTypeRecord) -> ByDimensionsDatasetRecordStorage: + """Create storage instance for a dataset type record.""" + tags_spec = makeTagTableSpec(record.dataset_type, type(self._collections), self.getIdColumnType()) + tags_table_factory = _SpecTableFactory(self._db, record.tag_table_name, tags_spec) + calibs_table_factory = None + if record.calib_table_name is not None: + calibs_spec = makeCalibTableSpec( + record.dataset_type, + type(self._collections), + self._db.getTimespanRepresentation(), + self.getIdColumnType(), ) - byName[datasetType.name] = storage - byId[storage._dataset_type_id] = storage - self._byName = byName - self._byId = byId + calibs_table_factory = _SpecTableFactory(self._db, record.calib_table_name, calibs_spec) + storage = self._recordStorageType( + db=self._db, + datasetType=record.dataset_type, + static=self._static, + summaries=self._summaries, + tags_table_factory=tags_table_factory, + calibs_table_factory=calibs_table_factory, + dataset_type_id=record.dataset_type_id, + collections=self._collections, + use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai, + ) + return storage def remove(self, name: str) -> None: # Docstring inherited from DatasetRecordStorageManager. @@ -296,33 +283,28 @@ def remove(self, name: str) -> None: def find(self, name: str) -> DatasetRecordStorage | None: # Docstring inherited from DatasetRecordStorageManager. - return self._byName.get(name) + record = self._fetch_dataset_type_record(name) + return self._make_storage(record) if record is not None else None - def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool]: + def register(self, datasetType: DatasetType) -> bool: # Docstring inherited from DatasetRecordStorageManager. if datasetType.isComponent(): raise ValueError( f"Component dataset types can not be stored in registry. Rejecting {datasetType.name}" ) - storage = self._byName.get(datasetType.name) - if storage is None: + record = self._fetch_dataset_type_record(datasetType.name) + if record is None: dimensionsKey = self._dimensions.saveDimensionGraph(datasetType.dimensions) tagTableName = makeTagTableName(datasetType, dimensionsKey) - calibTableName = ( - makeCalibTableName(datasetType, dimensionsKey) if datasetType.isCalibration() else None - ) - # The order is important here, we want to create tables first and - # only register them if this operation is successful. We cannot - # wrap it into a transaction because database class assumes that - # DDL is not transaction safe in general. - tags = self._db.ensureTableExists( + self._db.ensureTableExists( tagTableName, makeTagTableSpec(datasetType, type(self._collections), self.getIdColumnType()), ) - tags_table_factory = _ExistingTableFactory(tags) - calibs_table_factory = None + calibTableName = ( + makeCalibTableName(datasetType, dimensionsKey) if datasetType.isCalibration() else None + ) if calibTableName is not None: - calibs = self._db.ensureTableExists( + self._db.ensureTableExists( calibTableName, makeCalibTableSpec( datasetType, @@ -331,8 +313,7 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool self.getIdColumnType(), ), ) - calibs_table_factory = _ExistingTableFactory(calibs) - row, inserted = self._db.sync( + _, inserted = self._db.sync( self._static.dataset_type, keys={"name": datasetType.name}, compared={ @@ -347,28 +328,17 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool }, returning=["id", "tag_association_table"], ) - assert row is not None - storage = self._recordStorageType( - db=self._db, - datasetType=datasetType, - static=self._static, - summaries=self._summaries, - tags_table_factory=tags_table_factory, - calibs_table_factory=calibs_table_factory, - dataset_type_id=row["id"], - collections=self._collections, - use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai, - ) - self._byName[datasetType.name] = storage - self._byId[storage._dataset_type_id] = storage else: - if datasetType != storage.datasetType: + if datasetType != record.dataset_type: raise ConflictingDefinitionError( f"Given dataset type {datasetType} is inconsistent " - f"with database definition {storage.datasetType}." + f"with database definition {record.dataset_type}." ) inserted = False - return storage, bool(inserted) + # TODO: We return storage instance from this method, but the only + # client that uses this method ignores it. Maybe we should drop it + # and avoid making storage instance above. + return bool(inserted) def resolve_wildcard( self, @@ -422,15 +392,13 @@ def resolve_wildcard( raise TypeError( "Universal wildcard '...' is not permitted for dataset types in this context." ) - for storage in self._byName.values(): - result[storage.datasetType].add(None) + for datasetType in self._fetch_dataset_types(): + result[datasetType].add(None) if components: try: - result[storage.datasetType].update( - storage.datasetType.storageClass.allComponents().keys() - ) + result[datasetType].update(datasetType.storageClass.allComponents().keys()) if ( - storage.datasetType.storageClass.allComponents() + datasetType.storageClass.allComponents() and not already_warned and components_deprecated ): @@ -442,7 +410,7 @@ def resolve_wildcard( already_warned = True except KeyError as err: _LOG.warning( - f"Could not load storage class {err} for {storage.datasetType.name}; " + f"Could not load storage class {err} for {datasetType.name}; " "if it has components they will not be included in query results.", ) elif wildcard.patterns: @@ -454,29 +422,28 @@ def resolve_wildcard( FutureWarning, stacklevel=find_outside_stacklevel("lsst.daf.butler"), ) - for storage in self._byName.values(): - if any(p.fullmatch(storage.datasetType.name) for p in wildcard.patterns): - result[storage.datasetType].add(None) + dataset_types = self._fetch_dataset_types() + for datasetType in dataset_types: + if any(p.fullmatch(datasetType.name) for p in wildcard.patterns): + result[datasetType].add(None) if components is not False: - for storage in self._byName.values(): - if components is None and storage.datasetType in result: + for datasetType in dataset_types: + if components is None and datasetType in result: continue try: - components_for_parent = storage.datasetType.storageClass.allComponents().keys() + components_for_parent = datasetType.storageClass.allComponents().keys() except KeyError as err: _LOG.warning( - f"Could not load storage class {err} for {storage.datasetType.name}; " + f"Could not load storage class {err} for {datasetType.name}; " "if it has components they will not be included in query results." ) continue for component_name in components_for_parent: if any( - p.fullmatch( - DatasetType.nameWithComponent(storage.datasetType.name, component_name) - ) + p.fullmatch(DatasetType.nameWithComponent(datasetType.name, component_name)) for p in wildcard.patterns ): - result[storage.datasetType].add(component_name) + result[datasetType].add(component_name) if not already_warned and components_deprecated: warnings.warn( deprecation_message, @@ -492,49 +459,77 @@ def getDatasetRef(self, id: DatasetId) -> DatasetRef | None: sqlalchemy.sql.select( self._static.dataset.columns.dataset_type_id, self._static.dataset.columns[self._collections.getRunForeignKeyName()], + *self._static.dataset_type.columns, ) .select_from(self._static.dataset) + .join(self._static.dataset_type) .where(self._static.dataset.columns.id == id) ) with self._db.query(sql) as sql_result: row = sql_result.mappings().fetchone() if row is None: return None - recordsForType = self._byId.get(row[self._static.dataset.columns.dataset_type_id]) - if recordsForType is None: - self.refresh() - recordsForType = self._byId.get(row[self._static.dataset.columns.dataset_type_id]) - assert recordsForType is not None, "Should be guaranteed by foreign key constraints." + storage = self._make_storage(self._record_from_row(row)) return DatasetRef( - recordsForType.datasetType, - dataId=recordsForType.getDataId(id=id), + storage.datasetType, + dataId=storage.getDataId(id=id), id=id, run=self._collections[row[self._collections.getRunForeignKeyName()]].name, ) - def _dataset_type_factory(self, dataset_type_id: int) -> DatasetType: - """Return dataset type given its ID.""" - return self._byId[dataset_type_id].datasetType + def _fetch_dataset_type_record(self, name: str) -> _DatasetTypeRecord | None: + """Retrieve all dataset types defined in database. + + Yields + ------ + dataset_types : `_DatasetTypeRecord` + Information from a single database record. + """ + c = self._static.dataset_type.columns + stmt = self._static.dataset_type.select().where(c.name == name) + with self._db.query(stmt) as sql_result: + row = sql_result.mappings().one_or_none() + if row is None: + return None + else: + return self._record_from_row(row) + + def _record_from_row(self, row: Mapping) -> _DatasetTypeRecord: + name = row["name"] + dimensions = self._dimensions.loadDimensionGraph(row["dimensions_key"]) + calibTableName = row["calibration_association_table"] + datasetType = DatasetType( + name, dimensions, row["storage_class"], isCalibration=(calibTableName is not None) + ) + return _DatasetTypeRecord( + dataset_type=datasetType, + dataset_type_id=row["id"], + tag_table_name=row["tag_association_table"], + calib_table_name=calibTableName, + ) + + def _dataset_type_from_row(self, row: Mapping) -> DatasetType: + return self._record_from_row(row).dataset_type + + def _fetch_dataset_types(self) -> list[DatasetType]: + """Fetch list of all defined dataset types.""" + with self._db.query(self._static.dataset_type.select()) as sql_result: + sql_rows = sql_result.mappings().fetchall() + return [self._record_from_row(row).dataset_type for row in sql_rows] def getCollectionSummary(self, collection: CollectionRecord) -> CollectionSummary: # Docstring inherited from DatasetRecordStorageManager. - summaries = self._summaries.fetch_summaries([collection], None, self._dataset_type_factory) + summaries = self._summaries.fetch_summaries([collection], None, self._dataset_type_from_row) return summaries[collection.key] def fetch_summaries( self, collections: Iterable[CollectionRecord], dataset_types: Iterable[DatasetType] | None = None ) -> Mapping[Any, CollectionSummary]: # Docstring inherited from DatasetRecordStorageManager. - dataset_type_ids: list[int] | None = None + dataset_type_names: Iterable[str] | None = None if dataset_types is not None: - dataset_type_ids = [] - for dataset_type in dataset_types: - if dataset_type.isComponent(): - dataset_type = dataset_type.makeCompositeDatasetType() - # Assume we know all possible names. - dataset_type_id = self._byName[dataset_type.name]._dataset_type_id - dataset_type_ids.append(dataset_type_id) - return self._summaries.fetch_summaries(collections, dataset_type_ids, self._dataset_type_factory) + dataset_type_names = set(dataset_type.name for dataset_type in dataset_types) + return self._summaries.fetch_summaries(collections, dataset_type_names, self._dataset_type_from_row) _versions: list[VersionTuple] """Schema version for this class.""" diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py b/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py index 73511ddd0c..41687cb9c2 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py @@ -133,6 +133,8 @@ class CollectionSummaryManager: Manager object for the dimensions in this `Registry`. tables : `CollectionSummaryTables` Struct containing the tables that hold collection summaries. + dataset_type_table : `sqlalchemy.schema.Table` + Table containing dataset type definitions. """ def __init__( @@ -142,12 +144,14 @@ def __init__( collections: CollectionManager, dimensions: DimensionRecordStorageManager, tables: CollectionSummaryTables[sqlalchemy.schema.Table], + dataset_type_table: sqlalchemy.schema.Table, ): self._db = db self._collections = collections self._collectionKeyName = collections.getCollectionForeignKeyName() self._dimensions = dimensions self._tables = tables + self._dataset_type_table = dataset_type_table @classmethod def initialize( @@ -157,6 +161,7 @@ def initialize( *, collections: CollectionManager, dimensions: DimensionRecordStorageManager, + dataset_type_table: sqlalchemy.schema.Table, ) -> CollectionSummaryManager: """Create all summary tables (or check that they have been created), returning an object to manage them. @@ -172,6 +177,8 @@ def initialize( Manager object for the collections in this `Registry`. dimensions : `DimensionRecordStorageManager` Manager object for the dimensions in this `Registry`. + dataset_type_table : `sqlalchemy.schema.Table` + Table containing dataset type definitions. Returns ------- @@ -193,6 +200,7 @@ def initialize( collections=collections, dimensions=dimensions, tables=tables, + dataset_type_table=dataset_type_table, ) def update( @@ -240,8 +248,8 @@ def update( def fetch_summaries( self, collections: Iterable[CollectionRecord], - dataset_type_ids: Iterable[int] | None, - dataset_type_factory: Callable[[int], DatasetType], + dataset_type_names: Iterable[str] | None, + dataset_type_factory: Callable[[sqlalchemy.engine.RowMapping], DatasetType], ) -> Mapping[Any, CollectionSummary]: """Fetch collection summaries given their names and dataset types. @@ -249,12 +257,12 @@ def fetch_summaries( ---------- collections : `~collections.abc.Iterable` [`CollectionRecord`] Collection records to query. - dataset_type_ids : `~collections.abc.Iterable` [`int`] - IDs of dataset types to include into returned summaries. If `None` - then all dataset types will be included. + dataset_type_names : `~collections.abc.Iterable` [`str`] + Names of dataset types to include into returned summaries. If + `None` then all dataset types will be included. dataset_type_factory : `Callable` - Method that returns `DatasetType` instance given its dataset type - ID. + Method that takes a table row and make `DatasetType` instance out + of it. Returns ------- @@ -282,8 +290,10 @@ def fetch_summaries( # information at once. coll_col = self._tables.datasetType.columns[self._collectionKeyName].label(self._collectionKeyName) dataset_type_id_col = self._tables.datasetType.columns.dataset_type_id.label("dataset_type_id") - columns = [coll_col, dataset_type_id_col] - fromClause: sqlalchemy.sql.expression.FromClause = self._tables.datasetType + columns = [coll_col, dataset_type_id_col] + list(self._dataset_type_table.columns) + fromClause: sqlalchemy.sql.expression.FromClause = self._tables.datasetType.join( + self._dataset_type_table + ) for dimension, table in self._tables.dimensions.items(): columns.append(table.columns[dimension.name].label(dimension.name)) fromClause = fromClause.join( @@ -297,8 +307,8 @@ def fetch_summaries( sql = sqlalchemy.sql.select(*columns).select_from(fromClause) sql = sql.where(coll_col.in_([coll.key for coll in non_chains])) - if dataset_type_ids is not None: - sql = sql.where(dataset_type_id_col.in_(dataset_type_ids)) + if dataset_type_names is not None: + sql = sql.where(self._dataset_type_table.columns["name"].in_(dataset_type_names)) # Run the query and construct CollectionSummary objects from the result # rows. This will never include CHAINED collections or collections @@ -306,13 +316,16 @@ def fetch_summaries( summaries: dict[Any, CollectionSummary] = {} with self._db.query(sql) as sql_result: sql_rows = sql_result.mappings().fetchall() + dataset_type_ids: dict[int, DatasetType] = {} for row in sql_rows: # Collection key should never be None/NULL; it's what we join on. # Extract that and then turn it into a collection name. collectionKey = row[self._collectionKeyName] # dataset_type_id should also never be None/NULL; it's in the first # table we joined. - dataset_type = dataset_type_factory(row["dataset_type_id"]) + dataset_type_id = row["dataset_type_id"] + if (dataset_type := dataset_type_ids.get(dataset_type_id)) is None: + dataset_type_ids[dataset_type_id] = dataset_type = dataset_type_factory(row) # See if we have a summary already for this collection; if not, # make one. summary = summaries.get(collectionKey) diff --git a/python/lsst/daf/butler/registry/interfaces/_datasets.py b/python/lsst/daf/butler/registry/interfaces/_datasets.py index 84a5a735d4..8bafb02274 100644 --- a/python/lsst/daf/butler/registry/interfaces/_datasets.py +++ b/python/lsst/daf/butler/registry/interfaces/_datasets.py @@ -487,7 +487,7 @@ def find(self, name: str) -> DatasetRecordStorage | None: raise NotImplementedError() @abstractmethod - def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool]: + def register(self, datasetType: DatasetType) -> bool: """Ensure that this `Registry` can hold records for the given `DatasetType`, creating new tables as necessary. @@ -499,8 +499,6 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool Returns ------- - records : `DatasetRecordStorage` - The object representing the records for the given dataset type. inserted : `bool` `True` if the dataset type did not exist in the registry before. diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 5e03938b78..07eccf6ef8 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -697,8 +697,7 @@ def registerDatasetType(self, datasetType: DatasetType) -> bool: This method cannot be called within transactions, as it needs to be able to perform its own transaction to be concurrent. """ - _, inserted = self._managers.datasets.register(datasetType) - return inserted + return self._managers.datasets.register(datasetType) def removeDatasetType(self, name: str | tuple[str, ...]) -> None: """Remove the named `DatasetType` from the registry.