diff --git a/dmff/api/graph.py b/dmff/api/graph.py index 5e4085581..0096e5e68 100644 --- a/dmff/api/graph.py +++ b/dmff/api/graph.py @@ -14,14 +14,28 @@ import warnings warnings.warn("RDKit is not installed. SMIRKS pattern matching cannot be used.") +def is_same_list(l1, l2): + if len(l1) != len(l2): + return False + for nn in range(len(l1)): + if l1[nn] != l2[nn]: + return False + return True def matchTemplate(graph, template): if graph.number_of_nodes() != template.number_of_nodes(): # print("Node with different number of nodes.") return False, {}, {} - def match_func(n1, n2): - return n1["element"] == n2["element"] and n1["external_bond"] == n2["external_bond"] + name_graph = sorted([i[1]['name'] for i in graph.nodes.data()]) + name_template = sorted([i[1]['name'] for i in template.nodes.data()]) + + if is_same_list(name_graph, name_template): + def match_func(n1, n2): + return n1["element"] == n2["element"] and n1["external_bond"] == n2["external_bond"] and n1['name'] == n2['name'] + else: + def match_func(n1, n2): + return n1["element"] == n2["element"] and n1["external_bond"] == n2["external_bond"] def edge_match(e1, e2): if len(e1) == 0 and len(e2) == 0: diff --git a/dmff/classical/intra.py b/dmff/classical/intra.py index 0625f48ee..76843b979 100644 --- a/dmff/classical/intra.py +++ b/dmff/classical/intra.py @@ -1,5 +1,7 @@ +import jax import jax.numpy as jnp from jax import grad, value_and_grad, vmap +from ..admp.spatial import v_pbc_shift def distance(p1v, p2v): @@ -14,6 +16,13 @@ def angle(p1v, p2v, p3v): vzz = v1[:, 2] * v2[:, 2] return jnp.arccos(vxx + vyy + vzz) +@jax.vmap +def angle_v(v1, v2): + # compute the angle between v1 and v2 + v1n = v1 / jnp.linalg.norm(v1) + v2n = v2 / jnp.linalg.norm(v2) + return jnp.arccos(jnp.dot(v1n, v2n)) + def dihedral(i, j, k, l): b1, b2, b3 = j - i, k - j, l - k @@ -72,12 +81,15 @@ def __init__(self, p1idx, p2idx, p3idx, prmidx): def generate_get_energy(self): def get_energy(positions, box, pairs, k, theta0): + box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) p1 = positions[self.p1idx,:] p2 = positions[self.p2idx,:] p3 = positions[self.p3idx,:] + v1 = v_pbc_shift(p1 - p2, box, box_inv) + v2 = v_pbc_shift(p3 - p2, box, box_inv) kprm = k[self.prmidx] theta0prm = theta0[self.prmidx] - ang = angle(p1, p2, p3) + ang = angle_v(v1, v2) return jnp.sum(0.5 * kprm * jnp.power(ang - theta0prm, 2)) return get_energy diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py index b865f6b88..626957ade 100644 --- a/dmff/generators/classical.py +++ b/dmff/generators/classical.py @@ -418,7 +418,6 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, angle_a3 = jnp.array(angle_a3) angle_indices = jnp.array(angle_indices) - # 创建势函数 harmonic_angle_force = HarmonicAngleJaxForce( angle_a1, angle_a2, angle_a3, angle_indices) harmonic_angle_energy = harmonic_angle_force.generate_get_energy() @@ -427,7 +426,6 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, if "has_aux" in kwargs and kwargs["has_aux"]: has_aux = True - # 包装成统一的potential_function函数形式,传入四个参数:positions, box, pairs, parameters。 def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet, aux=None): isinstance_jnp(positions, box, params) energy = harmonic_angle_energy(