Skip to content

Commit

Permalink
Fix parsing list[openmm.unit.Quantity] and parse openmm.Vec3 (#36)
Browse files Browse the repository at this point in the history
* TST: Reproduce issue #35

* FIX: Work around OpenMM list-wrapping
  • Loading branch information
mattwthompson authored Feb 1, 2024
1 parent ceadf5e commit 02bc96b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
28 changes: 27 additions & 1 deletion openff/models/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
import pytest
from openff.units import unit
from openff.units import Quantity, unit
from openff.utilities.testing import skip_if_missing

from openff.models.exceptions import UnitValidationError
Expand Down Expand Up @@ -305,3 +305,29 @@ def test_from_omm_quantity():

with pytest.raises(UnitValidationError):
_from_omm_quantity(True * openmm.unit.femtosecond)


@skip_if_missing("openmm.unit")
def test_from_omm_box_vectors():
"""Reproduce issue #35."""
import openmm
import openmm.unit

# mimic the output of getDefaultPeriodicBoxVectors, which returns
# list[openmm.unit.Quantity[openmm.unit.Vec3]]
box_vectors = [
openmm.unit.Quantity(openmm.Vec3(x=4, y=0, z=0), openmm.unit.nanometer),
openmm.unit.Quantity(openmm.Vec3(x=0, y=2, z=0), openmm.unit.nanometer),
openmm.unit.Quantity(openmm.Vec3(x=0, y=0, z=5), openmm.unit.nanometer),
]

validated = ArrayQuantity.validate_type(box_vectors)

assert isinstance(validated, Quantity)

assert validated.shape == (3, 3)
assert validated.m.shape == (3, 3)
assert validated.units == unit.nanometer

for index, value in enumerate([4, 2, 5]):
assert validated.m[index][index] == value
17 changes: 15 additions & 2 deletions openff/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Any, Dict

import numpy as np
from openff.units import unit
from openff.units import Quantity, unit
from openff.utilities import has_package, requires_package

from openff.models.exceptions import (
Expand Down Expand Up @@ -99,7 +99,9 @@ def _from_omm_quantity(val: "openmm.unit.Quantity") -> unit.Quantity:
unit_ = val.unit
return float(val_) * unit.Unit(str(unit_))
# Here is where the toolkit's ValidatedList could go, if present in the environment
elif type(val_) in {tuple, list, np.ndarray}:
elif (type(val_) in {tuple, list, np.ndarray}) or (
type(val_).__module__ == "openmm.vec3"
):
array = np.asarray(val_)
return array * unit.Unit(str(unit_))
elif isinstance(val_, (float, int)) and type(val_).__module__ == "numpy":
Expand Down Expand Up @@ -182,10 +184,21 @@ def validate_type(cls, val):
unit_ = getattr(cls, "__unit__", Any)
if unit_ is Any:
if isinstance(val, (list, np.ndarray)):
# Work around a special case in which val might be list[openmm.unit.Quantity]
if {type(element).__module__ for element in val} == {
"openmm.unit.quantity"
}:
unit_ = _from_omm_quantity(val[-1]).units
return Quantity(
[_from_omm_quantity(element) for element in val],
units=unit_,
)

# TODO: Can this exception be raised with knowledge of the field it's in?
raise MissingUnitError(
f"Value {val} needs to be tagged with a unit"
)

elif isinstance(val, unit.Quantity):
# TODO: This might be a redundant cast causing wasted CPU time.
# But maybe it handles pint vs openff.units.unit?
Expand Down

0 comments on commit 02bc96b

Please sign in to comment.