diff --git a/python/lsst/daf/butler/registry/_dataset_type_cache.py b/python/lsst/daf/butler/registry/_dataset_type_cache.py index 17569b4058..05d3b27b35 100644 --- a/python/lsst/daf/butler/registry/_dataset_type_cache.py +++ b/python/lsst/daf/butler/registry/_dataset_type_cache.py @@ -57,6 +57,9 @@ class DatasetTypeCache: """ def __init__(self) -> None: + from .datasets.byDimensions.tables import DynamicTablesCache + + self.tables = DynamicTablesCache() self._by_name_cache: dict[str, tuple[DatasetType, int]] = {} self._by_dimensions_cache: dict[DimensionGroup, DynamicTables] = {} self._full = False diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py index e8b9ff7222..6f554f0373 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py @@ -79,7 +79,7 @@ def update_dynamic_tables(self, current: DynamicTables) -> DynamicTables: else: # Some previously-cached dataset type had the same dimensions # but was not a calibration. - current.calibs_name = self.calib_table_name + current = current.copy(calibs_name=self.calib_table_name) # If some previously-cached dataset type was a calibration but this # one isn't, we don't want to forget the calibs table. return current @@ -326,9 +326,13 @@ def register_dataset_type(self, dataset_type: DatasetType) -> bool: dynamic_tables = DynamicTables.from_dimensions_key( dataset_type.dimensions, dimensions_key, dataset_type.isCalibration() ) - dynamic_tables.create(self._db, type(self._collections)) + dynamic_tables.create( + self._db, type(self._collections), self._caching_context.dataset_types.tables + ) elif dataset_type.isCalibration() and dynamic_tables.calibs_name is None: - dynamic_tables.add_calibs(self._db, type(self._collections)) + dynamic_tables = dynamic_tables.add_calibs( + self._db, type(self._collections), self._caching_context.dataset_types.tables + ) row, inserted = self._db.sync( self._static.dataset_type, keys={"name": dataset_type.name}, @@ -454,7 +458,7 @@ def getDatasetRef(self, id: DatasetId) -> DatasetRef | None: # This query could return multiple rows (one for each tagged # collection the dataset is in, plus one for its run collection), # and we don't care which of those we get. - tags_table = dynamic_tables.tags(self._db, type(self._collections)) + tags_table = self._get_tags_table(dynamic_tables) data_id_sql = ( tags_table.select() .where( @@ -553,9 +557,10 @@ def _fetch_dataset_types(self) -> list[DatasetType]: for record in records: cache_data.append((record.dataset_type, record.dataset_type_id)) if (dynamic_tables := cache_dimensions_data.get(record.dataset_type.dimensions)) is None: - cache_dimensions_data[record.dataset_type.dimensions] = record.make_dynamic_tables() + tables = record.make_dynamic_tables() else: - record.update_dynamic_tables(dynamic_tables) + tables = record.update_dynamic_tables(dynamic_tables) + cache_dimensions_data[record.dataset_type.dimensions] = tables self._caching_context.dataset_types.set( cache_data, full=True, dimensions_data=cache_dimensions_data.items(), dimensions_full=True ) @@ -684,7 +689,7 @@ def insert( for dataId, row in zip(data_id_list, rows, strict=True) ] # Insert those rows into the tags table. - self._db.insert(storage.dynamic_tables.tags(self._db, type(self._collections)), *tagsRows) + self._db.insert(self._get_tags_table(storage.dynamic_tables), *tagsRows) return [ DatasetRef( @@ -767,9 +772,7 @@ def import_( summary.add_datasets(refs) self._summaries.update(run, [storage.dataset_type_id], summary) # Copy from temp table into tags table. - self._db.insert( - storage.dynamic_tables.tags(self._db, type(self._collections)), select=tmp_tags.select() - ) + self._db.insert(self._get_tags_table(storage.dynamic_tables), select=tmp_tags.select()) return refs def _validate_import( @@ -793,7 +796,7 @@ def _validate_import( Raise if new datasets conflict with existing ones. """ dataset = self._static.dataset - tags = storage.dynamic_tables.tags(self._db, type(self._collections)) + tags = self._get_tags_table(storage.dynamic_tables) collection_fkey_name = self._collections.getCollectionForeignKeyName() # Check that existing datasets have the same dataset type and @@ -943,7 +946,7 @@ def associate( # inserted there. self._summaries.update(collection, [storage.dataset_type_id], summary) # Update the tag table itself. - self._db.replace(storage.dynamic_tables.tags(self._db, type(self._collections)), *rows) + self._db.replace(self._get_tags_table(storage.dynamic_tables), *rows) def disassociate( self, dataset_type: DatasetType, collection: CollectionRecord, datasets: Iterable[DatasetRef] @@ -964,7 +967,7 @@ def disassociate( for dataset in datasets ] self._db.delete( - storage.dynamic_tables.tags(self._db, type(self._collections)), + self._get_tags_table(storage.dynamic_tables), ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows, ) @@ -1015,7 +1018,7 @@ def certify( # inserted there. self._summaries.update(collection, [storage.dataset_type_id], summary) # Update the association table itself. - calibs_table = storage.dynamic_tables.calibs(self._db, type(self._collections)) + calibs_table = self._get_calibs_table(storage.dynamic_tables) if TimespanReprClass.hasExclusionConstraint(): # Rely on database constraint to enforce invariants; we just # reraise the exception for consistency across DB engines. @@ -1099,7 +1102,7 @@ def decertify( rows_to_insert = [] # Acquire a table lock to ensure there are no concurrent writes # between the SELECT and the DELETE and INSERT queries based on it. - calibs_table = storage.dynamic_tables.calibs(self._db, type(self._collections)) + calibs_table = self._get_calibs_table(storage.dynamic_tables) with self._db.transaction(lock=[calibs_table], savepoint=True): # Enter SqlQueryContext in case we need to use a temporary table to # include the give data IDs in the query (see similar block in @@ -1186,7 +1189,7 @@ def make_relation( tag_relation: Relation | None = None calib_relation: Relation | None = None if collection_types != {CollectionType.CALIBRATION}: - tags_table = storage.dynamic_tables.tags(self._db, type(self._collections)) + tags_table = self._get_tags_table(storage.dynamic_tables) # We'll need a subquery for the tags table if any of the given # collections are not a CALIBRATION collection. This intentionally # also fires when the list of collections is empty as a way to @@ -1214,7 +1217,7 @@ def make_relation( # If at least one collection is a CALIBRATION collection, we'll # need a subquery for the calibs table, and could include the # timespan as a result or constraint. - calibs_table = storage.dynamic_tables.calibs(self._db, type(self._collections)) + calibs_table = self._get_calibs_table(storage.dynamic_tables) calibs_parts = sql.Payload[LogicalColumn](calibs_table.alias(f"{dataset_type.name}_calibs")) if "timespan" in columns: calibs_parts.columns_available[DatasetColumnTag(dataset_type.name, "timespan")] = ( @@ -1422,7 +1425,7 @@ def make_joins_builder( # create a dummy subquery that we know will fail. # We give the table an alias because it might appear multiple times # in the same query, for different dataset types. - tags_table = storage.dynamic_tables.tags(self._db, type(self._collections)).alias( + tags_table = self._get_tags_table(storage.dynamic_tables).alias( f"{dataset_type.name}_tags{'_union' if is_union else ''}" ) tags_builder = self._finish_query_builder( @@ -1441,7 +1444,7 @@ def make_joins_builder( # If at least one collection is a CALIBRATION collection, we'll # need a subquery for the calibs table, and could include the # timespan as a result or constraint. - calibs_table = storage.dynamic_tables.calibs(self._db, type(self._collections)).alias( + calibs_table = self._get_calibs_table(storage.dynamic_tables).alias( f"{dataset_type.name}_calibs{'_union' if is_union else ''}" ) calibs_builder = self._finish_query_builder( @@ -1616,14 +1619,14 @@ def refresh_collection_summaries(self, dataset_type: DatasetType) -> None: # Query datasets tables for associated collections. column_name = self._collections.getCollectionForeignKeyName() - tags_table = storage.dynamic_tables.tags(self._db, type(self._collections)) + tags_table = self._get_tags_table(storage.dynamic_tables) query: sqlalchemy.sql.expression.SelectBase = ( sqlalchemy.select(tags_table.columns[column_name]) .where(tags_table.columns.dataset_type_id == storage.dataset_type_id) .distinct() ) if dataset_type.isCalibration(): - calibs_table = storage.dynamic_tables.calibs(self._db, type(self._collections)) + calibs_table = self._get_calibs_table(storage.dynamic_tables) query2 = ( sqlalchemy.select(calibs_table.columns[column_name]) .where(calibs_table.columns.dataset_type_id == storage.dataset_type_id) @@ -1637,6 +1640,12 @@ def refresh_collection_summaries(self, dataset_type: DatasetType) -> None: collections_to_delete = summary_collection_ids - collection_ids self._summaries.delete_collections(storage.dataset_type_id, collections_to_delete) + def _get_tags_table(self, table: DynamicTables) -> sqlalchemy.Table: + return table.tags(self._db, type(self._collections), self._caching_context.dataset_types.tables) + + def _get_calibs_table(self, table: DynamicTables) -> sqlalchemy.Table: + return table.calibs(self._db, type(self._collections), self._caching_context.dataset_types.tables) + def _create_case_expression_for_collections( collections: Iterable[CollectionRecord], id_column: sqlalchemy.ColumnElement diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/tables.py b/python/lsst/daf/butler/registry/datasets/byDimensions/tables.py index 029997ba64..7254c8d104 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/tables.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/tables.py @@ -38,11 +38,13 @@ ) from collections import namedtuple -from typing import Any +from typing import Any, TypeAlias import sqlalchemy +from lsst.utils.classes import immutable from .... import ddl +from ...._utilities.thread_safe_cache import ThreadSafeCache from ....dimensions import DimensionGroup, DimensionUniverse, GovernorDimension, addDimensionForeignKey from ....timespan_database_representation import TimespanDatabaseRepresentation from ...interfaces import CollectionManager, Database, VersionTuple @@ -449,10 +451,17 @@ def makeCalibTableSpec( return tableSpec +DynamicTablesCache: TypeAlias = ThreadSafeCache[str, sqlalchemy.Table] + + +@immutable class DynamicTables: """A struct that holds the "dynamic" tables common to dataset types that share the same dimensions. + Objects of this class may be shared between multiple threads, so it must be + immutable to prevent concurrency issues. + Parameters ---------- dimensions : `DimensionGroup` @@ -477,8 +486,9 @@ def __init__( self.dimensions_key = dimensions_key self.tags_name = tags_name self.calibs_name = calibs_name - self._tags_table: sqlalchemy.Table | None = None - self._calibs_table: sqlalchemy.Table | None = None + + def copy(self, calibs_name: str) -> DynamicTables: + return DynamicTables(self._dimensions, self.dimensions_key, self.tags_name, calibs_name) @classmethod def from_dimensions_key( @@ -509,7 +519,7 @@ def from_dimensions_key( calibs_name=makeCalibTableName(dimensions_key) if is_calibration else None, ) - def create(self, db: Database, collections: type[CollectionManager]) -> None: + def create(self, db: Database, collections: type[CollectionManager], cache: DynamicTablesCache) -> None: """Create the tables if they don't already exist. Parameters @@ -519,19 +529,30 @@ def create(self, db: Database, collections: type[CollectionManager]) -> None: collections : `type` [ `CollectionManager` ] Manager class for collections; used to create foreign key columns for collections. + cache : `DynamicTablesCache` + Cache used to store sqlalchemy Table objects. """ - if self._tags_table is None: - self._tags_table = db.ensureTableExists( + if cache.get(self.tags_name) is None: + cache.set_or_get( self.tags_name, - makeTagTableSpec(self._dimensions, collections), + db.ensureTableExists( + self.tags_name, + makeTagTableSpec(self._dimensions, collections), + ), ) - if self.calibs_name is not None and self._calibs_table is None: - self._calibs_table = db.ensureTableExists( + + if self.calibs_name is not None and cache.get(self.calibs_name) is None: + cache.set_or_get( self.calibs_name, - makeCalibTableSpec(self._dimensions, collections, db.getTimespanRepresentation()), + db.ensureTableExists( + self.calibs_name, + makeCalibTableSpec(self._dimensions, collections, db.getTimespanRepresentation()), + ), ) - def add_calibs(self, db: Database, collections: type[CollectionManager]) -> None: + def add_calibs( + self, db: Database, collections: type[CollectionManager], cache: DynamicTablesCache + ) -> DynamicTables: """Create a calibs table for a dataset type whose dimensions already have a tags table. @@ -542,14 +563,23 @@ def add_calibs(self, db: Database, collections: type[CollectionManager]) -> None collections : `type` [ `CollectionManager` ] Manager class for collections; used to create foreign key columns for collections. + cache : `DynamicTablesCache` + Cache used to store sqlalchemy Table objects. """ - self.calibs_name = makeCalibTableName(self.dimensions_key) - self._calibs_table = db.ensureTableExists( - self.calibs_name, - makeCalibTableSpec(self._dimensions, collections, db.getTimespanRepresentation()), + calibs_name = makeCalibTableName(self.dimensions_key) + cache.set_or_get( + calibs_name, + db.ensureTableExists( + calibs_name, + makeCalibTableSpec(self._dimensions, collections, db.getTimespanRepresentation()), + ), ) - def tags(self, db: Database, collections: type[CollectionManager]) -> sqlalchemy.Table: + return self.copy(calibs_name=calibs_name) + + def tags( + self, db: Database, collections: type[CollectionManager], cache: DynamicTablesCache + ) -> sqlalchemy.Table: """Return the "tags" table that associates datasets with data IDs in TAGGED and RUN collections. @@ -563,21 +593,27 @@ def tags(self, db: Database, collections: type[CollectionManager]) -> sqlalchemy collections : `type` [ `CollectionManager` ] Manager class for collections; used to create foreign key columns for collections. + cache : `DynamicTablesCache` + Cache used to store sqlalchemy Table objects. Returns ------- table : `sqlalchemy.Table` SQLAlchemy table object. """ - if self._tags_table is None: - spec = makeTagTableSpec(self._dimensions, collections) - table = db.getExistingTable(self.tags_name, spec) - if table is None: - raise MissingDatabaseTableError(f"Table {self.tags_name!r} is missing from database schema.") - self._tags_table = table - return self._tags_table - - def calibs(self, db: Database, collections: type[CollectionManager]) -> sqlalchemy.Table: + table = cache.get(self.tags_name) + if table is not None: + return table + + spec = makeTagTableSpec(self._dimensions, collections) + table = db.getExistingTable(self.tags_name, spec) + if table is None: + raise MissingDatabaseTableError(f"Table {self.tags_name!r} is missing from database schema.") + return cache.set_or_get(self.tags_name, table) + + def calibs( + self, db: Database, collections: type[CollectionManager], cache: DynamicTablesCache + ) -> sqlalchemy.Table: """Return the "calibs" table that associates datasets with data IDs and timespans in CALIBRATION collections. @@ -592,6 +628,8 @@ def calibs(self, db: Database, collections: type[CollectionManager]) -> sqlalche collections : `type` [ `CollectionManager` ] Manager class for collections; used to create foreign key columns for collections. + cache : `DynamicTablesCache` + Cache used to store sqlalchemy Table objects. Returns ------- @@ -601,12 +639,12 @@ def calibs(self, db: Database, collections: type[CollectionManager]) -> sqlalche assert ( self.calibs_name is not None ), "Dataset type should be checked to be calibration by calling code." - if self._calibs_table is None: - spec = makeCalibTableSpec(self._dimensions, collections, db.getTimespanRepresentation()) - table = db.getExistingTable(self.calibs_name, spec) - if table is None: - raise MissingDatabaseTableError( - f"Table {self.calibs_name!r} is missing from database schema." - ) - self._calibs_table = table - return self._calibs_table + table = cache.get(self.calibs_name) + if table is not None: + return table + + spec = makeCalibTableSpec(self._dimensions, collections, db.getTimespanRepresentation()) + table = db.getExistingTable(self.calibs_name, spec) + if table is None: + raise MissingDatabaseTableError(f"Table {self.calibs_name!r} is missing from database schema.") + return cache.set_or_get(self.calibs_name, table)