diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py
index 4fd304ec8..4efd6b0e1 100644
--- a/dmff/classical/inter.py
+++ b/dmff/classical/inter.py
@@ -332,3 +332,67 @@ def get_energy(positions, box, pairs, bcc, mscales):
charges = self.init_charges + jnp.dot(self.top_mat, bcc).flatten()
return get_energy_kernel(positions, box, pairs, charges, mscales)
return get_energy
+
+
+class CustomGBForce:
+ def __init__(
+ self,
+ map_charge,
+ map_radius,
+ map_scale,
+ epsilon_1=1.0,
+ epsilon_solv=78.3,
+ alpha=1,
+ beta=0.8,
+ gamma=4.85,
+ ) -> None:
+ self.map_charge = map_charge
+ self.map_radius = map_radius
+ self.map_scale = map_scale
+ self.alpha = alpha
+ self.beta = beta
+ self.gamma = gamma
+ self.exp_solv = epsilon_solv
+ self.eps_1 = epsilon_1
+
+ def generate_get_energy(self):
+ @jax.jit
+ def get_energy(positions, box, pairs, Ipairs, charges, radius, scales):
+ def calI(posList, radMap, scalMap, rhoMap, pairMap):
+ I = jnp.array([])
+
+ for i in range(len(radMap)):
+ posj = posList[Ipairs[i]]
+ rhoj = rhoMap[Ipairs[i]]
+ scalj = scalMap[Ipairs[i]]
+ posi = posList[i]
+ rhoi = rhoMap[i]
+
+ r = jnp.sqrt(jnp.sum(jnp.power(posi-posj,2),axis=1))
+ sr2 = rhoj * scalj
+ D = jnp.abs(r - sr2)
+ L = jnp.maximum(D, rhoi)
+ C = 2 * (1 / rhoi - 1 / L) * jnp.heaviside(sr2 - r - rhoi, 1)
+ U = r + sr2
+ I = jnp.append(I, jnp.sum(0.5 * jnp.heaviside(r + sr2 - rhoi, 1) * (
+ 1 / L - 1 / U + 0.25 * (1 / U ** 2 - 1 / L ** 2) * (
+ r - sr2 ** 2 / r) + 0.5 * jnp.log(L / U) / r + C)))
+
+ return I
+
+ chargeMap = charges[self.map_charge]
+ radiusMap = radius[self.map_radius]
+ scalesMap = scales[self.map_scale]
+ rhoMap = radiusMap - 0.009
+
+ # effective radius
+ IList = calI(positions, radiusMap, scalesMap, rhoMap, Ipairs)
+ psi = IList*rhoMap
+ rEff = 1/(1/rhoMap-jnp.tanh(self.alpha*psi-self.beta*jnp.power(psi, 2)+self.gamma*jnp.power(psi, 3))/radiusMap)
+ Ese = jnp.sum(28.3919551*(radiusMap+0.14)**2*jnp.power(radiusMap/rEff, 6)-0.5*138.935456*(1/self.eps_1-1/self.exp_solv)*chargeMap**2/rEff)
+ dr_norm = jnp.linalg.norm(positions[pairs[:,0]] - positions[pairs[:,1]], axis=1)
+ chargepro = chargeMap[pairs[:, 0]] * chargeMap[pairs[:, 1]]
+ rEffpro = rEff[pairs[:, 0]] * rEff[pairs[:, 1]]
+ Egb = jnp.sum(-138.935456*(1/self.eps_1-1/self.exp_solv)*chargepro/jnp.sqrt(jnp.power(dr_norm, 2)+rEffpro*jnp.exp(-jnp.power(dr_norm,2)/(4*rEffpro))))
+ return Ese + Egb
+ return get_energy
\ No newline at end of file
diff --git a/dmff/classical/intra.py b/dmff/classical/intra.py
index 76843b979..198c16ce2 100644
--- a/dmff/classical/intra.py
+++ b/dmff/classical/intra.py
@@ -148,3 +148,78 @@ def refresh_calculators(self):
"""
self.get_energy = self.generate_get_energy()
self.get_forces = value_and_grad(self.get_energy)
+
+
+class Custom1_5BondJaxForce:
+ def __init__(self, p1idx, p2idx, prmidx):
+ self.p1idx = p1idx
+ self.p2idx = p2idx
+ self.prmidx = prmidx
+ self.refresh_calculators()
+
+ def generate_get_energy(self):
+ def get_energy(positions, box, pairs, k, length):
+ p1 = positions[self.p1idx,:]
+ p2 = positions[self.p2idx,:]
+ kprm = k[self.prmidx]
+ b0prm = length[self.prmidx]
+ dist = distance(p1, p2)
+ return jnp.sum(0.5 * kprm * jnp.power(dist - b0prm, 2))
+
+ return get_energy
+
+ def update_env(self, attr, val):
+ """
+ Update the environment of the calculator
+ """
+ setattr(self, attr, val)
+ self.refresh_calculators()
+
+ def refresh_calculators(self):
+ """
+ refresh the energy and force calculators according to the current environment
+ """
+ self.get_energy = self.generate_get_energy()
+ self.get_forces = value_and_grad(self.get_energy)
+
+
+class CustomTorsionJaxForce:
+ def __init__(self, p1idx, p2idx, p3idx, p4idx, prmidx, order):
+ self.p1idx = p1idx
+ self.p2idx = p2idx
+ self.p3idx = p3idx
+ self.p4idx = p4idx
+ self.prmidx = prmidx
+ self.order = order
+ self.refresh_calculators()
+
+ def generate_get_energy(self):
+ if len(self.p1idx) == 0:
+ return lambda positions, box, pairs, k, psi, shift: 0.0
+ def get_energy(positions, box, pairs, k, psi, shift):
+ p1 = positions[self.p1idx, :]
+ p2 = positions[self.p2idx, :]
+ p3 = positions[self.p3idx, :]
+ p4 = positions[self.p4idx, :]
+ kp = k[self.prmidx]
+ psip = psi[self.prmidx]
+ shiftp = shift[self.prmidx]
+ dih = dihedral(p1, p2, p3, p4)
+ ener = kp * (jnp.cos(self.order * dih - psip)) + shiftp
+ return jnp.sum(ener)
+
+ return get_energy
+
+ def update_env(self, attr, val):
+ """
+ Update the environment of the calculator
+ """
+ setattr(self, attr, val)
+ self.refresh_calculators()
+
+ def refresh_calculators(self):
+ """
+ refresh the energy and force calculators according to the current environment
+ """
+ self.get_energy = self.generate_get_energy()
+ self.get_forces = value_and_grad(self.get_energy)
\ No newline at end of file
diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py
index 00d70be45..bb260b6f1 100644
--- a/dmff/generators/classical.py
+++ b/dmff/generators/classical.py
@@ -7,8 +7,8 @@
import jax.numpy as jnp
import openmm.app as app
import openmm.unit as unit
-from ..classical.intra import HarmonicBondJaxForce, HarmonicAngleJaxForce, PeriodicTorsionJaxForce
-from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce, CoulReactionFieldForce, LennardJonesForce, LennardJonesLongRangeForce
+from ..classical.intra import HarmonicBondJaxForce, HarmonicAngleJaxForce, PeriodicTorsionJaxForce, Custom1_5BondJaxForce, CustomTorsionJaxForce
+from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce, CoulReactionFieldForce, LennardJonesForce, LennardJonesLongRangeForce, CustomGBForce
from typing import Tuple, List, Union, Callable
@@ -787,157 +787,921 @@ def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, p
_DMFFGenerators["PeriodicTorsionForce"] = PeriodicTorsionGenerator
+class CustomTorsionGenerator:
+
+ def __init__(self, ffinfo: dict, paramset: ParamSet):
+ """
+ Initializes a PeriodicTorsionForce object.
+
+ Args:
+ - ffinfo (dict): A dictionary containing force field information.
+ - paramset (ParamSet): A ParamSet object to register parameters.
+
+ Returns:
+ - None
+ """
+ self.name = "CustomTorsionForce"
+ self.ffinfo = ffinfo
+ paramset.addField(self.name)
+ self._use_smarts = False
+ self.key_type = None
+ self.torsionIndices = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if "roper" in self.ffinfo["Forces"][self.name]["node"][i]["name"]]
+
+ proper_keys, proper_periods, proper_prms, proper_shift = [], [], [], []
+ proper_key_to_prms = {}
+ improper_keys, improper_periods, improper_prms, improper_shift = [], [], [], []
+ improper_key_to_prms = {}
+ for i in self.torsionIndices:
+ node = self.ffinfo["Forces"][self.name]["node"][i]
+ attribs = node["attrib"]
+ if "type1" in attribs:
+ self.key_type = "type"
+ elif "class1" in attribs:
+ self.key_type = "class"
+ key = (attribs[self.key_type + "1"], attribs[self.key_type + "2"],
+ attribs[self.key_type + "3"], attribs[self.key_type + "4"])
+ if node["name"] == "Proper":
+ proper_keys.append(key)
+ elif node["name"] == "Improper":
+ improper_keys.append(key)
+
+ mask = 1.0
+ if "mask" in attribs and attribs["mask"].upper() == "TRUE":
+ mask = 0.0
+
+ for period_key in attribs.keys():
+ if "per" not in period_key:
+ continue
+ order = int(period_key.replace("per", ""))
+ period = int(attribs[period_key])
+ phase = float(attribs["phase" + str(order)])
+ k = float(attribs["k" + str(order)])
+ shift = float(attribs["shift"])/4
+ if node["name"] == "Proper":
+ proper_periods.append(period)
+ proper_prms.append([phase, k, mask, shift])
+ if len(proper_keys) - 1 not in proper_key_to_prms:
+ proper_key_to_prms[len(proper_keys) - 1] = []
+ proper_key_to_prms[len(
+ proper_keys) - 1].append(len(proper_periods) - 1)
+ elif node["name"] == "Improper":
+ improper_periods.append(period)
+ improper_prms.append([phase, k, mask, shift])
+ if len(improper_keys) - 1 not in improper_key_to_prms:
+ improper_key_to_prms[len(improper_keys) - 1] = []
+ improper_key_to_prms[len(
+ improper_keys) - 1].append(len(improper_periods) - 1)
+
+ self.proper_keys = proper_keys
+ self.proper_periods = jnp.array(proper_periods)
+ self.proper_key_to_prms = proper_key_to_prms
+ proper_phase = jnp.array([i[0] for i in proper_prms])
+ proper_k = jnp.array([i[1] for i in proper_prms])
+ proper_mask = jnp.array([i[2] for i in proper_prms])
+ proper_shift = jnp.array([i[3] for i in proper_prms])
+ # register parameters to ParamSet
+ paramset.addParameter(proper_phase, "proper_phase",
+ field=self.name, mask=proper_mask)
+ paramset.addParameter(proper_k, "proper_k",
+ field=self.name, mask=proper_mask)
+ paramset.addParameter(proper_shift, "proper_shift",
+ field=self.name, mask=proper_mask)
+
+ self.imp_keys = improper_keys
+ self.imp_periods = jnp.array(improper_periods)
+ self.imp_key_to_prms = improper_key_to_prms
+ improper_phase = jnp.array([i[0] for i in improper_prms])
+ improper_k = jnp.array([i[1] for i in improper_prms])
+ improper_mask = jnp.array([i[2] for i in improper_prms])
+ improper_shift = jnp.array([i[3] for i in improper_prms])
+ # register parameters to ParamSet
+ paramset.addParameter(improper_phase, "improper_phase",
+ field=self.name, mask=improper_mask)
+ paramset.addParameter(improper_k, "improper_k",
+ field=self.name, mask=improper_mask)
+ paramset.addParameter(improper_shift, "improper_shift",
+ field=self.name, mask=improper_mask)
+
+ def getName(self):
+ return self.name
+
+ def overwrite(self, paramset):
+ # paramset to ffinfo
+ proper_node_indices = [i for i in range(len(
+ self.ffinfo["Forces"][self.name]["node"])) if
+ self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Proper"]
+ improper_node_indices = [i for i in range(len(
+ self.ffinfo["Forces"][self.name]["node"])) if
+ self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Improper"]
+
+ proper_phase = paramset[self.name]["proper_phase"]
+ proper_k = paramset[self.name]["proper_k"]
+ proper_shift = paramset[self.name]["proper_shift"]
+ proper_msks = paramset.mask[self.name]["proper"]
+ for nnode, key in enumerate(self.proper_keys):
+ self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]]["attrib"] = {
+ }
+ self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]
+ ]["attrib"][f"{self.key_type}1"] = key[0]
+ self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]
+ ]["attrib"][f"{self.key_type}2"] = key[1]
+ self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]
+ ]["attrib"][f"{self.key_type}3"] = key[2]
+ self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]
+ ]["attrib"][f"{self.key_type}4"] = key[3]
+ shiftTem = 0
+ for nitem, item in enumerate(self.proper_key_to_prms[nnode]):
+ phase, k, shift = proper_phase[item], proper_k[item], proper_shift[item]
+ mask = proper_msks[item]
+ self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]
+ ]["attrib"]["per" + str(nitem + 1)] = str(self.proper_periods[item])
+ self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]
+ ]["attrib"]["phase" + str(nitem + 1)] = str(phase)
+ self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]
+ ]["attrib"]["k" + str(nitem + 1)] = str(k)
+ shiftTem += shift
+ self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]
+ ]["attrib"]["shift"] = str(shiftTem)
+ if mask < 0.999:
+ self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]
+ ]["attrib"]["mask"] = "true"
+
+ improper_phase = paramset[self.name]["improper_phase"]
+ improper_k = paramset[self.name]["improper_k"]
+ improper_shift = paramset[self.name]["improper_shift"]
+ improper_msks = paramset.mask[self.name]["improper"]
+ for nnode, key in enumerate(self.imp_keys):
+ self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]]["attrib"] = {
+ }
+ self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]
+ ]["attrib"][f"{self.key_type}1"] = key[0]
+ self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]
+ ]["attrib"][f"{self.key_type}2"] = key[1]
+ self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]
+ ]["attrib"][f"{self.key_type}3"] = key[2]
+ self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]
+ ]["attrib"][f"{self.key_type}4"] = key[3]
+ shiftTem = 0
+ for nitem, item in enumerate(self.imp_key_to_prms[nnode]):
+ phase = improper_phase[item]
+ k = improper_k[item]
+ shift = improper_shift[item]
+ mask = improper_msks[item]
+ self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]
+ ]["attrib"]["per" + str(nitem + 1)] = str(self.imp_periods[item])
+ self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]
+ ]["attrib"]["phase" + str(nitem + 1)] = str(phase)
+ self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]
+ ]["attrib"]["k" + str(nitem + 1)] = str(k)
+ shiftTem += shift
+ self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]
+ ]["attrib"]["shift"] = str(shiftTem)
+ if mask < 0.999:
+ self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]
+ ]["attrib"]["mask"] = "true"
+
+ def _find_proper_key_index(self, key: Tuple[str, str, str, str]) -> int:
+ wc_patch = []
+ for i, k in enumerate(self.proper_keys):
+ if k[0] in ["", key[0]] and k[1] in ["", key[1]] and k[2] in ["", key[2]] and k[3] in ["", key[3]]:
+ if "" in k:
+ wc_patch.append(i)
+ else:
+ return i
+ if k[0] in ["", key[3]] and k[1] in ["", key[2]] and k[2] in ["", key[1]] and k[3] in ["", key[0]]:
+ if "" in k:
+ wc_patch.append(i)
+ else:
+ return i
+ if len(wc_patch) > 0:
+ return wc_patch[0]
+ return None
+
+ def _find_improper_key_index(self, improper):
+
+ type1 = improper[0].meta[self.key_type]
+ type2 = improper[1].meta[self.key_type]
+ type3 = improper[2].meta[self.key_type]
+ type4 = improper[3].meta[self.key_type]
+
+ def _wild_match(tp, tps):
+ if tps == "":
+ return True
+ if tp == tps:
+ return True
+ return False
+
+ matched = None
+ for ndef, tordef in enumerate(self.imp_keys):
+ types1 = tordef[0]
+ types2 = tordef[1]
+ types3 = tordef[2]
+ types4 = tordef[3]
+ hasWildcard = ("" in (types1, types2, types3, types4))
+
+ if matched is not None and hasWildcard:
+ continue
+
+ import itertools
+ if type1 in types1:
+ for (t2, t3, t4) in itertools.permutations(((type2, 1), (type3, 2), (type4, 3))):
+ if _wild_match(t2[0], types2) and _wild_match(t3[0], types3) and _wild_match(t4[0], types4):
+ a1 = improper[t2[1]].index
+ a2 = improper[t3[1]].index
+ e1 = improper[t2[1]].element
+ e2 = improper[t3[1]].element
+ m1 = app.element.get_by_symbol(e1).mass
+ m2 = app.element.get_by_symbol(e2).mass
+ if e1 == e2 and a1 > a2:
+ (a1, a2) = (a2, a1)
+ elif e1 != "C" and (e2 == "C" or m1 < m2):
+ (a1, a2) = (a2, a1)
+ matched = (a1, a2, improper[0].index, improper[t4[1]].index, ndef)
+ break
+ if matched is None:
+ return None, None
+ return matched[4], matched[:4]
+
+ def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
+ nonbondedCutoff, **kwargs):
+
+ if self.key_type is None:
+ def potential_fn_zero(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray,
+ params: ParamSet) -> jnp.ndarray:
+ return jnp.zeros((1,))
+
+ self._jaxPotential = potential_fn_zero
+ return potential_fn_zero
+
+ proper_list = []
+
+ acenters = {}
+ atoms = [a for a in topdata.atoms()]
+ for bond in topdata.bonds():
+ a1, a2 = bond.atom1, bond.atom2
+ i1, i2 = a1.index, a2.index
+ if i1 not in acenters:
+ acenters[i1] = []
+ acenters[i1].append(i2)
+ if i2 not in acenters:
+ acenters[i2] = []
+ acenters[i2].append(i1)
+
+ # find rotamers and loop over proper torsions on the rotamer
+ for bond in topdata.bonds():
+ a1, a2 = bond.atom1, bond.atom2
+ i1, i2 = a1.index, a2.index
+ alinks1 = [i for i in acenters[i1] if i != i2]
+ alinks2 = [i for i in acenters[i2] if i != i1]
+ for i3 in alinks1:
+ for i4 in alinks2:
+ if i3 != i4:
+ proper_list.append(
+ (atoms[i3], atoms[i1], atoms[i2], atoms[i4]))
+
+ impr_list = []
+ # find atoms that link with three other atoms
+ import itertools as it
+ for i1 in acenters:
+ if len(acenters[i1]) < 3:
+ continue
+ for item in it.combinations(acenters[i1], 3):
+ impr_list.append(
+ (atoms[i1], atoms[item[0]], atoms[item[1]], atoms[item[2]]))
+
+ # create potential
+ proper_a1, proper_a2, proper_a3, proper_a4, proper_indices, proper_period = [
+ ], [], [], [], [], []
+ for proper in proper_list:
+ pidx = self._find_proper_key_index(
+ (proper[0].meta[self.key_type], proper[1].meta[self.key_type], proper[2].meta[self.key_type],
+ proper[3].meta[self.key_type]))
+ if pidx is None:
+ continue
+
+ prm_indices = self.proper_key_to_prms[pidx]
+ for prm_idx in prm_indices:
+ prm_period = self.proper_periods[prm_idx]
+ proper_a1.append(proper[0].index)
+ proper_a2.append(proper[1].index)
+ proper_a3.append(proper[2].index)
+ proper_a4.append(proper[3].index)
+ proper_indices.append(prm_idx)
+ proper_period.append(prm_period)
+
+ proper_a1 = jnp.array(proper_a1)
+ proper_a2 = jnp.array(proper_a2)
+ proper_a3 = jnp.array(proper_a3)
+ proper_a4 = jnp.array(proper_a4)
+ proper_indices = jnp.array(proper_indices)
+ proper_period = jnp.array(proper_period)
+
+ improper_a1, improper_a2, improper_a3, improper_a4, improper_indices, improper_period = [], [], [], [], [], []
+ for improper in impr_list:
+ iidx, order = self._find_improper_key_index(improper)
+ if iidx is None:
+ continue
+
+ prm_indices = self.imp_key_to_prms[iidx]
+ for prm_idx in prm_indices:
+ prm_period = self.imp_periods[prm_idx]
+ improper_a1.append(atoms[order[0]].index)
+ improper_a2.append(atoms[order[1]].index)
+ improper_a3.append(atoms[order[2]].index)
+ improper_a4.append(atoms[order[3]].index)
+ improper_indices.append(prm_idx)
+ improper_period.append(prm_period)
+ improper_a1 = jnp.array(improper_a1)
+ improper_a2 = jnp.array(improper_a2)
+ improper_a3 = jnp.array(improper_a3)
+ improper_a4 = jnp.array(improper_a4)
+ improper_indices = jnp.array(improper_indices)
+ improper_period = jnp.array(improper_period)
+
+ proper_func = CustomTorsionJaxForce(
+ proper_a1, proper_a2, proper_a3, proper_a4, proper_indices, proper_period)
+ proper_energy = proper_func.generate_get_energy()
+ improper_func = CustomTorsionJaxForce(
+ improper_a1, improper_a2, improper_a3, improper_a4, improper_indices, improper_period)
+ improper_energy = improper_func.generate_get_energy()
+
+ has_aux = False
+ if "has_aux" in kwargs and kwargs["has_aux"]:
+ has_aux = True
+
+ def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet, aux=None):
+ isinstance_jnp(positions, box, params)
+ proper_energy_ = proper_energy(
+ positions, box, pairs, params[self.name]["proper_k"], params[self.name]["proper_phase"], params[self.name]["proper_shift"])
+ improper_energy_ = improper_energy(
+ positions, box, pairs, params[self.name]["improper_k"], params[self.name]["improper_phase"], params[self.name]["improper_shift"])
+ if has_aux:
+ return proper_energy_ + improper_energy_, aux
+ else:
+ return proper_energy_ + improper_energy_
+
+ self._jaxPotential = potential_fn
+ return potential_fn
+
+
+_DMFFGenerators["CustomTorsionForce"] = CustomTorsionGenerator
+
+
class NonbondedGenerator:
def __init__(self, ffinfo: dict, paramset: ParamSet):
- self.name = "NonbondedForce"
+ self.name = "NonbondedForce"
+ self.ffinfo = ffinfo
+ paramset.addField(self.name)
+ self.coulomb14scale = float(
+ self.ffinfo["Forces"]["NonbondedForce"]["meta"].get("coulomb14scale", 0.8333333333333334))
+ self.lj14scale = float(
+ self.ffinfo["Forces"]["NonbondedForce"]["meta"].get("lj14scale", 0.5))
+ self.key_type = None
+ self.type_to_charge = {}
+
+ self.charge_in_residue = False
+ for node in self.ffinfo["Forces"]["NonbondedForce"]["node"]:
+ if not self.charge_in_residue and node["name"] == "UseAttributeFromResidue":
+ if node["attrib"]["name"] == "charge":
+ self.charge_in_residue = True
+
+ types, sigma, epsilon, atom_mask = [], [], [], []
+ for node in self.ffinfo["Forces"]["NonbondedForce"]["node"]:
+ if node["name"] == "Atom":
+ attribs = node["attrib"]
+ self.key_type = None
+ if "type" in attribs:
+ self.key_type = "type"
+ elif "class" in attribs:
+ self.key_type = "class"
+ types.append(attribs[self.key_type])
+ sigma.append(float(attribs["sigma"]))
+ epsilon.append(float(attribs["epsilon"]))
+ mask = 1.0
+ if "mask" in attribs and attribs["mask"].upper() == "TRUE":
+ mask = 0.0
+ atom_mask.append(mask)
+ if not self.charge_in_residue:
+ if "charge" not in attribs:
+ raise ValueError("No charge information found in NonbondedForce or Residues.")
+ self.type_to_charge[attribs[self.key_type]] = float(attribs["charge"])
+
+ sigma = jnp.array(sigma)
+ epsilon = jnp.array(epsilon)
+ atom_mask = jnp.array(atom_mask)
+ self.atom_keys = types
+ paramset.addParameter(sigma, "sigma", field=self.name, mask=atom_mask)
+ paramset.addParameter(epsilon, "epsilon", field=self.name, mask=atom_mask)
+
+ def getName(self):
+ return self.name
+
+ def overwrite(self, paramset):
+ sigma = paramset[self.name]["sigma"]
+ epsilon = paramset[self.name]["epsilon"]
+ atom_mask = paramset.mask[self.name]["sigma"]
+
+ node2atom = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Atom"]
+
+ for natom in range(len(self.atom_keys)):
+ nnode = node2atom[natom]
+ sig_new = sigma[natom]
+ eps_new = epsilon[natom]
+ mask = atom_mask[natom]
+ self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["sigma"] = str(sig_new)
+ self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["epsilon"] = str(eps_new)
+ if mask < 0.999:
+ self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["mask"] = "true"
+
+ def _find_atype_key_index(self, atype: str):
+ for n, i in enumerate(self.atom_keys):
+ if i == atype:
+ return n
+ return None
+
+ def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
+ nonbondedCutoff, **kwargs):
+ methodMap = {
+ app.NoCutoff: "NoCutoff",
+ app.CutoffPeriodic: "CutoffPeriodic",
+ app.CutoffNonPeriodic: "CutoffNonPeriodic",
+ app.PME: "PME",
+ }
+ methodString = methodMap[nonbondedMethod]
+ if nonbondedMethod not in methodMap:
+ raise DMFFException("Illegal nonbonded method for NonbondedForce")
+
+ isNoCut = False
+ if nonbondedMethod is app.NoCutoff:
+ isNoCut = True
+
+ mscales_coul = jnp.array([0.0, 0.0, self.coulomb14scale, 1.0, 1.0,
+ 1.0])
+ mscales_lj = jnp.array([0.0, 0.0, self.lj14scale, 1.0, 1.0,
+ 1.0])
+
+ # coulomb
+ # set PBC
+ if nonbondedMethod not in [app.NoCutoff, app.CutoffNonPeriodic]:
+ ifPBC = True
+ else:
+ ifPBC = False
+
+ if self.charge_in_residue:
+ charges = [a.meta["charge"] for a in topdata.atoms()]
+ charges = jnp.array(charges)
+ else:
+ types = [a.meta[self.key_type] for a in topdata.atoms()]
+ charges = jnp.array([self.type_to_charge[i] for i in types])
+
+ if unit.is_quantity(nonbondedCutoff):
+ r_cut = nonbondedCutoff.value_in_unit(unit.nanometer)
+ else:
+ r_cut = nonbondedCutoff
+
+ # PME Settings
+ if nonbondedMethod is app.PME:
+ cell = topdata.getPeriodicBoxVectors()
+ self.ethresh = kwargs.get("ethresh", 1e-6)
+ self.coeff_method = kwargs.get("PmeCoeffMethod", "openmm")
+ self.fourier_spacing = kwargs.get("PmeSpacing", 0.1)
+ kappa, K1, K2, K3 = setup_ewald_parameters(r_cut, self.ethresh,
+ cell,
+ self.fourier_spacing,
+ self.coeff_method)
+ if nonbondedMethod is not app.PME:
+ # do not use PME
+ if nonbondedMethod in [app.CutoffPeriodic, app.CutoffNonPeriodic]:
+ # use Reaction Field
+ coulforce = CoulReactionFieldForce(r_cut, charges, isPBC=ifPBC)
+ if nonbondedMethod is app.NoCutoff:
+ # use NoCutoff
+ coulforce = CoulNoCutoffForce(init_charges=charges)
+ else:
+ coulforce = CoulombPMEForce(r_cut, charges, kappa, (K1, K2, K3))
+
+ self.pme_force = coulforce
+ coulenergy = coulforce.generate_get_energy()
+
+ # LJ
+ atypes = [a.meta[self.key_type] for a in topdata.atoms()]
+ map_prm = []
+ for atype in atypes:
+ pidx = self._find_atype_key_index(atype)
+ if pidx is None:
+ raise DMFFException(f"Atom type {atype} not found.")
+ map_prm.append(pidx)
+ map_prm = jnp.array(map_prm)
+
+ # not use nbfix for now
+ map_nbfix = []
+ map_nbfix = jnp.array(map_nbfix, dtype=int).reshape((-1, 3))
+ eps_nbfix = jnp.array(map_nbfix, dtype=float).reshape((-1, 3))
+ sig_nbfix = jnp.array(map_nbfix, dtype=float).reshape((-1, 3))
+
+ if methodString in ["NoCutoff", "CutoffNonPeriodic"]:
+ isPBC = False
+ if methodString == "NoCutoff":
+ isNoCut = True
+ else:
+ isNoCut = False
+ else:
+ isPBC = True
+ isNoCut = False
+
+ ljforce = LennardJonesForce(0.0,
+ r_cut,
+ map_prm,
+ map_nbfix,
+ isSwitch=False,
+ isPBC=isPBC,
+ isNoCut=isNoCut)
+ ljenergy = ljforce.generate_get_energy()
+
+ # dispersion correction
+ use_disp_corr = False
+ if "useDispersionCorrection" in kwargs and kwargs["useDispersionCorrection"]:
+ use_disp_corr = True
+ numTypes = len(self.atom_keys)
+ countVec = np.zeros(numTypes, dtype=int)
+ countMat = np.zeros((numTypes, numTypes), dtype=int)
+ types, count = np.unique(map_prm, return_counts=True)
+ for typ, cnt in zip(types, count):
+ countVec[typ] += cnt
+ for i in range(numTypes):
+ for j in range(i, numTypes):
+ if i != j:
+ countMat[i, j] = countVec[i] * countVec[j]
+ else:
+ countMat[i, j] = countVec[i] * (countVec[i] - 1) // 2
+ assert np.sum(countMat) == len(map_prm) * (len(map_prm) - 1) // 2
+
+ coval_map = topdata.buildCovMat()
+ colv_pairs = np.argwhere(
+ np.logical_and(coval_map > 0, coval_map <= 3))
+ for pair in colv_pairs:
+ if pair[0] <= pair[1]:
+ tmp = (map_prm[pair[0]], map_prm[pair[1]])
+ t1, t2 = min(tmp), max(tmp)
+ countMat[t1, t2] -= 1
+
+ ljDispCorrForce = LennardJonesLongRangeForce(r_cut, map_prm, map_nbfix, countMat)
+ ljDispEnergyFn = ljDispCorrForce.generate_get_energy()
+
+ has_aux = False
+ if "has_aux" in kwargs and kwargs["has_aux"]:
+ has_aux = True
+
+ def potential_fn(positions, box, pairs, params, aux=None):
+
+ # check whether args passed into potential_fn are jnp.array and differentiable
+ # note this check will be optimized away by jit
+ # it is jit-compatiable
+ isinstance_jnp(positions, box, params)
+
+ coulE = coulenergy(positions, box, pairs, mscales_coul)
+
+ ljE = ljenergy(positions, box, pairs, params[self.name]["epsilon"],
+ params[self.name]["sigma"], eps_nbfix, sig_nbfix, mscales_lj)
+ if use_disp_corr:
+ ljdispE = ljDispEnergyFn(box, params[self.name]["epsilon"],
+ params[self.name]["sigma"], eps_nbfix, sig_nbfix)
+ if has_aux:
+ return coulE + ljE + ljdispE, aux
+ else:
+ return coulE + ljE + ljdispE
+ else:
+ if has_aux:
+ return coulE + ljE, aux
+ else:
+ return coulE + ljE
+
+ self._jaxPotential = potential_fn
+ return potential_fn
+
+
+_DMFFGenerators["NonbondedForce"] = NonbondedGenerator
+
+
+class CoulombGenerator:
+ def __init__(self, ffinfo: dict, paramset: ParamSet):
+ self.name = "CoulombForce"
self.ffinfo = ffinfo
paramset.addField(self.name)
self.coulomb14scale = float(
- self.ffinfo["Forces"]["NonbondedForce"]["meta"].get("coulomb14scale", 0.8333333333333334))
+ self.ffinfo["Forces"]["CoulombForce"]["meta"]["coulomb14scale"])
+ self._use_bcc = False
+ self._bcc_mol = []
+ self.bcc_parsers = []
+ bcc_prms = []
+ bcc_mask = []
+ for node in self.ffinfo["Forces"]["CoulombForce"]["node"]:
+ if node["name"] == "UseBondChargeCorrection":
+ self._use_bcc = True
+ self._bcc_mol.append(node["attrib"]["name"])
+ if node["name"] == "BondChargeCorrection":
+ bcc = node["attrib"]["bcc"]
+ parser = node["attrib"]["smarts"] if "smarts" in node["attrib"] else node["attrib"]["smirks"]
+ bcc_prms.append(float(bcc))
+ self.bcc_parsers.append(parser)
+ if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE":
+ bcc_mask.append(0.0)
+ else:
+ bcc_mask.append(1.0)
+ bcc_prms = jnp.array(bcc_prms)
+ bcc_mask = jnp.array(bcc_mask)
+ paramset.addParameter(bcc_prms, "bcc", field=self.name, mask=bcc_mask)
+ self._bcc_shape = paramset[self.name]["bcc"].shape[0]
+
+ def getName(self):
+ return self.name
+
+ def overwrite(self, paramset):
+ # paramset to ffinfo
+ if self._use_bcc:
+ bcc_now = paramset[self.name]["bcc"]
+ mask_list = paramset.mask[self.name]["bcc"]
+ nbcc = 0
+ for nnode, node in enumerate(self.ffinfo["Forces"][self.name]["node"]):
+ if node["name"] == "BondChargeCorrection":
+ mask = mask_list[nbcc]
+ self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["bcc"] = bcc_now[nbcc]
+ if mask < 0.999:
+ self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["mask"] = "true"
+ nbcc += 1
+
+ def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
+ nonbondedCutoff, **kwargs):
+ methodMap = {
+ app.NoCutoff: "NoCutoff",
+ app.CutoffPeriodic: "CutoffPeriodic",
+ app.CutoffNonPeriodic: "CutoffNonPeriodic",
+ app.PME: "PME",
+ }
+ if nonbondedMethod not in methodMap:
+ raise DMFFException("Illegal nonbonded method for NonbondedForce")
+
+ isNoCut = False
+ if nonbondedMethod is app.NoCutoff:
+ isNoCut = True
+
+ mscales_coul = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0,
+ 1.0]) # mscale for PME
+ mscales_coul = mscales_coul.at[2].set(self.coulomb14scale)
+ self.mscales_coul = mscales_coul # for qeq calculation
+
+ # set PBC
+ if nonbondedMethod not in [app.NoCutoff, app.CutoffNonPeriodic]:
+ ifPBC = True
+ else:
+ ifPBC = False
+
+ charges = [a.meta["charge"] for a in topdata.atoms()]
+ charges = jnp.array(charges)
+
+ cov_mat = topdata.buildCovMat()
+
+ if unit.is_quantity(nonbondedCutoff):
+ r_cut = nonbondedCutoff.value_in_unit(unit.nanometer)
+ else:
+ r_cut = nonbondedCutoff
+
+ # PME Settings
+ if nonbondedMethod is app.PME:
+ cell = topdata.getPeriodicBoxVectors()
+ box = jnp.array(cell)
+ self.ethresh = kwargs.get("ethresh", 1e-5)
+ self.coeff_method = kwargs.get("PmeCoeffMethod", "openmm")
+ self.fourier_spacing = kwargs.get("PmeSpacing", 0.1)
+ kappa, K1, K2, K3 = setup_ewald_parameters(r_cut, self.ethresh,
+ box,
+ self.fourier_spacing,
+ self.coeff_method)
+
+ if self._use_bcc:
+ top_mat = np.zeros(
+ (topdata.getNumAtoms(), self._bcc_shape))
+ matched_dict = {}
+ for nparser, parser in enumerate(self.bcc_parsers):
+ matches = topdata.parseSMARTS(parser, resname=self._bcc_mol)
+ for ii, jj in matches:
+ if (ii, jj) in matched_dict:
+ del matched_dict[(ii, jj)]
+ elif (jj, ii) in matched_dict:
+ del matched_dict[(jj, ii)]
+ matched_dict[(ii, jj)] = nparser
+ for ii, jj in matched_dict.keys():
+ nval = matched_dict[(ii, jj)]
+ top_mat[ii, nval] += 1.
+ top_mat[jj, nval] -= 1.
+ topdata._meta["bcc_top_mat"] = top_mat
+
+ if nonbondedMethod is not app.PME:
+ # do not use PME
+ if nonbondedMethod in [app.CutoffPeriodic, app.CutoffNonPeriodic]:
+ # use Reaction Field
+ coulforce = CoulReactionFieldForce(
+ r_cut,
+ charges,
+ isPBC=ifPBC,
+ topology_matrix=top_mat if self._use_bcc else None)
+ if nonbondedMethod is app.NoCutoff:
+ # use NoCutoff
+ coulforce = CoulNoCutoffForce(
+ charges, topology_matrix=top_mat if self._use_bcc else None)
+ else:
+ coulforce = CoulombPMEForce(
+ r_cut,
+ charges,
+ kappa, (K1, K2, K3),
+ topology_matrix=top_mat if self._use_bcc else None)
+
+ coulenergy = coulforce.generate_get_energy()
+
+ has_aux = False
+ if "has_aux" in kwargs and kwargs["has_aux"]:
+ has_aux = True
+
+ def potential_fn(positions, box, pairs, params, aux=None):
+
+ # check whether args passed into potential_fn are jnp.array and differentiable
+ # note this check will be optimized away by jit
+ # it is jit-compatiable
+ isinstance_jnp(positions, box, params)
+
+ if self._use_bcc:
+ coulE = coulenergy(positions, box, pairs,
+ params["CoulombForce"]["bcc"], mscales_coul)
+ else:
+ coulE = coulenergy(positions, box, pairs,
+ mscales_coul)
+
+ if has_aux:
+ return coulE, aux
+ else:
+ return coulE
+
+ self._jaxPotential = potential_fn
+ return potential_fn
+
+
+_DMFFGenerators["CoulombForce"] = CoulombGenerator
+
+
+class LennardJonesGenerator:
+
+ def __init__(self, ffinfo: dict, paramset: ParamSet):
+ self.name = "LennardJonesForce"
+ self.ffinfo = ffinfo
self.lj14scale = float(
- self.ffinfo["Forces"]["NonbondedForce"]["meta"].get("lj14scale", 0.5))
- self.key_type = None
- self.type_to_charge = {}
-
- self.charge_in_residue = False
- for node in self.ffinfo["Forces"]["NonbondedForce"]["node"]:
- if not self.charge_in_residue and node["name"] == "UseAttributeFromResidue":
- if node["attrib"]["name"] == "charge":
- self.charge_in_residue = True
-
- types, sigma, epsilon, atom_mask = [], [], [], []
- for node in self.ffinfo["Forces"]["NonbondedForce"]["node"]:
+ self.ffinfo["Forces"][self.name]["meta"]["lj14scale"])
+ self.nbfix_to_idx = {}
+ self.atype_to_idx = {}
+ sig_prms, eps_prms = [], []
+ sig_mask, eps_mask = [], []
+ sig_nbfix, eps_nbfix = [], []
+ sig_nbf_mask, eps_nbf_mask = [], []
+ for node in self.ffinfo["Forces"][self.name]["node"]:
if node["name"] == "Atom":
- attribs = node["attrib"]
- self.key_type = None
- if "type" in attribs:
- self.key_type = "type"
- elif "class" in attribs:
- self.key_type = "class"
- types.append(attribs[self.key_type])
- sigma.append(float(attribs["sigma"]))
- epsilon.append(float(attribs["epsilon"]))
- mask = 1.0
- if "mask" in attribs and attribs["mask"].upper() == "TRUE":
- mask = 0.0
- atom_mask.append(mask)
- if not self.charge_in_residue:
- if "charge" not in attribs:
- raise ValueError("No charge information found in NonbondedForce or Residues.")
- self.type_to_charge[attribs[self.key_type]] = float(attribs["charge"])
+ if "type" in node["attrib"]:
+ atype, eps, sig = node["attrib"]["type"], node["attrib"][
+ "epsilon"], node["attrib"]["sigma"]
+ self.atype_to_idx[atype] = len(sig_prms)
+ elif "class" in node["attrib"]:
+ acls, eps, sig = node["attrib"]["class"], node["attrib"][
+ "epsilon"], node["attrib"]["sigma"]
+ atypes = ffinfo["ClassToType"][acls]
+ for atype in atypes:
+ self.atype_to_idx[atype] = len(sig_prms)
+ sig_prms.append(float(sig))
+ eps_prms.append(float(eps))
+ if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE":
+ sig_mask.append(0.0)
+ eps_mask.append(0.0)
+ else:
+ sig_mask.append(1.0)
+ eps_mask.append(1.0)
+ elif node["name"] == "NBFixPair":
+ if "type1" in node["attrib"]:
+ atype1, atype2, eps, sig = node["attrib"]["type1"], node["attrib"][
+ "type2"], node["attrib"]["epsilon"], node["attrib"]["sigma"]
+ if atype1 not in self.nbfix_to_idx:
+ self.nbfix_to_idx[atype1] = {}
+ if atype2 not in self.nbfix_to_idx:
+ self.nbfix_to_idx[atype2] = {}
+ self.nbfix_to_idx[atype1][atype2] = len(sig_nbfix)
+ self.nbfix_to_idx[atype2][atype1] = len(sig_nbfix)
+ elif "class1" in node["attrib"]:
+ acls1, acls2, eps, sig = node["attrib"]["class1"], node["attrib"][
+ "class2"], node["attrib"]["epsilon"], node["attrib"]["sigma"]
+ atypes1 = ffinfo["ClassToType"][acls1]
+ atypes2 = ffinfo["ClassToType"][acls2]
+ for atype1 in atypes1:
+ if atype1 not in self.nbfix_to_idx:
+ self.nbfix_to_idx[atype1] = {}
+ for atype2 in atypes2:
+ if atype2 not in self.nbfix_to_idx:
+ self.nbfix_to_idx[atype2] = {}
+ self.nbfix_to_idx[atype1][atype2] = len(sig_nbfix)
+ self.nbfix_to_idx[atype2][atype1] = len(sig_nbfix)
+ sig_nbfix.append(float(sig))
+ eps_nbfix.append(float(eps))
+ if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE":
+ sig_nbf_mask.append(0.0)
+ eps_nbf_mask.append(0.0)
+ else:
+ sig_nbf_mask.append(1.0)
+ eps_nbf_mask.append(1.0)
- sigma = jnp.array(sigma)
- epsilon = jnp.array(epsilon)
- atom_mask = jnp.array(atom_mask)
- self.atom_keys = types
- paramset.addParameter(sigma, "sigma", field=self.name, mask=atom_mask)
- paramset.addParameter(epsilon, "epsilon", field=self.name, mask=atom_mask)
+ sig_prms = jnp.array(sig_prms)
+ eps_prms = jnp.array(eps_prms)
+ sig_mask = jnp.array(sig_mask)
+ eps_mask = jnp.array(eps_mask)
+
+ sig_nbfix, eps_nbfix = jnp.array(sig_nbfix), jnp.array(eps_nbfix)
+ sig_nbf_mask = jnp.array(sig_nbf_mask)
+ eps_nbf_mask = jnp.array(eps_nbf_mask)
+
+ paramset.addField(self.name)
+ paramset.addParameter(
+ sig_prms, "sigma", field=self.name, mask=sig_mask)
+ paramset.addParameter(eps_prms, "epsilon",
+ field=self.name, mask=eps_mask)
+ paramset.addParameter(sig_nbfix, "sigma_nbfix",
+ field=self.name, mask=sig_nbf_mask)
+ paramset.addParameter(eps_nbfix, "epsilon_nbfix",
+ field=self.name, mask=eps_nbf_mask)
def getName(self):
return self.name
def overwrite(self, paramset):
- sigma = paramset[self.name]["sigma"]
- epsilon = paramset[self.name]["epsilon"]
- atom_mask = paramset.mask[self.name]["sigma"]
+ # paramset to ffinfo
+ for nnode in range(len(self.ffinfo["Forces"][self.name]["node"])):
+ node = self.ffinfo["Forces"][self.name]["node"][nnode]
+ if node["name"] == "Atom":
+ if "type" in node["attrib"]:
+ atype = node["attrib"]["type"]
+ idx = self.atype_to_idx[atype]
- node2atom = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Atom"]
+ elif "class" in node["attrib"]:
+ acls = node["attrib"]["class"]
+ atypes = self.ffinfo["ClassToType"][acls]
+ idx = self.atype_to_idx[atypes[0]]
- for natom in range(len(self.atom_keys)):
- nnode = node2atom[natom]
- sig_new = sigma[natom]
- eps_new = epsilon[natom]
- mask = atom_mask[natom]
- self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["sigma"] = str(sig_new)
- self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["epsilon"] = str(eps_new)
- if mask < 0.999:
- self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["mask"] = "true"
+ eps_now = paramset[self.name]["epsilon"][idx]
+ sig_now = paramset[self.name]["sigma"][idx]
+ self.ffinfo["Forces"][
+ self.name]["node"][nnode]["attrib"]["sigma"] = sig_now
+ self.ffinfo["Forces"][
+ self.name]["node"][nnode]["attrib"]["epsilon"] = eps_now
+ # have not tested for NBFixPair overwrite
+ elif node["name"] == "NBFixPair":
+ if "type1" in node["attrib"]:
+ atype1, atype2 = node["attrib"]["type1"], node["attrib"]["type2"]
+ idx = self.nbfix_to_idx[atype1][atype2]
+ elif "class1" in node["attrib"]:
+ acls1, acls2 = node["attrib"]["class1"], node["attrib"]["class2"]
+ atypes1 = self.ffinfo["ClassToType"][acls1]
+ atypes2 = self.ffinfo["ClassToType"][acls2]
+ idx = self.nbfix_to_idx[atypes1[0]][atypes2[0]]
+ sig_now = paramset[self.name]["sigma_nbfix"][idx]
+ eps_now = paramset[self.name]["epsilon_nbfix"][idx]
+ self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["sigma"] = sig_now
+ self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["epsilon"] = eps_now
- def _find_atype_key_index(self, atype: str):
- for n, i in enumerate(self.atom_keys):
- if i == atype:
- return n
- return None
-
def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
nonbondedCutoff, **kwargs):
methodMap = {
app.NoCutoff: "NoCutoff",
app.CutoffPeriodic: "CutoffPeriodic",
app.CutoffNonPeriodic: "CutoffNonPeriodic",
- app.PME: "PME",
+ app.PME: "CutoffPeriodic",
}
- methodString = methodMap[nonbondedMethod]
if nonbondedMethod not in methodMap:
raise DMFFException("Illegal nonbonded method for NonbondedForce")
+ methodString = methodMap[nonbondedMethod]
- isNoCut = False
- if nonbondedMethod is app.NoCutoff:
- isNoCut = True
-
- mscales_coul = jnp.array([0.0, 0.0, self.coulomb14scale, 1.0, 1.0,
- 1.0])
- mscales_lj = jnp.array([0.0, 0.0, self.lj14scale, 1.0, 1.0,
- 1.0])
-
- # coulomb
- # set PBC
- if nonbondedMethod not in [app.NoCutoff, app.CutoffNonPeriodic]:
- ifPBC = True
- else:
- ifPBC = False
-
- if self.charge_in_residue:
- charges = [a.meta["charge"] for a in topdata.atoms()]
- charges = jnp.array(charges)
- else:
- types = [a.meta[self.key_type] for a in topdata.atoms()]
- charges = jnp.array([self.type_to_charge[i] for i in types])
-
- if unit.is_quantity(nonbondedCutoff):
- r_cut = nonbondedCutoff.value_in_unit(unit.nanometer)
- else:
- r_cut = nonbondedCutoff
-
- # PME Settings
- if nonbondedMethod is app.PME:
- cell = topdata.getPeriodicBoxVectors()
- self.ethresh = kwargs.get("ethresh", 1e-6)
- self.coeff_method = kwargs.get("PmeCoeffMethod", "openmm")
- self.fourier_spacing = kwargs.get("PmeSpacing", 0.1)
- kappa, K1, K2, K3 = setup_ewald_parameters(r_cut, self.ethresh,
- cell,
- self.fourier_spacing,
- self.coeff_method)
- if nonbondedMethod is not app.PME:
- # do not use PME
- if nonbondedMethod in [app.CutoffPeriodic, app.CutoffNonPeriodic]:
- # use Reaction Field
- coulforce = CoulReactionFieldForce(r_cut, charges, isPBC=ifPBC)
- if nonbondedMethod is app.NoCutoff:
- # use NoCutoff
- coulforce = CoulNoCutoffForce(init_charges=charges)
- else:
- coulforce = CoulombPMEForce(r_cut, charges, kappa, (K1, K2, K3))
-
- self.pme_force = coulforce
- coulenergy = coulforce.generate_get_energy()
-
- # LJ
- atypes = [a.meta[self.key_type] for a in topdata.atoms()]
+ atoms = [a for a in topdata.atoms()]
+ atypes = [a.meta["type"] for a in atoms]
map_prm = []
for atype in atypes:
- pidx = self._find_atype_key_index(atype)
- if pidx is None:
+ if atype not in self.atype_to_idx:
raise DMFFException(f"Atom type {atype} not found.")
- map_prm.append(pidx)
+ idx = self.atype_to_idx[atype]
+ map_prm.append(idx)
map_prm = jnp.array(map_prm)
+ topdata._meta["lj_map_idx"] = map_prm
# not use nbfix for now
map_nbfix = []
- map_nbfix = jnp.array(map_nbfix, dtype=int).reshape((-1, 3))
- eps_nbfix = jnp.array(map_nbfix, dtype=float).reshape((-1, 3))
- sig_nbfix = jnp.array(map_nbfix, dtype=float).reshape((-1, 3))
+ for atype1 in self.nbfix_to_idx.keys():
+ for atype2 in self.nbfix_to_idx[atype1].keys():
+ nbfix_idx = self.nbfix_to_idx[atype1][atype2]
+ type1_idx = self.atype_to_idx[atype1]
+ type2_idx = self.atype_to_idx[atype2]
+ map_nbfix.append([type1_idx, type2_idx, nbfix_idx])
+ map_nbfix = np.array(map_nbfix, dtype=int).reshape((-1, 3))
if methodString in ["NoCutoff", "CutoffNonPeriodic"]:
isPBC = False
@@ -949,6 +1713,14 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
isPBC = True
isNoCut = False
+ mscales_lj = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) # mscale for LJ
+ mscales_lj = mscales_lj.at[2].set(self.lj14scale)
+
+ if unit.is_quantity(nonbondedCutoff):
+ r_cut = nonbondedCutoff.value_in_unit(unit.nanometer)
+ else:
+ r_cut = nonbondedCutoff
+
ljforce = LennardJonesForce(0.0,
r_cut,
map_prm,
@@ -958,36 +1730,6 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
isNoCut=isNoCut)
ljenergy = ljforce.generate_get_energy()
- # dispersion correction
- use_disp_corr = False
- if "useDispersionCorrection" in kwargs and kwargs["useDispersionCorrection"]:
- use_disp_corr = True
- numTypes = len(self.atom_keys)
- countVec = np.zeros(numTypes, dtype=int)
- countMat = np.zeros((numTypes, numTypes), dtype=int)
- types, count = np.unique(map_prm, return_counts=True)
- for typ, cnt in zip(types, count):
- countVec[typ] += cnt
- for i in range(numTypes):
- for j in range(i, numTypes):
- if i != j:
- countMat[i, j] = countVec[i] * countVec[j]
- else:
- countMat[i, j] = countVec[i] * (countVec[i] - 1) // 2
- assert np.sum(countMat) == len(map_prm) * (len(map_prm) - 1) // 2
-
- coval_map = topdata.buildCovMat()
- colv_pairs = np.argwhere(
- np.logical_and(coval_map > 0, coval_map <= 3))
- for pair in colv_pairs:
- if pair[0] <= pair[1]:
- tmp = (map_prm[pair[0]], map_prm[pair[1]])
- t1, t2 = min(tmp), max(tmp)
- countMat[t1, t2] -= 1
-
- ljDispCorrForce = LennardJonesLongRangeForce(r_cut, map_prm, map_nbfix, countMat)
- ljDispEnergyFn = ljDispCorrForce.generate_get_energy()
-
has_aux = False
if "has_aux" in kwargs and kwargs["has_aux"]:
has_aux = True
@@ -999,402 +1741,372 @@ def potential_fn(positions, box, pairs, params, aux=None):
# it is jit-compatiable
isinstance_jnp(positions, box, params)
- coulE = coulenergy(positions, box, pairs, mscales_coul)
-
- ljE = ljenergy(positions, box, pairs, params[self.name]["epsilon"],
- params[self.name]["sigma"], eps_nbfix, sig_nbfix, mscales_lj)
- if use_disp_corr:
- ljdispE = ljDispEnergyFn(box, params[self.name]["epsilon"],
- params[self.name]["sigma"], eps_nbfix, sig_nbfix)
- if has_aux:
- return coulE + ljE + ljdispE, aux
- else:
- return coulE + ljE + ljdispE
+ ljE = ljenergy(positions, box, pairs,
+ params[self.name]["epsilon"],
+ params[self.name]["sigma"],
+ params[self.name]["epsilon_nbfix"],
+ params[self.name]["sigma_nbfix"],
+ mscales_lj)
+
+ if has_aux:
+ return ljE, aux
else:
- if has_aux:
- return coulE + ljE, aux
- else:
- return coulE + ljE
+ return ljE
self._jaxPotential = potential_fn
return potential_fn
-_DMFFGenerators["NonbondedForce"] = NonbondedGenerator
+_DMFFGenerators["LennardJonesForce"] = LennardJonesGenerator
+
+
+class Custom1_5BondGenerator:
+ """
+ A class for generating harmonic bond force field parameters.
+ Attributes:
+ -----------
+ name : str
+ The name of the force field.
+ ffinfo : dict
+ The force field information.
+ key_type : str
+ The type of the key.
+ bond_keys : list of tuple
+ The keys of the bonds.
+ bond_params : list of tuple
+ The parameters of the bonds.
+ bond_mask : list of float
+ The mask of the bonds.
+ """
-class CoulombGenerator:
def __init__(self, ffinfo: dict, paramset: ParamSet):
- self.name = "CoulombForce"
+ """
+ Initializes the HarmonicBondGenerator.
+
+ Parameters:
+ -----------
+ ffinfo : dict
+ The force field information.
+ paramset : ParamSet
+ The parameter set.
+ """
+ self.name = "Custom1_5BondForce"
self.ffinfo = ffinfo
paramset.addField(self.name)
- self.coulomb14scale = float(
- self.ffinfo["Forces"]["CoulombForce"]["meta"]["coulomb14scale"])
- self._use_bcc = False
- self._bcc_mol = []
- self.bcc_parsers = []
- bcc_prms = []
- bcc_mask = []
- for node in self.ffinfo["Forces"]["CoulombForce"]["node"]:
- if node["name"] == "UseBondChargeCorrection":
- self._use_bcc = True
- self._bcc_mol.append(node["attrib"]["name"])
- if node["name"] == "BondChargeCorrection":
- bcc = node["attrib"]["bcc"]
- parser = node["attrib"]["smarts"] if "smarts" in node["attrib"] else node["attrib"]["smirks"]
- bcc_prms.append(float(bcc))
- self.bcc_parsers.append(parser)
- if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE":
- bcc_mask.append(0.0)
+ self.key_type = None
+
+ bond_keys, bond_params = [], []
+ for node in self.ffinfo["Forces"][self.name]["node"]:
+ attribs = node["attrib"]
+ if self.key_type is None:
+ if "atomIndex1" in attribs:
+ self.key_type = "atomIndex"
else:
- bcc_mask.append(1.0)
- bcc_prms = jnp.array(bcc_prms)
- bcc_mask = jnp.array(bcc_mask)
- paramset.addParameter(bcc_prms, "bcc", field=self.name, mask=bcc_mask)
- self._bcc_shape = paramset[self.name]["bcc"].shape[0]
+ raise ValueError(
+ "Cannot find key type for Custom1_5BondForce.")
+ key = (attribs[self.key_type + "1"], attribs[self.key_type + "2"])
+ bond_keys.append(key)
+ k = float(attribs["k"])
+ r0 = float(attribs["length"])
+ bond_params.append([k, r0])
- def getName(self):
- return self.name
+ self.bond_keys = bond_keys
+ bond_length = jnp.array([i[1] for i in bond_params])
+ bond_k = jnp.array([i[0] for i in bond_params])
- def overwrite(self, paramset):
- # paramset to ffinfo
- if self._use_bcc:
- bcc_now = paramset[self.name]["bcc"]
- mask_list = paramset.mask[self.name]["bcc"]
- nbcc = 0
- for nnode, node in enumerate(self.ffinfo["Forces"][self.name]["node"]):
- if node["name"] == "BondChargeCorrection":
- mask = mask_list[nbcc]
- self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["bcc"] = bcc_now[nbcc]
- if mask < 0.999:
- self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["mask"] = "true"
- nbcc += 1
+ # register parameters to ParamSet
+ paramset.addParameter(bond_length, "length",
+ field=self.name)
+ # register parameters to ParamSet
+ paramset.addParameter(bond_k, "k", field=self.name)
- def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
- nonbondedCutoff, **kwargs):
- methodMap = {
- app.NoCutoff: "NoCutoff",
- app.CutoffPeriodic: "CutoffPeriodic",
- app.CutoffNonPeriodic: "CutoffNonPeriodic",
- app.PME: "PME",
- }
- if nonbondedMethod not in methodMap:
- raise DMFFException("Illegal nonbonded method for NonbondedForce")
+ def getName(self) -> str:
+ """
+ Returns the name of the force field.
- isNoCut = False
- if nonbondedMethod is app.NoCutoff:
- isNoCut = True
+ Returns:
+ --------
+ str
+ The name of the force field.
+ """
+ return self.name
- mscales_coul = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0,
- 1.0]) # mscale for PME
- mscales_coul = mscales_coul.at[2].set(self.coulomb14scale)
- self.mscales_coul = mscales_coul # for qeq calculation
+ def overwrite(self, paramset: ParamSet) -> None:
+ """
+ Overwrites the parameter set.
- # set PBC
- if nonbondedMethod not in [app.NoCutoff, app.CutoffNonPeriodic]:
- ifPBC = True
- else:
- ifPBC = False
+ Parameters:
+ -----------
+ paramset : ParamSet
+ The parameter set.
+ """
+ bond_node_indices = [i for i in range(len(
+ self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Bond"]
- charges = [a.meta["charge"] for a in topdata.atoms()]
- charges = jnp.array(charges)
+ bond_length = paramset[self.name]["length"]
+ bond_k = paramset[self.name]["k"]
+ for nnode, key in enumerate(self.bond_keys):
+ self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]]["attrib"] = {
+ }
+ self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]
+ ]["attrib"][f"{self.key_type}1"] = key[0]
+ self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]
+ ]["attrib"][f"{self.key_type}2"] = key[1]
+ r0 = bond_length[nnode]
+ k = bond_k[nnode]
+ self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]
+ ]["attrib"]["k"] = str(k)
+ self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]
+ ]["attrib"]["length"] = str(r0)
- cov_mat = topdata.buildCovMat()
+ def _find_key_index(self, key: Tuple[str, str]) -> int:
+ """
+ Finds the index of the key.
- if unit.is_quantity(nonbondedCutoff):
- r_cut = nonbondedCutoff.value_in_unit(unit.nanometer)
- else:
- r_cut = nonbondedCutoff
+ Parameters:
+ -----------
+ key : tuple of str
+ The key.
- # PME Settings
- if nonbondedMethod is app.PME:
- cell = topdata.getPeriodicBoxVectors()
- box = jnp.array(cell)
- self.ethresh = kwargs.get("ethresh", 1e-5)
- self.coeff_method = kwargs.get("PmeCoeffMethod", "openmm")
- self.fourier_spacing = kwargs.get("PmeSpacing", 0.1)
- kappa, K1, K2, K3 = setup_ewald_parameters(r_cut, self.ethresh,
- box,
- self.fourier_spacing,
- self.coeff_method)
+ Returns:
+ --------
+ int
+ The index of the key.
+ """
+ for i, k in enumerate(self.bond_keys):
+ if k[0] == key[0] and k[1] == key[1]:
+ return i
+ if k[0] == key[1] and k[1] == key[0]:
+ return i
+ return None
- if self._use_bcc:
- top_mat = np.zeros(
- (topdata.getNumAtoms(), self._bcc_shape))
- matched_dict = {}
- for nparser, parser in enumerate(self.bcc_parsers):
- matches = topdata.parseSMARTS(parser, resname=self._bcc_mol)
- for ii, jj in matches:
- if (ii, jj) in matched_dict:
- del matched_dict[(ii, jj)]
- elif (jj, ii) in matched_dict:
- del matched_dict[(jj, ii)]
- matched_dict[(ii, jj)] = nparser
- for ii, jj in matched_dict.keys():
- nval = matched_dict[(ii, jj)]
- top_mat[ii, nval] += 1.
- top_mat[jj, nval] -= 1.
- topdata._meta["bcc_top_mat"] = top_mat
+ def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
+ nonbondedCutoff, **kwargs):
+ """
+ Creates the potential.
- if nonbondedMethod is not app.PME:
- # do not use PME
- if nonbondedMethod in [app.CutoffPeriodic, app.CutoffNonPeriodic]:
- # use Reaction Field
- coulforce = CoulReactionFieldForce(
- r_cut,
- charges,
- isPBC=ifPBC,
- topology_matrix=top_mat if self._use_bcc else None)
- if nonbondedMethod is app.NoCutoff:
- # use NoCutoff
- coulforce = CoulNoCutoffForce(
- charges, topology_matrix=top_mat if self._use_bcc else None)
- else:
- coulforce = CoulombPMEForce(
- r_cut,
- charges,
- kappa, (K1, K2, K3),
- topology_matrix=top_mat if self._use_bcc else None)
+ Parameters:
+ -----------
+ topdata : DMFFTopology
+ The topology data.
+ nonbondedMethod : str
+ The nonbonded method.
+ nonbondedCutoff : float
+ The nonbonded cutoff.
+ args : list
+ The arguments.
- coulenergy = coulforce.generate_get_energy()
+ Returns:
+ --------
+ function
+ The potential function.
+ """
+ bond_a1, bond_a2, bond_indices = [], [], []
+ for i, k in enumerate(self.bond_keys):
+ bond_a1.append(int(k[0]))
+ bond_a2.append(int(k[1]))
+ bond_indices.append(int(i))
+ bond_a1 = jnp.array(bond_a1)
+ bond_a2 = jnp.array(bond_a2)
+ bond_indices = jnp.array(bond_indices)
+
+ # 创建势函数
+ harmonic_bond_force = HarmonicBondJaxForce(
+ bond_a1, bond_a2, bond_indices)
+ harmonic_bond_energy = harmonic_bond_force.generate_get_energy()
has_aux = False
if "has_aux" in kwargs and kwargs["has_aux"]:
has_aux = True
- def potential_fn(positions, box, pairs, params, aux=None):
-
- # check whether args passed into potential_fn are jnp.array and differentiable
- # note this check will be optimized away by jit
- # it is jit-compatiable
+ def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet, aux=None):
isinstance_jnp(positions, box, params)
-
- if self._use_bcc:
- coulE = coulenergy(positions, box, pairs,
- params["CoulombForce"]["bcc"], mscales_coul)
- else:
- coulE = coulenergy(positions, box, pairs,
- mscales_coul)
-
+ energy = harmonic_bond_energy(
+ positions, box, pairs, params[self.name]["k"], params[self.name]["length"])
if has_aux:
- return coulE, aux
+ return energy, aux
else:
- return coulE
+ return energy
self._jaxPotential = potential_fn
return potential_fn
-_DMFFGenerators["CoulombForce"] = CoulombGenerator
+# register the generator
+_DMFFGenerators["Custom1_5BondForce"] = Custom1_5BondGenerator
+
+class CustomGBGenerator:
+ """
+ A class for generating Custom Generalized Born implicit solvation models.
+ The following code implements the OBC variant of the GB/SA solvation model, using the ACE approximation to estimate surface area.
+
+ Attributes:
+ -----------
+ name : str
+ The name of the force field.
+ ffinfo : dict
+ The force field information.
+ key_type : str
+ The type of the key.
+ perParticleKey : list of tuple
+ The keys of the atoms
-class LennardJonesGenerator:
+ """
def __init__(self, ffinfo: dict, paramset: ParamSet):
- self.name = "LennardJonesForce"
- self.ffinfo = ffinfo
- self.lj14scale = float(
- self.ffinfo["Forces"][self.name]["meta"]["lj14scale"])
- self.nbfix_to_idx = {}
- self.atype_to_idx = {}
- sig_prms, eps_prms = [], []
- sig_mask, eps_mask = [], []
- sig_nbfix, eps_nbfix = [], []
- sig_nbf_mask, eps_nbf_mask = [], []
- for node in self.ffinfo["Forces"][self.name]["node"]:
- if node["name"] == "Atom":
- if "type" in node["attrib"]:
- atype, eps, sig = node["attrib"]["type"], node["attrib"][
- "epsilon"], node["attrib"]["sigma"]
- self.atype_to_idx[atype] = len(sig_prms)
- elif "class" in node["attrib"]:
- acls, eps, sig = node["attrib"]["class"], node["attrib"][
- "epsilon"], node["attrib"]["sigma"]
- atypes = ffinfo["ClassToType"][acls]
- for atype in atypes:
- self.atype_to_idx[atype] = len(sig_prms)
- sig_prms.append(float(sig))
- eps_prms.append(float(eps))
- if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE":
- sig_mask.append(0.0)
- eps_mask.append(0.0)
- else:
- sig_mask.append(1.0)
- eps_mask.append(1.0)
- elif node["name"] == "NBFixPair":
- if "type1" in node["attrib"]:
- atype1, atype2, eps, sig = node["attrib"]["type1"], node["attrib"][
- "type2"], node["attrib"]["epsilon"], node["attrib"]["sigma"]
- if atype1 not in self.nbfix_to_idx:
- self.nbfix_to_idx[atype1] = {}
- if atype2 not in self.nbfix_to_idx:
- self.nbfix_to_idx[atype2] = {}
- self.nbfix_to_idx[atype1][atype2] = len(sig_nbfix)
- self.nbfix_to_idx[atype2][atype1] = len(sig_nbfix)
- elif "class1" in node["attrib"]:
- acls1, acls2, eps, sig = node["attrib"]["class1"], node["attrib"][
- "class2"], node["attrib"]["epsilon"], node["attrib"]["sigma"]
- atypes1 = ffinfo["ClassToType"][acls1]
- atypes2 = ffinfo["ClassToType"][acls2]
- for atype1 in atypes1:
- if atype1 not in self.nbfix_to_idx:
- self.nbfix_to_idx[atype1] = {}
- for atype2 in atypes2:
- if atype2 not in self.nbfix_to_idx:
- self.nbfix_to_idx[atype2] = {}
- self.nbfix_to_idx[atype1][atype2] = len(sig_nbfix)
- self.nbfix_to_idx[atype2][atype1] = len(sig_nbfix)
- sig_nbfix.append(float(sig))
- eps_nbfix.append(float(eps))
- if "mask" in node["attrib"] and node["attrib"]["mask"].upper() == "TRUE":
- sig_nbf_mask.append(0.0)
- eps_nbf_mask.append(0.0)
- else:
- sig_nbf_mask.append(1.0)
- eps_nbf_mask.append(1.0)
-
- sig_prms = jnp.array(sig_prms)
- eps_prms = jnp.array(eps_prms)
- sig_mask = jnp.array(sig_mask)
- eps_mask = jnp.array(eps_mask)
+ """
+ Initialize the CustomGBForceGenerator
- sig_nbfix, eps_nbfix = jnp.array(sig_nbfix), jnp.array(eps_nbfix)
- sig_nbf_mask = jnp.array(sig_nbf_mask)
- eps_nbf_mask = jnp.array(eps_nbf_mask)
+ Parameters:
+ -----------
+ ffinfo : dict
+ The force field information.
+ paramset : ParamSet
+ The parameter set.
+ """
+ self.name = "CustomGBForce"
+ self.ffinfo = ffinfo
paramset.addField(self.name)
- paramset.addParameter(
- sig_prms, "sigma", field=self.name, mask=sig_mask)
- paramset.addParameter(eps_prms, "epsilon",
- field=self.name, mask=eps_mask)
- paramset.addParameter(sig_nbfix, "sigma_nbfix",
- field=self.name, mask=sig_nbf_mask)
- paramset.addParameter(eps_nbfix, "epsilon_nbfix",
- field=self.name, mask=eps_nbf_mask)
+ self.key_type = None
+ self.perParticleParamIndices = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Atom"]
- def getName(self):
- return self.name
+ perParticleKey, perParticleParam, chargeMask = [], [], []
+ for i in self.perParticleParamIndices:
+ attribs = self.ffinfo["Forces"][self.name]["node"][i]["attrib"]
+ if self.key_type is None:
+ if "type" in attribs:
+ self.key_type = "type"
+ elif "class" in attribs:
+ self.key_type = "class"
+ else:
+ raise ValueError(
+ "Cannot find key type for CustomGBForce."
+ )
+ key = (attribs[self.key_type])
+ perParticleKey.append(key)
- def overwrite(self, paramset):
- # paramset to ffinfo
- for nnode in range(len(self.ffinfo["Forces"][self.name]["node"])):
- node = self.ffinfo["Forces"][self.name]["node"][nnode]
- if node["name"] == "Atom":
- if "type" in node["attrib"]:
- atype = node["attrib"]["type"]
- idx = self.atype_to_idx[atype]
+ charge = float(attribs["charge"])
+ radius = float(attribs["radius"])
+ scale = float(attribs["scale"])
- elif "class" in node["attrib"]:
- acls = node["attrib"]["class"]
- atypes = self.ffinfo["ClassToType"][acls]
- idx = self.atype_to_idx[atypes[0]]
+ # Parameter Charge is not trainable
+ chargeMask.append(0.0)
+ perParticleParam.append([charge, radius, scale])
- eps_now = paramset[self.name]["epsilon"][idx]
- sig_now = paramset[self.name]["sigma"][idx]
- self.ffinfo["Forces"][
- self.name]["node"][nnode]["attrib"]["sigma"] = sig_now
- self.ffinfo["Forces"][
- self.name]["node"][nnode]["attrib"]["epsilon"] = eps_now
- # have not tested for NBFixPair overwrite
- elif node["name"] == "NBFixPair":
- if "type1" in node["attrib"]:
- atype1, atype2 = node["attrib"]["type1"], node["attrib"]["type2"]
- idx = self.nbfix_to_idx[atype1][atype2]
- elif "class1" in node["attrib"]:
- acls1, acls2 = node["attrib"]["class1"], node["attrib"]["class2"]
- atypes1 = self.ffinfo["ClassToType"][acls1]
- atypes2 = self.ffinfo["ClassToType"][acls2]
- idx = self.nbfix_to_idx[atypes1[0]][atypes2[0]]
- sig_now = paramset[self.name]["sigma_nbfix"][idx]
- eps_now = paramset[self.name]["epsilon_nbfix"][idx]
- self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["sigma"] = sig_now
- self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["epsilon"] = eps_now
+ self.perParticleKey = perParticleKey
+ paramset.addParameter(jnp.array([i[0] for i in perParticleParam]),
+ "charge", field=self.name, mask=chargeMask)
+ paramset.addParameter(jnp.array([i[1] for i in perParticleParam]),
+ "radius", field=self.name)
+ paramset.addParameter(jnp.array([i[2] for i in perParticleParam]),
+ "scale", field=self.name)
- def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
- nonbondedCutoff, **kwargs):
- methodMap = {
- app.NoCutoff: "NoCutoff",
- app.CutoffPeriodic: "CutoffPeriodic",
- app.CutoffNonPeriodic: "CutoffNonPeriodic",
- app.PME: "CutoffPeriodic",
- }
- if nonbondedMethod not in methodMap:
- raise DMFFException("Illegal nonbonded method for NonbondedForce")
- methodString = methodMap[nonbondedMethod]
- atoms = [a for a in topdata.atoms()]
- atypes = [a.meta["type"] for a in atoms]
- map_prm = []
- for atype in atypes:
- if atype not in self.atype_to_idx:
- raise DMFFException(f"Atom type {atype} not found.")
- idx = self.atype_to_idx[atype]
- map_prm.append(idx)
- map_prm = jnp.array(map_prm)
- topdata._meta["lj_map_idx"] = map_prm
+ def getName(self) -> str:
+ """
+ Returns the name of the force field.
- # not use nbfix for now
- map_nbfix = []
- for atype1 in self.nbfix_to_idx.keys():
- for atype2 in self.nbfix_to_idx[atype1].keys():
- nbfix_idx = self.nbfix_to_idx[atype1][atype2]
- type1_idx = self.atype_to_idx[atype1]
- type2_idx = self.atype_to_idx[atype2]
- map_nbfix.append([type1_idx, type2_idx, nbfix_idx])
- map_nbfix = np.array(map_nbfix, dtype=int).reshape((-1, 3))
+ Returns:
+ --------
+ str
+ The name of the force field.
+ """
+ return self.name
- if methodString in ["NoCutoff", "CutoffNonPeriodic"]:
- isPBC = False
- if methodString == "NoCutoff":
- isNoCut = True
- else:
- isNoCut = False
- else:
- isPBC = True
- isNoCut = False
- mscales_lj = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) # mscale for LJ
- mscales_lj = mscales_lj.at[2].set(self.lj14scale)
+ def overwrite(self, paramset: ParamSet) -> None:
+ """
+ Overwrites the parameter set.
- if unit.is_quantity(nonbondedCutoff):
- r_cut = nonbondedCutoff.value_in_unit(unit.nanometer)
- else:
- r_cut = nonbondedCutoff
+ Parameters:
+ -----------
+ paramset : ParamSet
+ The parameter set.
+ """
+ radius = paramset[self.name]["radius"]
+ scale = paramset[self.name]["scale"]
+ for i in self.ffinfo.perParticleParamIndices:
+ self.ffinfo["Forces"][self.name]["node"][i]["attrib"]["radius"] = str(radius[i])
+ self.ffinfo["Forces"][self.name]["node"][i]["attrib"]["scale"] = str(scale[i])
- ljforce = LennardJonesForce(0.0,
- r_cut,
- map_prm,
- map_nbfix,
- isSwitch=False,
- isPBC=isPBC,
- isNoCut=isNoCut)
- ljenergy = ljforce.generate_get_energy()
+ def _find_key_index(self, key: Tuple[str]) -> int:
+ """
+ Finds the index of the key.
- has_aux = False
- if "has_aux" in kwargs and kwargs["has_aux"]:
- has_aux = True
+ Parameters:
+ -----------
+ key : tuple of str
+ The key.
- def potential_fn(positions, box, pairs, params, aux=None):
+ Returns:
+ --------
+ int
+ The index of the key.
+ """
+ for i, k in enumerate(self.perParticleKey):
+ if k == key:
+ return i
+ return None
- # check whether args passed into potential_fn are jnp.array and differentiable
- # note this check will be optimized away by jit
- # it is jit-compatiable
- isinstance_jnp(positions, box, params)
+ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs):
+ """
+ Creates the potential.
- ljE = ljenergy(positions, box, pairs,
- params[self.name]["epsilon"],
- params[self.name]["sigma"],
- params[self.name]["epsilon_nbfix"],
- params[self.name]["sigma_nbfix"],
- mscales_lj)
+ Parameters:
+ -----------
+ topdata : DMFFTopology
+ The topology data.
+ nonbondedMethod : str
+ The nonbonded method.
+ nonbondedCutoff : float
+ The nonbonded cutoff.
+ args : list
+ The arguments.
- if has_aux:
- return ljE, aux
- else:
- return ljE
+ Returns:
+ --------
+ function
+ The potential function.
+ """
+ # Load CustomGBForce parameters
+ charge_indices, radius_indices, scale_indices = [], [] ,[]
+ for atom in topdata.atoms():
+ if self.key_type == "type":
+ key = (atom.meta["type"])
+ elif self.key_type == "class":
+ key = (atom.meta["class"])
+ idx = self._find_key_index(key)
+ if idx is None:
+ continue
+ charge_indices.append(idx)
+ radius_indices.append(idx)
+ scale_indices.append(idx)
+
+ charge_indices = jnp.array(charge_indices)
+ radius_indices = jnp.array(radius_indices)
+ scale_indices = jnp.array(scale_indices)
+
+ customGBforce = CustomGBForce(charge_indices, radius_indices, scale_indices)
+ GBSAOBCenergy = customGBforce.generate_get_energy()
+
+ def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet):
+ pairs = pairs[:int(positions.shape[0]*(positions.shape[0]-1)/2)]
+ tt = np.vstack((pairs, pairs[:,[1, 0, 2]]))
+ Ipair = []
+ for i in range(positions.shape[0]):
+ Ipair.append([pair[1] for pair in tt if pair[0] == i])
+ Ipair = jnp.array(Ipair)
+ energy = GBSAOBCenergy(positions, box, pairs, Ipair,
+ params[self.name]["charge"],
+ params[self.name]["radius"],
+ params[self.name]['scale'])
+ return energy
self._jaxPotential = potential_fn
return potential_fn
-_DMFFGenerators["LennardJonesForce"] = LennardJonesGenerator
+_DMFFGenerators["CustomGBForce"] = CustomGBGenerator
\ No newline at end of file
diff --git a/docs/user_guide/4.1classical.md b/docs/user_guide/4.1classical.md
index 1dabe12e6..840506dc7 100644
--- a/docs/user_guide/4.1classical.md
+++ b/docs/user_guide/4.1classical.md
@@ -230,4 +230,184 @@ The attribute `coulomb14scale` and `lj14scale` specifies the scale factors betwe
```
+# CustomGBJaxForce
+## 1. Theory
+
+### Generalized Born Term
+
+This force is used to present implicit solvent model. The force consists of two energy terms: a Generalized Born Approximation term to represent the electrostatic interaction between the solute and solvent, and a surface area term to represent the free energy cost of solvating a neutral molecule. The Generalized Born energy is given by
+
+$$
+E = -\frac{1}{2}(\frac{1}{\epsilon_{solute}}-\frac{1}{\epsilon_{solvent}})\sum_{i,j}\frac{q_iq_j}{f_{GB}(d_{ij},R_i,R_j)}
+$$
+
+where the indices $i$ and $j$ run over all particles, $\epsilon_{solute}$ and $\epsilon_{solvent}$ are the dielectric constants of the solute and solvent respectively, $q_i$ is the charge of particle i, and $d_{ij}$ is the distance between particles i and j. And $f_{GB}(d_{ij},R_i,R_j)$ is defined as:
+
+$$
+f_{GB}(d_{ij},R_i,R_j)=[d^2_{ij}+R_iR_jexp(-\frac{d^2_{ij}}{4R_iR_j})]^{\frac{1}{2}}
+$$
+
+$R_i$ is the Born radius of particle i, which is calculated as:
+
+$$
+R_i = \frac{1}{\rho_i^{-1}-r_i^{-1}tanh(\alpha\Psi_i-\beta\Psi_i^2+\gamma\Psi_i^3)}
+$$
+
+where $\alpha,\beta,\gamma$ are the $GB^{OBC}II$ parameters $\alpha=1, \beta=0.8,\gamma=4.85$. $\rho_i$ is the adjusted atomic radius of particle i, which is calculated from the atomic radius $r_i$ as $\rho_i=r_i-0.009$ nm. $\Psi_i$ is calculated as an integral over the van der Waals spheres of all particles outside particle i:
+
+$$
+\Psi_i=\frac{\rho_i}{4\pi}\int_{VDM}\theta(|r|-\rho_i)\frac{1}{|r|^4}d^3r
+$$
+
+where $\theta(r)$ is a step function that excludes the interior of particle i from the integral.
+
+### Surface Area Term
+
+The surface area term is given by:
+
+$$
+E=E_{SA}·4\pi\sum_i(r_i+r_{solvent})^2(\frac{r_i}{R_i})^6
+$$
+
+where $r_i$ is the atomic radius of particle i, $r_i$ is its atomic radius, and $r_{solvent}$ is the solvent radius, which is taken to be 0.14 nm. The default value for the energy scale $E_{SA}$ is $2.25936\ kJ/mol/nm^2$.
+
+## 2. Frontend
+
+The way to specify a CustomGBJaxForce in DMFF is the same as the way doing it in OpenMM with CustomGBForce: (the example is shown in DMFF/examples/classical/gbForce/gbForce.py )
+
+```xml
+
+
+
+
+
+
+
+ step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C);
+ U=r+sr2; C=2*(1/or1-1/L)*step(sr2-r-or1); L=max(or1, D); D=abs(r-sr2); sr2 =
+ scale2*or2; or1 = radius1-0.009; or2 = radius2-0.009
+
+
+ 1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius); psi=I*or; or=radius-0.009
+
+
+ 28.3919551*(radius+0.14)^2*(radius/B)^6-0.5*138.935456*(1/soluteDielectric-1/solventDielectric)*charge^2/B
+
+
+ -138.935456*(1/soluteDielectric-1/solventDielectric)*charge1*charge2/f;
+ f=sqrt(r^2+B1*B2*exp(-r^2/(4*B1*B2)))
+
+
+
+ ...
+
+```
+
+Every `` tag defines a rule for creating CustomGBForce interactions between atoms. Each tag may identify the atoms either by type (using the attributes `type1` and `type2`) or by class (using the attributes `class1` and `class2`).
+For now, only the ``, `` and `` (if necessary) are trainable. DMFF CustomGBForce generators do not accept other approximation for implicit solvent model. Unless user designs a new CustomGBForce generator for the specific approximation like HCT, OBC etc.
+
+
+# CustomTorsionJaxForce
+
+## 1. Theory
+
+The custom torsion is represented by a truncated periodic Fourier series:
+
+$$
+E = \sum_{n=0}^{4} k_n(\cos(n\phi-\phi_{0n})) + shift
+$$
+
+where $\phi$ is the dihedral angle formed by four particles, $n$ is the periodicity, $\phi_{0n}$ is the phase offset $k_{n}$ is the force constant. To preserve the symmetry, $\phi_{0n}$ usually adopts a value of $0$ (for $n=1,3,5$) or $\pi$ ($n=2,4,6$), and it is recommended to follow these definitions and not to optimize them in force field development.
+
+## 2. Frontend
+
+The way to specify a custom torsion in DMFF is the same as the way doing it in OpenMM:
+
+```xml
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+```
+
+Every child tag `` or `` defines a rule for creating periodic torsion interactions between sets of four atoms. Each tag may identify the atoms either by type (using the attributes `type1`, `type2`, ...) or by class (using the attributes `class1`, `class2`, ...).
+
+The force field recognizes two different types of torsions: `Proper` and `Improper`. A proper torsion involves four atoms that are bonded in sequence: 1 to 2, 2 to 3, and 3 to 4. An improper torsion involves a central atom and three others that are bonded to it: atoms 2, 3, and 4 are all bonded to atom 1. `per1` is the periodicity of the torsion, `phase1` is the phase offset in radians, and `k1` is the force constant in kJ/mol. To add a second periodicity, just add three more attributes: `per2`, `phase2`, and `k2`. **The maxium periodicity supported in DMFF is 6, which is different from OpenMM**.
+
+You can also use wildcards when defining torsions. To do this, simply leave the type or class name for an atom empty. That will cause it to match any atom:
+
+```xml
+
+```
+
+When the tag has an attribute named `mask` and it's value set to `true`, this means the parameter is not trainable. Such information will be passed to `ParamSet.mask` (the corresponding mask value will be 0.0 if not trainable).
+
+
+# Custom1_5BondJaxForce
+
+## 1. Theory
+
+The force is used to regulate the atoms relation between atom 1 to 5 in coarse-grained polyphosphate.
+
+$$
+E = \frac{1}{2}k(b-b_0)^2
+$$
+
+where $k$ is the force constant, $b$ is the distance betweeen two particles that forming a bond and $b_0$ is the equilibrium bond length. Note that in some other MD softwares, the potential form adopts a slight different form: $E=k(b-b_0)^2$. Users should check which form to use and multiply (or divide) the force constant by 2.
+
+## 2. Frontend
+
+The way to specify a harmonic bond in DMFF is different from the way doing it in OpenMM, which requires add special force `openmm.CustomCompoundBondForce()` in coding:
+
+```xml
+
+
+
+
+
+
+
+
+```
+When using this force, you need to add `openmm.CustomCompoundBondForce()` during your simulation: (the example is shown in DMFF/examples/classical/gbForce/gbForce.py )
+
+```python
+h = Hamiltonian("CG.xml")
+params = h.getParameters()
+compoundBondForceParam = params["Custom1_5BondForce"]
+length = compoundBondForceParam["length"]
+k = compoundBondForceParam["k"]
+system = ff.createSystem(pdb.topology, nonbondedMethod=NoCutoff)
+customCompoundForce = openmm.CustomCompoundBondForce(2, "0.5*k*(distance(p1,p2)-length)^2")
+customCompoundForce.addPerBondParameter("length")
+customCompoundForce.addPerBondParameter("k")
+for i, leng in enumerate(length):
+ customCompoundForce.addBond([i, i+4], [leng, k[i]])
+system.addForce(customCompoundForce)
+```
+
+Every `` tag defines a rule for creating harmonic bond interactions between 1 and 5 atoms. Each tag may identify the atoms by index (using the attributes `atomIndex1` and `atomIndex2`). `length` is the equilibrium bond length in $\mathrm{nm}$, and `k` is the force constant in $\mathrm{kJ/mol/nm^2}$.
+For now, `Custom1_5BondJaxForce` doesn't accept different energy forms unless user designs new generator for it.
diff --git a/examples/classical/gbForce/10p.pdb b/examples/classical/gbForce/10p.pdb
new file mode 100644
index 000000000..05b18aa71
--- /dev/null
+++ b/examples/classical/gbForce/10p.pdb
@@ -0,0 +1,25 @@
+TITLE GRoups of Organic Molecules in ACtion for Science
+REMARK THIS IS A SIMULATION BOX
+CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1
+MODEL 1
+ATOM 1 TP1 TE1 A 1 28.725 22.974 33.406 1.00 0.00 P
+ATOM 2 P00 IN1 A 2 26.892 23.631 31.811 1.00 0.00 C
+ATOM 3 P00 IN1 A 3 27.633 24.222 29.541 1.00 0.00 C
+ATOM 4 P00 IN1 A 4 25.695 23.995 27.884 1.00 0.00 C
+ATOM 5 P00 IN1 A 5 26.290 25.095 25.743 1.00 0.00 C
+ATOM 6 P00 IN1 A 6 24.826 24.484 23.700 1.00 0.00 C
+ATOM 7 P00 IN1 A 7 24.944 26.058 21.774 1.00 0.00 C
+ATOM 8 P00 IN1 A 8 23.018 26.878 20.449 1.00 0.00 C
+ATOM 9 P00 IN1 A 9 22.109 25.856 18.341 1.00 0.00 C
+ATOM 10 TP1 TE1 A 10 19.868 26.802 17.355 1.00 0.00 P
+TER
+CONECT 1 2
+CONECT 2 3
+CONECT 3 4
+CONECT 4 5
+CONECT 5 6
+CONECT 6 7
+CONECT 7 8
+CONECT 8 9
+CONECT 9 10
+ENDMDL
diff --git a/examples/classical/gbForce/1_5corrV2.xml b/examples/classical/gbForce/1_5corrV2.xml
new file mode 100644
index 000000000..633b4326b
--- /dev/null
+++ b/examples/classical/gbForce/1_5corrV2.xml
@@ -0,0 +1,98 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C);
+ U=r+sr2; C=2*(1/or1-1/L)*step(sr2-r-or1); L=max(or1, D); D=abs(r-sr2); sr2 =
+ scale2*or2; or1 = radius1-0.009; or2 = radius2-0.009
+
+
+ 1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius); psi=I*or; or=radius-0.009
+
+
+ 28.3919551*(radius+0.14)^2*(radius/B)^6-0.5*138.935456*(1/soluteDielectric-1/solventDielectric)*charge^2/B
+
+
+ -138.935456*(1/soluteDielectric-1/solventDielectric)*charge1*charge2/f;
+ f=sqrt(r^2+B1*B2*exp(-r^2/(4*B1*B2)))
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/examples/classical/gbForce/gbforce.py b/examples/classical/gbForce/gbforce.py
new file mode 100644
index 000000000..55a42756f
--- /dev/null
+++ b/examples/classical/gbForce/gbforce.py
@@ -0,0 +1,99 @@
+#!/usr/bin/env python
+from openmm import *
+from openmm.app import *
+from openmm.unit import *
+import numpy as np
+import sys
+sys.path.append("..")
+sys.path.append(".")
+sys.path.append("...")
+from dmff.api.hamiltonian import Hamiltonian
+from dmff.common import nblist
+from jax import jit
+import jax.numpy as jnp
+import mdtraj as md
+
+def forcegroupify(system):
+ forcegroups = {}
+ for i in range(system.getNumForces()):
+ force = system.getForce(i)
+ force.setForceGroup(i)
+ forcegroups[force] = i
+ return forcegroups
+
+def getEnergyDecomposition(context, forcegroups):
+ energies = {}
+ for f, i in forcegroups.items():
+ energies[f] = context.getState(getEnergy=True, groups=2 ** i).getPotentialEnergy()
+ return energies
+
+print("MM Reference Energy:")
+pdb = PDBFile("10p.pdb")
+ff = ForceField("1_5corrV2.xml")
+system = ff.createSystem(pdb.topology, nonbondedMethod=NoCutoff, constraints=None, removeCMMotion=False)
+h = Hamiltonian("1_5corrV2.xml")
+params = h.getParameters()
+compoundBondForceParam = params["Custom1_5BondForce"]
+length = compoundBondForceParam["length"]
+k = compoundBondForceParam["k"]
+customCompoundForce = openmm.CustomCompoundBondForce(2, "0.5*k*(distance(p1,p2)-length)^2")
+customCompoundForce.addPerBondParameter("length")
+customCompoundForce.addPerBondParameter("k")
+for i, leng in enumerate(length):
+ customCompoundForce.addBond([i, i+4], [leng, k[i]])
+system.addForce(customCompoundForce)
+print("Dih info:")
+for force in system.getForces():
+ if isinstance(force, PeriodicTorsionForce):
+ print("No. of dihs:", force.getNumTorsions())
+
+forcegroups = forcegroupify(system)
+integrator = VerletIntegrator(0.1)
+context = Context(system, integrator, Platform.getPlatformByName("Reference"))
+context.setPositions(pdb.positions)
+state = context.getState(getEnergy=True)
+energy = state.getPotentialEnergy()
+energies = getEnergyDecomposition(context, forcegroups)
+print("Total energy:", energy)
+for key in energies.keys():
+ print(key.getName(), energies[key])
+
+print("Jax Energy")
+h = Hamiltonian("1_5corrV2.xml")
+pot = h.createPotential(pdb.topology, nonbondedMethod=NoCutoff)
+params = h.getParameters()
+positions = pdb.getPositions(asNumpy=True).value_in_unit(nanometer)
+positions = jnp.array(positions)
+box = np.array([
+ [30.0, 0.0, 0.0],
+ [0.0, 30.0, 0.0],
+ [0.0, 0.0, 30.0]
+])
+
+# neighbor list
+rc = 6.0
+nbl = nblist.NeighborList(box, rc, pot.meta['cov_map'])
+nbl.allocate(positions)
+pairs = nbl.pairs
+
+bondE = pot.dmff_potentials['HarmonicBondForce']
+print("Bond:", bondE(positions, box, pairs, params))
+
+angleE = pot.dmff_potentials['HarmonicAngleForce']
+print("Angle:", angleE(positions, box, pairs, params))
+
+gbE = pot.dmff_potentials['CustomGBForce']
+print("CustomGBForce:", gbE(positions, box, pairs, params))
+
+E1_5 = pot.dmff_potentials['Custom1_5BondForce']
+print("Custom1_5BondForce:", E1_5(positions, box, pairs, params))
+
+dihE = pot.dmff_potentials['CustomTorsionForce']
+print("Torsion:", dihE(positions, box, pairs, params))
+
+nbE = pot.dmff_potentials['NonbondedForce']
+print("Nonbonded:", nbE(positions, box, pairs, params))
+
+etotal = pot.getPotentialFunc()
+print("Total:", etotal(positions, box, pairs, params))
+
diff --git a/tests/data/10p.pdb b/tests/data/10p.pdb
new file mode 100644
index 000000000..05b18aa71
--- /dev/null
+++ b/tests/data/10p.pdb
@@ -0,0 +1,25 @@
+TITLE GRoups of Organic Molecules in ACtion for Science
+REMARK THIS IS A SIMULATION BOX
+CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1
+MODEL 1
+ATOM 1 TP1 TE1 A 1 28.725 22.974 33.406 1.00 0.00 P
+ATOM 2 P00 IN1 A 2 26.892 23.631 31.811 1.00 0.00 C
+ATOM 3 P00 IN1 A 3 27.633 24.222 29.541 1.00 0.00 C
+ATOM 4 P00 IN1 A 4 25.695 23.995 27.884 1.00 0.00 C
+ATOM 5 P00 IN1 A 5 26.290 25.095 25.743 1.00 0.00 C
+ATOM 6 P00 IN1 A 6 24.826 24.484 23.700 1.00 0.00 C
+ATOM 7 P00 IN1 A 7 24.944 26.058 21.774 1.00 0.00 C
+ATOM 8 P00 IN1 A 8 23.018 26.878 20.449 1.00 0.00 C
+ATOM 9 P00 IN1 A 9 22.109 25.856 18.341 1.00 0.00 C
+ATOM 10 TP1 TE1 A 10 19.868 26.802 17.355 1.00 0.00 P
+TER
+CONECT 1 2
+CONECT 2 3
+CONECT 3 4
+CONECT 4 5
+CONECT 5 6
+CONECT 6 7
+CONECT 7 8
+CONECT 8 9
+CONECT 9 10
+ENDMDL
diff --git a/tests/data/1_5corrV2.xml b/tests/data/1_5corrV2.xml
new file mode 100644
index 000000000..633b4326b
--- /dev/null
+++ b/tests/data/1_5corrV2.xml
@@ -0,0 +1,98 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C);
+ U=r+sr2; C=2*(1/or1-1/L)*step(sr2-r-or1); L=max(or1, D); D=abs(r-sr2); sr2 =
+ scale2*or2; or1 = radius1-0.009; or2 = radius2-0.009
+
+
+ 1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius); psi=I*or; or=radius-0.009
+
+
+ 28.3919551*(radius+0.14)^2*(radius/B)^6-0.5*138.935456*(1/soluteDielectric-1/solventDielectric)*charge^2/B
+
+
+ -138.935456*(1/soluteDielectric-1/solventDielectric)*charge1*charge2/f;
+ f=sqrt(r^2+B1*B2*exp(-r^2/(4*B1*B2)))
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/data/pBox.pdb b/tests/data/pBox.pdb
new file mode 100644
index 000000000..b1a0118e1
--- /dev/null
+++ b/tests/data/pBox.pdb
@@ -0,0 +1,87 @@
+TITLE polyten_GMX.gro created by acpype (v: 2022.6.6) on Wed Oct 5 08:03:17 2022
+REMARK THIS IS A SIMULATION BOX
+CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1
+MODEL 1
+HETATM 1 P00 TE1 1 28.652 12.361 14.349 1.00 0.00 P
+HETATM 2 O01 TE1 1 28.511 11.590 12.887 1.00 0.00 O
+HETATM 3 O02 TE1 1 28.857 11.267 15.580 1.00 0.00 O
+HETATM 4 O03 TE1 1 29.865 13.487 14.329 1.00 0.00 O
+HETATM 5 O04 TE1 1 26.965 13.160 14.653 1.00 0.00 O
+HETATM 6 P00 IN1 2 26.260 14.805 14.965 1.00 0.00 P
+HETATM 7 O01 IN1 2 26.855 15.339 16.380 1.00 0.00 O
+HETATM 8 O02 IN1 2 26.508 15.706 13.631 1.00 0.00 O
+HETATM 9 O00 IN2 3 24.431 14.542 15.100 1.00 0.00 O
+HETATM 10 P00 IN1 4 23.086 15.420 16.142 1.00 0.00 P
+HETATM 11 O01 IN1 4 22.975 14.492 17.468 1.00 0.00 O
+HETATM 12 O02 IN1 4 23.530 16.970 16.231 1.00 0.00 O
+HETATM 13 O00 IN2 5 21.421 15.307 15.326 1.00 0.00 O
+HETATM 14 P00 IN1 6 19.909 16.539 15.340 1.00 0.00 P
+HETATM 15 O01 IN1 6 20.000 17.324 16.743 1.00 0.00 O
+HETATM 16 O02 IN1 6 20.063 17.265 13.902 1.00 0.00 O
+HETATM 17 O00 IN2 7 18.227 15.697 15.269 1.00 0.00 O
+HETATM 18 P00 IN1 8 16.526 16.321 15.943 1.00 0.00 P
+HETATM 19 O01 IN1 8 16.373 15.511 17.337 1.00 0.00 O
+HETATM 20 O02 IN1 8 16.580 17.927 15.852 1.00 0.00 O
+HETATM 21 O00 IN2 9 15.002 15.739 14.963 1.00 0.00 O
+HETATM 22 P00 IN1 10 13.371 16.585 14.465 1.00 0.00 P
+HETATM 23 O01 IN1 10 13.092 17.763 15.532 1.00 0.00 O
+HETATM 24 O02 IN1 10 13.581 16.841 12.885 1.00 0.00 O
+HETATM 25 O00 IN2 11 11.772 15.515 14.643 1.00 0.00 O
+HETATM 26 P00 IN1 12 10.263 15.275 13.556 1.00 0.00 P
+HETATM 27 O01 IN1 12 10.068 16.593 12.650 1.00 0.00 O
+HETATM 28 O02 IN1 12 10.482 13.800 12.922 1.00 0.00 O
+HETATM 29 O00 IN2 13 8.598 15.069 14.540 1.00 0.00 O
+HETATM 30 P00 IN1 14 6.866 15.556 14.061 1.00 0.00 P
+HETATM 31 O01 IN1 14 6.615 17.010 14.733 1.00 0.00 O
+HETATM 32 O02 IN1 14 6.677 15.310 12.476 1.00 0.00 O
+HETATM 33 O00 IN2 15 5.578 14.441 14.933 1.00 0.00 O
+HETATM 34 P00 IN1 16 3.855 13.938 14.474 1.00 0.00 P
+HETATM 35 O01 IN1 16 3.150 15.098 13.579 1.00 0.00 O
+HETATM 36 O02 IN1 16 3.961 12.437 13.850 1.00 0.00 O
+HETATM 37 O04 TE1 17 2.946 13.817 16.041 1.00 0.00 O
+HETATM 38 P00 TE1 17 1.214 13.377 16.663 1.00 0.00 P
+HETATM 39 O01 TE1 17 0.217 12.929 15.419 1.00 0.00 O
+HETATM 40 O02 TE1 17 0.655 14.736 17.433 1.00 0.00 O
+HETATM 41 O03 TE1 17 1.442 12.135 17.739 1.00 0.00 O
+TER
+CONECT 1 2
+CONECT 1 3
+CONECT 1 4
+CONECT 1 5
+CONECT 5 6
+CONECT 6 7
+CONECT 6 8
+CONECT 6 9
+CONECT 9 10
+CONECT 10 11
+CONECT 10 12
+CONECT 10 13
+CONECT 13 14
+CONECT 14 15
+CONECT 14 16
+CONECT 14 17
+CONECT 17 18
+CONECT 18 19
+CONECT 18 20
+CONECT 18 21
+CONECT 21 22
+CONECT 22 23
+CONECT 22 24
+CONECT 22 25
+CONECT 25 26
+CONECT 26 27
+CONECT 26 28
+CONECT 26 29
+CONECT 29 30
+CONECT 30 31
+CONECT 30 32
+CONECT 30 33
+CONECT 33 34
+CONECT 34 35
+CONECT 34 36
+CONECT 34 37
+CONECT 37 38
+CONECT 38 39
+CONECT 38 40
+CONECT 38 41
+ENDMDL
diff --git a/tests/data/polyp_amberImp.xml b/tests/data/polyp_amberImp.xml
new file mode 100644
index 000000000..15f6c9721
--- /dev/null
+++ b/tests/data/polyp_amberImp.xml
@@ -0,0 +1,111 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C);
+ U=r+sr2; C=2*(1/or1-1/L)*step(sr2-r-or1); L=max(or1, D); D=abs(r-sr2); sr2 =
+ scale2*or2; or1 = radius1-0.009; or2 = radius2-0.009
+
+
+ 1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius); psi=I*or; or=radius-0.009
+
+
+ 28.3919551*(radius+0.14)^2*(radius/B)^6-0.5*138.935456*
+ (1/soluteDielectric-1/solventDielectric)*charge^2/B
+
+
+ -138.935456*(1/soluteDielectric-1/solventDielectric)*charge1*charge2/f;
+ f=sqrt(r^2+B1*B2*exp(-r^2/(4*B1*B2)))
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/test_classical/test_gbforce.py b/tests/test_classical/test_gbforce.py
new file mode 100644
index 000000000..ffa29c5a7
--- /dev/null
+++ b/tests/test_classical/test_gbforce.py
@@ -0,0 +1,80 @@
+import pytest
+import jax
+import jax.numpy as jnp
+import openmm.app as app
+import openmm.unit as unit
+import numpy as np
+import numpy.testing as npt
+from dmff.api import Hamiltonian
+from dmff.common import nblist
+
+
+@pytest.mark.parametrize(
+ "pdb, prm, value",
+ [
+ ("./tests/data/10p.pdb", "./tests/data/1_5corrV2.xml", -11184.921239189738),
+ ("./tests/data/pBox.pdb", "./tests/data/polyp_amberImp.xml", -13914.34177591779),
+ ])
+def test_custom_gb_force(pdb, prm, value):
+ pdb = app.PDBFile(pdb)
+ h = Hamiltonian(prm)
+ potential = h.createPotential(
+ pdb.topology,
+ nonbondedMethod=app.NoCutoff
+ )
+ pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
+ box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]])
+ rc = 6.0
+ nbl = nblist.NeighborList(box, rc, potential.meta['cov_map'])
+ nbl.allocate(pos)
+ pairs = nbl.pairs
+ gbE = potential.getPotentialFunc(names=["CustomGBForce"])
+ energy = gbE(pos, box, pairs, h.paramset)
+ print(energy)
+ npt.assert_almost_equal(energy, value, decimal=3)
+
+
+@pytest.mark.parametrize(
+ "pdb, prm, value",
+ [
+ ("./tests/data/10p.pdb", "./tests/data/1_5corrV2.xml", 59.53033875302844),
+ ])
+def test_custom_torsion_force(pdb, prm, value):
+ pdb = app.PDBFile(pdb)
+ h = Hamiltonian(prm)
+ potential = h.createPotential(
+ pdb.topology,
+ nonbondedMethod=app.NoCutoff
+ )
+ pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
+ box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]])
+ rc = 6.0
+ nbl = nblist.NeighborList(box, rc, potential.meta['cov_map'])
+ nbl.allocate(pos)
+ pairs = nbl.pairs
+ gbE = potential.getPotentialFunc(names=["CustomTorsionForce"])
+ energy = gbE(pos, box, pairs, h.paramset)
+ npt.assert_almost_equal(energy, value, decimal=3)
+
+
+@pytest.mark.parametrize(
+ "pdb, prm, value",
+ [
+ ("./tests/data/10p.pdb", "./tests/data/1_5corrV2.xml", 117.95416362791674),
+ ])
+def test_custom_1_5bond_force(pdb, prm, value):
+ pdb = app.PDBFile(pdb)
+ h = Hamiltonian(prm)
+ potential = h.createPotential(
+ pdb.topology,
+ nonbondedMethod=app.NoCutoff
+ )
+ pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
+ box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]])
+ rc = 6.0
+ nbl = nblist.NeighborList(box, rc, potential.meta['cov_map'])
+ nbl.allocate(pos)
+ pairs = nbl.pairs
+ gbE = potential.getPotentialFunc(names=["Custom1_5BondForce"])
+ energy = gbE(pos, box, pairs, h.paramset)
+ npt.assert_almost_equal(energy, value, decimal=3)
\ No newline at end of file