Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix parsing list[openmm.unit.Quantity] and parse openmm.Vec3 #36

Merged
merged 2 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion openff/models/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
import pytest
from openff.units import unit
from openff.units import Quantity, unit
from openff.utilities.testing import skip_if_missing

from openff.models.exceptions import UnitValidationError
Expand Down Expand Up @@ -305,3 +305,29 @@ def test_from_omm_quantity():

with pytest.raises(UnitValidationError):
_from_omm_quantity(True * openmm.unit.femtosecond)


@skip_if_missing("openmm.unit")
def test_from_omm_box_vectors():
"""Reproduce issue #35."""
import openmm
import openmm.unit

# mimic the output of getDefaultPeriodicBoxVectors, which returns
# list[openmm.unit.Quantity[openmm.unit.Vec3]]
box_vectors = [
openmm.unit.Quantity(openmm.Vec3(x=4, y=0, z=0), openmm.unit.nanometer),
openmm.unit.Quantity(openmm.Vec3(x=0, y=2, z=0), openmm.unit.nanometer),
openmm.unit.Quantity(openmm.Vec3(x=0, y=0, z=5), openmm.unit.nanometer),
]

validated = ArrayQuantity.validate_type(box_vectors)

assert isinstance(validated, Quantity)

assert validated.shape == (3, 3)
assert validated.m.shape == (3, 3)
assert validated.units == unit.nanometer

for index, value in enumerate([4, 2, 5]):
assert validated.m[index][index] == value
17 changes: 15 additions & 2 deletions openff/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Any, Dict

import numpy as np
from openff.units import unit
from openff.units import Quantity, unit
from openff.utilities import has_package, requires_package

from openff.models.exceptions import (
Expand Down Expand Up @@ -99,7 +99,9 @@ def _from_omm_quantity(val: "openmm.unit.Quantity") -> unit.Quantity:
unit_ = val.unit
return float(val_) * unit.Unit(str(unit_))
# Here is where the toolkit's ValidatedList could go, if present in the environment
elif type(val_) in {tuple, list, np.ndarray}:
elif (type(val_) in {tuple, list, np.ndarray}) or (
type(val_).__module__ == "openmm.vec3"
):
array = np.asarray(val_)
return array * unit.Unit(str(unit_))
elif isinstance(val_, (float, int)) and type(val_).__module__ == "numpy":
Expand Down Expand Up @@ -182,10 +184,21 @@ def validate_type(cls, val):
unit_ = getattr(cls, "__unit__", Any)
if unit_ is Any:
if isinstance(val, (list, np.ndarray)):
# Work around a special case in which val might be list[openmm.unit.Quantity]
if {type(element).__module__ for element in val} == {
"openmm.unit.quantity"
}:
unit_ = _from_omm_quantity(val[-1]).units
return Quantity(
[_from_omm_quantity(element) for element in val],
units=unit_,
)

# TODO: Can this exception be raised with knowledge of the field it's in?
raise MissingUnitError(
f"Value {val} needs to be tagged with a unit"
)

elif isinstance(val, unit.Quantity):
# TODO: This might be a redundant cast causing wasted CPU time.
# But maybe it handles pint vs openff.units.unit?
Expand Down
Loading