diff --git a/dmff/admp/disp_pme.py b/dmff/admp/disp_pme.py index c0bcd3284..ff8588bb4 100755 --- a/dmff/admp/disp_pme.py +++ b/dmff/admp/disp_pme.py @@ -164,7 +164,7 @@ def disp_pme_real(positions, box, pairs, # pairs = pairs[pairs[:, 0] < pairs[:, 1]] pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2])) - box_inv = jnp.linalg.inv(box) + box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) ri = distribute_v3(positions, pairs[:, 0]) rj = distribute_v3(positions, pairs[:, 1]) diff --git a/dmff/admp/mbpol_intra.py b/dmff/admp/mbpol_intra.py index e6155e375..f8bb14186 100755 --- a/dmff/admp/mbpol_intra.py +++ b/dmff/admp/mbpol_intra.py @@ -431,7 +431,7 @@ ## compute intra def onebodyenergy(positions, box): - box_inv = jnp.linalg.inv(box) + box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) O = positions[::3] H1 = positions[1::3] H2 = positions[2::3] diff --git a/dmff/admp/pairwise.py b/dmff/admp/pairwise.py index 8fbc1b856..2af815740 100755 --- a/dmff/admp/pairwise.py +++ b/dmff/admp/pairwise.py @@ -77,7 +77,7 @@ def pair_int(positions, box, pairs, mScales, *atomic_params): buffer_scales = pair_buffer_scales(pairs) mscales = mscales * buffer_scales # mscales = mScales[nbonds-1] - box_inv = jnp.linalg.inv(box) + box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) dr = ri - rj dr = v_pbc_shift(dr, box, box_inv) dr = jnp.linalg.norm(dr, axis=1) diff --git a/dmff/admp/pme.py b/dmff/admp/pme.py index 14b9d4056..878ac66c0 100755 --- a/dmff/admp/pme.py +++ b/dmff/admp/pme.py @@ -1104,7 +1104,7 @@ def pme_real( """ pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2])) buffer_scales = pair_buffer_scales(pairs[:, :2]) - box_inv = jnp.linalg.inv(box) + box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) r1 = distribute_v3(positions, pairs[:, 0]) r2 = distribute_v3(positions, pairs[:, 1]) Q_extendi = distribute_multipoles(Q_global, pairs[:, 0]) diff --git a/dmff/admp/qeq.py b/dmff/admp/qeq.py index b263d2b5f..4b51f858d 100644 --- a/dmff/admp/qeq.py +++ b/dmff/admp/qeq.py @@ -164,7 +164,7 @@ def ds_pairs(positions, box, pairs, pbc_flag): if pbc_flag is False: dr = pos1 - pos2 else: - box_inv = jnp.linalg.inv(box) + box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) dpos = pos1 - pos2 dpos = dpos.dot(box_inv) dpos -= jnp.floor(dpos + 0.5) diff --git a/dmff/admp/recip.py b/dmff/admp/recip.py index fc907812c..7ae4e077d 100755 --- a/dmff/admp/recip.py +++ b/dmff/admp/recip.py @@ -42,7 +42,7 @@ def get_recip_vectors(N, box): 3 x 3 matrix, the first index denotes reciprocal lattice vector, the second index is the component xyz. (lattice vectors arranged in rows) """ - Nj_Aji_star = (N.reshape((1, 3)) * jnp.linalg.inv(box)).T + Nj_Aji_star = (N.reshape((1, 3)) * jnp.linalg.inv(box + jnp.eye(3) * 1e-36)).T return Nj_Aji_star @@ -396,7 +396,7 @@ def setup_kpts(box, kpts_int): 4 * K, K=K1*K2*K3, contains kx, ky, kz, k^2 for each kpoint ''' # in this array, a*, b*, c* (without 2*pi) are arranged in column - box_inv = jnp.linalg.inv(box).T + box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36).T # K * 3, coordinate in reciprocal space kpts = 2 * jnp.pi * kpts_int.dot(box_inv) ksq = jnp.sum(kpts**2, axis=1) diff --git a/dmff/admp/spatial.py b/dmff/admp/spatial.py index dbfc04464..dbfd7461e 100644 --- a/dmff/admp/spatial.py +++ b/dmff/admp/spatial.py @@ -37,7 +37,7 @@ def normalize(matrix, axis=1, ord=2): ''' Normalise a matrix along one dimension ''' - normalised = matrix / jnp.linalg.norm(matrix, axis=axis, keepdims=True, ord=ord) + normalised = matrix / jnp.linalg.norm(matrix + 1e-36, axis=axis, keepdims=True, ord=ord) return normalised @@ -93,7 +93,7 @@ def construct_local_frames(positions, box): positions = jnp.array(positions) n_sites = positions.shape[0] - box_inv = jnp.linalg.inv(box) + box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) ### Process the x, y, z vectors according to local axis rules vec_z = pbc_shift(positions[z_atoms] - positions, box, box_inv) diff --git a/dmff/classical/fep.py b/dmff/classical/fep.py index 01dc838ad..dba203200 100644 --- a/dmff/classical/fep.py +++ b/dmff/classical/fep.py @@ -107,7 +107,7 @@ def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales, l eps_scale = eps * mscale_pair if self.ifPBC: - dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box)) + dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box + jnp.eye(3) * 1e-36)) dr_norm = jnp.linalg.norm(dr_vec, axis=1) @@ -281,7 +281,7 @@ def get_energy(positions, box, pairs, charges, mscales, lambda_): pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2])) bufScales = pair_buffer_scales(pairs[:, :2]) dr_vec = positions[pairs[:, 0]] - positions[pairs[:, 1]] - dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box)) + dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box + jnp.eye(3) * 1e-36)) dr_norm = jnp.linalg.norm(dr_vec, axis=1) atomCharges = charges[self.map_prm[np.arange(positions.shape[0])]] diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py index d6f393783..4fd304ec8 100644 --- a/dmff/classical/inter.py +++ b/dmff/classical/inter.py @@ -37,7 +37,7 @@ def __init__( def generate_get_energy(self): def get_LJ_energy(dr_vec, sig, eps, box): if self.ifPBC: - dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box)) + dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box + jnp.eye(3) * 1e-36)) dr_norm = jnp.linalg.norm(dr_vec, axis=1) dr_inv = 1.0 / dr_norm @@ -224,7 +224,7 @@ def __init__( def generate_get_energy(self): def get_rf_energy(dr_vec, chrgprod, box): if self.ifPBC: - dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box)) + dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box + jnp.eye(3) * 1e-36)) dr_norm = jnp.linalg.norm(dr_vec, axis=1) dr_inv = 1.0 / dr_norm diff --git a/dmff/eann/eann.py b/dmff/eann/eann.py index 1c04dd400..bd7db7487 100644 --- a/dmff/eann/eann.py +++ b/dmff/eann/eann.py @@ -300,7 +300,7 @@ def get_energy(positions, box, pairs, params): buffer_scales = pair_buffer_scales(pairs) # get distances - box_inv = jnp.linalg.inv(box) + box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) ri = distribute_v3(positions, pairs[:, 0]) rj = distribute_v3(positions, pairs[:, 1]) dr = rj - ri diff --git a/dmff/sgnn/graph.py b/dmff/sgnn/graph.py index a4a6da5e9..1d85d92e2 100755 --- a/dmff/sgnn/graph.py +++ b/dmff/sgnn/graph.py @@ -94,7 +94,7 @@ def __init__(self, list_atom_elems, bonds, positions=None, box=None): self.set_internal_coords_indices() self.box = box if box is not None: - self.box_inv = jnp.linalg.inv(box) + self.box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) else: self.box_inv = None return @@ -109,7 +109,7 @@ def set_box(self, box): 3 * 3: the box array, pbc vectors arranged in rows ''' self.box = box - self.box_inv = jnp.linalg.inv(box) + self.box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) if hasattr(self, 'subgraphs'): self._propagate_attr('box') self._propagate_attr('box_inv') @@ -426,7 +426,7 @@ def calc_internal_coords_features(positions, box): All these variables should be "static" throughout NVE/NVT/NPT simulations ''' - box_inv = jnp.linalg.inv(box) + box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) @jit_condition(static_argnums=()) @partial(vmap, in_axes=(0, None, 0), out_axes=(0)) diff --git a/examples/eann/eann.py b/examples/eann/eann.py index b9085b233..ac02427cb 100644 --- a/examples/eann/eann.py +++ b/examples/eann/eann.py @@ -285,7 +285,7 @@ def get_energy(positions, box, pairs, params): buffer_scales = pair_buffer_scales(pairs) # get distances - box_inv = jnp.linalg.inv(box) + box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) ri = distribute_v3(positions, pairs[:, 0]) rj = distribute_v3(positions, pairs[:, 1]) dr = rj - ri diff --git a/examples/fluctuated_leading_term_waterff/run.py b/examples/fluctuated_leading_term_waterff/run.py index bc7c8d48c..ec6012dae 100755 --- a/examples/fluctuated_leading_term_waterff/run.py +++ b/examples/fluctuated_leading_term_waterff/run.py @@ -24,7 +24,7 @@ def compute_leading_terms(positions,box): n_atoms = len(positions) c0 = jnp.zeros(n_atoms) c6_list = jnp.zeros(n_atoms) - box_inv = jnp.linalg.inv(box) + box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) O = positions[::3] H1 = positions[1::3] H2 = positions[2::3] diff --git a/tests/test_frontend/test_inter_water.py b/tests/test_frontend/test_inter_water.py index b1c98fe24..1737f303a 100644 --- a/tests/test_frontend/test_inter_water.py +++ b/tests/test_frontend/test_inter_water.py @@ -13,7 +13,7 @@ def dist_pbc(vi, vj, box): - box_inv = np.linalg.inv(box) + box_inv = np.linalg.inv(box + jnp.eye(3) * 1e-36) drvec = (vi - vj).reshape((1, 3)) unshifted_dsvecs = drvec.dot(box_inv) dsvecs = unshifted_dsvecs - np.floor(unshifted_dsvecs + 0.5)