Merge pull request #27 from thorben-frank/charge-spin-embedding
Charge spin embedding
thorben-frank authored Mar 26, 2024
2 parents 965932a + a55af7e commit a114235
Showing 14 changed files with 608 additions and 94 deletions.
2 changes: 2 additions & 0 deletions mlff/config/config.yaml
@@ -29,6 +29,8 @@ model:
layer_normalization_2: false # Use layer normalization after the second residual mlp.
layers_behave_like_identity_fn_at_init: false # The message passing layers behave like the identity function at initialization.
output_is_zero_at_init: true # The output of the full network is zero at initialization.
use_charge_embed: false # Use embedding module for total charge.
use_spin_embed: false # Use embedding module for number of unpaired electrons.
energy_regression_dim: 128 # Dimension to which final features are projected, before atomic energies are calculated.
energy_activation_fn: identity # Activation function to use on the energy_regression_dim before atomic energies are calculated.
energy_learn_atomic_type_scales: false
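For orientation, a minimal sketch of toggling the two new flags programmatically via ml_collections (the config_dict library used in from_config.py below); all other model keys are omitted, and the values are illustrative:

```python
from ml_collections import config_dict

# Hedged sketch: enable the new embedding modules in code rather than YAML.
# Only the two new keys are shown; a real config carries many more entries.
config = config_dict.ConfigDict()
config.model = config_dict.ConfigDict()
config.model.use_charge_embed = True   # embed the total charge
config.model.use_spin_embed = False    # no embedding for unpaired electrons
```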
2 changes: 2 additions & 0 deletions mlff/config/config_itp_net.yaml
@@ -27,6 +27,8 @@ model:
include_pseudotensors: false
message_normalization: avg_num_neighbors # How to normalize the message function. Options are (identity, sqrt_num_features, avg_num_neighbors)
output_is_zero_at_init: true # The output of the full network is zero at initialization.
use_charge_embed: false # Use embedding module for total charge.
use_spin_embed: false # Use embedding module for number of unpaired electrons.
energy_regression_dim: 128 # Dimension to which final features are projected, before atomic energies are calculated.
energy_activation_fn: silu # Activation function to use on the energy_regression_dim before atomic energies are calculated.
energy_learn_atomic_type_scales: false
8 changes: 6 additions & 2 deletions mlff/config/from_config.py
@@ -54,10 +54,12 @@ def make_so3krates_sparse_from_config(config: config_dict.ConfigDict = None):
layers_behave_like_identity_fn_at_init=model_config.layers_behave_like_identity_fn_at_init,
output_is_zero_at_init=model_config.output_is_zero_at_init,
input_convention=model_config.input_convention,
use_charge_embed=model_config.use_charge_embed,
use_spin_embed=model_config.use_spin_embed,
energy_regression_dim=model_config.energy_regression_dim,
energy_activation_fn=model_config.energy_activation_fn,
energy_learn_atomic_type_scales=model_config.energy_learn_atomic_type_scales,
energy_learn_atomic_type_shifts=model_config.energy_learn_atomic_type_shifts
energy_learn_atomic_type_shifts=model_config.energy_learn_atomic_type_shifts,
)


@@ -95,10 +97,12 @@ def make_itp_net_from_config(config: config_dict.ConfigDict):
avg_num_neighbors=config.data.avg_num_neighbors if config.model.message_normalization == 'avg_num_neighbors' else None,
output_is_zero_at_init=model_config.output_is_zero_at_init,
input_convention=model_config.input_convention,
use_charge_embed=model_config.use_charge_embed,
use_spin_embed=model_config.use_spin_embed,
energy_regression_dim=model_config.energy_regression_dim,
energy_activation_fn=model_config.energy_activation_fn,
energy_learn_atomic_type_scales=model_config.energy_learn_atomic_type_scales,
energy_learn_atomic_type_shifts=model_config.energy_learn_atomic_type_shifts
energy_learn_atomic_type_shifts=model_config.energy_learn_atomic_type_shifts,
)


10 changes: 7 additions & 3 deletions mlff/data/dataloader_sparse_npz.py
@@ -83,6 +83,8 @@ def entry_to_jraph(

atomic_numbers = entry['atomic_numbers']
positions = entry['positions']
total_charge = entry.get('total_charge')
num_unpaired_electrons = entry.get('num_unpaired_electrons')
forces = entry.get('forces')
energy = entry.get('energy')
stress = entry.get('stress')
@@ -107,7 +109,7 @@
"positions": np.array(positions),
"atomic_numbers": np.array(atomic_numbers, dtype=np.int64),
"forces": np.array(forces),
}
}

senders = np.array(j)
receivers = np.array(i)
@@ -116,8 +118,10 @@
n_edge = np.array([len(i)])

global_context = {
"energy": np.array(energy) if energy is not None else None,
"stress": np.array(stress) if stress is not None else None
"energy": np.array(energy).reshape(-1) if energy is not None else None,
"stress": np.array(stress) if stress is not None else None,
"total_charge": np.array(total_charge, dtype=np.int16).reshape(-1) if total_charge is not None else None,
"num_unpaired_electrons": np.array(num_unpaired_electrons, dtype=np.int16).reshape(-1) if num_unpaired_electrons is not None else None
}

return jraph.GraphsTuple(
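To illustrate the two new fields, a small sketch of an npz-style entry and the reshaping used in global_context above; all numeric values are invented:

```python
import numpy as np

# Invented entry for a single charged, open-shell structure.
entry = {
    "atomic_numbers": np.array([8, 1, 1]),
    "positions": np.zeros((3, 3)),
    "energy": np.array(-76.4),
    "total_charge": np.array(1),            # cation
    "num_unpaired_electrons": np.array(1),  # doublet
}

# As in global_context above: scalars become shape-(1,) int16 arrays so
# per-graph quantities can later be batched along axis 0.
total_charge = np.array(entry.get("total_charge"), dtype=np.int16).reshape(-1)
print(total_charge, total_charge.shape)  # [1] (1,)
```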
4 changes: 3 additions & 1 deletion mlff/nn/embed/__init__.py
@@ -7,7 +7,9 @@
from .embed_sparse import (
GeometryEmbedSparse,
GeometryEmbedE3x,
AtomTypeEmbedSparse
AtomTypeEmbedSparse,
SpinEmbedSparse,
ChargeEmbedSparse
)

from .h_register import get_embedding_module
181 changes: 180 additions & 1 deletion mlff/nn/embed/embed_sparse.py
@@ -2,12 +2,13 @@
import jax.numpy as jnp

from functools import partial
from typing import (Any, Dict, Sequence)
from typing import (Any, Callable, Dict, Sequence)

import flax.linen as nn
import e3x

from mlff.nn.base.sub_module import BaseSubModule
from mlff.nn.mlp import Residual
from mlff.masking.mask import safe_mask
from mlff.masking.mask import safe_norm
from mlff.cutoff_function import add_cell_offsets_sparse
@@ -221,3 +222,181 @@ def __dict_repr__(self):
return {self.module_name: {'num_features': self.num_features,
'zmax': self.zmax,
'prop_keys': self.prop_keys}}


class ChargeSpinEmbedSparse(nn.Module):
num_features: int
activation_fn: str = 'silu'
zmax: int = 118

@nn.compact
def __call__(self,
atomic_numbers: jnp.ndarray,
psi: jnp.ndarray,
batch_segments: jnp.ndarray,
graph_mask: jnp.ndarray,
*args,
**kwargs) -> jnp.ndarray:
"""
Create atomic embeddings based on the total charge or the number of unpaired spins in the system, following the
embedding procedure introduced in SpookyNet. Returns per atom embeddings of dimension F.
Args:
z (Array): Atomic types, shape: (N)
psi (Array): Total charge or number of unpaired spins, shape: (num_graphs)
batch_segment (Array): (N)
graph_mask (Array): Mask for atom-wise operations, shape: (num_graphs)
*args ():
**kwargs ():
Returns: Per atom embedding, shape: (n,F)
"""
assert psi.ndim == 1

q = nn.Embed(
num_embeddings=self.zmax + 1,
features=self.num_features
)(atomic_numbers) # shape: (N,F)

        psi_ = psi // jnp.inf  # floor division by +inf: -1 if psi < 0, 0 otherwise
        psi_ = psi_.astype(jnp.int32)  # shape: (num_graphs)

k = nn.Embed(
num_embeddings=2,
features=self.num_features
)(psi_)[batch_segments] # shape: (N, F)

v = nn.Embed(
num_embeddings=2,
features=self.num_features
)(psi_)[batch_segments] # shape: (N, F)

q_x_k = (q*k).sum(axis=-1) / jnp.sqrt(self.num_features) # shape: (N)

y = nn.softplus(q_x_k) # shape: (N)
denominator = jax.ops.segment_sum(
y,
segment_ids=batch_segments,
num_segments=len(graph_mask)
) # (num_graphs)

denominator = jnp.where(
graph_mask,
denominator,
jnp.asarray(1., dtype=q.dtype)
) # (num_graphs)

a = psi[batch_segments] * y / denominator[batch_segments] # shape: (N)
e_psi = Residual(
use_bias=False,
activation_fn=getattr(jax.nn, self.activation_fn) if self.activation_fn != 'identity' else lambda u: u
)(jnp.expand_dims(a, axis=-1) * v) # shape: (N, F)

return e_psi


class ChargeEmbedSparse(BaseSubModule):
prop_keys: Dict
num_features: int
activation_fn: str = 'silu'
zmax: int = 118
module_name: str = 'charge_embed_sparse'

@nn.compact
def __call__(self,
inputs: Dict,
*args,
**kwargs):
"""
Args:
inputs (Dict):
atomic_numbers (Array): atomic types, shape: (N)
total_charge (Array): total charge, shape: (num_graphs)
graph_mask (Array): (num_graphs)
batch_segments (Array): (N)
*args ():
**kwargs ():
        Returns: Per-atom charge embedding, shape: (N, F)
"""
atomic_numbers = inputs['atomic_numbers']
Q = inputs['total_charge']
graph_mask = inputs['graph_mask']
batch_segments = inputs['batch_segments']

if Q is None:
raise ValueError(
                'ChargeEmbedSparse requires a non-None `total_charge`.'
)

return ChargeSpinEmbedSparse(
zmax=self.zmax,
num_features=self.num_features,
activation_fn=self.activation_fn
)(
atomic_numbers=atomic_numbers,
psi=Q,
batch_segments=batch_segments,
graph_mask=graph_mask
)

def __dict_repr__(self):
return {self.module_name: {'num_features': self.num_features,
'zmax': self.zmax,
'prop_keys': self.prop_keys}}


class SpinEmbedSparse(BaseSubModule):
prop_keys: Dict
num_features: int
activation_fn: str = 'silu'
zmax: int = 118
module_name: str = 'spin_embed_sparse'

@nn.compact
def __call__(self,
inputs: Dict,
*args,
**kwargs):
"""
Args:
inputs (Dict):
atomic_numbers (Array): atomic types, shape: (N)
                num_unpaired_electrons (Array): number of unpaired electrons, shape: (num_graphs)
graph_mask (Array): (num_graphs)
batch_segments (Array): (N)
*args ():
**kwargs ():
        Returns: Per-atom spin embedding, shape: (N, F)
"""
atomic_numbers = inputs['atomic_numbers']
S = inputs['num_unpaired_electrons']
graph_mask = inputs['graph_mask']
batch_segments = inputs['batch_segments']

if S is None:
raise ValueError(
                'SpinEmbedSparse requires a non-None `num_unpaired_electrons`.'
)

return ChargeSpinEmbedSparse(
zmax=self.zmax,
num_features=self.num_features,
activation_fn=self.activation_fn
)(
atomic_numbers=atomic_numbers,
psi=S,
batch_segments=batch_segments,
graph_mask=graph_mask
)

def __dict_repr__(self):
return {self.module_name: {'num_features': self.num_features,
'zmax': self.zmax,
'prop_keys': self.prop_keys}}
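Two details of ChargeSpinEmbedSparse deserve a closer look: the floor division by +inf is a sign test, and the softplus attention weights distribute psi over the atoms of each graph, so the per-atom weights sum to the total charge (or number of unpaired electrons). A self-contained check with invented numbers:

```python
import jax
import jax.numpy as jnp

# Sign trick: floor division by +inf gives -1 for negative psi, 0 otherwise.
psi_graphs = jnp.array([-1.0, 0.0, 2.0])
print((psi_graphs // jnp.inf).astype(jnp.int32))  # [-1  0  0]

# Charge-distribution weights for one graph with three atoms; y stands in
# for softplus(q_x_k), the per-atom attention logits after softplus.
y = jnp.array([0.5, 1.5, 2.0])
batch_segments = jnp.array([0, 0, 0])
denominator = jax.ops.segment_sum(y, batch_segments, num_segments=1)
psi = jnp.array([1.0])  # total charge of the graph
a = psi[batch_segments] * y / denominator[batch_segments]
print(a, a.sum())  # per-atom weights scale v; they sum to psi = 1.0
```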
8 changes: 7 additions & 1 deletion mlff/nn/embed/h_register.py
@@ -9,13 +9,19 @@
from .embed_sparse import (
GeometryEmbedSparse,
GeometryEmbedE3x,
AtomTypeEmbedSparse
AtomTypeEmbedSparse,
SpinEmbedSparse,
ChargeEmbedSparse
)


def get_embedding_module(name: str, h: Dict):
if name == 'atom_type_embed':
return AtomTypeEmbed(**h)
elif name == 'spin_embed_sparse':
return SpinEmbedSparse(**h)
elif name == 'charge_embed_sparse':
return ChargeEmbedSparse(**h)
elif name == 'atom_type_embed_sparse':
return AtomTypeEmbedSparse(**h)
elif name == 'geometry_embed':
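A hedged usage sketch for the extended registry; the hyperparameter values are illustrative:

```python
# Build the charge embedding module by its registered name.
charge_embed = get_embedding_module(
    'charge_embed_sparse',
    {'num_features': 128, 'zmax': 118, 'prop_keys': None}
)
```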
2 changes: 1 addition & 1 deletion mlff/nn/mlp/__init__.py
@@ -1 +1 @@
from .mlp import MLP, ResidualMLP
from .mlp import MLP, Residual, ResidualMLP
12 changes: 9 additions & 3 deletions mlff/nn/representation/itp_net.py
@@ -4,6 +4,8 @@
from mlff.nn.embed import GeometryEmbedE3x, AtomTypeEmbedSparse
from mlff.nn.layer import ITPLayer
from mlff.nn.observable import EnergySparse
from .representation_utils import make_embedding_modules

from typing import Optional, Sequence


@@ -29,16 +31,20 @@ def init_itp_net(
feature_collection_over_layers: str = 'final',
include_pseudotensors: bool = False,
output_is_zero_at_init: bool = True,
use_charge_embed: bool = False,
use_spin_embed: bool = False,
energy_regression_dim: int = 128,
energy_activation_fn: str = 'identity',
energy_learn_atomic_type_scales: bool = False,
energy_learn_atomic_type_shifts: bool = False,
input_convention: str = 'positions'
):
atom_type_embed = AtomTypeEmbedSparse(
embedding_modules = make_embedding_modules(
num_features=num_features,
prop_keys=None
use_spin_embed=use_spin_embed,
use_charge_embed=use_charge_embed
)

geometry_embed = GeometryEmbedE3x(
max_degree=mp_max_degree,
radial_basis_fn=radial_basis_fn,
@@ -80,7 +86,7 @@

return StackNetSparse(
geometry_embeddings=[geometry_embed],
feature_embeddings=[atom_type_embed],
feature_embeddings=embedding_modules,
layers=layers,
observables=[energy],
prop_keys=None
32 changes: 32 additions & 0 deletions mlff/nn/representation/representation_utils.py
@@ -0,0 +1,32 @@
from mlff.nn.embed import AtomTypeEmbedSparse, ChargeEmbedSparse, SpinEmbedSparse


def make_embedding_modules(
num_features: int,
use_charge_embed: bool,
use_spin_embed: bool
):
embedding_modules = []
atom_type_embed = AtomTypeEmbedSparse(
num_features=num_features,
prop_keys=None
)
embedding_modules.append(atom_type_embed)

# Embed the total charge.
if use_charge_embed:
charge_embed = ChargeEmbedSparse(
num_features=num_features,
prop_keys=None
)
embedding_modules.append(charge_embed)

# Embed the number of unpaired electrons.
if use_spin_embed:
spin_embed = SpinEmbedSparse(
num_features=num_features,
prop_keys=None
)
embedding_modules.append(spin_embed)

return embedding_modules
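And a short sketch of how the new helper is consumed (cf. init_itp_net above); num_features is illustrative, and the expected module_name values follow the registry in h_register.py:

```python
# Assemble the feature-embedding stack with the charge embedding enabled.
modules = make_embedding_modules(
    num_features=128,
    use_charge_embed=True,
    use_spin_embed=False,
)
print([m.module_name for m in modules])
# expected: ['atom_type_embed_sparse', 'charge_embed_sparse']
```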