Skip to content

Commit

Permalink
fix StressForceNode (#110)
Browse files Browse the repository at this point in the history
* fix StressForceNode

* update changelog

* better implementation

* remove debugging print statements

* add shape checking in ExternalNeighbors layer
  • Loading branch information
shinkle-lanl authored Oct 24, 2024
1 parent 862ed0f commit 957de00
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 10 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Improvements:
set the default to use a better value for epsilon.
- Improved detection of valid custom kernel implementation.
- Improved computational efficiency of HIP-NN-TS network.

- ``StressForceNode`` now also works with batch size greater than 1.


Bug Fixes:
Expand Down
2 changes: 1 addition & 1 deletion examples/ase_example_multilayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# Load the files
try:
with active_directory("TEST_ALUMINUM_MODEL_MULTILAYER", create=False):
bundle = load_checkpoint_from_cwd(map_location='cpu',e)
bundle = load_checkpoint_from_cwd(map_location='cpu')
except FileNotFoundError:
raise FileNotFoundError("Model not found, run ani_aluminum_example_multilayer.py first!")

Expand Down
3 changes: 1 addition & 2 deletions hippynn/interfaces/ase_interface/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,8 @@ def calculate(self, atoms=None, properties=None, system_changes=True):
# Convert from ASE distance (angstrom) to whatever the network uses.
positions = positions / self.dist_unit
species = torch.as_tensor(self.atoms.numbers,dtype=torch.long).unsqueeze(0)
cell = torch.as_tensor(self.atoms.cell.array) # ExternalNieghbors doesn't take batch index
cell = torch.as_tensor(self.atoms.cell.array).unsqueeze(0)
# Get pair first and second from neighbors list

pair_first = torch.as_tensor(self.nl.nl.pair_first,dtype=torch.long)
pair_second = torch.as_tensor(self.nl.nl.pair_second,dtype=torch.long)
pair_shiftvecs = torch.as_tensor(self.nl.nl.offset_vec,dtype=torch.long)
Expand Down
5 changes: 2 additions & 3 deletions hippynn/layers/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,9 @@ def __init__(self, *args, **kwargs):
def forward(self, coordinates, cell):
strain = torch.eye(
coordinates.shape[2], dtype=coordinates.dtype, device=coordinates.device, requires_grad=True
).unsqueeze(0)
).tile(coordinates.shape[0],1,1)
strained_coordinates = torch.bmm(coordinates, strain)
if cell.dim() == 2:
strained_cell = torch.mm(cell, strain.squeeze(0))
strained_cell = torch.bmm(cell, strain)
return strained_coordinates, strained_cell, strain


Expand Down
13 changes: 10 additions & 3 deletions hippynn/layers/pairs/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,18 @@ class ExternalNeighbors(_PairIndexer):
"""

def forward(self, coordinates, real_atoms, shifts, cell, pair_first, pair_second):
n_molecules, n_atoms, _ = coordinates.shape
atom_coordinates = coordinates.reshape(n_molecules * n_atoms, 3)[real_atoms]
if (coordinates.ndim > 3) or (coordinates.ndim == 3 and coordinates.shape[0] != 1):
raise ValueError(f"coordinates must have (n,3) or (1,n,3) but has shape {coordinates.shape}")
if coordinates.ndim == 3:
coordinates = coordinates.squeeze(0)
if (cell.ndim > 3) or (cell.ndim == 3 and cell.shape[0] != 1):
raise ValueError(f"cell must have (3,3) or (1,3,3) but has shape {cell.shape}")
if cell.ndim == 3:
cell = cell.squeeze(0)

atom_coordinates = coordinates[real_atoms]
paircoord = atom_coordinates[pair_second] - atom_coordinates[pair_first] + shifts.to(cell.dtype) @ cell
distflat = paircoord.norm(dim=1)

# We filter the lists to only send forward relevant pairs (those with distance under cutoff), improving performance.
return filter_pairs(self.hard_dist_cutoff, distflat, pair_first, pair_second, paircoord)

Expand Down

0 comments on commit 957de00

Please sign in to comment.