diff --git a/setup.cfg b/setup.cfg index 5db7b47d..41000901 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,6 +30,7 @@ python_requires = >=3.7 install_requires = pydantic[email]>=1.0 xmlschema>=1.0.16 + Pint>=0.15 [options.package_data] * = *.xsd diff --git a/src/ome_types/dataclasses.py b/src/ome_types/dataclasses.py index 475b641b..a9be1ae1 100644 --- a/src/ome_types/dataclasses.py +++ b/src/ome_types/dataclasses.py @@ -6,6 +6,7 @@ from textwrap import indent from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Type, Union +import pint from pydantic import validator from pydantic.dataclasses import _process_class @@ -13,6 +14,13 @@ from pydantic.dataclasses import DataclassType +ureg = pint.UnitRegistry(auto_reduce_dimensions=True) +ureg.define("reference_frame = [_reference_frame]") +ureg.define("@alias grade = gradian") +ureg.define("@alias astronomical_unit = ua") +ureg.define("line = inch / 12 = li") + + class Sentinel: """Create singleton sentinel objects with a readable repr.""" @@ -89,6 +97,22 @@ def new_post_init(self: Any, *args: Any) -> None: setattr(_cls, "__post_init__", new_post_init) +def add_quantities(_cls: Type[Any]) -> None: + value_fields = [f for f in dir(_cls) if f + "_unit" in dir(_cls)] + for field in value_fields: + setattr(_cls, field + "_quantity", quantity_property(field)) + + +def quantity_property(field: str) -> property: + def quantity(self: Any) -> Optional[pint.Quantity]: + value = getattr(self, field) + if value is None: + return None + unit = getattr(self, field + "_unit").value.replace(' ', '_') + return ureg.Quantity(value, unit) + return property(quantity) + + def modify_repr(_cls: Type[Any]) -> None: """Improved dataclass repr function. @@ -153,6 +177,7 @@ def wrap(cls: Type[Any]) -> DataclassType: if getattr(cls, "id", None) is AUTO_SEQUENCE: setattr(cls, "validate_id", validate_id) modify_post_init(cls) + add_quantities(cls) if not repr: modify_repr(cls) return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, config) diff --git a/testing/test_units.py b/testing/test_units.py new file mode 100644 index 00000000..765d6ce0 --- /dev/null +++ b/testing/test_units.py @@ -0,0 +1,62 @@ +import pytest +from pydantic import ValidationError +from pint import DimensionalityError +from ome_types.model import Channel, Laser, Plane +from ome_types.dataclasses import ureg + + +def test_quantity_math(): + """Validate math on quantities with different but compatible units.""" + channel = Channel( + excitation_wavelength=475, + excitation_wavelength_unit="nm", + emission_wavelength=530000, + emission_wavelength_unit="pm", + ) + shift = ( + channel.emission_wavelength_quantity - channel.excitation_wavelength_quantity + ) + # Compare to a tolerance due to Pint internal factor representation. + assert abs(shift.to("nm").m - 55) < 1e-12 + + +def test_invalid_unit(): + """Ensure incompatible units in constructor raises ValidationError.""" + with pytest.raises(ValidationError): + Channel( + excitation_wavelength=475, excitation_wavelength_unit="kg", + ) + + +def test_dimensionality_error(): + """Ensure math on incompatible units raises DimensionalityError.""" + laser = Laser( + id="LightSource:1", + repetition_rate=10, + repetition_rate_unit="MHz", + wavelength=640, + ) + with pytest.raises(DimensionalityError): + laser.repetition_rate_quantity + laser.wavelength_quantity + + +def test_reference_frame(): + """Validate reference_frame behavior.""" + plane = Plane( + the_c=0, + the_t=0, + the_z=0, + position_x=1, + position_x_unit="reference frame", + position_y=2, + position_y_unit="mm", + ) + # Verify two different ways that reference_frame and length are incompatible. + with pytest.raises(DimensionalityError): + plane.position_x_quantity + plane.position_y_quantity + product = plane.position_x_quantity * plane.position_y_quantity + assert not product.check("[area]") + # Verify that we can obtain a usable length if we know the conversion factor. + conversion_factor = ureg.Quantity(1, "micron/reference_frame") + position_x = plane.position_x_quantity * conversion_factor + assert position_x.check("[length]")