Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add _quantity property for values with units #38

Merged
merged 3 commits into from
Sep 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ python_requires = >=3.7
install_requires =
pydantic[email]>=1.0
xmlschema>=1.0.16
Pint>=0.15

[options.package_data]
* = *.xsd
Expand Down
25 changes: 25 additions & 0 deletions src/ome_types/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,21 @@
from textwrap import indent
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Type, Union

import pint
from pydantic import validator
from pydantic.dataclasses import _process_class

if TYPE_CHECKING:
from pydantic.dataclasses import DataclassType


ureg = pint.UnitRegistry(auto_reduce_dimensions=True)
ureg.define("reference_frame = [_reference_frame]")
ureg.define("@alias grade = gradian")
ureg.define("@alias astronomical_unit = ua")
ureg.define("line = inch / 12 = li")


class Sentinel:
"""Create singleton sentinel objects with a readable repr."""

Expand Down Expand Up @@ -89,6 +97,22 @@ def new_post_init(self: Any, *args: Any) -> None:
setattr(_cls, "__post_init__", new_post_init)


def add_quantities(_cls: Type[Any]) -> None:
value_fields = [f for f in dir(_cls) if f + "_unit" in dir(_cls)]
for field in value_fields:
setattr(_cls, field + "_quantity", quantity_property(field))


def quantity_property(field: str) -> property:
def quantity(self: Any) -> Optional[pint.Quantity]:
value = getattr(self, field)
if value is None:
return None
unit = getattr(self, field + "_unit").value.replace(' ', '_')
return ureg.Quantity(value, unit)
return property(quantity)


def modify_repr(_cls: Type[Any]) -> None:
"""Improved dataclass repr function.

Expand Down Expand Up @@ -153,6 +177,7 @@ def wrap(cls: Type[Any]) -> DataclassType:
if getattr(cls, "id", None) is AUTO_SEQUENCE:
setattr(cls, "validate_id", validate_id)
modify_post_init(cls)
add_quantities(cls)
if not repr:
modify_repr(cls)
return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, config)
Expand Down
62 changes: 62 additions & 0 deletions testing/test_units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
from pydantic import ValidationError
from pint import DimensionalityError
from ome_types.model import Channel, Laser, Plane
from ome_types.dataclasses import ureg


def test_quantity_math():
"""Validate math on quantities with different but compatible units."""
channel = Channel(
excitation_wavelength=475,
excitation_wavelength_unit="nm",
emission_wavelength=530000,
emission_wavelength_unit="pm",
)
shift = (
channel.emission_wavelength_quantity - channel.excitation_wavelength_quantity
)
# Compare to a tolerance due to Pint internal factor representation.
assert abs(shift.to("nm").m - 55) < 1e-12


def test_invalid_unit():
"""Ensure incompatible units in constructor raises ValidationError."""
with pytest.raises(ValidationError):
Channel(
excitation_wavelength=475, excitation_wavelength_unit="kg",
)


def test_dimensionality_error():
"""Ensure math on incompatible units raises DimensionalityError."""
laser = Laser(
id="LightSource:1",
repetition_rate=10,
repetition_rate_unit="MHz",
wavelength=640,
)
with pytest.raises(DimensionalityError):
laser.repetition_rate_quantity + laser.wavelength_quantity


def test_reference_frame():
"""Validate reference_frame behavior."""
plane = Plane(
the_c=0,
the_t=0,
the_z=0,
position_x=1,
position_x_unit="reference frame",
position_y=2,
position_y_unit="mm",
)
# Verify two different ways that reference_frame and length are incompatible.
with pytest.raises(DimensionalityError):
plane.position_x_quantity + plane.position_y_quantity
product = plane.position_x_quantity * plane.position_y_quantity
assert not product.check("[area]")
# Verify that we can obtain a usable length if we know the conversion factor.
conversion_factor = ureg.Quantity(1, "micron/reference_frame")
position_x = plane.position_x_quantity * conversion_factor
assert position_x.check("[length]")