From 6ae6df1bf6ce237568a9ae6a261d73095344d5aa Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Sun, 17 Dec 2023 09:35:59 -0500 Subject: [PATCH] Add set-like, dict-backed container for DimensionRecords. --- python/lsst/daf/butler/dimensions/__init__.py | 1 + .../lsst/daf/butler/dimensions/_record_set.py | 469 ++++++++++++++++++ .../daf/butler/dimensions/_record_table.py | 8 + tests/test_dimension_record_containers.py | 162 +++++- 4 files changed, 638 insertions(+), 2 deletions(-) create mode 100644 python/lsst/daf/butler/dimensions/_record_set.py diff --git a/python/lsst/daf/butler/dimensions/__init__.py b/python/lsst/daf/butler/dimensions/__init__.py index ab4626a59c..5494a7efed 100644 --- a/python/lsst/daf/butler/dimensions/__init__.py +++ b/python/lsst/daf/butler/dimensions/__init__.py @@ -43,6 +43,7 @@ from ._graph import * from ._group import * from ._packer import * +from ._record_set import * from ._record_table import * from ._records import * from ._schema import * diff --git a/python/lsst/daf/butler/dimensions/_record_set.py b/python/lsst/daf/butler/dimensions/_record_set.py new file mode 100644 index 0000000000..3871cf0788 --- /dev/null +++ b/python/lsst/daf/butler/dimensions/_record_set.py @@ -0,0 +1,469 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("DimensionRecordSet", "DimensionRecordFactory") + +from collections.abc import Collection, Iterable, Iterator +from typing import TYPE_CHECKING, Any, Protocol, final + +from ._coordinate import DataCoordinate, DataIdValue +from ._records import DimensionRecord + +if TYPE_CHECKING: + from ._elements import DimensionElement + from ._universe import DimensionUniverse + + +class DimensionRecordFactory(Protocol): + """Protocol for a callback that can be used to create a dimension record + to add to a `DimensionRecordSet` when a search for an existing one fails. + """ + + def __call__( + self, record_class: type[DimensionRecord], required_values: tuple[DataIdValue, ...] + ) -> DimensionRecord: + """Make a new `DimensionRecord` instance. + + Parameters + ---------- + record_class : `type` [ `DimensionRecord` ] + A concrete `DimensionRecord` subclass. + required_values : `tuple` + Tuple of data ID values, corresponding to + ``record_class.definition.required``. + """ + ... # pragma: no cover + + +def fail_record_lookup( + record_class: type[DimensionRecord], required_values: tuple[DataIdValue, ...] 
+) -> DimensionRecord: + """Raise `LookupError` to indicate that a `DimensionRecord` could not be + found or created. + + This is intended for use as the default value for arguments that take a + `DimensionRecordFactory` callback. + + Parameters + ---------- + record_class : `type` [ `DimensionRecord` ] + Type of record to create. + required_values : `tuple` + Tuple of data ID required values that are sufficient to identify a + record that exists in the data repository. + + Returns + ------- + record : `DimensionRecord` + Never returned; this function always raises `LookupError`. + """ + raise LookupError( + f"No {record_class.definition.name!r} record with data ID " + f"{DataCoordinate.from_required_values(record_class.definition.minimal_group, required_values)}." + ) + + +@final +class DimensionRecordSet(Collection[DimensionRecord]): # numpydoc ignore=PR01 + """A mutable set-like container specialized for `DimensionRecord` objects. + + Parameters + ---------- + element : `DimensionElement` or `str`, optional + The dimension element that defines the records held by this set. If + not a `DimensionElement` instance, ``universe`` must be provided. + records : `~collections.abc.Iterable` [ `DimensionRecord` ], optional + Dimension records to add to the set. + universe : `DimensionUniverse`, optional + Object that defines all dimensions. Ignored if ``element`` is a + `DimensionElement` instance. + + Notes + ----- + `DimensionRecordSet` maintains its insertion order (like `dict`, and unlike + `set`). + + `DimensionRecordSet` implements `collections.abc.Collection` but not + `collections.abc.Set` because the latter would require interoperability + with all other `~collections.abc.Set` implementations rather than just + `DimensionRecordSet`, and that adds a lot of complexity without much clear + value. To help make this clear to type checkers it implements only the + named-method versions of these operations (e.g. 
`issubset`) rather than the + operator special methods (e.g. ``__le__``). + + `DimensionRecord` equality is defined in terms of a record's data ID fields + only, and `DimensionRecordSet` does not generally specify which record + "wins" when two records with the same data ID interact (e.g. in + `intersection`). The `add` and `update` methods are notable exceptions: + they always replace the existing record with the new one. + + Dimension records can also be held by `DimensionRecordTable`, which + provides column-oriented access and Arrow interoperability. + """ + + def __init__( + self, + element: DimensionElement | str, + records: Iterable[DimensionRecord] = (), + universe: DimensionUniverse | None = None, + *, + _by_required_values: dict[tuple[DataIdValue, ...], DimensionRecord] | None = None, + ): + if isinstance(element, str): + if universe is None: + raise TypeError("'universe' must be provided if 'element' is not a DimensionElement.") + element = universe[element] + else: + universe = element.universe + if _by_required_values is None: + _by_required_values = {} + self._record_type = element.RecordClass + self._by_required_values = _by_required_values + self._dimensions = element.minimal_group + self.update(records) + + @property + def element(self) -> DimensionElement: + """Name of the dimension element these records correspond to.""" + return self._record_type.definition + + def __contains__(self, key: object) -> bool: + match key: + case DimensionRecord() if key.definition == self.element: + required_values = key.dataId.required_values + case DataCoordinate() if key.dimensions == self.element.minimal_group: + required_values = key.required_values + case _: + return False + return required_values in self._by_required_values + + def __len__(self) -> int: + return len(self._by_required_values) + + def __iter__(self) -> Iterator[DimensionRecord]: + return iter(self._by_required_values.values()) + + def __eq__(self, other: object) -> bool: + if not 
isinstance(other, DimensionRecordSet): + return False + return ( + self._record_type is other._record_type + and self._by_required_values.keys() == other._by_required_values.keys() + ) + + def issubset(self, other: DimensionRecordSet) -> bool: + """Test whether all elements in ``self`` are in ``other``. + + Parameters + ---------- + other : `DimensionRecordSet` + Another record set with the same record type. + + Returns + ------- + issubset : `bool` + Whether all elements in ``self`` are in ``other``. + """ + if self._record_type is not other._record_type: + raise ValueError( + "Invalid comparison between dimension record sets for elements " + f"{self.element.name!r} and {other.element.name!r}." + ) + return self._by_required_values.keys() <= other._by_required_values.keys() + + def issuperset(self, other: DimensionRecordSet) -> bool: + """Test whether all elements in ``other`` are in ``self``. + + Parameters + ---------- + other : `DimensionRecordSet` + Another record set with the same record type. + + Returns + ------- + issuperset : `bool` + Whether all elements in ``other`` are in ``self``. + """ + if self._record_type is not other._record_type: + raise ValueError( + "Invalid comparison between dimension record sets for elements " + f"{self.element.name!r} and {other.element.name!r}." + ) + return self._by_required_values.keys() >= other._by_required_values.keys() + + def isdisjoint(self, other: DimensionRecordSet) -> bool: + """Test whether the intersection of ``self`` and ``other`` is empty. + + Parameters + ---------- + other : `DimensionRecordSet` + Another record set with the same record type. + + Returns + ------- + isdisjoint : `bool` + Whether the intersection of ``self`` and ``other`` is empty. + """ + if self._record_type is not other._record_type: + raise ValueError( + "Invalid comparison between dimension record sets for elements " + f"{self.element.name!r} and {other.element.name!r}." 
+ ) + return self._by_required_values.keys().isdisjoint(other._by_required_values.keys()) + + def intersection(self, other: DimensionRecordSet) -> DimensionRecordSet: + """Return a new set with only records that are in both ``self`` and + ``other``. + + Parameters + ---------- + other : `DimensionRecordSet` + Another record set with the same record type. + + Returns + ------- + intersection : `DimensionRecordSet` + A new record set with all elements in both sets. + """ + if self._record_type is not other._record_type: + raise ValueError( + "Invalid intersection between dimension record sets for elements " + f"{self.element.name!r} and {other.element.name!r}." + ) + return DimensionRecordSet( + self.element, + _by_required_values={ + k: v for k, v in self._by_required_values.items() if k in other._by_required_values + }, + ) + + def difference(self, other: DimensionRecordSet) -> DimensionRecordSet: + """Return a new set with only records that are in ``self`` and not in + ``other``. + + Parameters + ---------- + other : `DimensionRecordSet` + Another record set with the same record type. + + Returns + ------- + difference : `DimensionRecordSet` + A new record set with all elements in ``self`` that are not in + ``other``. + """ + if self._record_type is not other._record_type: + raise ValueError( + "Invalid difference between dimension record sets for elements " + f"{self.element.name!r} and {other.element.name!r}." + ) + return DimensionRecordSet( + self.element, + _by_required_values={ + k: v for k, v in self._by_required_values.items() if k not in other._by_required_values + }, + ) + + def union(self, other: DimensionRecordSet) -> DimensionRecordSet: + """Return a new set with all records that are either in ``self`` or + ``other``. + + Parameters + ---------- + other : `DimensionRecordSet` + Another record set with the same record type. + + Returns + ------- + union : `DimensionRecordSet` + A new record set with all elements in either set. 
+ """ + if self._record_type is not other._record_type: + raise ValueError( + "Invalid union between dimension record sets for elements " + f"{self.element.name!r} and {other.element.name!r}." + ) + return DimensionRecordSet( + self.element, + _by_required_values=self._by_required_values | other._by_required_values, + ) + + def find( + self, + data_id: DataCoordinate, + or_add: DimensionRecordFactory = fail_record_lookup, + ) -> DimensionRecord: + """Return the record with the given data ID. + + Parameters + ---------- + data_id : `DataCoordinate` + Data ID to match. + or_add : `DimensionRecordFactory` + Callback that is invoked if no existing record is found, to create + a new record that is added to the set and returned. The return + value of this callback is *not* checked to see if it is a valid + dimension record with the right element and data ID. + + Returns + ------- + record : `DimensionRecord` + Matching record. + + Raises + ------ + LookupError + Raised if no record with this data ID was found. + ValueError + Raised if the data ID did not have the right dimensions. + """ + if data_id.dimensions != self._dimensions: + raise ValueError( + f"data ID {data_id} has incorrect dimensions for dimension records for {self.element!r}." + ) + return self.find_with_required_values(data_id.required_values, or_add) + + def find_with_required_values( + self, required_values: tuple[DataIdValue, ...], or_add: DimensionRecordFactory = fail_record_lookup + ) -> DimensionRecord: + """Return the record whose data ID has the given required values. + + Parameters + ---------- + required_values : `tuple` [ `int` or `str` ] + Data ID values to match. + or_add : `DimensionRecordFactory` + Callback that is invoked if no existing record is found, to create + a new record that is added to the set and returned. The return + value of this callback is *not* checked to see if it is a valid + dimension record with the right element and data ID. 
+ + Returns + ------- + record : `DimensionRecord` + Matching record. + + Raises + ------ + ValueError + Raised if the data ID did not have the right dimensions. + """ + if (result := self._by_required_values.get(required_values)) is None: + result = or_add(self._record_type, required_values) + self._by_required_values[required_values] = result + return result + + def add(self, value: DimensionRecord) -> None: + """Add a new record to the set. + + Parameters + ---------- + value : `DimensionRecord` + Record to add. + + Raises + ------ + ValueError + Raised if ``value.element != self.element``. + """ + if value.definition.name != self.element: + raise ValueError( + f"Cannot add record {value} for {value.definition.name!r} to set for {self.element!r}." + ) + self._by_required_values[value.dataId.required_values] = value + + def update(self, values: Iterable[DimensionRecord]) -> None: + """Add new records to the set. + + Parameters + ---------- + values : `~collections.abc.Iterable` [ `DimensionRecord` ] + Record to add. + + Raises + ------ + ValueError + Raised if ``value.element != self.element``. + """ + for value in values: + self.add(value) + + def update_from_data_coordinates(self, data_coordinates: Iterable[DataCoordinate]) -> None: + """Add records to the set by extracting and deduplicating them from + data coordinates. + + Parameters + ---------- + data_coordinates : `~collections.abc.Iterable` [ `DataCoordinate` ] + Data coordinates to extract from. `DataCoordinate.hasRecords` must + be `True`. + """ + for data_coordinate in data_coordinates: + if record := data_coordinate._record(self.element.name): + self._by_required_values[record.dataId.required_values] = record + + def discard(self, value: DimensionRecord | DataCoordinate) -> None: + """Remove a record if it exists. + + Parameters + ---------- + value : `DimensionRecord` or `DataCoordinate` + Record to remove, or its data ID. 
+ """ + if isinstance(value, DimensionRecord): + value = value.dataId + if value.dimensions != self._dimensions: + raise ValueError(f"{value} has incorrect dimensions for dimension records for {self.element!r}.") + self._by_required_values.pop(value.required_values, None) + + def remove(self, value: DimensionRecord | DataCoordinate) -> None: + """Remove a record. + + Parameters + ---------- + value : `DimensionRecord` or `DataCoordinate` + Record to remove, or its data ID. + + Raises + ------ + KeyError + Raised if there is no matching record. + """ + if isinstance(value, DimensionRecord): + value = value.dataId + if value.dimensions != self._dimensions: + raise ValueError(f"{value} has incorrect dimensions for dimension records for {self.element!r}.") + del self._by_required_values[value.required_values] + + def pop(self) -> DimensionRecord: + """Remove and return an arbitrary record.""" + return self._by_required_values.popitem()[1] + + def __deepcopy__(self, memo: dict[str, Any]) -> DimensionRecordSet: + return DimensionRecordSet(self.element, _by_required_values=self._by_required_values.copy()) diff --git a/python/lsst/daf/butler/dimensions/_record_table.py b/python/lsst/daf/butler/dimensions/_record_table.py index a5ac1cc262..c7134ef544 100644 --- a/python/lsst/daf/butler/dimensions/_record_table.py +++ b/python/lsst/daf/butler/dimensions/_record_table.py @@ -62,6 +62,14 @@ class DimensionRecordTable: `make_arrow_schema` for this element. This argument is primarily intended to serve as the way to reconstruct a `DimensionRecordTable` that has been serialized to an Arrow-supported file or IPC format. + + Notes + ----- + `DimensionRecordTable` should generally have a smaller memory footprint + than `DimensionRecordSet` if its rows are unique, and it provides fast + column-oriented access and Arrow interoperability that `DimensionRecordSet` + lacks entirely. In other respects `DimensionRecordSet` is more + featureful and simpler to use efficiently. 
""" def __init__( diff --git a/tests/test_dimension_record_containers.py b/tests/test_dimension_record_containers.py index bc390cf159..c9840944ee 100644 --- a/tests/test_dimension_record_containers.py +++ b/tests/test_dimension_record_containers.py @@ -25,12 +25,13 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import copy import os import unittest import pyarrow as pa import pyarrow.parquet as pq -from lsst.daf.butler import DimensionRecordTable, YamlRepoImportBackend +from lsst.daf.butler import DimensionRecordSet, DimensionRecordTable, YamlRepoImportBackend from lsst.daf.butler.registry import RegistryConfig, _RegistryFactory DIMENSION_DATA_FILES = [ @@ -55,9 +56,11 @@ def setUpClass(cls): backend.register() backend.load(datastore=None) cls.records = { - element: tuple(list(registry.queryDimensionRecords(element))) for element in ("visit", "skymap") + element: tuple(list(registry.queryDimensionRecords(element))) + for element in ("visit", "skymap", "patch") } cls.universe = registry.dimensions + cls.data_ids = list(registry.queryDataIds(["visit", "patch"]).expanded()) def test_record_table_schema_visit(self): """Test that the Arrow schema for 'visit' has the right types, @@ -195,6 +198,161 @@ def test_record_table_parquet_skymap(self): ) self.assertEqual(list(table1), list(table2)) + def test_record_set_const(self): + """Test attributes and methods of `DimensionRecordSet` that do not + modify the set. + + We use 'patch' records for this test because there are enough of them + to do nontrivial set-operation tests. + """ + element = self.universe["patch"] + records = self.records["patch"] + set1 = DimensionRecordSet(element, records[:7]) + self.assertEqual(set1, DimensionRecordSet("patch", records[:7], universe=self.universe)) + # DimensionRecordSets do not compare as equal with other set types, + # even with the same content. 
+ self.assertNotEqual(set1, set(records[:7])) + with self.assertRaises(TypeError): + DimensionRecordSet("patch", records[:7]) + self.assertEqual(set1.element, self.universe["patch"]) + self.assertEqual(len(set1), 7) + self.assertEqual(list(set1), list(records[:7])) + self.assertIn(records[4], set1) + self.assertIn(records[5].dataId, set1) + self.assertNotIn(self.records["visit"][0], set1) + self.assertTrue(set1.issubset(DimensionRecordSet(element, records[:8]))) + self.assertFalse(set1.issubset(DimensionRecordSet(element, records[1:6]))) + with self.assertRaises(ValueError): + set1.issubset(DimensionRecordSet(self.universe["tract"])) + self.assertTrue(set1.issuperset(DimensionRecordSet(element, records[1:6]))) + self.assertFalse(set1.issuperset(DimensionRecordSet(element, records[:8]))) + with self.assertRaises(ValueError): + set1.issuperset(DimensionRecordSet(self.universe["tract"])) + self.assertTrue(set1.isdisjoint(DimensionRecordSet(element, records[7:]))) + self.assertFalse(set1.isdisjoint(DimensionRecordSet(element, records[5:8]))) + with self.assertRaises(ValueError): + set1.isdisjoint(DimensionRecordSet(self.universe["tract"])) + self.assertEqual( + set1.intersection(DimensionRecordSet(element, records[5:])), + DimensionRecordSet(element, records[5:7]), + ) + self.assertEqual( + set1.intersection(DimensionRecordSet(element, records[5:])), + DimensionRecordSet(element, records[5:7]), + ) + with self.assertRaises(ValueError): + set1.intersection(DimensionRecordSet(self.universe["tract"])) + self.assertEqual( + set1.difference(DimensionRecordSet(element, records[5:])), + DimensionRecordSet(element, records[:5]), + ) + with self.assertRaises(ValueError): + set1.difference(DimensionRecordSet(self.universe["tract"])) + self.assertEqual( + set1.union(DimensionRecordSet(element, records[5:9])), + DimensionRecordSet(element, records[:9]), + ) + with self.assertRaises(ValueError): + set1.union(DimensionRecordSet(self.universe["tract"])) + 
self.assertEqual(set1.find(records[0].dataId), records[0]) + with self.assertRaises(LookupError): + set1.find(self.records["patch"][8].dataId) + with self.assertRaises(ValueError): + set1.find(self.records["visit"][0].dataId) + self.assertEqual(set1.find_with_required_values(records[0].dataId.required_values), records[0]) + + def test_record_set_add(self): + """Test DimensionRecordSet.add.""" + set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe) + set1.add(self.records["patch"][2]) + with self.assertRaises(ValueError): + set1.add(self.records["visit"][0]) + self.assertEqual(set1, DimensionRecordSet("patch", self.records["patch"][:3], universe=self.universe)) + set1.add(self.records["patch"][2]) + self.assertEqual(list(set1), list(self.records["patch"][:3])) + + def test_record_set_find_or_add(self): + """Test DimensionRecordSet.find and find_with_required_values with + a 'or_add' callback. + """ + set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe) + set1.find(self.records["patch"][2].dataId, or_add=lambda _c, _r: self.records["patch"][2]) + with self.assertRaises(ValueError): + set1.find(self.records["visit"][0].dataId, or_add=lambda _c, _r: self.records["visit"][0]) + self.assertEqual(set1, DimensionRecordSet("patch", self.records["patch"][:3], universe=self.universe)) + + set1.find_with_required_values( + self.records["patch"][3].dataId.required_values, or_add=lambda _c, _r: self.records["patch"][3] + ) + self.assertEqual(set1, DimensionRecordSet("patch", self.records["patch"][:4], universe=self.universe)) + + def test_record_set_update_from_data_coordinates(self): + """Test DimensionRecordSet.update_from_data_coordinates.""" + set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe) + set1.update_from_data_coordinates(self.data_ids) + for data_id in self.data_ids: + self.assertIn(data_id.records["patch"], set1) + + def test_record_set_discard(self): + """Test 
DimensionRecordSet.discard.""" + set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe) + set2 = copy.deepcopy(set1) + # These discards should do nothing. + set1.discard(self.records["patch"][2]) + self.assertEqual(set1, set2) + set1.discard(self.records["patch"][2].dataId) + self.assertEqual(set1, set2) + with self.assertRaises(ValueError): + set1.discard(self.records["visit"][0]) + self.assertEqual(set1, set2) + with self.assertRaises(ValueError): + set1.discard(self.records["visit"][0].dataId) + self.assertEqual(set1, set2) + # These ones should remove a record from each set. + set1.discard(self.records["patch"][1]) + set2.discard(self.records["patch"][1].dataId) + self.assertEqual(set1, set2) + self.assertNotIn(self.records["patch"][1], set1) + self.assertNotIn(self.records["patch"][1], set2) + + def test_record_set_remove(self): + """Test DimensionRecordSet.remove.""" + set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe) + set2 = copy.deepcopy(set1) + # These removes should raise with strong exception safety. + with self.assertRaises(KeyError): + set1.remove(self.records["patch"][2]) + self.assertEqual(set1, set2) + with self.assertRaises(KeyError): + set1.remove(self.records["patch"][2].dataId) + self.assertEqual(set1, set2) + with self.assertRaises(ValueError): + set1.remove(self.records["visit"][0]) + self.assertEqual(set1, set2) + with self.assertRaises(ValueError): + set1.remove(self.records["visit"][0].dataId) + self.assertEqual(set1, set2) + # These ones should remove a record from each set. 
+ set1.remove(self.records["patch"][1]) + set2.remove(self.records["patch"][1].dataId) + self.assertEqual(set1, set2) + self.assertNotIn(self.records["patch"][1], set1) + self.assertNotIn(self.records["patch"][1], set2) + + def test_record_set_pop(self): + """Test DimensionRecordSet.pop.""" + set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe) + set2 = copy.deepcopy(set1) + record1 = set1.pop() + set2.remove(record1) + self.assertNotIn(record1, set1) + self.assertEqual(set1, set2) + record2 = set1.pop() + set2.remove(record2) + self.assertNotIn(record2, set1) + self.assertEqual(set1, set2) + self.assertFalse(set1) + if __name__ == "__main__": unittest.main()