Skip to content

Commit

Permalink
Merge pull request #193 from openforcefield/connectivity-check
Browse files Browse the repository at this point in the history
Deal with connectivity changes in QC generation
  • Loading branch information
Yoshanuikabundi authored Sep 21, 2022
2 parents 17edf9c + f112f24 commit 5663b52
Show file tree
Hide file tree
Showing 6 changed files with 281 additions and 5 deletions.
1 change: 1 addition & 0 deletions devtools/conda-envs/docs-env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies:
- chemper
- geometric
- torsiondrive
- pymbar <4

# Executor
- uvicorn
Expand Down
1 change: 1 addition & 0 deletions devtools/conda-envs/no_openeye.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dependencies:
- chemper
- geometric
- torsiondrive
- pymbar <4

# Executor
- uvicorn
Expand Down
1 change: 1 addition & 0 deletions devtools/conda-envs/test-env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dependencies:
- chemper
- geometric
- torsiondrive
- pymbar <4

# Executor
- uvicorn
Expand Down
5 changes: 4 additions & 1 deletion openff/bespokefit/executor/services/coordinator/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
ProperTorsionSMIRKS,
VdWSMIRKS,
)
from openff.bespokefit.schema.targets import TargetSchema
from openff.bespokefit.schema.tasks import Torsion1DTask
from openff.bespokefit.utilities.pydantic import BaseModel
from openff.bespokefit.utilities.smirks import (
Expand Down Expand Up @@ -512,7 +513,9 @@ async def _inject_bespoke_qc_data(
input_schema: BespokeOptimizationSchema,
):

targets = [target for stage in input_schema.stages for target in stage.targets]
targets: List[TargetSchema] = [
target for stage in input_schema.stages for target in stage.targets
]
for i, target in enumerate(targets):

if not isinstance(target.reference_data, BespokeQCData):
Expand Down
140 changes: 138 additions & 2 deletions openff/bespokefit/schema/targets.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import abc
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, TypeVar, Union

from openff.qcsubmit.results import (
BasicResultCollection,
OptimizationResultCollection,
TorsionDriveResultCollection,
)
from pydantic import Field, PositiveFloat
from openff.toolkit.topology import Molecule
from pydantic import Field, PositiveFloat, validator
from qcelemental.models import AtomicResult
from qcelemental.models import Molecule as QCEMolecule
from qcelemental.models.procedures import OptimizationResult, TorsionDriveResult
from qcelemental.molutil import guess_connectivity
from typing_extensions import Literal

from openff.bespokefit.schema.data import BespokeQCData, LocalQCData
Expand All @@ -20,6 +23,130 @@
from openff.bespokefit.utilities.pydantic import SchemaBase


def _check_connectivity(
qcschema: QCEMolecule,
name: str,
fragment: Optional[Molecule] = None,
):
"""
Raise an exception if the geometry of ``qcschema`` does not match ``fragment``
Parameters
==========
qcschema
A QCElemental ``Molecule`` representing the final geometry of a QC
computation
name
A name for the current calculation. Used in the exception raised by this
method.
fragment
An OpenFF ``Molecule`` representing the true chemical identity of the
fragment.
Returns
=======
None
Raises
======
ValueError
If the connectivity does not match.
"""
# If expected connectivity is not provided, compute it from the fragment
if fragment is None:
fragment = Molecule.from_qcschema(qcschema)

# Get expected connectivity from bonds
expected_connectivity = {
tuple(sorted([bond.atom1_index + 1, bond.atom2_index + 1]))
for bond in fragment.bonds
}

# Guess found connectivity from the output geometry
actual_connectivity = {
tuple(sorted([a + 1, b + 1]))
for a, b in guess_connectivity(qcschema.symbols, qcschema.geometry)
}

if expected_connectivity != actual_connectivity:
# Pydantic validators must raise ValueError, TypeError or AssertionError
raise ValueError(
f"Target record {name}: "
+ "Reference data does not match target.\n"
+ f"Expected mapped SMILES: {fragment.to_smiles(mapped=True)}\n"
+ "The following connections were expected but not found: "
+ f"{expected_connectivity - actual_connectivity}\n"
+ "The following connections were found but not expected: "
+ f"{actual_connectivity - expected_connectivity}\n"
+ f"The reference geometry is: {qcschema.geometry}"
)


R = TypeVar(
"R",
None,
LocalQCData[TorsionDriveResult],
LocalQCData[OptimizationResult],
BespokeQCData[Torsion1DTaskSpec],
BespokeQCData[OptimizationTaskSpec],
OptimizationResultCollection,
TorsionDriveResultCollection,
)


def _validate_connectivity(
cls,
ref_data: R,
) -> R:
"""
Check that connectivity has not changed over the course of QC computation.
This function can be used as a validator for the ``reference_data`` field of
a target schema if the connectivity may change over the course of computing
the target:
def __init__(...):
...
_reference_data_connectivity = validator("reference_data", allow_reuse=True)(
_check_connectivity
)
...
"""
if ref_data is None or isinstance(ref_data, BespokeQCData):
# Reference data has not been computed, so the connectivity is intact
return ref_data
elif isinstance(ref_data, LocalQCData):
for qc_record in ref_data.qc_records:
# Some qc records (eg, TorsionDriveResult) use .final_molecules (plural),
# others (eg, OptimizationResult) use .final_molecule (singular)
try:
final_molecules = qc_record.final_molecules
except AttributeError:
final_molecules = {"opt": qc_record.final_molecule}

for name, qcschema in final_molecules.items():
_check_connectivity(qcschema, name)

elif hasattr(ref_data, "to_records"):
for qc_record, fragment in ref_data.to_records():
# Some qc records (eg, TorsionDriveRecord) use .get_final_molecules() (plural),
# others (eg, OptimizationRecord) use .get_final_molecule() (singular)
try:
final_molecules = qc_record.get_final_molecules()
except AttributeError:
final_molecules = {"opt": qc_record.get_final_molecule()}

for name, qcschema in final_molecules.items():
_check_connectivity(qcschema, name, fragment)

# No connectivity changes found, so return the unchanged input as validated
return ref_data


class BaseTargetSchema(SchemaBase, abc.ABC):
"""The base class for models which store information about fitting targets."""

Expand Down Expand Up @@ -67,6 +194,9 @@ class TorsionProfileTargetSchema(BaseTargetSchema):
description="The reference QC data (either existing or to be generated on the "
"fly) to fit against.",
)
_reference_data_connectivity = validator("reference_data", allow_reuse=True)(
_validate_connectivity
)
calculation_specification: Optional[Torsion1DTaskSpec] = Field(
None,
description="The specification for the reference torsion drive calculation, also acts as a provenance source.",
Expand Down Expand Up @@ -102,6 +232,9 @@ class AbInitioTargetSchema(BaseTargetSchema):
description="The reference QC data (either existing or to be generated on the "
"fly) to fit against.",
)
_reference_data_connectivity = validator("reference_data", allow_reuse=True)(
_validate_connectivity
)
calculation_specification: Optional[Torsion1DTaskSpec] = Field(
None,
description="The specification for the reference torsion drive calculation, also acts as a provenance source.",
Expand Down Expand Up @@ -169,6 +302,9 @@ class OptGeoTargetSchema(BaseTargetSchema):
description="The reference QC data (either existing or to be generated on the "
"fly) to fit against.",
)
_reference_data_connectivity = validator("reference_data", allow_reuse=True)(
_validate_connectivity
)
calculation_specification: Optional[OptimizationTaskSpec] = Field(
None,
description="The specification for the reference optimisation calculation, also acts as a provenance source.",
Expand Down
138 changes: 136 additions & 2 deletions openff/bespokefit/tests/schema/test_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
from pydantic import ValidationError
from qcelemental.models.common_models import Model

from openff.bespokefit.schema.data import BespokeQCData
from openff.bespokefit.schema.targets import TorsionProfileTargetSchema
from openff.bespokefit.schema.data import BespokeQCData, LocalQCData
from openff.bespokefit.schema.targets import (
AbInitioTargetSchema,
OptGeoTargetSchema,
TorsionProfileTargetSchema,
)
from openff.bespokefit.schema.tasks import HessianTaskSpec, Torsion1DTaskSpec


Expand Down Expand Up @@ -31,3 +35,133 @@ def test_check_reference_data(qc_torsion_drive_results):
)
)
)


@pytest.mark.parametrize(
"TargetSchema",
[
TorsionProfileTargetSchema,
AbInitioTargetSchema,
OptGeoTargetSchema,
],
)
class TestCheckConnectivity:
@pytest.fixture()
def expected_err(self, TargetSchema) -> str:
return (
r"1 validation error for "
+ TargetSchema.__name__
+ r"\n"
+ r"reference_data\n"
+ r" Target record (opt|\[-165\]): Reference data "
+ r"does not match target\.\n"
+ r"Expected mapped SMILES: "
# This regex for the mapped SMILES is probably extremely fragile;
# if this test breaks after an RDkit/OpenEye update, try replacing it
# with something like `+ r".*"`
+ r"(\(?[-+]?\[("
+ r"H:13|c:1|c:3|c:7|c:11|c:8|c:4|H:16|H:20|c:12|c:9|c:5|c:2|c:6|c:10"
+ r"|H:22|H:18|H:14|H:17|H:21|H:19|H:15"
+ r")\]1?2?\)?){22}"
# End fragile regex
+ r"\n"
+ r"The following connections were expected but not found: "
+ r"{\(1, 13\), \(1, 3\), \(1, 4\)}\n"
+ r"The following connections were found but not expected: "
+ r"{\(1, 6\), \(1, 2\), \(1, 14\), \(1, 5\)}\n"
+ r"The reference geometry is: \[(\[.*\]\n ){21}\[.*\]\]"
# + r"The reference geometry is: \[.*\]"
+ r" \(type=value_error\)"
)

@pytest.fixture()
def ref_data_local(self, TargetSchema, request):
ref_data_fixture = {
TorsionProfileTargetSchema: "qc_torsion_drive_qce_result",
AbInitioTargetSchema: "qc_torsion_drive_qce_result",
OptGeoTargetSchema: "qc_optimization_qce_result",
}[TargetSchema]
result, _ = request.getfixturevalue(ref_data_fixture)
return [result]

@pytest.fixture()
def ref_data_qcfractal(self, TargetSchema, request):
ref_data_fixture = {
TorsionProfileTargetSchema: "qc_torsion_drive_results",
AbInitioTargetSchema: "qc_torsion_drive_results",
OptGeoTargetSchema: "qc_optimization_results",
}[TargetSchema]
return request.getfixturevalue(ref_data_fixture)

def test_check_connectivity_local_positive(
self,
TargetSchema,
ref_data_local,
):
TargetSchema(reference_data=LocalQCData(qc_records=ref_data_local))

def test_check_connectivity_local_negative(
self,
TargetSchema,
ref_data_local,
expected_err,
):
# Swap the first two atoms' coordinates to break their connectivity
torsiondrive_result_disconnection = ref_data_local[0]
try:
geom = next(
iter(torsiondrive_result_disconnection.final_molecules.values())
).geometry
except AttributeError:
geom = torsiondrive_result_disconnection.final_molecule.geometry
geom[0], geom[1] = geom[1], geom[0]

with pytest.raises(ValidationError, match=expected_err):
TargetSchema(
reference_data=LocalQCData(
qc_records=[torsiondrive_result_disconnection]
)
)

def test_check_connectivity_qcfractal_positive(
self,
TargetSchema,
ref_data_qcfractal,
):
TargetSchema(reference_data=ref_data_qcfractal)

def test_check_connectivity_qcfractal_negative(
self,
TargetSchema,
ref_data_qcfractal,
expected_err,
):
[(record, _)] = ref_data_qcfractal.to_records()

# Get the first of the final molecules, and prepare an update function so
# that changes are preserved across calls to get_final_molecule(s)()
try:
final_molecules = record.get_final_molecules()
key, first_final_mol = next(iter(record.cache["final_molecules"].items()))

def update_record(updated_mol):
final_molecules.update({key: updated_mol})
record.__dict__["get_final_molecules"] = lambda: final_molecules

except AttributeError:
first_final_mol = record.get_final_molecule()

def update_record(updated_mol):
record.__dict__["get_final_molecule"] = lambda: updated_mol

# Swap the first two atoms' coordinates to break their connectivity
geom = first_final_mol.geometry.copy()
geom[0], geom[1] = geom[1], geom[0]
updated_mol = first_final_mol.copy(update={"geometry": geom})

# Update the record with the new geometry
update_record(updated_mol)

# Create the target schema, which should fail to validate
with pytest.raises(ValidationError, match=expected_err):
TargetSchema(reference_data=ref_data_qcfractal)

0 comments on commit 5663b52

Please sign in to comment.