diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py index 4fd304ec8..4efd6b0e1 100644 --- a/dmff/classical/inter.py +++ b/dmff/classical/inter.py @@ -332,3 +332,67 @@ def get_energy(positions, box, pairs, bcc, mscales): charges = self.init_charges + jnp.dot(self.top_mat, bcc).flatten() return get_energy_kernel(positions, box, pairs, charges, mscales) return get_energy + + +class CustomGBForce: + def __init__( + self, + map_charge, + map_radius, + map_scale, + epsilon_1=1.0, + epsilon_solv=78.3, + alpha=1, + beta=0.8, + gamma=4.85, + ) -> None: + self.map_charge = map_charge + self.map_radius = map_radius + self.map_scale = map_scale + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.exp_solv = epsilon_solv + self.eps_1 = epsilon_1 + + def generate_get_energy(self): + @jax.jit + def get_energy(positions, box, pairs, Ipairs, charges, radius, scales): + def calI(posList, radMap, scalMap, rhoMap, pairMap): + I = jnp.array([]) + + for i in range(len(radMap)): + posj = posList[Ipairs[i]] + rhoj = rhoMap[Ipairs[i]] + scalj = scalMap[Ipairs[i]] + posi = posList[i] + rhoi = rhoMap[i] + + r = jnp.sqrt(jnp.sum(jnp.power(posi-posj,2),axis=1)) + sr2 = rhoj * scalj + D = jnp.abs(r - sr2) + L = jnp.maximum(D, rhoi) + C = 2 * (1 / rhoi - 1 / L) * jnp.heaviside(sr2 - r - rhoi, 1) + U = r + sr2 + I = jnp.append(I, jnp.sum(0.5 * jnp.heaviside(r + sr2 - rhoi, 1) * ( + 1 / L - 1 / U + 0.25 * (1 / U ** 2 - 1 / L ** 2) * ( + r - sr2 ** 2 / r) + 0.5 * jnp.log(L / U) / r + C))) + + return I + + chargeMap = charges[self.map_charge] + radiusMap = radius[self.map_radius] + scalesMap = scales[self.map_scale] + rhoMap = radiusMap - 0.009 + + # effective radius + IList = calI(positions, radiusMap, scalesMap, rhoMap, Ipairs) + psi = IList*rhoMap + rEff = 1/(1/rhoMap-jnp.tanh(self.alpha*psi-self.beta*jnp.power(psi, 2)+self.gamma*jnp.power(psi, 3))/radiusMap) + Ese = jnp.sum(28.3919551*(radiusMap+0.14)**2*jnp.power(radiusMap/rEff, 6)-0.5*138.935456*(1/self.eps_1-1/self.exp_solv)*chargeMap**2/rEff) + dr_norm = jnp.linalg.norm(positions[pairs[:,0]] - positions[pairs[:,1]], axis=1) + chargepro = chargeMap[pairs[:, 0]] * chargeMap[pairs[:, 1]] + rEffpro = rEff[pairs[:, 0]] * rEff[pairs[:, 1]] + Egb = jnp.sum(-138.935456*(1/self.eps_1-1/self.exp_solv)*chargepro/jnp.sqrt(jnp.power(dr_norm, 2)+rEffpro*jnp.exp(-jnp.power(dr_norm,2)/(4*rEffpro)))) + return Ese + Egb + return get_energy \ No newline at end of file diff --git a/dmff/classical/intra.py b/dmff/classical/intra.py index 76843b979..198c16ce2 100644 --- a/dmff/classical/intra.py +++ b/dmff/classical/intra.py @@ -148,3 +148,78 @@ def refresh_calculators(self): """ self.get_energy = self.generate_get_energy() self.get_forces = value_and_grad(self.get_energy) + + +class Custom1_5BondJaxForce: + def __init__(self, p1idx, p2idx, prmidx): + self.p1idx = p1idx + self.p2idx = p2idx + self.prmidx = prmidx + self.refresh_calculators() + + def generate_get_energy(self): + def get_energy(positions, box, pairs, k, length): + p1 = positions[self.p1idx,:] + p2 = positions[self.p2idx,:] + kprm = k[self.prmidx] + b0prm = length[self.prmidx] + dist = distance(p1, p2) + return jnp.sum(0.5 * kprm * jnp.power(dist - b0prm, 2)) + + return get_energy + + def update_env(self, attr, val): + """ + Update the environment of the calculator + """ + setattr(self, attr, val) + self.refresh_calculators() + + def refresh_calculators(self): + """ + refresh the energy and force calculators according to the current environment + """ + self.get_energy = self.generate_get_energy() + self.get_forces = value_and_grad(self.get_energy) + + +class CustomTorsionJaxForce: + def __init__(self, p1idx, p2idx, p3idx, p4idx, prmidx, order): + self.p1idx = p1idx + self.p2idx = p2idx + self.p3idx = p3idx + self.p4idx = p4idx + self.prmidx = prmidx + self.order = order + self.refresh_calculators() + + def generate_get_energy(self): + if len(self.p1idx) == 0: + return lambda positions, box, pairs, k, psi, shift: 0.0 + def get_energy(positions, box, pairs, k, psi, shift): + p1 = positions[self.p1idx, :] + p2 = positions[self.p2idx, :] + p3 = positions[self.p3idx, :] + p4 = positions[self.p4idx, :] + kp = k[self.prmidx] + psip = psi[self.prmidx] + shiftp = shift[self.prmidx] + dih = dihedral(p1, p2, p3, p4) + ener = kp * (jnp.cos(self.order * dih - psip)) + shiftp + return jnp.sum(ener) + + return get_energy + + def update_env(self, attr, val): + """ + Update the environment of the calculator + """ + setattr(self, attr, val) + self.refresh_calculators() + + def refresh_calculators(self): + """ + refresh the energy and force calculators according to the current environment + """ + self.get_energy = self.generate_get_energy() + self.get_forces = value_and_grad(self.get_energy) \ No newline at end of file diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py index 00d70be45..bb260b6f1 100644 --- a/dmff/generators/classical.py +++ b/dmff/generators/classical.py @@ -7,8 +7,8 @@ import jax.numpy as jnp import openmm.app as app import openmm.unit as unit -from ..classical.intra import HarmonicBondJaxForce, HarmonicAngleJaxForce, PeriodicTorsionJaxForce -from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce, CoulReactionFieldForce, LennardJonesForce, LennardJonesLongRangeForce +from ..classical.intra import HarmonicBondJaxForce, HarmonicAngleJaxForce, PeriodicTorsionJaxForce, Custom1_5BondJaxForce, CustomTorsionJaxForce +from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce, CoulReactionFieldForce, LennardJonesForce, LennardJonesLongRangeForce, CustomGBForce from typing import Tuple, List, Union, Callable @@ -787,157 +787,921 @@ def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, p _DMFFGenerators["PeriodicTorsionForce"] = PeriodicTorsionGenerator +class CustomTorsionGenerator: + + def __init__(self, ffinfo: dict, paramset: ParamSet): + """ + Initializes a PeriodicTorsionForce object. + + Args: + - ffinfo (dict): A dictionary containing force field information. + - paramset (ParamSet): A ParamSet object to register parameters. + + Returns: + - None + """ + self.name = "CustomTorsionForce" + self.ffinfo = ffinfo + paramset.addField(self.name) + self._use_smarts = False + self.key_type = None + self.torsionIndices = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if "roper" in self.ffinfo["Forces"][self.name]["node"][i]["name"]] + + proper_keys, proper_periods, proper_prms, proper_shift = [], [], [], [] + proper_key_to_prms = {} + improper_keys, improper_periods, improper_prms, improper_shift = [], [], [], [] + improper_key_to_prms = {} + for i in self.torsionIndices: + node = self.ffinfo["Forces"][self.name]["node"][i] + attribs = node["attrib"] + if "type1" in attribs: + self.key_type = "type" + elif "class1" in attribs: + self.key_type = "class" + key = (attribs[self.key_type + "1"], attribs[self.key_type + "2"], + attribs[self.key_type + "3"], attribs[self.key_type + "4"]) + if node["name"] == "Proper": + proper_keys.append(key) + elif node["name"] == "Improper": + improper_keys.append(key) + + mask = 1.0 + if "mask" in attribs and attribs["mask"].upper() == "TRUE": + mask = 0.0 + + for period_key in attribs.keys(): + if "per" not in period_key: + continue + order = int(period_key.replace("per", "")) + period = int(attribs[period_key]) + phase = float(attribs["phase" + str(order)]) + k = float(attribs["k" + str(order)]) + shift = float(attribs["shift"])/4 + if node["name"] == "Proper": + proper_periods.append(period) + proper_prms.append([phase, k, mask, shift]) + if len(proper_keys) - 1 not in proper_key_to_prms: + proper_key_to_prms[len(proper_keys) - 1] = [] + proper_key_to_prms[len( + proper_keys) - 1].append(len(proper_periods) - 1) + elif node["name"] == "Improper": + improper_periods.append(period) + improper_prms.append([phase, k, mask, shift]) + if len(improper_keys) - 1 not in improper_key_to_prms: + improper_key_to_prms[len(improper_keys) - 1] = [] + improper_key_to_prms[len( + improper_keys) - 1].append(len(improper_periods) - 1) + + self.proper_keys = proper_keys + self.proper_periods = jnp.array(proper_periods) + self.proper_key_to_prms = proper_key_to_prms + proper_phase = jnp.array([i[0] for i in proper_prms]) + proper_k = jnp.array([i[1] for i in proper_prms]) + proper_mask = jnp.array([i[2] for i in proper_prms]) + proper_shift = jnp.array([i[3] for i in proper_prms]) + # register parameters to ParamSet + paramset.addParameter(proper_phase, "proper_phase", + field=self.name, mask=proper_mask) + paramset.addParameter(proper_k, "proper_k", + field=self.name, mask=proper_mask) + paramset.addParameter(proper_shift, "proper_shift", + field=self.name, mask=proper_mask) + + self.imp_keys = improper_keys + self.imp_periods = jnp.array(improper_periods) + self.imp_key_to_prms = improper_key_to_prms + improper_phase = jnp.array([i[0] for i in improper_prms]) + improper_k = jnp.array([i[1] for i in improper_prms]) + improper_mask = jnp.array([i[2] for i in improper_prms]) + improper_shift = jnp.array([i[3] for i in improper_prms]) + # register parameters to ParamSet + paramset.addParameter(improper_phase, "improper_phase", + field=self.name, mask=improper_mask) + paramset.addParameter(improper_k, "improper_k", + field=self.name, mask=improper_mask) + paramset.addParameter(improper_shift, "improper_shift", + field=self.name, mask=improper_mask) + + def getName(self): + return self.name + + def overwrite(self, paramset): + # paramset to ffinfo + proper_node_indices = [i for i in range(len( + self.ffinfo["Forces"][self.name]["node"])) if + self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Proper"] + improper_node_indices = [i for i in range(len( + self.ffinfo["Forces"][self.name]["node"])) if + self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Improper"] + + proper_phase = paramset[self.name]["proper_phase"] + proper_k = paramset[self.name]["proper_k"] + proper_shift = paramset[self.name]["proper_shift"] + proper_msks = paramset.mask[self.name]["proper"] + for nnode, key in enumerate(self.proper_keys): + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]]["attrib"] = { + } + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"][f"{self.key_type}1"] = key[0] + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"][f"{self.key_type}2"] = key[1] + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"][f"{self.key_type}3"] = key[2] + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"][f"{self.key_type}4"] = key[3] + shiftTem = 0 + for nitem, item in enumerate(self.proper_key_to_prms[nnode]): + phase, k, shift = proper_phase[item], proper_k[item], proper_shift[item] + mask = proper_msks[item] + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["per" + str(nitem + 1)] = str(self.proper_periods[item]) + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["phase" + str(nitem + 1)] = str(phase) + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["k" + str(nitem + 1)] = str(k) + shiftTem += shift + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["shift"] = str(shiftTem) + if mask < 0.999: + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["mask"] = "true" + + improper_phase = paramset[self.name]["improper_phase"] + improper_k = paramset[self.name]["improper_k"] + improper_shift = paramset[self.name]["improper_shift"] + improper_msks = paramset.mask[self.name]["improper"] + for nnode, key in enumerate(self.imp_keys): + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]]["attrib"] = { + } + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"][f"{self.key_type}1"] = key[0] + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"][f"{self.key_type}2"] = key[1] + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"][f"{self.key_type}3"] = key[2] + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"][f"{self.key_type}4"] = key[3] + shiftTem = 0 + for nitem, item in enumerate(self.imp_key_to_prms[nnode]): + phase = improper_phase[item] + k = improper_k[item] + shift = improper_shift[item] + mask = improper_msks[item] + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["per" + str(nitem + 1)] = str(self.imp_periods[item]) + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["phase" + str(nitem + 1)] = str(phase) + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["k" + str(nitem + 1)] = str(k) + shiftTem += shift + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["shift"] = str(shiftTem) + if mask < 0.999: + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["mask"] = "true" + + def _find_proper_key_index(self, key: Tuple[str, str, str, str]) -> int: + wc_patch = [] + for i, k in enumerate(self.proper_keys): + if k[0] in ["", key[0]] and k[1] in ["", key[1]] and k[2] in ["", key[2]] and k[3] in ["", key[3]]: + if "" in k: + wc_patch.append(i) + else: + return i + if k[0] in ["", key[3]] and k[1] in ["", key[2]] and k[2] in ["", key[1]] and k[3] in ["", key[0]]: + if "" in k: + wc_patch.append(i) + else: + return i + if len(wc_patch) > 0: + return wc_patch[0] + return None + + def _find_improper_key_index(self, improper): + + type1 = improper[0].meta[self.key_type] + type2 = improper[1].meta[self.key_type] + type3 = improper[2].meta[self.key_type] + type4 = improper[3].meta[self.key_type] + + def _wild_match(tp, tps): + if tps == "": + return True + if tp == tps: + return True + return False + + matched = None + for ndef, tordef in enumerate(self.imp_keys): + types1 = tordef[0] + types2 = tordef[1] + types3 = tordef[2] + types4 = tordef[3] + hasWildcard = ("" in (types1, types2, types3, types4)) + + if matched is not None and hasWildcard: + continue + + import itertools + if type1 in types1: + for (t2, t3, t4) in itertools.permutations(((type2, 1), (type3, 2), (type4, 3))): + if _wild_match(t2[0], types2) and _wild_match(t3[0], types3) and _wild_match(t4[0], types4): + a1 = improper[t2[1]].index + a2 = improper[t3[1]].index + e1 = improper[t2[1]].element + e2 = improper[t3[1]].element + m1 = app.element.get_by_symbol(e1).mass + m2 = app.element.get_by_symbol(e2).mass + if e1 == e2 and a1 > a2: + (a1, a2) = (a2, a1) + elif e1 != "C" and (e2 == "C" or m1 < m2): + (a1, a2) = (a2, a1) + matched = (a1, a2, improper[0].index, improper[t4[1]].index, ndef) + break + if matched is None: + return None, None + return matched[4], matched[:4] + + def createPotential(self, topdata: DMFFTopology, nonbondedMethod, + nonbondedCutoff, **kwargs): + + if self.key_type is None: + def potential_fn_zero(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, + params: ParamSet) -> jnp.ndarray: + return jnp.zeros((1,)) + + self._jaxPotential = potential_fn_zero + return potential_fn_zero + + proper_list = [] + + acenters = {} + atoms = [a for a in topdata.atoms()] + for bond in topdata.bonds(): + a1, a2 = bond.atom1, bond.atom2 + i1, i2 = a1.index, a2.index + if i1 not in acenters: + acenters[i1] = [] + acenters[i1].append(i2) + if i2 not in acenters: + acenters[i2] = [] + acenters[i2].append(i1) + + # find rotamers and loop over proper torsions on the rotamer + for bond in topdata.bonds(): + a1, a2 = bond.atom1, bond.atom2 + i1, i2 = a1.index, a2.index + alinks1 = [i for i in acenters[i1] if i != i2] + alinks2 = [i for i in acenters[i2] if i != i1] + for i3 in alinks1: + for i4 in alinks2: + if i3 != i4: + proper_list.append( + (atoms[i3], atoms[i1], atoms[i2], atoms[i4])) + + impr_list = [] + # find atoms that link with three other atoms + import itertools as it + for i1 in acenters: + if len(acenters[i1]) < 3: + continue + for item in it.combinations(acenters[i1], 3): + impr_list.append( + (atoms[i1], atoms[item[0]], atoms[item[1]], atoms[item[2]])) + + # create potential + proper_a1, proper_a2, proper_a3, proper_a4, proper_indices, proper_period = [ + ], [], [], [], [], [] + for proper in proper_list: + pidx = self._find_proper_key_index( + (proper[0].meta[self.key_type], proper[1].meta[self.key_type], proper[2].meta[self.key_type], + proper[3].meta[self.key_type])) + if pidx is None: + continue + + prm_indices = self.proper_key_to_prms[pidx] + for prm_idx in prm_indices: + prm_period = self.proper_periods[prm_idx] + proper_a1.append(proper[0].index) + proper_a2.append(proper[1].index) + proper_a3.append(proper[2].index) + proper_a4.append(proper[3].index) + proper_indices.append(prm_idx) + proper_period.append(prm_period) + + proper_a1 = jnp.array(proper_a1) + proper_a2 = jnp.array(proper_a2) + proper_a3 = jnp.array(proper_a3) + proper_a4 = jnp.array(proper_a4) + proper_indices = jnp.array(proper_indices) + proper_period = jnp.array(proper_period) + + improper_a1, improper_a2, improper_a3, improper_a4, improper_indices, improper_period = [], [], [], [], [], [] + for improper in impr_list: + iidx, order = self._find_improper_key_index(improper) + if iidx is None: + continue + + prm_indices = self.imp_key_to_prms[iidx] + for prm_idx in prm_indices: + prm_period = self.imp_periods[prm_idx] + improper_a1.append(atoms[order[0]].index) + improper_a2.append(atoms[order[1]].index) + improper_a3.append(atoms[order[2]].index) + improper_a4.append(atoms[order[3]].index) + improper_indices.append(prm_idx) + improper_period.append(prm_period) + improper_a1 = jnp.array(improper_a1) + improper_a2 = jnp.array(improper_a2) + improper_a3 = jnp.array(improper_a3) + improper_a4 = jnp.array(improper_a4) + improper_indices = jnp.array(improper_indices) + improper_period = jnp.array(improper_period) + + proper_func = CustomTorsionJaxForce( + proper_a1, proper_a2, proper_a3, proper_a4, proper_indices, proper_period) + proper_energy = proper_func.generate_get_energy() + improper_func = CustomTorsionJaxForce( + improper_a1, improper_a2, improper_a3, improper_a4, improper_indices, improper_period) + improper_energy = improper_func.generate_get_energy() + + has_aux = False + if "has_aux" in kwargs and kwargs["has_aux"]: + has_aux = True + + def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet, aux=None): + isinstance_jnp(positions, box, params) + proper_energy_ = proper_energy( + positions, box, pairs, params[self.name]["proper_k"], params[self.name]["proper_phase"], params[self.name]["proper_shift"]) + improper_energy_ = improper_energy( + positions, box, pairs, params[self.name]["improper_k"], params[self.name]["improper_phase"], params[self.name]["improper_shift"]) + if has_aux: + return proper_energy_ + improper_energy_, aux + else: + return proper_energy_ + improper_energy_ + + self._jaxPotential = potential_fn + return potential_fn + + +_DMFFGenerators["CustomTorsionForce"] = CustomTorsionGenerator + + class NonbondedGenerator: def __init__(self, ffinfo: dict, paramset: ParamSet): - self.name = "NonbondedForce" + self.name = "NonbondedForce" + self.ffinfo = ffinfo + paramset.addField(self.name) + self.coulomb14scale = float( + self.ffinfo["Forces"]["NonbondedForce"]["meta"].get("coulomb14scale", 0.8333333333333334)) + self.lj14scale = float( + self.ffinfo["Forces"]["NonbondedForce"]["meta"].get("lj14scale", 0.5)) + self.key_type = None + self.type_to_charge = {} + + self.charge_in_residue = False + for node in self.ffinfo["Forces"]["NonbondedForce"]["node"]: + if not self.charge_in_residue and node["name"] == "UseAttributeFromResidue": + if node["attrib"]["name"] == "charge": + self.charge_in_residue = True + + types, sigma, epsilon, atom_mask = [], [], [], [] + for node in self.ffinfo["Forces"]["NonbondedForce"]["node"]: + if node["name"] == "Atom": + attribs = node["attrib"] + self.key_type = None + if "type" in attribs: + self.key_type = "type" + elif "class" in attribs: + self.key_type = "class" + types.append(attribs[self.key_type]) + sigma.append(float(attribs["sigma"])) + epsilon.append(float(attribs["epsilon"])) + mask = 1.0 + if "mask" in attribs and attribs["mask"].upper() == "TRUE": + mask = 0.0 + atom_mask.append(mask) + if not self.charge_in_residue: + if "charge" not in attribs: + raise ValueError("No charge information found in NonbondedForce or Residues.") + self.type_to_charge[attribs[self.key_type]] = float(attribs["charge"]) + + sigma = jnp.array(sigma) + epsilon = jnp.array(epsilon) + atom_mask = jnp.array(atom_mask) + self.atom_keys = types + paramset.addParameter(sigma, "sigma", field=self.name, mask=atom_mask) + paramset.addParameter(epsilon, "epsilon", field=self.name, mask=atom_mask) + + def getName(self): + return self.name + + def overwrite(self, paramset): + sigma = paramset[self.name]["sigma"] + epsilon = paramset[self.name]["epsilon"] + atom_mask = paramset.mask[self.name]["sigma"] + + node2atom = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Atom"] + + for natom in range(len(self.atom_keys)): + nnode = node2atom[natom] + sig_new = sigma[natom] + eps_new = epsilon[natom] + mask = atom_mask[natom] + self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["sigma"] = str(sig_new) + self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["epsilon"] = str(eps_new) + if mask < 0.999: + self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["mask"] = "true" + + def _find_atype_key_index(self, atype: str): + for n, i in enumerate(self.atom_keys): + if i == atype: + return n + return None + + def createPotential(self, topdata: DMFFTopology, nonbondedMethod, + nonbondedCutoff, **kwargs): + methodMap = { + app.NoCutoff: "NoCutoff", + app.CutoffPeriodic: "CutoffPeriodic", + app.CutoffNonPeriodic: "CutoffNonPeriodic", + app.PME: "PME", + } + methodString = methodMap[nonbondedMethod] + if nonbondedMethod not in methodMap: + raise DMFFException("Illegal nonbonded method for NonbondedForce") + + isNoCut = False + if nonbondedMethod is app.NoCutoff: + isNoCut = True + + mscales_coul = jnp.array([0.0, 0.0, self.coulomb14scale, 1.0, 1.0, + 1.0]) + mscales_lj = jnp.array([0.0, 0.0, self.lj14scale, 1.0, 1.0, + 1.0]) + + # coulomb + # set PBC + if nonbondedMethod not in [app.NoCutoff, app.CutoffNonPeriodic]: + ifPBC = True + else: + ifPBC = False + + if self.charge_in_residue: + charges = [a.meta["charge"] for a in topdata.atoms()] + charges = jnp.array(charges) + else: + types = [a.meta[self.key_type] for a in topdata.atoms()] + charges = jnp.array([self.type_to_charge[i] for i in types]) + + if unit.is_quantity(nonbondedCutoff): + r_cut = nonbondedCutoff.value_in_unit(unit.nanometer) + else: + r_cut = nonbondedCutoff + + # PME Settings + if nonbondedMethod is app.PME: + cell = topdata.getPeriodicBoxVectors() + self.ethresh = kwargs.get("ethresh", 1e-6) + self.coeff_method = kwargs.get("PmeCoeffMethod", "openmm") + self.fourier_spacing = kwargs.get("PmeSpacing", 0.1) + kappa, K1, K2, K3 = setup_ewald_parameters(r_cut, self.ethresh, + cell, + self.fourier_spacing, + self.coeff_method) + if nonbondedMethod is not app.PME: + # do not use PME + if nonbondedMethod in [app.CutoffPeriodic, app.CutoffNonPeriodic]: + # use Reaction Field + coulforce = CoulReactionFieldForce(r_cut, charges, isPBC=ifPBC) + if nonbondedMethod is app.NoCutoff: + # use NoCutoff + coulforce = CoulNoCutoffForce(init_charges=charges) + else: + coulforce = CoulombPMEForce(r_cut, charges, kappa, (K1, K2, K3)) + + self.pme_force = coulforce + coulenergy = coulforce.generate_get_energy() + + # LJ + atypes = [a.meta[self.key_type] for a in topdata.atoms()] + map_prm = [] + for atype in atypes: + pidx = self._find_atype_key_index(atype) + if pidx is None: + raise DMFFException(f"Atom type {atype} not found.") + map_prm.append(pidx) + map_prm = jnp.array(map_prm) + + # not use nbfix for now + map_nbfix = [] + map_nbfix = jnp.array(map_nbfix, dtype=int).reshape((-1, 3)) + eps_nbfix = jnp.array(map_nbfix, dtype=float).reshape((-1, 3)) + sig_nbfix = jnp.array(map_nbfix, dtype=float).reshape((-1, 3)) + + if methodString in ["NoCutoff", "CutoffNonPeriodic"]: + isPBC = False + if methodString == "NoCutoff": + isNoCut = True + else: + isNoCut = False + else: + isPBC = True + isNoCut = False + + ljforce = LennardJonesForce(0.0, + r_cut, + map_prm, + map_nbfix, + isSwitch=False, + isPBC=isPBC, + isNoCut=isNoCut) + ljenergy = ljforce.generate_get_energy() + + # dispersion correction + use_disp_corr = False + if "useDispersionCorrection" in kwargs and kwargs["useDispersionCorrection"]: + use_disp_corr = True + numTypes = len(self.atom_keys) + countVec = np.zeros(numTypes, dtype=int) + countMat = np.zeros((numTypes, numTypes), dtype=int) + types, count = np.unique(map_prm, return_counts=True) + for typ, cnt in zip(types, count): + countVec[typ] += cnt + for i in range(numTypes): + for j in range(i, numTypes): + if i != j: + countMat[i, j] = countVec[i] * countVec[j] + else: + countMat[i, j] = countVec[i] * (countVec[i] - 1) // 2 + assert np.sum(countMat) == len(map_prm) * (len(map_prm) - 1) // 2 + + coval_map = topdata.buildCovMat() + colv_pairs = np.argwhere( + np.logical_and(coval_map > 0, coval_map <= 3)) + for pair in colv_pairs: + if pair[0] <= pair[1]: + tmp = (map_prm[pair[0]], map_prm[pair[1]]) + t1, t2 = min(tmp), max(tmp) + countMat[t1, t2] -= 1 + + ljDispCorrForce = LennardJonesLongRangeForce(r_cut, map_prm, map_nbfix, countMat) + ljDispEnergyFn = ljDispCorrForce.generate_get_energy() + + has_aux = False + if "has_aux" in kwargs and kwargs["has_aux"]: + has_aux = True + + def potential_fn(positions, box, pairs, params, aux=None): + + # check whether args passed into potential_fn are jnp.array and differentiable + # note this check will be optimized away by jit + # it is jit-compatiable + isinstance_jnp(positions, box, params) + + coulE = coulenergy(positions, box, pairs, mscales_coul) + + ljE = ljenergy(positions, box, pairs, params[self.name]["epsilon"], + params[self.name]["sigma"], eps_nbfix, sig_nbfix, mscales_lj) + if use_disp_corr: + ljdispE = ljDispEnergyFn(box, params[self.name]["epsilon"], + params[self.name]["sigma"], eps_nbfix, sig_nbfix) + if has_aux: + return coulE + ljE + ljdispE, aux + else: + return coulE + ljE + ljdispE + else: + if has_aux: + return coulE + ljE, aux + else: + return coulE + ljE + + self._jaxPotential = potential_fn + return potential_fn + + +_DMFFGenerators["NonbondedForce"] = NonbondedGenerator + + +class CoulombGenerator: + def __init__(self, ffinfo: dict, paramset: ParamSet): + self.name = "CoulombForce" self.ffinfo = ffinfo paramset.addField(self.name) self.coulomb14scale = float( - self.ffinfo["Forces"]["NonbondedForce"]["meta"].get("coulomb14scale", 0.8333333333333334)) + self.ffinfo["Forces"]["CoulombForce"]["meta"]["coulomb14scale"]) + self._use_bcc = False + self._bcc_mol = [] + self.bcc_parsers = [] + bcc_prms = [] + bcc_mask = [] + for node in self.ffinfo["Forces"]["CoulombForce"]["node"]: + if node["name"] == "UseBondChargeCorrection": + self._use_bcc = True + self._bcc_mol.append(node["attrib"]["name"]) + if node["name"] == "BondChargeCorrection": + bcc = node["attrib"]["bcc"] + parser = node["attrib"]["smarts"] if "smarts" in node["attrib"] else node["attrib"]["smirks"] + bcc_prms.append(float(bcc)) + self.bcc_parsers.append(parser) + if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE": + bcc_mask.append(0.0) + else: + bcc_mask.append(1.0) + bcc_prms = jnp.array(bcc_prms) + bcc_mask = jnp.array(bcc_mask) + paramset.addParameter(bcc_prms, "bcc", field=self.name, mask=bcc_mask) + self._bcc_shape = paramset[self.name]["bcc"].shape[0] + + def getName(self): + return self.name + + def overwrite(self, paramset): + # paramset to ffinfo + if self._use_bcc: + bcc_now = paramset[self.name]["bcc"] + mask_list = paramset.mask[self.name]["bcc"] + nbcc = 0 + for nnode, node in enumerate(self.ffinfo["Forces"][self.name]["node"]): + if node["name"] == "BondChargeCorrection": + mask = mask_list[nbcc] + self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["bcc"] = bcc_now[nbcc] + if mask < 0.999: + self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["mask"] = "true" + nbcc += 1 + + def createPotential(self, topdata: DMFFTopology, nonbondedMethod, + nonbondedCutoff, **kwargs): + methodMap = { + app.NoCutoff: "NoCutoff", + app.CutoffPeriodic: "CutoffPeriodic", + app.CutoffNonPeriodic: "CutoffNonPeriodic", + app.PME: "PME", + } + if nonbondedMethod not in methodMap: + raise DMFFException("Illegal nonbonded method for NonbondedForce") + + isNoCut = False + if nonbondedMethod is app.NoCutoff: + isNoCut = True + + mscales_coul = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, + 1.0]) # mscale for PME + mscales_coul = mscales_coul.at[2].set(self.coulomb14scale) + self.mscales_coul = mscales_coul # for qeq calculation + + # set PBC + if nonbondedMethod not in [app.NoCutoff, app.CutoffNonPeriodic]: + ifPBC = True + else: + ifPBC = False + + charges = [a.meta["charge"] for a in topdata.atoms()] + charges = jnp.array(charges) + + cov_mat = topdata.buildCovMat() + + if unit.is_quantity(nonbondedCutoff): + r_cut = nonbondedCutoff.value_in_unit(unit.nanometer) + else: + r_cut = nonbondedCutoff + + # PME Settings + if nonbondedMethod is app.PME: + cell = topdata.getPeriodicBoxVectors() + box = jnp.array(cell) + self.ethresh = kwargs.get("ethresh", 1e-5) + self.coeff_method = kwargs.get("PmeCoeffMethod", "openmm") + self.fourier_spacing = kwargs.get("PmeSpacing", 0.1) + kappa, K1, K2, K3 = setup_ewald_parameters(r_cut, self.ethresh, + box, + self.fourier_spacing, + self.coeff_method) + + if self._use_bcc: + top_mat = np.zeros( + (topdata.getNumAtoms(), self._bcc_shape)) + matched_dict = {} + for nparser, parser in enumerate(self.bcc_parsers): + matches = topdata.parseSMARTS(parser, resname=self._bcc_mol) + for ii, jj in matches: + if (ii, jj) in matched_dict: + del matched_dict[(ii, jj)] + elif (jj, ii) in matched_dict: + del matched_dict[(jj, ii)] + matched_dict[(ii, jj)] = nparser + for ii, jj in matched_dict.keys(): + nval = matched_dict[(ii, jj)] + top_mat[ii, nval] += 1. + top_mat[jj, nval] -= 1. + topdata._meta["bcc_top_mat"] = top_mat + + if nonbondedMethod is not app.PME: + # do not use PME + if nonbondedMethod in [app.CutoffPeriodic, app.CutoffNonPeriodic]: + # use Reaction Field + coulforce = CoulReactionFieldForce( + r_cut, + charges, + isPBC=ifPBC, + topology_matrix=top_mat if self._use_bcc else None) + if nonbondedMethod is app.NoCutoff: + # use NoCutoff + coulforce = CoulNoCutoffForce( + charges, topology_matrix=top_mat if self._use_bcc else None) + else: + coulforce = CoulombPMEForce( + r_cut, + charges, + kappa, (K1, K2, K3), + topology_matrix=top_mat if self._use_bcc else None) + + coulenergy = coulforce.generate_get_energy() + + has_aux = False + if "has_aux" in kwargs and kwargs["has_aux"]: + has_aux = True + + def potential_fn(positions, box, pairs, params, aux=None): + + # check whether args passed into potential_fn are jnp.array and differentiable + # note this check will be optimized away by jit + # it is jit-compatiable + isinstance_jnp(positions, box, params) + + if self._use_bcc: + coulE = coulenergy(positions, box, pairs, + params["CoulombForce"]["bcc"], mscales_coul) + else: + coulE = coulenergy(positions, box, pairs, + mscales_coul) + + if has_aux: + return coulE, aux + else: + return coulE + + self._jaxPotential = potential_fn + return potential_fn + + +_DMFFGenerators["CoulombForce"] = CoulombGenerator + + +class LennardJonesGenerator: + + def __init__(self, ffinfo: dict, paramset: ParamSet): + self.name = "LennardJonesForce" + self.ffinfo = ffinfo self.lj14scale = float( - self.ffinfo["Forces"]["NonbondedForce"]["meta"].get("lj14scale", 0.5)) - self.key_type = None - self.type_to_charge = {} - - self.charge_in_residue = False - for node in self.ffinfo["Forces"]["NonbondedForce"]["node"]: - if not self.charge_in_residue and node["name"] == "UseAttributeFromResidue": - if node["attrib"]["name"] == "charge": - self.charge_in_residue = True - - types, sigma, epsilon, atom_mask = [], [], [], [] - for node in self.ffinfo["Forces"]["NonbondedForce"]["node"]: + self.ffinfo["Forces"][self.name]["meta"]["lj14scale"]) + self.nbfix_to_idx = {} + self.atype_to_idx = {} + sig_prms, eps_prms = [], [] + sig_mask, eps_mask = [], [] + sig_nbfix, eps_nbfix = [], [] + sig_nbf_mask, eps_nbf_mask = [], [] + for node in self.ffinfo["Forces"][self.name]["node"]: if node["name"] == "Atom": - attribs = node["attrib"] - self.key_type = None - if "type" in attribs: - self.key_type = "type" - elif "class" in attribs: - self.key_type = "class" - types.append(attribs[self.key_type]) - sigma.append(float(attribs["sigma"])) - epsilon.append(float(attribs["epsilon"])) - mask = 1.0 - if "mask" in attribs and attribs["mask"].upper() == "TRUE": - mask = 0.0 - atom_mask.append(mask) - if not self.charge_in_residue: - if "charge" not in attribs: - raise ValueError("No charge information found in NonbondedForce or Residues.") - self.type_to_charge[attribs[self.key_type]] = float(attribs["charge"]) + if "type" in node["attrib"]: + atype, eps, sig = node["attrib"]["type"], node["attrib"][ + "epsilon"], node["attrib"]["sigma"] + self.atype_to_idx[atype] = len(sig_prms) + elif "class" in node["attrib"]: + acls, eps, sig = node["attrib"]["class"], node["attrib"][ + "epsilon"], node["attrib"]["sigma"] + atypes = ffinfo["ClassToType"][acls] + for atype in atypes: + self.atype_to_idx[atype] = len(sig_prms) + sig_prms.append(float(sig)) + eps_prms.append(float(eps)) + if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE": + sig_mask.append(0.0) + eps_mask.append(0.0) + else: + sig_mask.append(1.0) + eps_mask.append(1.0) + elif node["name"] == "NBFixPair": + if "type1" in node["attrib"]: + atype1, atype2, eps, sig = node["attrib"]["type1"], node["attrib"][ + "type2"], node["attrib"]["epsilon"], node["attrib"]["sigma"] + if atype1 not in self.nbfix_to_idx: + self.nbfix_to_idx[atype1] = {} + if atype2 not in self.nbfix_to_idx: + self.nbfix_to_idx[atype2] = {} + self.nbfix_to_idx[atype1][atype2] = len(sig_nbfix) + self.nbfix_to_idx[atype2][atype1] = len(sig_nbfix) + elif "class1" in node["attrib"]: + acls1, acls2, eps, sig = node["attrib"]["class1"], node["attrib"][ + "class2"], node["attrib"]["epsilon"], node["attrib"]["sigma"] + atypes1 = ffinfo["ClassToType"][acls1] + atypes2 = ffinfo["ClassToType"][acls2] + for atype1 in atypes1: + if atype1 not in self.nbfix_to_idx: + self.nbfix_to_idx[atype1] = {} + for atype2 in atypes2: + if atype2 not in self.nbfix_to_idx: + self.nbfix_to_idx[atype2] = {} + self.nbfix_to_idx[atype1][atype2] = len(sig_nbfix) + self.nbfix_to_idx[atype2][atype1] = len(sig_nbfix) + sig_nbfix.append(float(sig)) + eps_nbfix.append(float(eps)) + if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE": + sig_nbf_mask.append(0.0) + eps_nbf_mask.append(0.0) + else: + sig_nbf_mask.append(1.0) + eps_nbf_mask.append(1.0) - sigma = jnp.array(sigma) - epsilon = jnp.array(epsilon) - atom_mask = jnp.array(atom_mask) - self.atom_keys = types - paramset.addParameter(sigma, "sigma", field=self.name, mask=atom_mask) - paramset.addParameter(epsilon, "epsilon", field=self.name, mask=atom_mask) + sig_prms = jnp.array(sig_prms) + eps_prms = jnp.array(eps_prms) + sig_mask = jnp.array(sig_mask) + eps_mask = jnp.array(eps_mask) + + sig_nbfix, eps_nbfix = jnp.array(sig_nbfix), jnp.array(eps_nbfix) + sig_nbf_mask = jnp.array(sig_nbf_mask) + eps_nbf_mask = jnp.array(eps_nbf_mask) + + paramset.addField(self.name) + paramset.addParameter( + sig_prms, "sigma", field=self.name, mask=sig_mask) + paramset.addParameter(eps_prms, "epsilon", + field=self.name, mask=eps_mask) + paramset.addParameter(sig_nbfix, "sigma_nbfix", + field=self.name, mask=sig_nbf_mask) + paramset.addParameter(eps_nbfix, "epsilon_nbfix", + field=self.name, mask=eps_nbf_mask) def getName(self): return self.name def overwrite(self, paramset): - sigma = paramset[self.name]["sigma"] - epsilon = paramset[self.name]["epsilon"] - atom_mask = paramset.mask[self.name]["sigma"] + # paramset to ffinfo + for nnode in range(len(self.ffinfo["Forces"][self.name]["node"])): + node = self.ffinfo["Forces"][self.name]["node"][nnode] + if node["name"] == "Atom": + if "type" in node["attrib"]: + atype = node["attrib"]["type"] + idx = self.atype_to_idx[atype] - node2atom = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Atom"] + elif "class" in node["attrib"]: + acls = node["attrib"]["class"] + atypes = self.ffinfo["ClassToType"][acls] + idx = self.atype_to_idx[atypes[0]] - for natom in range(len(self.atom_keys)): - nnode = node2atom[natom] - sig_new = sigma[natom] - eps_new = epsilon[natom] - mask = atom_mask[natom] - self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["sigma"] = str(sig_new) - self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["epsilon"] = str(eps_new) - if mask < 0.999: - self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["mask"] = "true" + eps_now = paramset[self.name]["epsilon"][idx] + sig_now = paramset[self.name]["sigma"][idx] + self.ffinfo["Forces"][ + self.name]["node"][nnode]["attrib"]["sigma"] = sig_now + self.ffinfo["Forces"][ + self.name]["node"][nnode]["attrib"]["epsilon"] = eps_now + # have not tested for NBFixPair overwrite + elif node["name"] == "NBFixPair": + if "type1" in node["attrib"]: + atype1, atype2 = node["attrib"]["type1"], node["attrib"]["type2"] + idx = self.nbfix_to_idx[atype1][atype2] + elif "class1" in node["attrib"]: + acls1, acls2 = node["attrib"]["class1"], node["attrib"]["class2"] + atypes1 = self.ffinfo["ClassToType"][acls1] + atypes2 = self.ffinfo["ClassToType"][acls2] + idx = self.nbfix_to_idx[atypes1[0]][atypes2[0]] + sig_now = paramset[self.name]["sigma_nbfix"][idx] + eps_now = paramset[self.name]["epsilon_nbfix"][idx] + self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["sigma"] = sig_now + self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["epsilon"] = eps_now - def _find_atype_key_index(self, atype: str): - for n, i in enumerate(self.atom_keys): - if i == atype: - return n - return None - def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs): methodMap = { app.NoCutoff: "NoCutoff", app.CutoffPeriodic: "CutoffPeriodic", app.CutoffNonPeriodic: "CutoffNonPeriodic", - app.PME: "PME", + app.PME: "CutoffPeriodic", } - methodString = methodMap[nonbondedMethod] if nonbondedMethod not in methodMap: raise DMFFException("Illegal nonbonded method for NonbondedForce") + methodString = methodMap[nonbondedMethod] - isNoCut = False - if nonbondedMethod is app.NoCutoff: - isNoCut = True - - mscales_coul = jnp.array([0.0, 0.0, self.coulomb14scale, 1.0, 1.0, - 1.0]) - mscales_lj = jnp.array([0.0, 0.0, self.lj14scale, 1.0, 1.0, - 1.0]) - - # coulomb - # set PBC - if nonbondedMethod not in [app.NoCutoff, app.CutoffNonPeriodic]: - ifPBC = True - else: - ifPBC = False - - if self.charge_in_residue: - charges = [a.meta["charge"] for a in topdata.atoms()] - charges = jnp.array(charges) - else: - types = [a.meta[self.key_type] for a in topdata.atoms()] - charges = jnp.array([self.type_to_charge[i] for i in types]) - - if unit.is_quantity(nonbondedCutoff): - r_cut = nonbondedCutoff.value_in_unit(unit.nanometer) - else: - r_cut = nonbondedCutoff - - # PME Settings - if nonbondedMethod is app.PME: - cell = topdata.getPeriodicBoxVectors() - self.ethresh = kwargs.get("ethresh", 1e-6) - self.coeff_method = kwargs.get("PmeCoeffMethod", "openmm") - self.fourier_spacing = kwargs.get("PmeSpacing", 0.1) - kappa, K1, K2, K3 = setup_ewald_parameters(r_cut, self.ethresh, - cell, - self.fourier_spacing, - self.coeff_method) - if nonbondedMethod is not app.PME: - # do not use PME - if nonbondedMethod in [app.CutoffPeriodic, app.CutoffNonPeriodic]: - # use Reaction Field - coulforce = CoulReactionFieldForce(r_cut, charges, isPBC=ifPBC) - if nonbondedMethod is app.NoCutoff: - # use NoCutoff - coulforce = CoulNoCutoffForce(init_charges=charges) - else: - coulforce = CoulombPMEForce(r_cut, charges, kappa, (K1, K2, K3)) - - self.pme_force = coulforce - coulenergy = coulforce.generate_get_energy() - - # LJ - atypes = [a.meta[self.key_type] for a in topdata.atoms()] + atoms = [a for a in topdata.atoms()] + atypes = [a.meta["type"] for a in atoms] map_prm = [] for atype in atypes: - pidx = self._find_atype_key_index(atype) - if pidx is None: + if atype not in self.atype_to_idx: raise DMFFException(f"Atom type {atype} not found.") - map_prm.append(pidx) + idx = self.atype_to_idx[atype] + map_prm.append(idx) map_prm = jnp.array(map_prm) + topdata._meta["lj_map_idx"] = map_prm # not use nbfix for now map_nbfix = [] - map_nbfix = jnp.array(map_nbfix, dtype=int).reshape((-1, 3)) - eps_nbfix = jnp.array(map_nbfix, dtype=float).reshape((-1, 3)) - sig_nbfix = jnp.array(map_nbfix, dtype=float).reshape((-1, 3)) + for atype1 in self.nbfix_to_idx.keys(): + for atype2 in self.nbfix_to_idx[atype1].keys(): + nbfix_idx = self.nbfix_to_idx[atype1][atype2] + type1_idx = self.atype_to_idx[atype1] + type2_idx = self.atype_to_idx[atype2] + map_nbfix.append([type1_idx, type2_idx, nbfix_idx]) + map_nbfix = np.array(map_nbfix, dtype=int).reshape((-1, 3)) if methodString in ["NoCutoff", "CutoffNonPeriodic"]: isPBC = False @@ -949,6 +1713,14 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, isPBC = True isNoCut = False + mscales_lj = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) # mscale for LJ + mscales_lj = mscales_lj.at[2].set(self.lj14scale) + + if unit.is_quantity(nonbondedCutoff): + r_cut = nonbondedCutoff.value_in_unit(unit.nanometer) + else: + r_cut = nonbondedCutoff + ljforce = LennardJonesForce(0.0, r_cut, map_prm, @@ -958,36 +1730,6 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, isNoCut=isNoCut) ljenergy = ljforce.generate_get_energy() - # dispersion correction - use_disp_corr = False - if "useDispersionCorrection" in kwargs and kwargs["useDispersionCorrection"]: - use_disp_corr = True - numTypes = len(self.atom_keys) - countVec = np.zeros(numTypes, dtype=int) - countMat = np.zeros((numTypes, numTypes), dtype=int) - types, count = np.unique(map_prm, return_counts=True) - for typ, cnt in zip(types, count): - countVec[typ] += cnt - for i in range(numTypes): - for j in range(i, numTypes): - if i != j: - countMat[i, j] = countVec[i] * countVec[j] - else: - countMat[i, j] = countVec[i] * (countVec[i] - 1) // 2 - assert np.sum(countMat) == len(map_prm) * (len(map_prm) - 1) // 2 - - coval_map = topdata.buildCovMat() - colv_pairs = np.argwhere( - np.logical_and(coval_map > 0, coval_map <= 3)) - for pair in colv_pairs: - if pair[0] <= pair[1]: - tmp = (map_prm[pair[0]], map_prm[pair[1]]) - t1, t2 = min(tmp), max(tmp) - countMat[t1, t2] -= 1 - - ljDispCorrForce = LennardJonesLongRangeForce(r_cut, map_prm, map_nbfix, countMat) - ljDispEnergyFn = ljDispCorrForce.generate_get_energy() - has_aux = False if "has_aux" in kwargs and kwargs["has_aux"]: has_aux = True @@ -999,402 +1741,372 @@ def potential_fn(positions, box, pairs, params, aux=None): # it is jit-compatiable isinstance_jnp(positions, box, params) - coulE = coulenergy(positions, box, pairs, mscales_coul) - - ljE = ljenergy(positions, box, pairs, params[self.name]["epsilon"], - params[self.name]["sigma"], eps_nbfix, sig_nbfix, mscales_lj) - if use_disp_corr: - ljdispE = ljDispEnergyFn(box, params[self.name]["epsilon"], - params[self.name]["sigma"], eps_nbfix, sig_nbfix) - if has_aux: - return coulE + ljE + ljdispE, aux - else: - return coulE + ljE + ljdispE + ljE = ljenergy(positions, box, pairs, + params[self.name]["epsilon"], + params[self.name]["sigma"], + params[self.name]["epsilon_nbfix"], + params[self.name]["sigma_nbfix"], + mscales_lj) + + if has_aux: + return ljE, aux else: - if has_aux: - return coulE + ljE, aux - else: - return coulE + ljE + return ljE self._jaxPotential = potential_fn return potential_fn -_DMFFGenerators["NonbondedForce"] = NonbondedGenerator +_DMFFGenerators["LennardJonesForce"] = LennardJonesGenerator + + +class Custom1_5BondGenerator: + """ + A class for generating harmonic bond force field parameters. + Attributes: + ----------- + name : str + The name of the force field. + ffinfo : dict + The force field information. + key_type : str + The type of the key. + bond_keys : list of tuple + The keys of the bonds. + bond_params : list of tuple + The parameters of the bonds. + bond_mask : list of float + The mask of the bonds. + """ -class CoulombGenerator: def __init__(self, ffinfo: dict, paramset: ParamSet): - self.name = "CoulombForce" + """ + Initializes the HarmonicBondGenerator. + + Parameters: + ----------- + ffinfo : dict + The force field information. + paramset : ParamSet + The parameter set. + """ + self.name = "Custom1_5BondForce" self.ffinfo = ffinfo paramset.addField(self.name) - self.coulomb14scale = float( - self.ffinfo["Forces"]["CoulombForce"]["meta"]["coulomb14scale"]) - self._use_bcc = False - self._bcc_mol = [] - self.bcc_parsers = [] - bcc_prms = [] - bcc_mask = [] - for node in self.ffinfo["Forces"]["CoulombForce"]["node"]: - if node["name"] == "UseBondChargeCorrection": - self._use_bcc = True - self._bcc_mol.append(node["attrib"]["name"]) - if node["name"] == "BondChargeCorrection": - bcc = node["attrib"]["bcc"] - parser = node["attrib"]["smarts"] if "smarts" in node["attrib"] else node["attrib"]["smirks"] - bcc_prms.append(float(bcc)) - self.bcc_parsers.append(parser) - if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE": - bcc_mask.append(0.0) + self.key_type = None + + bond_keys, bond_params = [], [] + for node in self.ffinfo["Forces"][self.name]["node"]: + attribs = node["attrib"] + if self.key_type is None: + if "atomIndex1" in attribs: + self.key_type = "atomIndex" else: - bcc_mask.append(1.0) - bcc_prms = jnp.array(bcc_prms) - bcc_mask = jnp.array(bcc_mask) - paramset.addParameter(bcc_prms, "bcc", field=self.name, mask=bcc_mask) - self._bcc_shape = paramset[self.name]["bcc"].shape[0] + raise ValueError( + "Cannot find key type for Custom1_5BondForce.") + key = (attribs[self.key_type + "1"], attribs[self.key_type + "2"]) + bond_keys.append(key) + k = float(attribs["k"]) + r0 = float(attribs["length"]) + bond_params.append([k, r0]) - def getName(self): - return self.name + self.bond_keys = bond_keys + bond_length = jnp.array([i[1] for i in bond_params]) + bond_k = jnp.array([i[0] for i in bond_params]) - def overwrite(self, paramset): - # paramset to ffinfo - if self._use_bcc: - bcc_now = paramset[self.name]["bcc"] - mask_list = paramset.mask[self.name]["bcc"] - nbcc = 0 - for nnode, node in enumerate(self.ffinfo["Forces"][self.name]["node"]): - if node["name"] == "BondChargeCorrection": - mask = mask_list[nbcc] - self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["bcc"] = bcc_now[nbcc] - if mask < 0.999: - self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["mask"] = "true" - nbcc += 1 + # register parameters to ParamSet + paramset.addParameter(bond_length, "length", + field=self.name) + # register parameters to ParamSet + paramset.addParameter(bond_k, "k", field=self.name) - def createPotential(self, topdata: DMFFTopology, nonbondedMethod, - nonbondedCutoff, **kwargs): - methodMap = { - app.NoCutoff: "NoCutoff", - app.CutoffPeriodic: "CutoffPeriodic", - app.CutoffNonPeriodic: "CutoffNonPeriodic", - app.PME: "PME", - } - if nonbondedMethod not in methodMap: - raise DMFFException("Illegal nonbonded method for NonbondedForce") + def getName(self) -> str: + """ + Returns the name of the force field. - isNoCut = False - if nonbondedMethod is app.NoCutoff: - isNoCut = True + Returns: + -------- + str + The name of the force field. + """ + return self.name - mscales_coul = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, - 1.0]) # mscale for PME - mscales_coul = mscales_coul.at[2].set(self.coulomb14scale) - self.mscales_coul = mscales_coul # for qeq calculation + def overwrite(self, paramset: ParamSet) -> None: + """ + Overwrites the parameter set. - # set PBC - if nonbondedMethod not in [app.NoCutoff, app.CutoffNonPeriodic]: - ifPBC = True - else: - ifPBC = False + Parameters: + ----------- + paramset : ParamSet + The parameter set. + """ + bond_node_indices = [i for i in range(len( + self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Bond"] - charges = [a.meta["charge"] for a in topdata.atoms()] - charges = jnp.array(charges) + bond_length = paramset[self.name]["length"] + bond_k = paramset[self.name]["k"] + for nnode, key in enumerate(self.bond_keys): + self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]]["attrib"] = { + } + self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode] + ]["attrib"][f"{self.key_type}1"] = key[0] + self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode] + ]["attrib"][f"{self.key_type}2"] = key[1] + r0 = bond_length[nnode] + k = bond_k[nnode] + self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode] + ]["attrib"]["k"] = str(k) + self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode] + ]["attrib"]["length"] = str(r0) - cov_mat = topdata.buildCovMat() + def _find_key_index(self, key: Tuple[str, str]) -> int: + """ + Finds the index of the key. - if unit.is_quantity(nonbondedCutoff): - r_cut = nonbondedCutoff.value_in_unit(unit.nanometer) - else: - r_cut = nonbondedCutoff + Parameters: + ----------- + key : tuple of str + The key. - # PME Settings - if nonbondedMethod is app.PME: - cell = topdata.getPeriodicBoxVectors() - box = jnp.array(cell) - self.ethresh = kwargs.get("ethresh", 1e-5) - self.coeff_method = kwargs.get("PmeCoeffMethod", "openmm") - self.fourier_spacing = kwargs.get("PmeSpacing", 0.1) - kappa, K1, K2, K3 = setup_ewald_parameters(r_cut, self.ethresh, - box, - self.fourier_spacing, - self.coeff_method) + Returns: + -------- + int + The index of the key. + """ + for i, k in enumerate(self.bond_keys): + if k[0] == key[0] and k[1] == key[1]: + return i + if k[0] == key[1] and k[1] == key[0]: + return i + return None - if self._use_bcc: - top_mat = np.zeros( - (topdata.getNumAtoms(), self._bcc_shape)) - matched_dict = {} - for nparser, parser in enumerate(self.bcc_parsers): - matches = topdata.parseSMARTS(parser, resname=self._bcc_mol) - for ii, jj in matches: - if (ii, jj) in matched_dict: - del matched_dict[(ii, jj)] - elif (jj, ii) in matched_dict: - del matched_dict[(jj, ii)] - matched_dict[(ii, jj)] = nparser - for ii, jj in matched_dict.keys(): - nval = matched_dict[(ii, jj)] - top_mat[ii, nval] += 1. - top_mat[jj, nval] -= 1. - topdata._meta["bcc_top_mat"] = top_mat + def createPotential(self, topdata: DMFFTopology, nonbondedMethod, + nonbondedCutoff, **kwargs): + """ + Creates the potential. - if nonbondedMethod is not app.PME: - # do not use PME - if nonbondedMethod in [app.CutoffPeriodic, app.CutoffNonPeriodic]: - # use Reaction Field - coulforce = CoulReactionFieldForce( - r_cut, - charges, - isPBC=ifPBC, - topology_matrix=top_mat if self._use_bcc else None) - if nonbondedMethod is app.NoCutoff: - # use NoCutoff - coulforce = CoulNoCutoffForce( - charges, topology_matrix=top_mat if self._use_bcc else None) - else: - coulforce = CoulombPMEForce( - r_cut, - charges, - kappa, (K1, K2, K3), - topology_matrix=top_mat if self._use_bcc else None) + Parameters: + ----------- + topdata : DMFFTopology + The topology data. + nonbondedMethod : str + The nonbonded method. + nonbondedCutoff : float + The nonbonded cutoff. + args : list + The arguments. - coulenergy = coulforce.generate_get_energy() + Returns: + -------- + function + The potential function. + """ + bond_a1, bond_a2, bond_indices = [], [], [] + for i, k in enumerate(self.bond_keys): + bond_a1.append(int(k[0])) + bond_a2.append(int(k[1])) + bond_indices.append(int(i)) + bond_a1 = jnp.array(bond_a1) + bond_a2 = jnp.array(bond_a2) + bond_indices = jnp.array(bond_indices) + + # 创建势函数 + harmonic_bond_force = HarmonicBondJaxForce( + bond_a1, bond_a2, bond_indices) + harmonic_bond_energy = harmonic_bond_force.generate_get_energy() has_aux = False if "has_aux" in kwargs and kwargs["has_aux"]: has_aux = True - def potential_fn(positions, box, pairs, params, aux=None): - - # check whether args passed into potential_fn are jnp.array and differentiable - # note this check will be optimized away by jit - # it is jit-compatiable + def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet, aux=None): isinstance_jnp(positions, box, params) - - if self._use_bcc: - coulE = coulenergy(positions, box, pairs, - params["CoulombForce"]["bcc"], mscales_coul) - else: - coulE = coulenergy(positions, box, pairs, - mscales_coul) - + energy = harmonic_bond_energy( + positions, box, pairs, params[self.name]["k"], params[self.name]["length"]) if has_aux: - return coulE, aux + return energy, aux else: - return coulE + return energy self._jaxPotential = potential_fn return potential_fn -_DMFFGenerators["CoulombForce"] = CoulombGenerator +# register the generator +_DMFFGenerators["Custom1_5BondForce"] = Custom1_5BondGenerator + +class CustomGBGenerator: + """ + A class for generating Custom Generalized Born implicit solvation models. + The following code implements the OBC variant of the GB/SA solvation model, using the ACE approximation to estimate surface area. + + Attributes: + ----------- + name : str + The name of the force field. + ffinfo : dict + The force field information. + key_type : str + The type of the key. + perParticleKey : list of tuple + The keys of the atoms -class LennardJonesGenerator: + """ def __init__(self, ffinfo: dict, paramset: ParamSet): - self.name = "LennardJonesForce" - self.ffinfo = ffinfo - self.lj14scale = float( - self.ffinfo["Forces"][self.name]["meta"]["lj14scale"]) - self.nbfix_to_idx = {} - self.atype_to_idx = {} - sig_prms, eps_prms = [], [] - sig_mask, eps_mask = [], [] - sig_nbfix, eps_nbfix = [], [] - sig_nbf_mask, eps_nbf_mask = [], [] - for node in self.ffinfo["Forces"][self.name]["node"]: - if node["name"] == "Atom": - if "type" in node["attrib"]: - atype, eps, sig = node["attrib"]["type"], node["attrib"][ - "epsilon"], node["attrib"]["sigma"] - self.atype_to_idx[atype] = len(sig_prms) - elif "class" in node["attrib"]: - acls, eps, sig = node["attrib"]["class"], node["attrib"][ - "epsilon"], node["attrib"]["sigma"] - atypes = ffinfo["ClassToType"][acls] - for atype in atypes: - self.atype_to_idx[atype] = len(sig_prms) - sig_prms.append(float(sig)) - eps_prms.append(float(eps)) - if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE": - sig_mask.append(0.0) - eps_mask.append(0.0) - else: - sig_mask.append(1.0) - eps_mask.append(1.0) - elif node["name"] == "NBFixPair": - if "type1" in node["attrib"]: - atype1, atype2, eps, sig = node["attrib"]["type1"], node["attrib"][ - "type2"], node["attrib"]["epsilon"], node["attrib"]["sigma"] - if atype1 not in self.nbfix_to_idx: - self.nbfix_to_idx[atype1] = {} - if atype2 not in self.nbfix_to_idx: - self.nbfix_to_idx[atype2] = {} - self.nbfix_to_idx[atype1][atype2] = len(sig_nbfix) - self.nbfix_to_idx[atype2][atype1] = len(sig_nbfix) - elif "class1" in node["attrib"]: - acls1, acls2, eps, sig = node["attrib"]["class1"], node["attrib"][ - "class2"], node["attrib"]["epsilon"], node["attrib"]["sigma"] - atypes1 = ffinfo["ClassToType"][acls1] - atypes2 = ffinfo["ClassToType"][acls2] - for atype1 in atypes1: - if atype1 not in self.nbfix_to_idx: - self.nbfix_to_idx[atype1] = {} - for atype2 in atypes2: - if atype2 not in self.nbfix_to_idx: - self.nbfix_to_idx[atype2] = {} - self.nbfix_to_idx[atype1][atype2] = len(sig_nbfix) - self.nbfix_to_idx[atype2][atype1] = len(sig_nbfix) - sig_nbfix.append(float(sig)) - eps_nbfix.append(float(eps)) - if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE": - sig_nbf_mask.append(0.0) - eps_nbf_mask.append(0.0) - else: - sig_nbf_mask.append(1.0) - eps_nbf_mask.append(1.0) - - sig_prms = jnp.array(sig_prms) - eps_prms = jnp.array(eps_prms) - sig_mask = jnp.array(sig_mask) - eps_mask = jnp.array(eps_mask) + """ + Initialize the CustomGBForceGenerator - sig_nbfix, eps_nbfix = jnp.array(sig_nbfix), jnp.array(eps_nbfix) - sig_nbf_mask = jnp.array(sig_nbf_mask) - eps_nbf_mask = jnp.array(eps_nbf_mask) + Parameters: + ----------- + ffinfo : dict + The force field information. + paramset : ParamSet + The parameter set. + """ + self.name = "CustomGBForce" + self.ffinfo = ffinfo paramset.addField(self.name) - paramset.addParameter( - sig_prms, "sigma", field=self.name, mask=sig_mask) - paramset.addParameter(eps_prms, "epsilon", - field=self.name, mask=eps_mask) - paramset.addParameter(sig_nbfix, "sigma_nbfix", - field=self.name, mask=sig_nbf_mask) - paramset.addParameter(eps_nbfix, "epsilon_nbfix", - field=self.name, mask=eps_nbf_mask) + self.key_type = None + self.perParticleParamIndices = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Atom"] - def getName(self): - return self.name + perParticleKey, perParticleParam, chargeMask = [], [], [] + for i in self.perParticleParamIndices: + attribs = self.ffinfo["Forces"][self.name]["node"][i]["attrib"] + if self.key_type is None: + if "type" in attribs: + self.key_type = "type" + elif "class" in attribs: + self.key_type = "class" + else: + raise ValueError( + "Cannot find key type for CustomGBForce." + ) + key = (attribs[self.key_type]) + perParticleKey.append(key) - def overwrite(self, paramset): - # paramset to ffinfo - for nnode in range(len(self.ffinfo["Forces"][self.name]["node"])): - node = self.ffinfo["Forces"][self.name]["node"][nnode] - if node["name"] == "Atom": - if "type" in node["attrib"]: - atype = node["attrib"]["type"] - idx = self.atype_to_idx[atype] + charge = float(attribs["charge"]) + radius = float(attribs["radius"]) + scale = float(attribs["scale"]) - elif "class" in node["attrib"]: - acls = node["attrib"]["class"] - atypes = self.ffinfo["ClassToType"][acls] - idx = self.atype_to_idx[atypes[0]] + # Parameter Charge is not trainable + chargeMask.append(0.0) + perParticleParam.append([charge, radius, scale]) - eps_now = paramset[self.name]["epsilon"][idx] - sig_now = paramset[self.name]["sigma"][idx] - self.ffinfo["Forces"][ - self.name]["node"][nnode]["attrib"]["sigma"] = sig_now - self.ffinfo["Forces"][ - self.name]["node"][nnode]["attrib"]["epsilon"] = eps_now - # have not tested for NBFixPair overwrite - elif node["name"] == "NBFixPair": - if "type1" in node["attrib"]: - atype1, atype2 = node["attrib"]["type1"], node["attrib"]["type2"] - idx = self.nbfix_to_idx[atype1][atype2] - elif "class1" in node["attrib"]: - acls1, acls2 = node["attrib"]["class1"], node["attrib"]["class2"] - atypes1 = self.ffinfo["ClassToType"][acls1] - atypes2 = self.ffinfo["ClassToType"][acls2] - idx = self.nbfix_to_idx[atypes1[0]][atypes2[0]] - sig_now = paramset[self.name]["sigma_nbfix"][idx] - eps_now = paramset[self.name]["epsilon_nbfix"][idx] - self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["sigma"] = sig_now - self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["epsilon"] = eps_now + self.perParticleKey = perParticleKey + paramset.addParameter(jnp.array([i[0] for i in perParticleParam]), + "charge", field=self.name, mask=chargeMask) + paramset.addParameter(jnp.array([i[1] for i in perParticleParam]), + "radius", field=self.name) + paramset.addParameter(jnp.array([i[2] for i in perParticleParam]), + "scale", field=self.name) - def createPotential(self, topdata: DMFFTopology, nonbondedMethod, - nonbondedCutoff, **kwargs): - methodMap = { - app.NoCutoff: "NoCutoff", - app.CutoffPeriodic: "CutoffPeriodic", - app.CutoffNonPeriodic: "CutoffNonPeriodic", - app.PME: "CutoffPeriodic", - } - if nonbondedMethod not in methodMap: - raise DMFFException("Illegal nonbonded method for NonbondedForce") - methodString = methodMap[nonbondedMethod] - atoms = [a for a in topdata.atoms()] - atypes = [a.meta["type"] for a in atoms] - map_prm = [] - for atype in atypes: - if atype not in self.atype_to_idx: - raise DMFFException(f"Atom type {atype} not found.") - idx = self.atype_to_idx[atype] - map_prm.append(idx) - map_prm = jnp.array(map_prm) - topdata._meta["lj_map_idx"] = map_prm + def getName(self) -> str: + """ + Returns the name of the force field. - # not use nbfix for now - map_nbfix = [] - for atype1 in self.nbfix_to_idx.keys(): - for atype2 in self.nbfix_to_idx[atype1].keys(): - nbfix_idx = self.nbfix_to_idx[atype1][atype2] - type1_idx = self.atype_to_idx[atype1] - type2_idx = self.atype_to_idx[atype2] - map_nbfix.append([type1_idx, type2_idx, nbfix_idx]) - map_nbfix = np.array(map_nbfix, dtype=int).reshape((-1, 3)) + Returns: + -------- + str + The name of the force field. + """ + return self.name - if methodString in ["NoCutoff", "CutoffNonPeriodic"]: - isPBC = False - if methodString == "NoCutoff": - isNoCut = True - else: - isNoCut = False - else: - isPBC = True - isNoCut = False - mscales_lj = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) # mscale for LJ - mscales_lj = mscales_lj.at[2].set(self.lj14scale) + def overwrite(self, paramset: ParamSet) -> None: + """ + Overwrites the parameter set. - if unit.is_quantity(nonbondedCutoff): - r_cut = nonbondedCutoff.value_in_unit(unit.nanometer) - else: - r_cut = nonbondedCutoff + Parameters: + ----------- + paramset : ParamSet + The parameter set. + """ + radius = paramset[self.name]["radius"] + scale = paramset[self.name]["scale"] + for i in self.ffinfo.perParticleParamIndices: + self.ffinfo["Forces"][self.name]["node"][i]["attrib"]["radius"] = str(radius[i]) + self.ffinfo["Forces"][self.name]["node"][i]["attrib"]["scale"] = str(scale[i]) - ljforce = LennardJonesForce(0.0, - r_cut, - map_prm, - map_nbfix, - isSwitch=False, - isPBC=isPBC, - isNoCut=isNoCut) - ljenergy = ljforce.generate_get_energy() + def _find_key_index(self, key: Tuple[str]) -> int: + """ + Finds the index of the key. - has_aux = False - if "has_aux" in kwargs and kwargs["has_aux"]: - has_aux = True + Parameters: + ----------- + key : tuple of str + The key. - def potential_fn(positions, box, pairs, params, aux=None): + Returns: + -------- + int + The index of the key. + """ + for i, k in enumerate(self.perParticleKey): + if k == key: + return i + return None - # check whether args passed into potential_fn are jnp.array and differentiable - # note this check will be optimized away by jit - # it is jit-compatiable - isinstance_jnp(positions, box, params) + def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs): + """ + Creates the potential. - ljE = ljenergy(positions, box, pairs, - params[self.name]["epsilon"], - params[self.name]["sigma"], - params[self.name]["epsilon_nbfix"], - params[self.name]["sigma_nbfix"], - mscales_lj) + Parameters: + ----------- + topdata : DMFFTopology + The topology data. + nonbondedMethod : str + The nonbonded method. + nonbondedCutoff : float + The nonbonded cutoff. + args : list + The arguments. - if has_aux: - return ljE, aux - else: - return ljE + Returns: + -------- + function + The potential function. + """ + # Load CustomGBForce parameters + charge_indices, radius_indices, scale_indices = [], [] ,[] + for atom in topdata.atoms(): + if self.key_type == "type": + key = (atom.meta["type"]) + elif self.key_type == "class": + key = (atom.meta["class"]) + idx = self._find_key_index(key) + if idx is None: + continue + charge_indices.append(idx) + radius_indices.append(idx) + scale_indices.append(idx) + + charge_indices = jnp.array(charge_indices) + radius_indices = jnp.array(radius_indices) + scale_indices = jnp.array(scale_indices) + + customGBforce = CustomGBForce(charge_indices, radius_indices, scale_indices) + GBSAOBCenergy = customGBforce.generate_get_energy() + + def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet): + pairs = pairs[:int(positions.shape[0]*(positions.shape[0]-1)/2)] + tt = np.vstack((pairs, pairs[:,[1, 0, 2]])) + Ipair = [] + for i in range(positions.shape[0]): + Ipair.append([pair[1] for pair in tt if pair[0] == i]) + Ipair = jnp.array(Ipair) + energy = GBSAOBCenergy(positions, box, pairs, Ipair, + params[self.name]["charge"], + params[self.name]["radius"], + params[self.name]['scale']) + return energy self._jaxPotential = potential_fn return potential_fn -_DMFFGenerators["LennardJonesForce"] = LennardJonesGenerator +_DMFFGenerators["CustomGBForce"] = CustomGBGenerator \ No newline at end of file diff --git a/docs/user_guide/4.1classical.md b/docs/user_guide/4.1classical.md index 1dabe12e6..840506dc7 100644 --- a/docs/user_guide/4.1classical.md +++ b/docs/user_guide/4.1classical.md @@ -230,4 +230,184 @@ The attribute `coulomb14scale` and `lj14scale` specifies the scale factors betwe ``` +# CustomGBJaxForce +## 1. Theory + +### Generalized Born Term + +This force is used to present implicit solvent model. The force consists of two energy terms: a Generalized Born Approximation term to represent the electrostatic interaction between the solute and solvent, and a surface area term to represent the free energy cost of solvating a neutral molecule. The Generalized Born energy is given by + +$$ +E = -\frac{1}{2}(\frac{1}{\epsilon_{solute}}-\frac{1}{\epsilon_{solvent}})\sum_{i,j}\frac{q_iq_j}{f_{GB}(d_{ij},R_i,R_j)} +$$ + +where the indices $i$ and $j$ run over all particles, $\epsilon_{solute}$ and $\epsilon_{solvent}$ are the dielectric constants of the solute and solvent respectively, $q_i$ is the charge of particle i, and $d_{ij}$ is the distance between particles i and j. And $f_{GB}(d_{ij},R_i,R_j)$ is defined as: + +$$ +f_{GB}(d_{ij},R_i,R_j)=[d^2_{ij}+R_iR_jexp(-\frac{d^2_{ij}}{4R_iR_j})]^{\frac{1}{2}} +$$ + +$R_i$ is the Born radius of particle i, which is calculated as: + +$$ +R_i = \frac{1}{\rho_i^{-1}-r_i^{-1}tanh(\alpha\Psi_i-\beta\Psi_i^2+\gamma\Psi_i^3)} +$$ + +where $\alpha,\beta,\gamma$ are the $GB^{OBC}II$ parameters $\alpha=1, \beta=0.8,\gamma=4.85$. $\rho_i$ is the adjusted atomic radius of particle i, which is calculated from the atomic radius $r_i$ as $\rho_i=r_i-0.009$ nm. $\Psi_i$ is calculated as an integral over the van der Waals spheres of all particles outside particle i: + +$$ +\Psi_i=\frac{\rho_i}{4\pi}\int_{VDM}\theta(|r|-\rho_i)\frac{1}{|r|^4}d^3r +$$ + +where $\theta(r)$ is a step function that excludes the interior of particle i from the integral. + +### Surface Area Term + +The surface area term is given by: + +$$ +E=E_{SA}·4\pi\sum_i(r_i+r_{solvent})^2(\frac{r_i}{R_i})^6 +$$ + +where $r_i$ is the atomic radius of particle i, $r_i$ is its atomic radius, and $r_{solvent}$ is the solvent radius, which is taken to be 0.14 nm. The default value for the energy scale $E_{SA}$ is $2.25936\ kJ/mol/nm^2$. + +## 2. Frontend + +The way to specify a CustomGBJaxForce in DMFF is the same as the way doing it in OpenMM with CustomGBForce: (the example is shown in DMFF/examples/classical/gbForce/gbForce.py ) + +```xml + + + + + + + + step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C); + U=r+sr2; C=2*(1/or1-1/L)*step(sr2-r-or1); L=max(or1, D); D=abs(r-sr2); sr2 = + scale2*or2; or1 = radius1-0.009; or2 = radius2-0.009 + + + 1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius); psi=I*or; or=radius-0.009 + + + 28.3919551*(radius+0.14)^2*(radius/B)^6-0.5*138.935456*(1/soluteDielectric-1/solventDielectric)*charge^2/B + + + -138.935456*(1/soluteDielectric-1/solventDielectric)*charge1*charge2/f; + f=sqrt(r^2+B1*B2*exp(-r^2/(4*B1*B2))) + + + + ... + +``` + +Every `` tag defines a rule for creating CustomGBForce interactions between atoms. Each tag may identify the atoms either by type (using the attributes `type1` and `type2`) or by class (using the attributes `class1` and `class2`). +For now, only the ``, `` and `` (if necessary) are trainable. DMFF CustomGBForce generators do not accept other approximation for implicit solvent model. Unless user designs a new CustomGBForce generator for the specific approximation like HCT, OBC etc. + + +# CustomTorsionJaxForce + +## 1. Theory + +The custom torsion is represented by a truncated periodic Fourier series: + +$$ +E = \sum_{n=0}^{4} k_n(\cos(n\phi-\phi_{0n})) + shift +$$ + +where $\phi$ is the dihedral angle formed by four particles, $n$ is the periodicity, $\phi_{0n}$ is the phase offset $k_{n}$ is the force constant. To preserve the symmetry, $\phi_{0n}$ usually adopts a value of $0$ (for $n=1,3,5$) or $\pi$ ($n=2,4,6$), and it is recommended to follow these definitions and not to optimize them in force field development. + +## 2. Frontend + +The way to specify a custom torsion in DMFF is the same as the way doing it in OpenMM: + +```xml + + + + + + + + + + + + + + + + + + +``` + +Every child tag `` or `` defines a rule for creating periodic torsion interactions between sets of four atoms. Each tag may identify the atoms either by type (using the attributes `type1`, `type2`, ...) or by class (using the attributes `class1`, `class2`, ...). + +The force field recognizes two different types of torsions: `Proper` and `Improper`. A proper torsion involves four atoms that are bonded in sequence: 1 to 2, 2 to 3, and 3 to 4. An improper torsion involves a central atom and three others that are bonded to it: atoms 2, 3, and 4 are all bonded to atom 1. `per1` is the periodicity of the torsion, `phase1` is the phase offset in radians, and `k1` is the force constant in kJ/mol. To add a second periodicity, just add three more attributes: `per2`, `phase2`, and `k2`. **The maxium periodicity supported in DMFF is 6, which is different from OpenMM**. + +You can also use wildcards when defining torsions. To do this, simply leave the type or class name for an atom empty. That will cause it to match any atom: + +```xml + +``` + +When the tag has an attribute named `mask` and it's value set to `true`, this means the parameter is not trainable. Such information will be passed to `ParamSet.mask` (the corresponding mask value will be 0.0 if not trainable). + + +# Custom1_5BondJaxForce + +## 1. Theory + +The force is used to regulate the atoms relation between atom 1 to 5 in coarse-grained polyphosphate. + +$$ +E = \frac{1}{2}k(b-b_0)^2 +$$ + +where $k$ is the force constant, $b$ is the distance betweeen two particles that forming a bond and $b_0$ is the equilibrium bond length. Note that in some other MD softwares, the potential form adopts a slight different form: $E=k(b-b_0)^2$. Users should check which form to use and multiply (or divide) the force constant by 2. + +## 2. Frontend + +The way to specify a harmonic bond in DMFF is different from the way doing it in OpenMM, which requires add special force `openmm.CustomCompoundBondForce()` in coding: + +```xml + + + + + + + + +``` +When using this force, you need to add `openmm.CustomCompoundBondForce()` during your simulation: (the example is shown in DMFF/examples/classical/gbForce/gbForce.py ) + +```python +h = Hamiltonian("CG.xml") +params = h.getParameters() +compoundBondForceParam = params["Custom1_5BondForce"] +length = compoundBondForceParam["length"] +k = compoundBondForceParam["k"] +system = ff.createSystem(pdb.topology, nonbondedMethod=NoCutoff) +customCompoundForce = openmm.CustomCompoundBondForce(2, "0.5*k*(distance(p1,p2)-length)^2") +customCompoundForce.addPerBondParameter("length") +customCompoundForce.addPerBondParameter("k") +for i, leng in enumerate(length): + customCompoundForce.addBond([i, i+4], [leng, k[i]]) +system.addForce(customCompoundForce) +``` + +Every `` tag defines a rule for creating harmonic bond interactions between 1 and 5 atoms. Each tag may identify the atoms by index (using the attributes `atomIndex1` and `atomIndex2`). `length` is the equilibrium bond length in $\mathrm{nm}$, and `k` is the force constant in $\mathrm{kJ/mol/nm^2}$. +For now, `Custom1_5BondJaxForce` doesn't accept different energy forms unless user designs new generator for it. diff --git a/examples/classical/gbForce/10p.pdb b/examples/classical/gbForce/10p.pdb new file mode 100644 index 000000000..05b18aa71 --- /dev/null +++ b/examples/classical/gbForce/10p.pdb @@ -0,0 +1,25 @@ +TITLE GRoups of Organic Molecules in ACtion for Science +REMARK THIS IS A SIMULATION BOX +CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1 +MODEL 1 +ATOM 1 TP1 TE1 A 1 28.725 22.974 33.406 1.00 0.00 P +ATOM 2 P00 IN1 A 2 26.892 23.631 31.811 1.00 0.00 C +ATOM 3 P00 IN1 A 3 27.633 24.222 29.541 1.00 0.00 C +ATOM 4 P00 IN1 A 4 25.695 23.995 27.884 1.00 0.00 C +ATOM 5 P00 IN1 A 5 26.290 25.095 25.743 1.00 0.00 C +ATOM 6 P00 IN1 A 6 24.826 24.484 23.700 1.00 0.00 C +ATOM 7 P00 IN1 A 7 24.944 26.058 21.774 1.00 0.00 C +ATOM 8 P00 IN1 A 8 23.018 26.878 20.449 1.00 0.00 C +ATOM 9 P00 IN1 A 9 22.109 25.856 18.341 1.00 0.00 C +ATOM 10 TP1 TE1 A 10 19.868 26.802 17.355 1.00 0.00 P +TER +CONECT 1 2 +CONECT 2 3 +CONECT 3 4 +CONECT 4 5 +CONECT 5 6 +CONECT 6 7 +CONECT 7 8 +CONECT 8 9 +CONECT 9 10 +ENDMDL diff --git a/examples/classical/gbForce/1_5corrV2.xml b/examples/classical/gbForce/1_5corrV2.xml new file mode 100644 index 000000000..633b4326b --- /dev/null +++ b/examples/classical/gbForce/1_5corrV2.xml @@ -0,0 +1,98 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C); + U=r+sr2; C=2*(1/or1-1/L)*step(sr2-r-or1); L=max(or1, D); D=abs(r-sr2); sr2 = + scale2*or2; or1 = radius1-0.009; or2 = radius2-0.009 + + + 1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius); psi=I*or; or=radius-0.009 + + + 28.3919551*(radius+0.14)^2*(radius/B)^6-0.5*138.935456*(1/soluteDielectric-1/solventDielectric)*charge^2/B + + + -138.935456*(1/soluteDielectric-1/solventDielectric)*charge1*charge2/f; + f=sqrt(r^2+B1*B2*exp(-r^2/(4*B1*B2))) + + + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/classical/gbForce/gbforce.py b/examples/classical/gbForce/gbforce.py new file mode 100644 index 000000000..55a42756f --- /dev/null +++ b/examples/classical/gbForce/gbforce.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python +from openmm import * +from openmm.app import * +from openmm.unit import * +import numpy as np +import sys +sys.path.append("..") +sys.path.append(".") +sys.path.append("...") +from dmff.api.hamiltonian import Hamiltonian +from dmff.common import nblist +from jax import jit +import jax.numpy as jnp +import mdtraj as md + +def forcegroupify(system): + forcegroups = {} + for i in range(system.getNumForces()): + force = system.getForce(i) + force.setForceGroup(i) + forcegroups[force] = i + return forcegroups + +def getEnergyDecomposition(context, forcegroups): + energies = {} + for f, i in forcegroups.items(): + energies[f] = context.getState(getEnergy=True, groups=2 ** i).getPotentialEnergy() + return energies + +print("MM Reference Energy:") +pdb = PDBFile("10p.pdb") +ff = ForceField("1_5corrV2.xml") +system = ff.createSystem(pdb.topology, nonbondedMethod=NoCutoff, constraints=None, removeCMMotion=False) +h = Hamiltonian("1_5corrV2.xml") +params = h.getParameters() +compoundBondForceParam = params["Custom1_5BondForce"] +length = compoundBondForceParam["length"] +k = compoundBondForceParam["k"] +customCompoundForce = openmm.CustomCompoundBondForce(2, "0.5*k*(distance(p1,p2)-length)^2") +customCompoundForce.addPerBondParameter("length") +customCompoundForce.addPerBondParameter("k") +for i, leng in enumerate(length): + customCompoundForce.addBond([i, i+4], [leng, k[i]]) +system.addForce(customCompoundForce) +print("Dih info:") +for force in system.getForces(): + if isinstance(force, PeriodicTorsionForce): + print("No. of dihs:", force.getNumTorsions()) + +forcegroups = forcegroupify(system) +integrator = VerletIntegrator(0.1) +context = Context(system, integrator, Platform.getPlatformByName("Reference")) +context.setPositions(pdb.positions) +state = context.getState(getEnergy=True) +energy = state.getPotentialEnergy() +energies = getEnergyDecomposition(context, forcegroups) +print("Total energy:", energy) +for key in energies.keys(): + print(key.getName(), energies[key]) + +print("Jax Energy") +h = Hamiltonian("1_5corrV2.xml") +pot = h.createPotential(pdb.topology, nonbondedMethod=NoCutoff) +params = h.getParameters() +positions = pdb.getPositions(asNumpy=True).value_in_unit(nanometer) +positions = jnp.array(positions) +box = np.array([ + [30.0, 0.0, 0.0], + [0.0, 30.0, 0.0], + [0.0, 0.0, 30.0] +]) + +# neighbor list +rc = 6.0 +nbl = nblist.NeighborList(box, rc, pot.meta['cov_map']) +nbl.allocate(positions) +pairs = nbl.pairs + +bondE = pot.dmff_potentials['HarmonicBondForce'] +print("Bond:", bondE(positions, box, pairs, params)) + +angleE = pot.dmff_potentials['HarmonicAngleForce'] +print("Angle:", angleE(positions, box, pairs, params)) + +gbE = pot.dmff_potentials['CustomGBForce'] +print("CustomGBForce:", gbE(positions, box, pairs, params)) + +E1_5 = pot.dmff_potentials['Custom1_5BondForce'] +print("Custom1_5BondForce:", E1_5(positions, box, pairs, params)) + +dihE = pot.dmff_potentials['CustomTorsionForce'] +print("Torsion:", dihE(positions, box, pairs, params)) + +nbE = pot.dmff_potentials['NonbondedForce'] +print("Nonbonded:", nbE(positions, box, pairs, params)) + +etotal = pot.getPotentialFunc() +print("Total:", etotal(positions, box, pairs, params)) + diff --git a/tests/data/10p.pdb b/tests/data/10p.pdb new file mode 100644 index 000000000..05b18aa71 --- /dev/null +++ b/tests/data/10p.pdb @@ -0,0 +1,25 @@ +TITLE GRoups of Organic Molecules in ACtion for Science +REMARK THIS IS A SIMULATION BOX +CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1 +MODEL 1 +ATOM 1 TP1 TE1 A 1 28.725 22.974 33.406 1.00 0.00 P +ATOM 2 P00 IN1 A 2 26.892 23.631 31.811 1.00 0.00 C +ATOM 3 P00 IN1 A 3 27.633 24.222 29.541 1.00 0.00 C +ATOM 4 P00 IN1 A 4 25.695 23.995 27.884 1.00 0.00 C +ATOM 5 P00 IN1 A 5 26.290 25.095 25.743 1.00 0.00 C +ATOM 6 P00 IN1 A 6 24.826 24.484 23.700 1.00 0.00 C +ATOM 7 P00 IN1 A 7 24.944 26.058 21.774 1.00 0.00 C +ATOM 8 P00 IN1 A 8 23.018 26.878 20.449 1.00 0.00 C +ATOM 9 P00 IN1 A 9 22.109 25.856 18.341 1.00 0.00 C +ATOM 10 TP1 TE1 A 10 19.868 26.802 17.355 1.00 0.00 P +TER +CONECT 1 2 +CONECT 2 3 +CONECT 3 4 +CONECT 4 5 +CONECT 5 6 +CONECT 6 7 +CONECT 7 8 +CONECT 8 9 +CONECT 9 10 +ENDMDL diff --git a/tests/data/1_5corrV2.xml b/tests/data/1_5corrV2.xml new file mode 100644 index 000000000..633b4326b --- /dev/null +++ b/tests/data/1_5corrV2.xml @@ -0,0 +1,98 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C); + U=r+sr2; C=2*(1/or1-1/L)*step(sr2-r-or1); L=max(or1, D); D=abs(r-sr2); sr2 = + scale2*or2; or1 = radius1-0.009; or2 = radius2-0.009 + + + 1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius); psi=I*or; or=radius-0.009 + + + 28.3919551*(radius+0.14)^2*(radius/B)^6-0.5*138.935456*(1/soluteDielectric-1/solventDielectric)*charge^2/B + + + -138.935456*(1/soluteDielectric-1/solventDielectric)*charge1*charge2/f; + f=sqrt(r^2+B1*B2*exp(-r^2/(4*B1*B2))) + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/data/pBox.pdb b/tests/data/pBox.pdb new file mode 100644 index 000000000..b1a0118e1 --- /dev/null +++ b/tests/data/pBox.pdb @@ -0,0 +1,87 @@ +TITLE polyten_GMX.gro created by acpype (v: 2022.6.6) on Wed Oct 5 08:03:17 2022 +REMARK THIS IS A SIMULATION BOX +CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1 +MODEL 1 +HETATM 1 P00 TE1 1 28.652 12.361 14.349 1.00 0.00 P +HETATM 2 O01 TE1 1 28.511 11.590 12.887 1.00 0.00 O +HETATM 3 O02 TE1 1 28.857 11.267 15.580 1.00 0.00 O +HETATM 4 O03 TE1 1 29.865 13.487 14.329 1.00 0.00 O +HETATM 5 O04 TE1 1 26.965 13.160 14.653 1.00 0.00 O +HETATM 6 P00 IN1 2 26.260 14.805 14.965 1.00 0.00 P +HETATM 7 O01 IN1 2 26.855 15.339 16.380 1.00 0.00 O +HETATM 8 O02 IN1 2 26.508 15.706 13.631 1.00 0.00 O +HETATM 9 O00 IN2 3 24.431 14.542 15.100 1.00 0.00 O +HETATM 10 P00 IN1 4 23.086 15.420 16.142 1.00 0.00 P +HETATM 11 O01 IN1 4 22.975 14.492 17.468 1.00 0.00 O +HETATM 12 O02 IN1 4 23.530 16.970 16.231 1.00 0.00 O +HETATM 13 O00 IN2 5 21.421 15.307 15.326 1.00 0.00 O +HETATM 14 P00 IN1 6 19.909 16.539 15.340 1.00 0.00 P +HETATM 15 O01 IN1 6 20.000 17.324 16.743 1.00 0.00 O +HETATM 16 O02 IN1 6 20.063 17.265 13.902 1.00 0.00 O +HETATM 17 O00 IN2 7 18.227 15.697 15.269 1.00 0.00 O +HETATM 18 P00 IN1 8 16.526 16.321 15.943 1.00 0.00 P +HETATM 19 O01 IN1 8 16.373 15.511 17.337 1.00 0.00 O +HETATM 20 O02 IN1 8 16.580 17.927 15.852 1.00 0.00 O +HETATM 21 O00 IN2 9 15.002 15.739 14.963 1.00 0.00 O +HETATM 22 P00 IN1 10 13.371 16.585 14.465 1.00 0.00 P +HETATM 23 O01 IN1 10 13.092 17.763 15.532 1.00 0.00 O +HETATM 24 O02 IN1 10 13.581 16.841 12.885 1.00 0.00 O +HETATM 25 O00 IN2 11 11.772 15.515 14.643 1.00 0.00 O +HETATM 26 P00 IN1 12 10.263 15.275 13.556 1.00 0.00 P +HETATM 27 O01 IN1 12 10.068 16.593 12.650 1.00 0.00 O +HETATM 28 O02 IN1 12 10.482 13.800 12.922 1.00 0.00 O +HETATM 29 O00 IN2 13 8.598 15.069 14.540 1.00 0.00 O +HETATM 30 P00 IN1 14 6.866 15.556 14.061 1.00 0.00 P +HETATM 31 O01 IN1 14 6.615 17.010 14.733 1.00 0.00 O +HETATM 32 O02 IN1 14 6.677 15.310 12.476 1.00 0.00 O +HETATM 33 O00 IN2 15 5.578 14.441 14.933 1.00 0.00 O +HETATM 34 P00 IN1 16 3.855 13.938 14.474 1.00 0.00 P +HETATM 35 O01 IN1 16 3.150 15.098 13.579 1.00 0.00 O +HETATM 36 O02 IN1 16 3.961 12.437 13.850 1.00 0.00 O +HETATM 37 O04 TE1 17 2.946 13.817 16.041 1.00 0.00 O +HETATM 38 P00 TE1 17 1.214 13.377 16.663 1.00 0.00 P +HETATM 39 O01 TE1 17 0.217 12.929 15.419 1.00 0.00 O +HETATM 40 O02 TE1 17 0.655 14.736 17.433 1.00 0.00 O +HETATM 41 O03 TE1 17 1.442 12.135 17.739 1.00 0.00 O +TER +CONECT 1 2 +CONECT 1 3 +CONECT 1 4 +CONECT 1 5 +CONECT 5 6 +CONECT 6 7 +CONECT 6 8 +CONECT 6 9 +CONECT 9 10 +CONECT 10 11 +CONECT 10 12 +CONECT 10 13 +CONECT 13 14 +CONECT 14 15 +CONECT 14 16 +CONECT 14 17 +CONECT 17 18 +CONECT 18 19 +CONECT 18 20 +CONECT 18 21 +CONECT 21 22 +CONECT 22 23 +CONECT 22 24 +CONECT 22 25 +CONECT 25 26 +CONECT 26 27 +CONECT 26 28 +CONECT 26 29 +CONECT 29 30 +CONECT 30 31 +CONECT 30 32 +CONECT 30 33 +CONECT 33 34 +CONECT 34 35 +CONECT 34 36 +CONECT 34 37 +CONECT 37 38 +CONECT 38 39 +CONECT 38 40 +CONECT 38 41 +ENDMDL diff --git a/tests/data/polyp_amberImp.xml b/tests/data/polyp_amberImp.xml new file mode 100644 index 000000000..15f6c9721 --- /dev/null +++ b/tests/data/polyp_amberImp.xml @@ -0,0 +1,111 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C); + U=r+sr2; C=2*(1/or1-1/L)*step(sr2-r-or1); L=max(or1, D); D=abs(r-sr2); sr2 = + scale2*or2; or1 = radius1-0.009; or2 = radius2-0.009 + + + 1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius); psi=I*or; or=radius-0.009 + + + 28.3919551*(radius+0.14)^2*(radius/B)^6-0.5*138.935456* + (1/soluteDielectric-1/solventDielectric)*charge^2/B + + + -138.935456*(1/soluteDielectric-1/solventDielectric)*charge1*charge2/f; + f=sqrt(r^2+B1*B2*exp(-r^2/(4*B1*B2))) + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/test_classical/test_gbforce.py b/tests/test_classical/test_gbforce.py new file mode 100644 index 000000000..ffa29c5a7 --- /dev/null +++ b/tests/test_classical/test_gbforce.py @@ -0,0 +1,80 @@ +import pytest +import jax +import jax.numpy as jnp +import openmm.app as app +import openmm.unit as unit +import numpy as np +import numpy.testing as npt +from dmff.api import Hamiltonian +from dmff.common import nblist + + +@pytest.mark.parametrize( + "pdb, prm, value", + [ + ("./tests/data/10p.pdb", "./tests/data/1_5corrV2.xml", -11184.921239189738), + ("./tests/data/pBox.pdb", "./tests/data/polyp_amberImp.xml", -13914.34177591779), + ]) +def test_custom_gb_force(pdb, prm, value): + pdb = app.PDBFile(pdb) + h = Hamiltonian(prm) + potential = h.createPotential( + pdb.topology, + nonbondedMethod=app.NoCutoff + ) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) + box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]]) + rc = 6.0 + nbl = nblist.NeighborList(box, rc, potential.meta['cov_map']) + nbl.allocate(pos) + pairs = nbl.pairs + gbE = potential.getPotentialFunc(names=["CustomGBForce"]) + energy = gbE(pos, box, pairs, h.paramset) + print(energy) + npt.assert_almost_equal(energy, value, decimal=3) + + +@pytest.mark.parametrize( + "pdb, prm, value", + [ + ("./tests/data/10p.pdb", "./tests/data/1_5corrV2.xml", 59.53033875302844), + ]) +def test_custom_torsion_force(pdb, prm, value): + pdb = app.PDBFile(pdb) + h = Hamiltonian(prm) + potential = h.createPotential( + pdb.topology, + nonbondedMethod=app.NoCutoff + ) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) + box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]]) + rc = 6.0 + nbl = nblist.NeighborList(box, rc, potential.meta['cov_map']) + nbl.allocate(pos) + pairs = nbl.pairs + gbE = potential.getPotentialFunc(names=["CustomTorsionForce"]) + energy = gbE(pos, box, pairs, h.paramset) + npt.assert_almost_equal(energy, value, decimal=3) + + +@pytest.mark.parametrize( + "pdb, prm, value", + [ + ("./tests/data/10p.pdb", "./tests/data/1_5corrV2.xml", 117.95416362791674), + ]) +def test_custom_1_5bond_force(pdb, prm, value): + pdb = app.PDBFile(pdb) + h = Hamiltonian(prm) + potential = h.createPotential( + pdb.topology, + nonbondedMethod=app.NoCutoff + ) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) + box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]]) + rc = 6.0 + nbl = nblist.NeighborList(box, rc, potential.meta['cov_map']) + nbl.allocate(pos) + pairs = nbl.pairs + gbE = potential.getPotentialFunc(names=["Custom1_5BondForce"]) + energy = gbE(pos, box, pairs, h.paramset) + npt.assert_almost_equal(energy, value, decimal=3) \ No newline at end of file