Skip to content

Commit

Permalink
Merge pull request #28 from thorben-frank/global-displacements-in-geo…
Browse files Browse the repository at this point in the history
…metry-embed

add long range indices to GeometryEmbedSparse
  • Loading branch information
thorben-frank authored Apr 22, 2024
2 parents 0c2ff72 + a35dcda commit 04ebed8
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 2 deletions.
40 changes: 38 additions & 2 deletions mlff/nn/embed/embed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import flax.linen as nn
import e3x
import logging

from mlff.nn.base.sub_module import BaseSubModule
from mlff.nn.mlp import Residual
Expand Down Expand Up @@ -123,17 +124,28 @@ def __call__(self, inputs: Dict):
"""
idx_i = inputs['idx_i'] # shape: (num_pairs)
idx_j = inputs['idx_j'] # shape: (num_pairs)
idx_i_lr = inputs.get('idx_i_lr') # shape: (num_pairs_lr)
idx_j_lr = inputs.get('idx_j_lr') # shape: (num_pairs_lr)
cell = inputs.get('cell') # shape: (num_graphs, 3, 3)
cell_offsets = inputs.get('cell_offset') # shape: (num_pairs, 3)
cell_offsets_lr = inputs.get('cell_offset_lr') # shape: (num_pairs, 3)

if self.input_convention == 'positions':
positions = inputs['positions'] # (N, 3)

# Calculate pairwise distance vectors
# Calculate pairwise distance vectors.
r_ij = jax.vmap(
lambda i, j: positions[j] - positions[i]
)(idx_i, idx_j) # (num_pairs, 3)

r_ij_lr = None
# If indices for long range corrections are present they are used.
if idx_i_lr is not None:
# Calculate pairwise distance vectors on long range indices.
r_ij_lr = jax.vmap(
lambda i, j: positions[j] - positions[i]
)(idx_i_lr, idx_j_lr) # (num_pairs_lr, 3)

# Apply minimal image convention if needed.
if cell is not None:
r_ij = add_cell_offsets_sparse(
Expand All @@ -142,16 +154,39 @@ def __call__(self, inputs: Dict):
cell_offsets=cell_offsets
) # shape: (num_pairs,3)

if idx_i_lr is not None:
if cell_offsets_lr is None:
raise ValueError(
'`cell_offsets_lr` are required in GeometryEmbed when using global indices with periodic'
'boundary conditions.'
)
logging.warning(
'The use of long range indices with PBCs has not been tested thoroughly yet, so use with care!'
)

r_ij_lr = add_cell_offsets_sparse(
r_ij=r_ij_lr,
cell=cell,
cell_offsets=cell_offsets_lr
) # shape: (num_pairs_lr,3)

# Here it is assumed that PBC (if present) have already been respected in displacement calculation.
elif self.input_convention == 'displacements':
positions = None
r_ij = inputs['displacements']
r_ij = inputs['displacements'] # shape : (num_pairs, 3)
r_ij_lr = inputs.get('displacements_lr') # shape : (num_pairs_lr, 3)
else:
raise ValueError(f"{self.input_convention} is not a valid argument for `input_convention`.")

# Calculate pairwise distances.
d_ij = safe_norm(r_ij, axis=-1) # shape : (num_pairs)

if r_ij_lr is not None:
d_ij_lr = safe_norm(r_ij_lr, axis=-1) # shape : (num_pairs_lr)
del r_ij_lr
else:
d_ij_lr = None

# Gaussian basis expansion of distances.
rbf_ij = self.rbf_fn(jnp.expand_dims(d_ij, axis=-1)) # shape: (num_pairs, num_radial_basis_fn)

Expand All @@ -173,6 +208,7 @@ def __call__(self, inputs: Dict):
'r_ij': r_ij,
'unit_r_ij': unit_r_ij,
'd_ij': d_ij,
'd_ij_lr': d_ij_lr,
'rbf_ij': rbf_ij,
'cut': cut,
'ylm_ij': ylm_ij,
Expand Down
74 changes: 74 additions & 0 deletions tests/test_geometry_embed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

degrees = [0, 1, 2]

atomic_numbers = jnp.array([1, 11, 31])

positions = jnp.array([
[0., 0., 0.],
[0., -2., 0.],
Expand All @@ -18,8 +20,11 @@

idx_i = jnp.array([0, 0, 1, 2])
idx_j = jnp.array([1, 2, 0, 0])
idx_i_lr = jnp.array([0, 0, 1, 1, 2, 2])
idx_j_lr = jnp.array([1, 2, 0, 2, 0, 1])

inputs = dict(positions=positions,
atomic_numbers=atomic_numbers,
idx_i=idx_i,
idx_j=idx_j,
cell=None,
Expand Down Expand Up @@ -61,6 +66,7 @@ def test_apply(cutoff: float):
npt.assert_equal(output.get('ylm_ij').shape, (4, 9))
npt.assert_equal(output.get('rbf_ij').shape, (4, 16))
npt.assert_allclose(output.get('d_ij'), jnp.array([2., jnp.sqrt(1.25), 2., jnp.sqrt(1.25)]))
npt.assert_equal(output.get('d_ij_lr'), None)
npt.assert_allclose(
output.get('r_ij'),
jnp.array(
Expand Down Expand Up @@ -96,3 +102,71 @@ def test_apply(cutoff: float):
raise RuntimeError('Invalid test argument.')


def test_apply_with_long_range():
geometry_embed = GeometryEmbedSparse(degrees=degrees,
radial_basis_fn='bernstein',
num_radial_basis_fn=16,
cutoff_fn='exponential',
cutoff=10.,
input_convention='positions',
prop_keys=None)

params = geometry_embed.init(
jax.random.PRNGKey(0),
inputs
)

inputs.update(
dict(
idx_i_lr=idx_i_lr,
idx_j_lr=idx_j_lr
)
)

output = geometry_embed.apply(params, inputs)

npt.assert_equal(output.get('ylm_ij').shape, (4, 9))
npt.assert_equal(output.get('rbf_ij').shape, (4, 16))
npt.assert_allclose(
output.get('d_ij'),
jnp.array([2., jnp.sqrt(1.25), 2., jnp.sqrt(1.25)])
)
npt.assert_allclose(
output.get('r_ij'),
jnp.array(
[
[0.0, -2.0, 0.0],
[1.0, 0.5, 0.0],
[0.0, 2.0, 0.0],
[-1.0, -0.5, 0.0],
]
)
)
npt.assert_allclose(
output.get('unit_r_ij'),
jnp.array(
[
[0.0, -1.0, 0.0],
[1.0/jnp.sqrt(1.25), 0.5/jnp.sqrt(1.25), 0.0],
[0.0, 1.0, 0.0],
[-1.0/jnp.sqrt(1.25), -0.5/jnp.sqrt(1.25), 0.0],
]
)
)
npt.assert_allclose(
output.get('d_ij_lr'),
jnp.array([2., jnp.sqrt(1.25), 2., jnp.sqrt(2.5**2 + 1), jnp.sqrt(1.25), jnp.sqrt(2.5**2 + 1)])
)

# if cutoff == 2.5:
# npt.assert_allclose(output.get('cut') > 0., jnp.array([True, True, True, True]))
# with npt.assert_raises(AssertionError):
# npt.assert_allclose(output.get('cut') > 0., jnp.array([False, True, False, True]))
# elif cutoff == 1.5:
# npt.assert_allclose(output.get('cut') > 0., jnp.array([False, True, False, True]))
# with npt.assert_raises(AssertionError):
# npt.assert_allclose(output.get('cut') > 0., jnp.array([True, True, True, True]))
# else:
# raise RuntimeError('Invalid test argument.')


0 comments on commit 04ebed8

Please sign in to comment.