Skip to content

Commit

Permalink
Updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pfebrer committed Jun 3, 2024
1 parent 26045e6 commit f690753
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 108 deletions.
2 changes: 1 addition & 1 deletion src/sisl/geom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
# TODO add category and neighbors in the above discussion
from ._category import *
from ._composite import *
from ._neighbors import *
from ._neighbors import NeighborFinder

# isort: split
from .basic import *
Expand Down
2 changes: 1 addition & 1 deletion src/sisl/geom/_neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations

from ._finder import NeighborFinder
from ._finder import *
62 changes: 44 additions & 18 deletions src/sisl/geom/_neighbors/_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@

from . import _operations

__all__ = [
"NeighborFinder",
"Neighbors",
"AtomsNeighborList",
"UniqueNeighborList",
"FullNeighborList",
"PartialNeighborList",
"AtomNeighborList",
"PointsNeighborList",
"PointNeighborList",
]


class Neighbors:

Expand All @@ -36,7 +48,10 @@ def isc(self):

@cached_property
def nsc(self):
return np.max(np.abs(self.isc), axis=0) * 2 + 1
if len(self.isc) == 0:
return np.ones(3, dtype=int)
else:
return np.max(np.abs(self.isc), axis=0) * 2 + 1

@property
def j_sc(self):
Expand Down Expand Up @@ -98,7 +113,7 @@ def n_neighbors(self):
else:
return np.diff(self._split_indices, prepend=0)

Check warning on line 114 in src/sisl/geom/_neighbors/_finder.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/geom/_neighbors/_finder.py#L114

Added line #L114 was not covered by tests

def to_full(self):
def to_full(self) -> FullNeighborList:
"""Converts the unique neighbors list to a full neighbors list."""
upper_tri = self._finder_results
lower_tri = np.column_stack(
Expand All @@ -108,10 +123,11 @@ def to_full(self):
self_interactions = (self.i == self.j) & np.all(self.isc == 0, axis=1)
lower_tri = lower_tri[~self_interactions]

# Sort
# Concatenate the lower triangular with the upper triangular part
all_finder_results = np.concatenate([upper_tri, lower_tri], axis=0)

sorted_indices = np.lexsort(all_finder_results[:, [1, 0]])
# Sort by i and then by j
sorted_indices = np.lexsort(all_finder_results[:, [1, 0]].T)
all_finder_results = all_finder_results[sorted_indices]

return FullNeighborList(
Expand Down Expand Up @@ -140,6 +156,17 @@ def __getitem__(self, item):
else:
raise ValueError("Only integer indexing is supported.")

Check warning on line 157 in src/sisl/geom/_neighbors/_finder.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/geom/_neighbors/_finder.py#L157

Added line #L157 was not covered by tests

def to_unique(self) -> UniqueNeighborList:
"""Converts the full neighbors list to a unique neighbors list."""

full_finder_results = self._finder_results
unique_finder_results = full_finder_results[self.i <= self.j]

# Concatenate the uc connections with the rest of the connections.
return UniqueNeighborList(
geometry=self.geometry, finder_results=unique_finder_results
)


class PartialNeighborList(AtomsNeighborList):

Expand Down Expand Up @@ -168,7 +195,7 @@ def __getitem__(self, item):
return AtomNeighborList(
self.geometry,
self._finder_results[start:end],
atom=self.atoms[at],
atom=self.atoms[item],
nsc=self.nsc,
)
else:
Expand All @@ -187,6 +214,11 @@ def __init__(self, geometry, finder_results, atom: int, nsc: np.ndarray):
def nsc(self):
return self._nsc

Check warning on line 215 in src/sisl/geom/_neighbors/_finder.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/geom/_neighbors/_finder.py#L215

Added line #L215 was not covered by tests

@cached_property
def n_neighbors(self):
"""Number of neighbors of the atom."""
return len(self._finder_results)


class PointsNeighborList(Neighbors):
"""List of atoms that are close to a set of points in space."""
Expand Down Expand Up @@ -238,6 +270,11 @@ def __init__(self, geometry, point: np.ndarray, finder_results, nsc):
def nsc(self):
return self._nsc

Check warning on line 271 in src/sisl/geom/_neighbors/_finder.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/geom/_neighbors/_finder.py#L271

Added line #L271 was not covered by tests

@cached_property
def n_neighbors(self):
"""Number of neighbors of the point."""
return len(self._finder_results)


class NeighborFinder:
"""Efficient linear scaling finding of neighbors.
Expand Down Expand Up @@ -774,20 +811,9 @@ def find_unique_pairs(
# just find all neighbors and then drop duplicate connections. Otherwise it is a bit of a mess.
if self._R_too_big:
# Find all neighbors
all_neighbors = self.find_neighbors(
as_pairs=True, self_interaction=self_interaction
)

# Find out which of the pairs are uc connections
is_uc_neigh = ~np.any(all_neighbors[:, 2:], axis=1)

# Create an array with unit cell connections where duplicates are removed
unique_uc = np.unique(np.sort(all_neighbors[is_uc_neigh][:, :2]), axis=0)
uc_neighbors = np.zeros((len(unique_uc), 5), dtype=int)
uc_neighbors[:, :2] = unique_uc
all_neighbors = self.find_neighbors(self_interaction=self_interaction)

# Concatenate the uc connections with the rest of the connections.
return np.concatenate((uc_neighbors, all_neighbors[~is_uc_neigh]))
return all_neighbors.to_unique()

# Cast R into array of appropiate shape and type.
thresholds = np.full(self.geometry.na, self.R, dtype=np.float64)
Expand Down
Loading

0 comments on commit f690753

Please sign in to comment.