Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix StressForceNode #110

Merged
merged 5 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Improvements:
set the default to use a better value for epsilon.
- Improved detection of valid custom kernel implementation.
- Improved computational efficiency of HIP-NN-TS network.

- ``StressForceNode`` now also works with batch size greater than 1.


Bug Fixes:
Expand Down
2 changes: 1 addition & 1 deletion examples/ase_example_multilayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# Load the files
try:
with active_directory("TEST_ALUMINUM_MODEL_MULTILAYER", create=False):
bundle = load_checkpoint_from_cwd(map_location='cpu',e)
bundle = load_checkpoint_from_cwd(map_location='cpu')
except FileNotFoundError:
raise FileNotFoundError("Model not found, run ani_aluminum_example_multilayer.py first!")

Expand Down
3 changes: 1 addition & 2 deletions hippynn/interfaces/ase_interface/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,8 @@ def calculate(self, atoms=None, properties=None, system_changes=True):
# Convert from ASE distance (angstrom) to whatever the network uses.
positions = positions / self.dist_unit
species = torch.as_tensor(self.atoms.numbers,dtype=torch.long).unsqueeze(0)
cell = torch.as_tensor(self.atoms.cell.array) # ExternalNieghbors doesn't take batch index
cell = torch.as_tensor(self.atoms.cell.array).unsqueeze(0)
# Get pair first and second from neighbors list

pair_first = torch.as_tensor(self.nl.nl.pair_first,dtype=torch.long)
pair_second = torch.as_tensor(self.nl.nl.pair_second,dtype=torch.long)
pair_shiftvecs = torch.as_tensor(self.nl.nl.offset_vec,dtype=torch.long)
Expand Down
5 changes: 2 additions & 3 deletions hippynn/layers/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,9 @@ def __init__(self, *args, **kwargs):
def forward(self, coordinates, cell):
strain = torch.eye(
coordinates.shape[2], dtype=coordinates.dtype, device=coordinates.device, requires_grad=True
).unsqueeze(0)
).tile(coordinates.shape[0],1,1)
strained_coordinates = torch.bmm(coordinates, strain)
if cell.dim() == 2:
strained_cell = torch.mm(cell, strain.squeeze(0))
strained_cell = torch.bmm(cell, strain)
return strained_coordinates, strained_cell, strain


Expand Down
13 changes: 10 additions & 3 deletions hippynn/layers/pairs/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,18 @@ class ExternalNeighbors(_PairIndexer):
"""

def forward(self, coordinates, real_atoms, shifts, cell, pair_first, pair_second):
n_molecules, n_atoms, _ = coordinates.shape
atom_coordinates = coordinates.reshape(n_molecules * n_atoms, 3)[real_atoms]
if (coordinates.ndim > 3) or (coordinates.ndim == 3 and coordinates.shape[0] != 1):
raise ValueError(f"coordinates must have (n,3) or (1,n,3) but has shape {coordinates.shape}")
if coordinates.ndim == 3:
coordinates = coordinates.squeeze(0)
if (cell.ndim > 3) or (cell.ndim == 3 and cell.shape[0] != 1):
raise ValueError(f"cell must have (3,3) or (1,3,3) but has shape {cell.shape}")
if cell.ndim == 3:
cell = cell.squeeze(0)

atom_coordinates = coordinates[real_atoms]
paircoord = atom_coordinates[pair_second] - atom_coordinates[pair_first] + shifts.to(cell.dtype) @ cell
distflat = paircoord.norm(dim=1)

# We filter the lists to only send forward relevant pairs (those with distance under cutoff), improving performance.
return filter_pairs(self.hard_dist_cutoff, distflat, pair_first, pair_second, paircoord)

Expand Down
Loading