Skip to content

Commit

Permalink
Merge pull request #44 from openforcefield/ensure-quantity
Browse files Browse the repository at this point in the history
Add `ensure_quantity` helper function
  • Loading branch information
mattwthompson authored Sep 16, 2022
2 parents ba3d5b9 + c99d04e commit 12dab1c
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 3 deletions.
56 changes: 54 additions & 2 deletions openff/units/openmm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""
Functions for converting between OpenFF and OpenMM units
"""

import ast
import operator as op
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Literal, Union

from openff.utilities import has_package, requires_package

Expand All @@ -21,6 +20,7 @@
"to_openmm",
"openmm_unit_to_string",
"string_to_openmm_unit",
"ensure_quantity",
]

if has_package("openmm.unit") or TYPE_CHECKING:
Expand Down Expand Up @@ -191,3 +191,55 @@ def to_openmm_inner(quantity) -> "openmm_unit.Quantity":
return to_openmm_inner(quantity)
except MissingOpenMMUnitError:
return to_openmm_inner(quantity.to_base_units())


@requires_package("openmm.unit")
def _ensure_openmm_quantity(
unknown_quantity: Union[Quantity, "openmm_unit.Quantity"]
) -> "openmm_unit.Quantity":
if "openmm" in str(type(unknown_quantity)):
from openmm import unit as openmm_unit

if isinstance(unknown_quantity, openmm_unit.Quantity):
return unknown_quantity
else:
raise ValueError(
f"Failed to process input of type {type(unknown_quantity)}."
)
elif isinstance(unknown_quantity, Quantity):
return to_openmm(unknown_quantity)
else:
raise ValueError(f"Failed to process input of type {type(unknown_quantity)}.")


def _ensure_openff_quantity(
unknown_quantity: Union[Quantity, "openmm_unit.Quantity"]
) -> Quantity:
if isinstance(unknown_quantity, Quantity):
return unknown_quantity
elif "openmm" in str(type(unknown_quantity)):
from openmm import unit as openmm_unit

if isinstance(unknown_quantity, openmm_unit.Quantity):
return from_openmm(unknown_quantity)
else:
raise ValueError(
f"Failed to process input of type {type(unknown_quantity)}."
)
else:
raise Exception


def ensure_quantity(
unknown_quantity: Union[Quantity, "openmm_unit.Quantity"],
type_to_ensure: Literal["openmm", "openff"],
) -> Union[Quantity, "openmm_unit.Quantity"]:
if type_to_ensure == "openmm":
return _ensure_openmm_quantity(unknown_quantity)
elif type_to_ensure == "openff":
return _ensure_openff_quantity(unknown_quantity)
else:
raise ValueError(
f"Unsupported `type_to_ensure` found. Given {type_to_ensure}, "
"expected 'openff' or 'openmm'."
)
32 changes: 31 additions & 1 deletion openff/units/tests/test_openmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from openff.units import unit
from openff.units.exceptions import NoneQuantityError, NoneUnitError
from openff.units.openmm import from_openmm
from openff.units.openmm import ensure_quantity, from_openmm

if has_package("openmm.unit"):
from openmm import unit as openmm_unit
Expand Down Expand Up @@ -176,3 +176,33 @@ def test_openmm_unit_constants(self, from_openff_quantity, to_openmm_quantity):
* to_openmm_quantity
* openmm_unit.dimensionless
)


@skip_if_missing("openmm.unit")
class TestEnsureType:
from openmm import unit as openmm_unit

from openff.units import unit

@pytest.mark.parametrize(
"registry",
["openmm", "openff"],
)
def test_ensure_units(self, registry):
x = unit.Quantity(4.0, unit.angstrom)
y = openmm_unit.Quantity(4.0, openmm_unit.angstrom)

assert ensure_quantity(x, registry) == ensure_quantity(y, registry)

def test_unsupported_type(self):
x = unit.Quantity(4.0, unit.angstrom)

with pytest.raises(ValueError, match="Unsupported.*type_to_ensure.*pint"):
ensure_quantity(x, "pint")

def test_short_circuit(self):
x = unit.Quantity(4.0, unit.angstrom)
y = openmm_unit.Quantity(4.0, openmm_unit.angstrom)

assert id(ensure_quantity(x, "openff")) == id(x)
assert id(ensure_quantity(y, "openmm")) == id(y)

0 comments on commit 12dab1c

Please sign in to comment.