Skip to content

Commit

Permalink
Merge branch 'main' of github.com:mosdef-hub/foyer into use-lark
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjonesBSU committed Aug 19, 2024
2 parents 158b85e + ad0b4d3 commit 06d4e6d
Show file tree
Hide file tree
Showing 22 changed files with 157 additions and 398 deletions.
23 changes: 11 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,26 @@ ci:
skip: []
submodules: false
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.5.7
hooks:
# Run the linter.
- id: ruff
args: [--line-length=80, --fix]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
exclude: 'setup.cfg'
- repo: https://github.com/psf/black
rev: 24.8.0
hooks:
- id: black
args: [--line-length=80]
exclude: 'setup.cfg|foyer/tests/files/.*'
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
args: [--profile=black, --line-length=80]
- repo: https://github.com/pycqa/pydocstyle
rev: '6.3.0'
hooks:
- id: pydocstyle
exclude: ^(foyer/tests/|docs/|devtools/|setup.py)
args: [--convention=numpy]
exclude: "foyer/tests/files/.*"
5 changes: 2 additions & 3 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
import pathlib
import sys

import sphinx_rtd_theme

sys.path.insert(0, os.path.abspath("../.."))
sys.path.insert(0, os.path.abspath("sphinxext"))

base_path = pathlib.Path(__file__).parent
os.system("python {} --name".format((base_path / "../../setup.py").resolve()))


import foyer

# -- Project information -----------------------------------------------------

project = "foyer"
Expand Down Expand Up @@ -147,7 +147,6 @@
# a list of builtin themes.
#
# html_theme = 'alabaster'
import sphinx_rtd_theme

html_theme = "sphinx_rtd_theme"
hhtml_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
Expand Down
16 changes: 4 additions & 12 deletions foyer/atomtyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ def find_atomtypes(structure, forcefield, max_iter=10):
topology_graph = TopologyGraph.from_gmso_topology(structure)

if isinstance(forcefield, Forcefield):
atomtype_rules = AtomTypingRulesProvider.from_foyer_forcefield(
forcefield
)
atomtype_rules = AtomTypingRulesProvider.from_foyer_forcefield(forcefield)
elif isinstance(forcefield, AtomTypingRulesProvider):
atomtype_rules = forcefield
else:
Expand Down Expand Up @@ -110,9 +108,7 @@ def find_atomtypes(structure, forcefield, max_iter=10):
atomic_number = atom_data.atomic_number
atomic_symbol = atom_data.element
try:
element_from_num = ele.element_from_atomic_number(
atomic_number
).symbol
element_from_num = ele.element_from_atomic_number(atomic_number).symbol
element_from_sym = ele.element_from_symbol(atomic_symbol).symbol
assert element_from_num == element_from_sym
system_elements.add(element_from_num)
Expand Down Expand Up @@ -210,13 +206,9 @@ def _iterate_rules(rules, topology_graph, typemap, max_iter):

def _resolve_atomtypes(topology_graph, typemap):
"""Determine the final atomtypes from the white- and blacklists."""
atoms = {
atom_idx: data for atom_idx, data in topology_graph.atoms(data=True)
}
atoms = {atom_idx: data for atom_idx, data in topology_graph.atoms(data=True)}
for atom_id, atom in typemap.items():
atomtype = [
rule_name for rule_name in atom["whitelist"] - atom["blacklist"]
]
atomtype = [rule_name for rule_name in atom["whitelist"] - atom["blacklist"]]
if len(atomtype) == 1:
atom["atomtype"] = atomtype[0]
elif len(atomtype) > 1:
Expand Down
118 changes: 29 additions & 89 deletions foyer/forcefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def generate_topology(non_omm_topology, non_element_types=None, residues=None):
return _topology_from_parmed(non_omm_topology, non_element_types)
elif has_mbuild:
mb = import_("mbuild")
if (non_omm_topology, mb.Compound):
if all([non_omm_topology, mb.Compound]):
pmd_comp_struct = non_omm_topology.to_parmed(residues=residues)
return _topology_from_parmed(pmd_comp_struct, non_element_types)
else:
Expand All @@ -162,16 +162,12 @@ def generate_topology(non_omm_topology, non_element_types=None, residues=None):
def _structure_from_residue(residue, parent_structure):
"""Convert a ParmEd Residue to an equivalent Structure."""
structure = pmd.Structure()
orig_to_copy = (
dict()
) # Clone a lot of atoms to avoid any of parmed's tracking
orig_to_copy = dict() # Clone a lot of atoms to avoid any of parmed's tracking
for atom in residue.atoms:
new_atom = copy(atom)
new_atom._idx = atom.idx
orig_to_copy[atom] = new_atom
structure.add_atom(
new_atom, resname=residue.name, resnum=residue.number
)
structure.add_atom(new_atom, resname=residue.name, resnum=residue.number)

for bond in parent_structure.bonds:
if bond.atom1 in residue.atoms and bond.atom2 in residue.atoms:
Expand All @@ -198,10 +194,7 @@ def _topology_from_parmed(structure, non_element_types):
if pmd_atom.name in non_element_types:
element = non_element_types[pmd_atom.name]
else:
if (
isinstance(pmd_atom.atomic_number, int)
and pmd_atom.atomic_number != 0
):
if isinstance(pmd_atom.atomic_number, int) and pmd_atom.atomic_number != 0:
element = elem.Element.getByAtomicNumber(pmd_atom.atomic_number)
else:
element = elem.Element.getBySymbol(pmd_atom.name)
Expand All @@ -221,9 +214,7 @@ def _topology_from_parmed(structure, non_element_types):
topology.addBond(atom1, atom2)
atom1.bond_partners.append(atom2)
atom2.bond_partners.append(atom1)
if structure.box_vectors and np.any(
[x._value for x in structure.box_vectors]
):
if structure.box_vectors and np.any([x._value for x in structure.box_vectors]):
topology.setPeriodicBoxVectors(structure.box_vectors)

positions = structure.positions
Expand Down Expand Up @@ -293,9 +284,7 @@ def _unwrap_typemap(structure, residue_map):
for res_ref, val in residue_map.items():
if id(res.name) == id(res_ref):
for i, atom in enumerate(res.atoms):
master_typemap[int(atom.idx)]["atomtype"] = val[i][
"atomtype"
]
master_typemap[int(atom.idx)]["atomtype"] = val[i]["atomtype"]
return master_typemap


Expand Down Expand Up @@ -325,9 +314,7 @@ def _separate_urey_bradleys(system, topology):
) not in bonds:
ub_force.addBond(*force.getBondParameters(bond_idx))
else:
harmonic_bond_force.addBond(
*force.getBondParameters(bond_idx)
)
harmonic_bond_force.addBond(*force.getBondParameters(bond_idx))
system.removeForce(force_idx)

system.addForce(harmonic_bond_force)
Expand Down Expand Up @@ -499,9 +486,7 @@ class Forcefield(app.ForceField):
"""

def __init__(
self, forcefield_files=None, name=None, validation=True, debug=False
):
def __init__(self, forcefield_files=None, name=None, validation=True, debug=False):
self.atomTypeDefinitions = dict()
self.atomTypeOverrides = dict()
self.atomTypeDesc = dict()
Expand Down Expand Up @@ -539,13 +524,9 @@ def __init__(
if len(preprocessed_files) == 1:
self._version = self._parse_version_number(preprocessed_files[0])
self._name = self._parse_name(preprocessed_files[0])
self._combining_rule = self._parse_combining_rule(
preprocessed_files[0]
)
self._combining_rule = self._parse_combining_rule(preprocessed_files[0])
elif len(preprocessed_files) > 1:
self._version = [
self._parse_version_number(f) for f in preprocessed_files
]
self._version = [self._parse_version_number(f) for f in preprocessed_files]
self._name = [self._parse_name(f) for f in preprocessed_files]
self._combining_rule = [
self._parse_combining_rule(f) for f in preprocessed_files
Expand Down Expand Up @@ -639,9 +620,7 @@ def _parse_name(self, forcefield_file):
try:
return root.attrib["name"]
except KeyError:
warnings.warn(
"No force field name found in force field XML file."
)
warnings.warn("No force field name found in force field XML file.")
return None

def _parse_combining_rule(self, forcefield_file):
Expand All @@ -651,9 +630,7 @@ def _parse_combining_rule(self, forcefield_file):
try:
return root.attrib["combining_rule"]
except KeyError:
warnings.warn(
"No combining rule found in force field XML file."
)
warnings.warn("No combining rule found in force field XML file.")
return "lorentz"

def _create_element(self, element, mass):
Expand All @@ -679,9 +656,7 @@ def registerAtomType(self, parameters):
"""Register a new atom type."""
name = parameters["name"]
if name in self._atomTypes:
raise ValueError(
"Found multiple definitions for atom type: " + name
)
raise ValueError("Found multiple definitions for atom type: " + name)
atom_class = parameters["class"]
mass = _convertParameterToNumber(parameters["mass"])
element = None
Expand Down Expand Up @@ -846,10 +821,7 @@ def run_atomtyping(self, structure, use_residue_map=True, **kwargs):

# Need to call this only once and store results for later id() comparisons
for res_id, res in enumerate(structure.residues):
if (
structure.residues[res_id].name
not in residue_map.keys()
):
if structure.residues[res_id].name not in residue_map.keys():
tmp_res = _structure_from_residue(res, structure)
typemap = find_atomtypes(tmp_res, forcefield=self)
residue_map[res.name] = typemap
Expand Down Expand Up @@ -877,9 +849,7 @@ def parametrize_system(
**kwargs,
):
"""Create system based on resulting typemapping."""
topology, positions = _topology_from_parmed(
structure, self.non_element_types
)
topology, positions = _topology_from_parmed(structure, self.non_element_types)

system = self.createSystem(topology, *args, **kwargs)

Expand Down Expand Up @@ -918,9 +888,7 @@ def parametrize_system(
)

if self.combining_rule == "geometric":
self._patch_parmed_adjusts(
structure, combining_rule=self.combining_rule
)
self._patch_parmed_adjusts(structure, combining_rule=self.combining_rule)

total_charge = sum([atom.charge for atom in structure.atoms])
if not np.allclose(total_charge, 0):
Expand Down Expand Up @@ -1032,9 +1000,7 @@ def createSystem(
elem.hydrogen,
None,
):
transfer_mass = hydrogenMass - sys.getParticleMass(
atom2.index
)
transfer_mass = hydrogenMass - sys.getParticleMass(atom2.index)
sys.setParticleMass(atom2.index, hydrogenMass)
mass = sys.getParticleMass(atom1.index) - transfer_mass
sys.setParticleMass(atom1.index, mass)
Expand Down Expand Up @@ -1091,9 +1057,7 @@ def createSystem(
bonded_to = data.bondedToAtom[atom]
if len(bonded_to) > 2:
for subset in itertools.combinations(bonded_to, 3):
data.impropers.append(
(atom, subset[0], subset[1], subset[2])
)
data.impropers.append((atom, subset[0], subset[1], subset[2]))

# Identify bonds that should be implemented with constraints
if constraints == AllBonds or constraints == HAngles:
Expand Down Expand Up @@ -1188,15 +1152,9 @@ def createSystem(
site.originWeights[1],
site.originWeights[2],
),
mm.Vec3(
site.xWeights[0], site.xWeights[1], site.xWeights[2]
),
mm.Vec3(
site.yWeights[0], site.yWeights[1], site.yWeights[2]
),
mm.Vec3(
site.localPos[0], site.localPos[1], site.localPos[2]
),
mm.Vec3(site.xWeights[0], site.xWeights[1], site.xWeights[2]),
mm.Vec3(site.yWeights[0], site.yWeights[1], site.yWeights[2]),
mm.Vec3(site.localPos[0], site.localPos[1], site.localPos[2]),
)
sys.setVirtualSite(index, local_coord_site)

Expand Down Expand Up @@ -1263,9 +1221,7 @@ def _write_references_to_file(self, atom_types, references_file):
for atomtype, dois in atomtype_references.items():
for doi in dois:
unique_references[doi].append(atomtype)
unique_references = collections.OrderedDict(
sorted(unique_references.items())
)
unique_references = collections.OrderedDict(sorted(unique_references.items()))
with open(references_file, "w") as f:
for doi, atomtypes in unique_references.items():
url = "http://api.crossref.org/works/{}/transform/application/x-bibtex".format(
Expand Down Expand Up @@ -1338,11 +1294,7 @@ def get_parameters(self, group, key, keys_are_atom_classes=False):
if group not in param_extractors:
raise ValueError(f"Cannot extract parameters for {group}")

key = (
[key]
if isinstance(key, str) or not isinstance(key, Iterable)
else key
)
key = [key] if isinstance(key, str) or not isinstance(key, Iterable) else key

validate_type(key, str)

Expand All @@ -1367,18 +1319,14 @@ def _extract_non_bonded_params(self, atom_type):

atom_type = atom_type[0]

non_bonded_forces_gen = self.get_generator(
ff=self, gen_type=NonbondedGenerator
)
non_bonded_forces_gen = self.get_generator(ff=self, gen_type=NonbondedGenerator)

non_bonded_params = non_bonded_forces_gen.params.paramsForType

try:
return non_bonded_params[atom_type]
except KeyError:
raise MissingParametersError(
f"Missing parameters for atom {atom_type}"
)
raise MissingParametersError(f"Missing parameters for atom {atom_type}")

def _extract_harmonic_bond_params(self, atom_types):
"""Return parameters for a specific HarmonicBondForce between atom types."""
Expand Down Expand Up @@ -1548,9 +1496,7 @@ def _extract_rb_proper_params(self, atom_types):
f"be extracted for four atoms. Provided {len(atom_types)}"
)

rb_torsion_force_gen = self.get_generator(
ff=self, gen_type=RBTorsionGenerator
)
rb_torsion_force_gen = self.get_generator(ff=self, gen_type=RBTorsionGenerator)

wildcard = self._atomClasses[""]
(
Expand Down Expand Up @@ -1600,9 +1546,7 @@ def _extract_rb_improper_params(self, atom_types):
f"be extracted for four atoms. Provided {len(atom_types)}"
)

rb_torsion_force_gen = self.get_generator(
ff=self, gen_type=RBTorsionGenerator
)
rb_torsion_force_gen = self.get_generator(ff=self, gen_type=RBTorsionGenerator)

match = self._match_impropers(atom_types, rb_torsion_force_gen)

Expand All @@ -1622,9 +1566,7 @@ def map_atom_classes_to_types(self, atom_classes_keys, strict=False):
# When to do this substitution with wildcards?
substitution = self._atomClasses.get(key)
if not substitution:
raise ValueError(
f"Atom class {key} is missing from the Forcefield"
)
raise ValueError(f"Atom class {key} is missing from the Forcefield")
atom_type_keys.append(next(iter(substitution)))

return atom_type_keys
Expand Down Expand Up @@ -1715,9 +1657,7 @@ def get_generator(ff, gen_type):
@staticmethod
def substitute_wildcards(atom_types, wildcard):
"""Return possible wildcard options."""
return tuple(
atom_type or next(iter(wildcard)) for atom_type in atom_types
)
return tuple(atom_type or next(iter(wildcard)) for atom_type in atom_types)


pmd.Structure.write_foyer = write_foyer
Loading

0 comments on commit 06d4e6d

Please sign in to comment.