Skip to content

Commit

Permalink
fixing triplet reference lists in MolecularNeighbourhood
Browse files Browse the repository at this point in the history
Now, modulo a sort operation, the triplet reference lists into the pair list
should match the behavior of the cutoff-based routines
  • Loading branch information
prs513rosewood committed Jun 16, 2023
1 parent 58facf1 commit 89b34a2
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 22 deletions.
47 changes: 37 additions & 10 deletions matscipy/neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ def find_triplet_types(self, atoms: ase.Atoms, i, j, k):
"""Return triplet types from atom ids."""
return self.triplet_type(*self.triplet_to_numbers(atoms, i, j, k))

@staticmethod
def lexsort(connectivity: np.ndarray):
return np.lexsort(np.flipud(connectivity.T))

@abstractmethod
def double_neighbourhood(self):
"""Return neighbourhood with double cutoff/connectivity."""
Expand Down Expand Up @@ -308,8 +312,8 @@ def molecules(self, molecules):
if not self.double_cutoff:
self.connectivity["angles"] = \
self.double_connectivity(molecules.angles)
self.triplet_list = np.vstack([self.triplet_list,
self.triplet_list[:, (1, 0, 2)]])

# not doing anything to triplet list
else:
self.triplet_list = np.zeros([0, 3], dtype=np.int32)

Expand All @@ -321,7 +325,12 @@ def pair_type(self):
@property
def triplet_type(self):
"""Map atom types to triplet types."""
return lambda ti_p, tj_p, tk_p: self.connectivity["angles"]["type"]
def tp(ti_p, tj_p, tk_p):
types = self.connectivity["angles"]["type"]
if self.double_cutoff:
return np.concatenate([types] * 2)
return types
return tp

@staticmethod
def double_connectivity(connectivity: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -363,7 +372,8 @@ def complete_connectivity(self, typeoffset: int = 0):
np.unique(new_bonds, return_inverse=True)

# Need to sort after all the shenanigans
idx = np.argsort(self.connectivity["bonds"]["atoms"][:, 0])
# Below sorts lexicographically the pairs (first col, then second col)
idx = Neighbourhood.lexsort(self.connectivity["bonds"]["atoms"])
self.connectivity["bonds"][:] = self.connectivity["bonds"][idx]

# To construct triplet references (aka ij_t, ik_t and jk_t):
Expand All @@ -376,7 +386,7 @@ def complete_connectivity(self, typeoffset: int = 0):
r_idx[idx] = np.arange(len(idx)) # revert sort
self.triplet_list = r_idx[indices_r][n:].reshape(e, -1).T

idx = np.argsort(self.triplet_list[:, 0]) # sort ij_t
idx = Neighbourhood.lexsort(self.triplet_list) # sort ij_t
self.triplet_list = self.triplet_list[idx]

def get_pairs(self, atoms: ase.Atoms, quantities: str, cutoff=None):
Expand All @@ -402,19 +412,36 @@ def get_triplets(self,

# Need to reorder connectivity for distances
bonds = self.connectivity["bonds"]["atoms"]
connectivity = np.array([
bonds[self.triplet_list[:, i], j]
double_triplets = np.vstack([self.triplet_list,
self.triplet_list[:, (1, 0, 2)]])

# Returning triplet references in bonds list
connectivity = double_triplets.copy()
i_p, j_p = bonds.T

first_neigh = first_neighbours(len(atoms), i_p)
ij_t, ik_t, jk_t = connectivity.T
jk_t[:] = -np.ones(len(ij_t), dtype='int32')
# This is slow as
for t, (ij, ik) in enumerate(zip(ij_t, ik_t)):
for i in np.arange(first_neigh[j_p[ij]],
first_neigh[j_p[ij] + 1]):
if i_p[i] == j_p[ij] and j_p[i] == j_p[ik]:
jk_t[t] = i
break

connectivity_in_bounds = np.array([
bonds[connectivity[:, i], j]
for i, j in [(0, 0), (0, 1), (1, 1)]
]).T

# If any distance is requested, compute distances vectors and norms
if "d" in quantities or "D" in quantities:
# i j i k j k
indices = [(0, 1), (0, 2), (1, 2)] # defined in Jan's paper
D, d = self.compute_distances(atoms, connectivity, indices)
D, d = self.compute_distances(atoms,
connectivity_in_bounds, indices)

# Returning triplet references in bonds list
connectivity = self.triplet_list
return self.make_result(
quantities, connectivity, D, d, None, accepted_quantities="ijkdD")

Expand Down
33 changes: 21 additions & 12 deletions tests/test_neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,13 @@ class TestNeighbourhood(matscipytest.MatSciPyTestCase):
def test_pairs(self):
cutoff_d = self.cutoff.get_pairs(self.atoms, "ijdD")
molecule_d = self.molecule.get_pairs(self.atoms, "ijdD")
p = np.array([0, 1, 2, 3, 5, 4])
mask_extra_bonds = self.molecule.connectivity["bonds"]["type"] >= 0

# Lexicographic sort of pair indices, as in cutoff neighborhood
p = CutoffNeighbourhood.lexsort(
np.asarray(molecule_d[0:2]).T[mask_extra_bonds]
)

# print("CUTOFF", cutoff_d)
# print("MOLECULE", molecule_d)

Expand All @@ -361,24 +365,29 @@ def test_triplets(self):
molecules_pairs = np.array(self.molecule.get_pairs(self.atoms, "ij")).T
cutoff_d = self.cutoff.get_triplets(self.atoms, "ijk")
molecule_d = self.molecule.get_triplets(self.atoms, "ijk")
p = np.array([0, 1, 3, 2, 4, 5])

# We compare the refered pairs, not the triplet info directly
# We compare:
# - i_p[ij_t], j_p[ij_t]
# - i_p[ik_t], j_p[ik_t]
# - i_p[jk_t], j_p[jk_t]
sort_cutoff, sort_molecules = [], []
for c, m in zip(cutoff_d, molecule_d):
print("c =", cutoff_pairs[:][c])
print("m =", molecules_pairs[:][m])
self.assertArrayAlmostEqual(cutoff_pairs[:, 0][c],
molecules_pairs[:, 0][m][p], tol=1e-10)
self.assertArrayAlmostEqual(cutoff_pairs[:, 1][c],
molecules_pairs[:, 1][m][p], tol=1e-10)
sort_cutoff.append(CutoffNeighbourhood.lexsort(cutoff_pairs[c]))
sort_molecules.append(CutoffNeighbourhood.lexsort(molecules_pairs[m]))
cpairs = cutoff_pairs[c][sort_cutoff[-1]]
mpairs = molecules_pairs[m][sort_molecules[-1]]

print("c =", cpairs)
print("m =", mpairs)
self.assertArrayAlmostEqual(cpairs[:, 0], mpairs[:, 0], tol=1e-10)
self.assertArrayAlmostEqual(cpairs[:, 1], mpairs[:, 1], tol=1e-10)

# Testing computed distances and vectors
cutoff_d = self.cutoff.get_triplets(self.atoms, "dD")
molecule_d = self.molecule.get_triplets(self.atoms, "dD")

# TODO why no permutation?
for c, m in zip(cutoff_d, molecule_d):
self.assertArrayAlmostEqual(c, m, tol=1e-10)
for c, m, pc, pm in zip(cutoff_d, molecule_d, sort_cutoff, sort_molecules):
self.assertArrayAlmostEqual(c[pc], m[pm], tol=1e-10)

def test_pair_types(self):
pass
Expand Down

0 comments on commit 89b34a2

Please sign in to comment.