Skip to content

Commit

Permalink
Centralize tags table cache
Browse files Browse the repository at this point in the history
In preparation for sharing DatasetTypeCache between threads, make its inner DynamicTables values immutable.  The mutable portion has been moved to a separate cache inside DatasetTypeCache.

As a side effect, this reduces the number of times we go to the DB to check for the existence of tag and calib tables.
  • Loading branch information
dhirving committed Dec 9, 2024
1 parent 59b45c5 commit c5ae328
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 55 deletions.
3 changes: 3 additions & 0 deletions python/lsst/daf/butler/registry/_dataset_type_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class DatasetTypeCache:
"""

def __init__(self) -> None:
from .datasets.byDimensions.tables import DynamicTablesCache

self.tables = DynamicTablesCache()
self._by_name_cache: dict[str, tuple[DatasetType, int]] = {}
self._by_dimensions_cache: dict[DimensionGroup, DynamicTables] = {}
self._full = False
Expand Down
51 changes: 30 additions & 21 deletions python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def update_dynamic_tables(self, current: DynamicTables) -> DynamicTables:
else:
# Some previously-cached dataset type had the same dimensions
# but was not a calibration.
current.calibs_name = self.calib_table_name
current = current.copy(calibs_name=self.calib_table_name)
# If some previously-cached dataset type was a calibration but this
# one isn't, we don't want to forget the calibs table.
return current
Expand Down Expand Up @@ -326,9 +326,13 @@ def register_dataset_type(self, dataset_type: DatasetType) -> bool:
dynamic_tables = DynamicTables.from_dimensions_key(
dataset_type.dimensions, dimensions_key, dataset_type.isCalibration()
)
dynamic_tables.create(self._db, type(self._collections))
dynamic_tables.create(
self._db, type(self._collections), self._caching_context.dataset_types.tables
)
elif dataset_type.isCalibration() and dynamic_tables.calibs_name is None:
dynamic_tables.add_calibs(self._db, type(self._collections))
dynamic_tables = dynamic_tables.add_calibs(
self._db, type(self._collections), self._caching_context.dataset_types.tables
)
row, inserted = self._db.sync(
self._static.dataset_type,
keys={"name": dataset_type.name},
Expand Down Expand Up @@ -454,7 +458,7 @@ def getDatasetRef(self, id: DatasetId) -> DatasetRef | None:
# This query could return multiple rows (one for each tagged
# collection the dataset is in, plus one for its run collection),
# and we don't care which of those we get.
tags_table = dynamic_tables.tags(self._db, type(self._collections))
tags_table = self._get_tags_table(dynamic_tables)
data_id_sql = (
tags_table.select()
.where(
Expand Down Expand Up @@ -553,9 +557,10 @@ def _fetch_dataset_types(self) -> list[DatasetType]:
for record in records:
cache_data.append((record.dataset_type, record.dataset_type_id))
if (dynamic_tables := cache_dimensions_data.get(record.dataset_type.dimensions)) is None:
cache_dimensions_data[record.dataset_type.dimensions] = record.make_dynamic_tables()
tables = record.make_dynamic_tables()
else:
record.update_dynamic_tables(dynamic_tables)
tables = record.update_dynamic_tables(dynamic_tables)
cache_dimensions_data[record.dataset_type.dimensions] = tables
self._caching_context.dataset_types.set(
cache_data, full=True, dimensions_data=cache_dimensions_data.items(), dimensions_full=True
)
Expand Down Expand Up @@ -684,7 +689,7 @@ def insert(
for dataId, row in zip(data_id_list, rows, strict=True)
]
# Insert those rows into the tags table.
self._db.insert(storage.dynamic_tables.tags(self._db, type(self._collections)), *tagsRows)
self._db.insert(self._get_tags_table(storage.dynamic_tables), *tagsRows)

return [
DatasetRef(
Expand Down Expand Up @@ -767,9 +772,7 @@ def import_(
summary.add_datasets(refs)
self._summaries.update(run, [storage.dataset_type_id], summary)
# Copy from temp table into tags table.
self._db.insert(
storage.dynamic_tables.tags(self._db, type(self._collections)), select=tmp_tags.select()
)
self._db.insert(self._get_tags_table(storage.dynamic_tables), select=tmp_tags.select())
return refs

def _validate_import(
Expand All @@ -793,7 +796,7 @@ def _validate_import(
Raise if new datasets conflict with existing ones.
"""
dataset = self._static.dataset
tags = storage.dynamic_tables.tags(self._db, type(self._collections))
tags = self._get_tags_table(storage.dynamic_tables)
collection_fkey_name = self._collections.getCollectionForeignKeyName()

# Check that existing datasets have the same dataset type and
Expand Down Expand Up @@ -943,7 +946,7 @@ def associate(
# inserted there.
self._summaries.update(collection, [storage.dataset_type_id], summary)
# Update the tag table itself.
self._db.replace(storage.dynamic_tables.tags(self._db, type(self._collections)), *rows)
self._db.replace(self._get_tags_table(storage.dynamic_tables), *rows)

def disassociate(
self, dataset_type: DatasetType, collection: CollectionRecord, datasets: Iterable[DatasetRef]
Expand All @@ -964,7 +967,7 @@ def disassociate(
for dataset in datasets
]
self._db.delete(
storage.dynamic_tables.tags(self._db, type(self._collections)),
self._get_tags_table(storage.dynamic_tables),
["dataset_id", self._collections.getCollectionForeignKeyName()],
*rows,
)
Expand Down Expand Up @@ -1015,7 +1018,7 @@ def certify(
# inserted there.
self._summaries.update(collection, [storage.dataset_type_id], summary)
# Update the association table itself.
calibs_table = storage.dynamic_tables.calibs(self._db, type(self._collections))
calibs_table = self._get_calibs_table(storage.dynamic_tables)
if TimespanReprClass.hasExclusionConstraint():
# Rely on database constraint to enforce invariants; we just
# reraise the exception for consistency across DB engines.
Expand Down Expand Up @@ -1099,7 +1102,7 @@ def decertify(
rows_to_insert = []
# Acquire a table lock to ensure there are no concurrent writes
# between the SELECT and the DELETE and INSERT queries based on it.
calibs_table = storage.dynamic_tables.calibs(self._db, type(self._collections))
calibs_table = self._get_calibs_table(storage.dynamic_tables)
with self._db.transaction(lock=[calibs_table], savepoint=True):
# Enter SqlQueryContext in case we need to use a temporary table to
# include the give data IDs in the query (see similar block in
Expand Down Expand Up @@ -1186,7 +1189,7 @@ def make_relation(
tag_relation: Relation | None = None
calib_relation: Relation | None = None
if collection_types != {CollectionType.CALIBRATION}:
tags_table = storage.dynamic_tables.tags(self._db, type(self._collections))
tags_table = self._get_tags_table(storage.dynamic_tables)
# We'll need a subquery for the tags table if any of the given
# collections are not a CALIBRATION collection. This intentionally
# also fires when the list of collections is empty as a way to
Expand Down Expand Up @@ -1214,7 +1217,7 @@ def make_relation(
# If at least one collection is a CALIBRATION collection, we'll
# need a subquery for the calibs table, and could include the
# timespan as a result or constraint.
calibs_table = storage.dynamic_tables.calibs(self._db, type(self._collections))
calibs_table = self._get_calibs_table(storage.dynamic_tables)
calibs_parts = sql.Payload[LogicalColumn](calibs_table.alias(f"{dataset_type.name}_calibs"))
if "timespan" in columns:
calibs_parts.columns_available[DatasetColumnTag(dataset_type.name, "timespan")] = (
Expand Down Expand Up @@ -1422,7 +1425,7 @@ def make_joins_builder(
# create a dummy subquery that we know will fail.
# We give the table an alias because it might appear multiple times
# in the same query, for different dataset types.
tags_table = storage.dynamic_tables.tags(self._db, type(self._collections)).alias(
tags_table = self._get_tags_table(storage.dynamic_tables).alias(
f"{dataset_type.name}_tags{'_union' if is_union else ''}"
)
tags_builder = self._finish_query_builder(
Expand All @@ -1441,7 +1444,7 @@ def make_joins_builder(
# If at least one collection is a CALIBRATION collection, we'll
# need a subquery for the calibs table, and could include the
# timespan as a result or constraint.
calibs_table = storage.dynamic_tables.calibs(self._db, type(self._collections)).alias(
calibs_table = self._get_calibs_table(storage.dynamic_tables).alias(
f"{dataset_type.name}_calibs{'_union' if is_union else ''}"
)
calibs_builder = self._finish_query_builder(
Expand Down Expand Up @@ -1616,14 +1619,14 @@ def refresh_collection_summaries(self, dataset_type: DatasetType) -> None:

# Query datasets tables for associated collections.
column_name = self._collections.getCollectionForeignKeyName()
tags_table = storage.dynamic_tables.tags(self._db, type(self._collections))
tags_table = self._get_tags_table(storage.dynamic_tables)
query: sqlalchemy.sql.expression.SelectBase = (
sqlalchemy.select(tags_table.columns[column_name])
.where(tags_table.columns.dataset_type_id == storage.dataset_type_id)
.distinct()
)
if dataset_type.isCalibration():
calibs_table = storage.dynamic_tables.calibs(self._db, type(self._collections))
calibs_table = self._get_calibs_table(storage.dynamic_tables)
query2 = (
sqlalchemy.select(calibs_table.columns[column_name])
.where(calibs_table.columns.dataset_type_id == storage.dataset_type_id)
Expand All @@ -1637,6 +1640,12 @@ def refresh_collection_summaries(self, dataset_type: DatasetType) -> None:
collections_to_delete = summary_collection_ids - collection_ids
self._summaries.delete_collections(storage.dataset_type_id, collections_to_delete)

def _get_tags_table(self, table: DynamicTables) -> sqlalchemy.Table:
    """Return the sqlalchemy "tags" table for ``table``, consulting the
    shared table cache held by the caching context.
    """
    table_cache = self._caching_context.dataset_types.tables
    return table.tags(self._db, type(self._collections), table_cache)

def _get_calibs_table(self, table: DynamicTables) -> sqlalchemy.Table:
    """Return the sqlalchemy "calibs" table for ``table``, consulting the
    shared table cache held by the caching context.
    """
    table_cache = self._caching_context.dataset_types.tables
    return table.calibs(self._db, type(self._collections), table_cache)


def _create_case_expression_for_collections(
collections: Iterable[CollectionRecord], id_column: sqlalchemy.ColumnElement
Expand Down
106 changes: 72 additions & 34 deletions python/lsst/daf/butler/registry/datasets/byDimensions/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@
)

from collections import namedtuple
from typing import Any
from typing import Any, TypeAlias

import sqlalchemy
from lsst.utils.classes import immutable

from .... import ddl
from ...._utilities.thread_safe_cache import ThreadSafeCache
from ....dimensions import DimensionGroup, DimensionUniverse, GovernorDimension, addDimensionForeignKey
from ....timespan_database_representation import TimespanDatabaseRepresentation
from ...interfaces import CollectionManager, Database, VersionTuple
Expand Down Expand Up @@ -449,10 +451,17 @@ def makeCalibTableSpec(
return tableSpec


DynamicTablesCache: TypeAlias = ThreadSafeCache[str, sqlalchemy.Table]


@immutable
class DynamicTables:
"""A struct that holds the "dynamic" tables common to dataset types that
share the same dimensions.
Objects of this class may be shared between multiple threads, so it must be
immutable to prevent concurrency issues.
Parameters
----------
dimensions : `DimensionGroup`
Expand All @@ -477,8 +486,9 @@ def __init__(
self.dimensions_key = dimensions_key
self.tags_name = tags_name
self.calibs_name = calibs_name
self._tags_table: sqlalchemy.Table | None = None
self._calibs_table: sqlalchemy.Table | None = None

def copy(self, calibs_name: str) -> DynamicTables:
    """Return a new ``DynamicTables`` identical to this one except for the
    given calibs table name (this instance is immutable, so updating the
    calibs name requires constructing a replacement).
    """
    return DynamicTables(
        self._dimensions,
        self.dimensions_key,
        self.tags_name,
        calibs_name,
    )

@classmethod
def from_dimensions_key(
Expand Down Expand Up @@ -509,7 +519,7 @@ def from_dimensions_key(
calibs_name=makeCalibTableName(dimensions_key) if is_calibration else None,
)

def create(self, db: Database, collections: type[CollectionManager]) -> None:
def create(self, db: Database, collections: type[CollectionManager], cache: DynamicTablesCache) -> None:
"""Create the tables if they don't already exist.
Parameters
Expand All @@ -519,19 +529,30 @@ def create(self, db: Database, collections: type[CollectionManager]) -> None:
collections : `type` [ `CollectionManager` ]
Manager class for collections; used to create foreign key columns
for collections.
cache : `DynamicTablesCache`
Cache used to store sqlalchemy Table objects.
"""
if self._tags_table is None:
self._tags_table = db.ensureTableExists(
if cache.get(self.tags_name) is None:
cache.set_or_get(
self.tags_name,
makeTagTableSpec(self._dimensions, collections),
db.ensureTableExists(
self.tags_name,
makeTagTableSpec(self._dimensions, collections),
),
)
if self.calibs_name is not None and self._calibs_table is None:
self._calibs_table = db.ensureTableExists(

if self.calibs_name is not None and cache.get(self.calibs_name) is None:
cache.set_or_get(
self.calibs_name,
makeCalibTableSpec(self._dimensions, collections, db.getTimespanRepresentation()),
db.ensureTableExists(
self.calibs_name,
makeCalibTableSpec(self._dimensions, collections, db.getTimespanRepresentation()),
),
)

def add_calibs(self, db: Database, collections: type[CollectionManager]) -> None:
def add_calibs(
self, db: Database, collections: type[CollectionManager], cache: DynamicTablesCache
) -> DynamicTables:
"""Create a calibs table for a dataset type whose dimensions already
have a tags table.
Expand All @@ -542,14 +563,23 @@ def add_calibs(self, db: Database, collections: type[CollectionManager]) -> None
collections : `type` [ `CollectionManager` ]
Manager class for collections; used to create foreign key columns
for collections.
cache : `DynamicTablesCache`
Cache used to store sqlalchemy Table objects.
"""
self.calibs_name = makeCalibTableName(self.dimensions_key)
self._calibs_table = db.ensureTableExists(
self.calibs_name,
makeCalibTableSpec(self._dimensions, collections, db.getTimespanRepresentation()),
calibs_name = makeCalibTableName(self.dimensions_key)
cache.set_or_get(
calibs_name,
db.ensureTableExists(
calibs_name,
makeCalibTableSpec(self._dimensions, collections, db.getTimespanRepresentation()),
),
)

def tags(self, db: Database, collections: type[CollectionManager]) -> sqlalchemy.Table:
return self.copy(calibs_name=calibs_name)

def tags(
self, db: Database, collections: type[CollectionManager], cache: DynamicTablesCache
) -> sqlalchemy.Table:
"""Return the "tags" table that associates datasets with data IDs in
TAGGED and RUN collections.
Expand All @@ -563,21 +593,27 @@ def tags(self, db: Database, collections: type[CollectionManager]) -> sqlalchemy
collections : `type` [ `CollectionManager` ]
Manager class for collections; used to create foreign key columns
for collections.
cache : `DynamicTablesCache`
Cache used to store sqlalchemy Table objects.
Returns
-------
table : `sqlalchemy.Table`
SQLAlchemy table object.
"""
if self._tags_table is None:
spec = makeTagTableSpec(self._dimensions, collections)
table = db.getExistingTable(self.tags_name, spec)
if table is None:
raise MissingDatabaseTableError(f"Table {self.tags_name!r} is missing from database schema.")
self._tags_table = table
return self._tags_table

def calibs(self, db: Database, collections: type[CollectionManager]) -> sqlalchemy.Table:
table = cache.get(self.tags_name)
if table is not None:
return table

spec = makeTagTableSpec(self._dimensions, collections)
table = db.getExistingTable(self.tags_name, spec)
if table is None:
raise MissingDatabaseTableError(f"Table {self.tags_name!r} is missing from database schema.")
return cache.set_or_get(self.tags_name, table)

def calibs(
self, db: Database, collections: type[CollectionManager], cache: DynamicTablesCache
) -> sqlalchemy.Table:
"""Return the "calibs" table that associates datasets with data IDs and
timespans in CALIBRATION collections.
Expand All @@ -592,6 +628,8 @@ def calibs(self, db: Database, collections: type[CollectionManager]) -> sqlalche
collections : `type` [ `CollectionManager` ]
Manager class for collections; used to create foreign key columns
for collections.
cache : `DynamicTablesCache`
Cache used to store sqlalchemy Table objects.
Returns
-------
Expand All @@ -601,12 +639,12 @@ def calibs(self, db: Database, collections: type[CollectionManager]) -> sqlalche
assert (
self.calibs_name is not None
), "Dataset type should be checked to be calibration by calling code."
if self._calibs_table is None:
spec = makeCalibTableSpec(self._dimensions, collections, db.getTimespanRepresentation())
table = db.getExistingTable(self.calibs_name, spec)
if table is None:
raise MissingDatabaseTableError(
f"Table {self.calibs_name!r} is missing from database schema."
)
self._calibs_table = table
return self._calibs_table
table = cache.get(self.calibs_name)
if table is not None:
return table

spec = makeCalibTableSpec(self._dimensions, collections, db.getTimespanRepresentation())
table = db.getExistingTable(self.calibs_name, spec)
if table is None:
raise MissingDatabaseTableError(f"Table {self.calibs_name!r} is missing from database schema.")
return cache.set_or_get(self.calibs_name, table)

0 comments on commit c5ae328

Please sign in to comment.