Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ForceField.combine #1996

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions openff/toolkit/_tests/test_forcefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,13 @@ def alkethoh_forcefield(self):
def force_field(self):
return ForceField(get_data_file_path("test_forcefields/test_forcefield.offxml"))

@pytest.fixture
def force_field_with_cosmetic_attributes(self):
return ForceField(
xml_ff_w_cosmetic_elements,
allow_cosmetic_attributes=True,
)


class TestForceField(_ForceFieldFixtures):
"""Test the ForceField class"""
Expand Down Expand Up @@ -1054,6 +1061,60 @@ def test_xml_string_roundtrip_keep_cosmetic(self):
assert 'cosmetic_element="why not?"' not in string_3
assert 'parameterize_eval="blah=blah2"' not in string_3

def test_combine_order_dependent(self):
assert hash(
ForceField("openff-1.3.0.offxml").combine(ForceField("openff-2.2.0.offxml"))
) != hash(
ForceField("openff-2.2.0.offxml").combine(ForceField("openff-1.3.0.offxml"))
)

def test_combine_same_force_field(self, force_field):
combined = force_field.combine(force_field)

for handler_name in force_field.registered_parameter_handlers:
n_parameters = len(force_field[handler_name].parameters)
assert len(combined[handler_name].parameters) == 2 * n_parameters

for parameter_index, parameter in enumerate(force_field[handler_name].parameters):
# __eq__ is undefined, comparing dicts should be close enough
assert combined[handler_name].parameters[parameter_index + n_parameters].to_dict() == parameter.to_dict()

assert hash(force_field) != hash(combined)

def test_combine_same_results_as_loading(self):
assert hash(
ForceField("openff-1.3.0.offxml").combine(ForceField("openff-2.2.0.offxml"))
) == hash(
ForceField("openff-1.3.0.offxml", "openff-2.2.0.offxml")
)

def test_combine_chain_calls(self, force_field):
"""Just test that nothing weird happens if squishing together twice."""
tripled = force_field.combine(force_field.combine(force_field))

assert len(tripled['vdW'].parameters) == 3 * len(force_field['vdW'].parameters)

def test_combine_basic_with_cosmetic_attributes(self, force_field_with_cosmetic_attributes):
combined = force_field_with_cosmetic_attributes.combine(
force_field_with_cosmetic_attributes, allow_cosmetic_attributes=True)

original_cosmetics = force_field_with_cosmetic_attributes['Bonds'].parameters[0]._cosmetic_attribs

# should be ['parameters', 'parameterize_eval']
assert len(original_cosmetics) == 2

assert combined['Bonds'].parameters[0]._cosmetic_attribs == original_cosmetics

def test_combine_errors_with_cosmetic_attributes(self, force_field_with_cosmetic_attributes):
with pytest.raises(
SMIRNOFFSpecError,
match="parameters.*k, length",
):
force_field_with_cosmetic_attributes.combine(
force_field_with_cosmetic_attributes,
allow_cosmetic_attributes=False,
)

def test_read_0_1_smirnoff(self):
"""Test reading an 0.1 spec OFFXML file"""
ForceField(
Expand Down
53 changes: 40 additions & 13 deletions openff/toolkit/typing/engines/smirnoff/forcefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,6 @@
* Speed up overall import time by putting non-global imports only where they are needed

"""

__all__ = [
"MAX_SUPPORTED_VERSION",
"ForceField",
"ParameterHandlerRegistrationError",
"PartialChargeVirtualSitesError",
"SMIRNOFFAromaticityError",
"SMIRNOFFParseError",
"SMIRNOFFVersionError",
"get_available_force_fields",
]

import copy
import logging
import os
import pathlib
Expand Down Expand Up @@ -63,6 +50,18 @@
from openff.toolkit.utils.base_wrapper import ToolkitWrapper
from openff.toolkit.utils.toolkit_registry import ToolkitRegistry


__all__ = [
"MAX_SUPPORTED_VERSION",
"ForceField",
"ParameterHandlerRegistrationError",
"PartialChargeVirtualSitesError",
"SMIRNOFFAromaticityError",
"SMIRNOFFParseError",
"SMIRNOFFVersionError",
"get_available_force_fields",
]

logger = logging.getLogger(__name__)

# Directory paths used by ForceField to discover offxml files.
Expand Down Expand Up @@ -286,6 +285,10 @@ def __init__(
--------
parse_sources

Notes
-----
No effort is made to de-duplicate redundant parameters or parameters with identical SMIRKS patterns.

"""
# Clear all object fields
self._initialize()
Expand Down Expand Up @@ -1137,6 +1140,29 @@ def to_file(
)
io_handler.to_file(filename, smirnoff_data)

def combine(
self,
other: "ForceField",
allow_cosmetic_attributes: bool = False,
) -> "ForceField":
"""
Combine this `ForceField` with another `ForceField`, returning a new `ForceField`.

The same rules as `ForceField.__init__` are followed.
"""
import copy

combined = copy.deepcopy(self)

combined._load_smirnoff_data(
smirnoff_data=other._to_smirnoff_data(
discard_cosmetic_attributes=False,
),
allow_cosmetic_attributes=allow_cosmetic_attributes,
)

return combined

# TODO: Should we also accept a Molecule as an alternative to a Topology?
@requires_package("openmm")
def create_openmm_system(
Expand Down Expand Up @@ -1421,6 +1447,7 @@ def __hash__(self) -> int:
Notable behavior:
* `author` and `date` are stripped from the ForceField
* `id` and `parent_id` are stripped from each ParameterType"""
import copy

# Completely re-constructing the force field may be overkill
# compared to deepcopying and modifying, but is not currently slow
Expand Down
Loading