
Commit

add charge spin embed to CLI pipeline
thorben-frank committed Mar 26, 2024
1 parent 72f5aa7 commit a55af7e
Showing 10 changed files with 97 additions and 91 deletions.
2 changes: 2 additions & 0 deletions mlff/config/config.yaml
@@ -29,6 +29,8 @@ model:
layer_normalization_2: false # Use layer normalization after the second residual mlp.
layers_behave_like_identity_fn_at_init: false # The message passing layers behave like the identity function at initialization.
output_is_zero_at_init: true # The output of the full network is zero at initialization.
+use_charge_embed: false # Use embedding module for total charge.
+use_spin_embed: false # Use embedding module for number of unpaired electrons.
energy_regression_dim: 128 # Dimension to which final features are projected, before atomic energies are calculated.
energy_activation_fn: identity # Activation function to use on the energy_regression_dim before atomic energies are calculated.
energy_learn_atomic_type_scales: false
2 changes: 2 additions & 0 deletions mlff/config/config_itp_net.yaml
@@ -27,6 +27,8 @@ model:
include_pseudotensors: false
message_normalization: avg_num_neighbors # How to normalize the message function. Options are (identity, sqrt_num_features, avg_num_neighbors)
output_is_zero_at_init: true # The output of the full network is zero at initialization.
+use_charge_embed: false # Use embedding module for total charge.
+use_spin_embed: false # Use embedding module for number of unpaired electrons.
energy_regression_dim: 128 # Dimension to which final features are projected, before atomic energies are calculated.
energy_activation_fn: silu # Activation function to use on the energy_regression_dim before atomic energies are calculated.
energy_learn_atomic_type_scales: false
8 changes: 6 additions & 2 deletions mlff/config/from_config.py
@@ -54,10 +54,12 @@ def make_so3krates_sparse_from_config(config: config_dict.ConfigDict = None):
layers_behave_like_identity_fn_at_init=model_config.layers_behave_like_identity_fn_at_init,
output_is_zero_at_init=model_config.output_is_zero_at_init,
input_convention=model_config.input_convention,
+use_charge_embed=model_config.use_charge_embed,
+use_spin_embed=model_config.use_spin_embed,
energy_regression_dim=model_config.energy_regression_dim,
energy_activation_fn=model_config.energy_activation_fn,
energy_learn_atomic_type_scales=model_config.energy_learn_atomic_type_scales,
-energy_learn_atomic_type_shifts=model_config.energy_learn_atomic_type_shifts
+energy_learn_atomic_type_shifts=model_config.energy_learn_atomic_type_shifts,
)


@@ -95,10 +97,12 @@ def make_itp_net_from_config(config: config_dict.ConfigDict):
avg_num_neighbors=config.data.avg_num_neighbors if config.model.message_normalization == 'avg_num_neighbors' else None,
output_is_zero_at_init=model_config.output_is_zero_at_init,
input_convention=model_config.input_convention,
+use_charge_embed=model_config.use_charge_embed,
+use_spin_embed=model_config.use_spin_embed,
energy_regression_dim=model_config.energy_regression_dim,
energy_activation_fn=model_config.energy_activation_fn,
energy_learn_atomic_type_scales=model_config.energy_learn_atomic_type_scales,
-energy_learn_atomic_type_shifts=model_config.energy_learn_atomic_type_shifts
+energy_learn_atomic_type_shifts=model_config.energy_learn_atomic_type_shifts,
)


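Note: with both factory functions forwarding the new flags, the charge/spin embeddings become a pure config switch. A minimal sketch, assuming the YAML file is loaded into an `ml_collections.ConfigDict` (the loading code shown here is an illustrative convention, not mlff API):

```python
# Illustrative only: build a SO3krates model with the new embeddings enabled.
# `make_so3krates_sparse_from_config` is from the diff above; loading the
# YAML via yaml.safe_load + ConfigDict is an assumption.
import yaml
from ml_collections import config_dict

from mlff.config.from_config import make_so3krates_sparse_from_config

with open('mlff/config/config.yaml') as f:
    cfg = config_dict.ConfigDict(yaml.safe_load(f))

cfg.model.use_charge_embed = True   # embed the total charge
cfg.model.use_spin_embed = True     # embed the number of unpaired electrons

model = make_so3krates_sparse_from_config(cfg)
```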
10 changes: 7 additions & 3 deletions mlff/data/dataloader_sparse_npz.py
@@ -83,6 +83,8 @@ def entry_to_jraph(

atomic_numbers = entry['atomic_numbers']
positions = entry['positions']
+total_charge = entry.get('total_charge')
+num_unpaired_electrons = entry.get('num_unpaired_electrons')
forces = entry.get('forces')
energy = entry.get('energy')
stress = entry.get('stress')
@@ -107,7 +109,7 @@
"positions": np.array(positions),
"atomic_numbers": np.array(atomic_numbers, dtype=np.int64),
"forces": np.array(forces),
-            }
+        }

senders = np.array(j)
receivers = np.array(i)
@@ -116,8 +118,10 @@
n_edge = np.array([len(i)])

global_context = {
"energy": np.array(energy) if energy is not None else None,
"stress": np.array(stress) if stress is not None else None
"energy": np.array(energy).reshape(-1) if energy is not None else None,
"stress": np.array(stress) if stress is not None else None,
"total_charge": np.array(total_charge, dtype=np.int16).reshape(-1) if total_charge is not None else None,
"num_unpaired_electrons": np.array(num_unpaired_electrons, dtype=np.int16).reshape(-1) if num_unpaired_electrons is not None else None
}

return jraph.GraphsTuple(
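For reference, a sketch of a data entry that exercises the new optional fields; key names follow `entry_to_jraph` above, while shapes and values are illustrative:

```python
# Hypothetical entry dict; 'total_charge' and 'num_unpaired_electrons' are
# optional and read via entry.get(...), so older datasets still load as None.
import numpy as np

entry = {
    'atomic_numbers': np.array([8, 1, 1]),   # e.g. a water molecule
    'positions': np.zeros((3, 3)),           # (num_atoms, 3)
    'energy': np.array(-76.4),
    'forces': np.zeros((3, 3)),
    'total_charge': 0,                       # stored as int16, reshaped to (1,)
    'num_unpaired_electrons': 0,             # stored as int16, reshaped to (1,)
}
```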
13 changes: 12 additions & 1 deletion mlff/nn/embed/embed_sparse.py
@@ -252,6 +252,7 @@ def __call__(self,
Returns: Per atom embedding, shape: (n,F)
"""
assert psi.ndim == 1

q = nn.Embed(
num_embeddings=self.zmax + 1,
@@ -289,7 +290,7 @@ def __call__(self,
a = psi[batch_segments] * y / denominator[batch_segments] # shape: (N)
e_psi = Residual(
use_bias=False,
-activation_fn=getattr(jax.nn, self.activation_fn) if self.activation_fn is not 'identity' else lambda u: u
+activation_fn=getattr(jax.nn, self.activation_fn) if self.activation_fn != 'identity' else lambda u: u
)(jnp.expand_dims(a, axis=-1) * v) # shape: (N, F)

return e_psi
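The rewritten guard above fixes a real bug rather than style: `is not` compares object identity, while `!=` compares value, and only the latter is a reliable test against the string sentinel `'identity'`. A quick illustration:

```python
# `is` checks identity, `==`/`!=` check value; identity tests against a
# string literal can silently misfire depending on interning.
s = ''.join(['iden', 'tity'])  # equals 'identity' but is a distinct object
print(s == 'identity')         # True  -> value equality, the intended test
print(s is 'identity')         # typically False; SyntaxWarning in CPython 3.8+
```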
@@ -326,6 +327,11 @@ def __call__(self,
graph_mask = inputs['graph_mask']
batch_segments = inputs['batch_segments']

+if Q is None:
+    raise ValueError(
+        f'ChargeEmbedSparse requires to pass `total_charge != None`.'
+    )

return ChargeSpinEmbedSparse(
zmax=self.zmax,
num_features=self.num_features,
@@ -374,6 +380,11 @@ def __call__(self,
graph_mask = inputs['graph_mask']
batch_segments = inputs['batch_segments']

+if S is None:
+    raise ValueError(
+        f'SpinEmbedSparse requires to pass `num_unpaired_electrons != None`.'
+    )

return ChargeSpinEmbedSparse(
zmax=self.zmax,
num_features=self.num_features,
12 changes: 9 additions & 3 deletions mlff/nn/representation/itp_net.py
@@ -4,6 +4,8 @@
from mlff.nn.embed import GeometryEmbedE3x, AtomTypeEmbedSparse
from mlff.nn.layer import ITPLayer
from mlff.nn.observable import EnergySparse
+from .representation_utils import make_embedding_modules

from typing import Optional, Sequence


@@ -29,16 +31,20 @@ def init_itp_net(
feature_collection_over_layers: str = 'final',
include_pseudotensors: bool = False,
output_is_zero_at_init: bool = True,
+use_charge_embed: bool = False,
+use_spin_embed: bool = False,
energy_regression_dim: int = 128,
energy_activation_fn: str = 'identity',
energy_learn_atomic_type_scales: bool = False,
energy_learn_atomic_type_shifts: bool = False,
input_convention: str = 'positions'
):
-atom_type_embed = AtomTypeEmbedSparse(
+embedding_modules = make_embedding_modules(
num_features=num_features,
-prop_keys=None
+use_spin_embed=use_spin_embed,
+use_charge_embed=use_charge_embed
)

geometry_embed = GeometryEmbedE3x(
max_degree=mp_max_degree,
radial_basis_fn=radial_basis_fn,
@@ -80,7 +86,7 @@ def init_itp_net(

return StackNetSparse(
geometry_embeddings=[geometry_embed],
-feature_embeddings=[atom_type_embed],
+feature_embeddings=embedding_modules,
layers=layers,
observables=[energy],
prop_keys=None
32 changes: 32 additions & 0 deletions mlff/nn/representation/representation_utils.py
@@ -0,0 +1,32 @@
+from mlff.nn.embed import AtomTypeEmbedSparse, ChargeEmbedSparse, SpinEmbedSparse
+
+
+def make_embedding_modules(
+        num_features: int,
+        use_charge_embed: bool,
+        use_spin_embed: bool
+):
+    embedding_modules = []
+    atom_type_embed = AtomTypeEmbedSparse(
+        num_features=num_features,
+        prop_keys=None
+    )
+    embedding_modules.append(atom_type_embed)
+
+    # Embed the total charge.
+    if use_charge_embed:
+        charge_embed = ChargeEmbedSparse(
+            num_features=num_features,
+            prop_keys=None
+        )
+        embedding_modules.append(charge_embed)
+
+    # Embed the number of unpaired electrons.
+    if use_spin_embed:
+        spin_embed = SpinEmbedSparse(
+            num_features=num_features,
+            prop_keys=None
+        )
+        embedding_modules.append(spin_embed)
+
+    return embedding_modules
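A minimal usage sketch of the new helper (the module path below is assumed from the file location above):

```python
from mlff.nn.representation.representation_utils import make_embedding_modules

modules = make_embedding_modules(
    num_features=128,
    use_charge_embed=True,
    use_spin_embed=False,
)
# -> [AtomTypeEmbedSparse(...), ChargeEmbedSparse(...)]; the atom type
#    embedding always comes first, charge/spin are appended on demand.
```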
13 changes: 9 additions & 4 deletions mlff/nn/representation/so3krates_sparse.py
@@ -1,9 +1,10 @@
import flax.linen as nn
import jax
from mlff.nn.stacknet import StackNetSparse
-from mlff.nn.embed import GeometryEmbedSparse, AtomTypeEmbedSparse
+from mlff.nn.embed import GeometryEmbedSparse, AtomTypeEmbedSparse, ChargeEmbedSparse, SpinEmbedSparse
from mlff.nn.layer import SO3kratesLayerSparse
from mlff.nn.observable import EnergySparse
+from .representation_utils import make_embedding_modules
from typing import Sequence


@@ -27,16 +28,20 @@ def init_so3krates_sparse(
activation_fn: str = 'silu',
layers_behave_like_identity_fn_at_init: bool = False,
output_is_zero_at_init: bool = True,
+use_charge_embed: bool = False,
+use_spin_embed: bool = False,
energy_regression_dim: int = 128,
energy_activation_fn: str = 'identity',
energy_learn_atomic_type_scales: bool = False,
energy_learn_atomic_type_shifts: bool = False,
input_convention: str = 'positions'
):
-atom_type_embed = AtomTypeEmbedSparse(
+embedding_modules = make_embedding_modules(
num_features=num_features,
-prop_keys=None
+use_spin_embed=use_spin_embed,
+use_charge_embed=use_charge_embed
)

geometry_embed = GeometryEmbedSparse(
degrees=degrees,
radial_basis_fn=radial_basis_fn,
@@ -76,7 +81,7 @@ def init_so3krates_sparse(

return StackNetSparse(
geometry_embeddings=[geometry_embed],
-feature_embeddings=[atom_type_embed],
+feature_embeddings=embedding_modules,
layers=layers,
observables=[energy],
prop_keys=None
94 changes: 16 additions & 78 deletions mlff/nn/stacknet/observable_function_sparse.py
@@ -32,6 +32,8 @@ def observable_fn(
atomic_numbers: jnp.ndarray,
idx_i: jnp.ndarray,
idx_j: jnp.ndarray,
+total_charge: jnp.ndarray = None,
+num_unpaired_electrons: jnp.ndarray = None,
cell: jnp.ndarray = None,
cell_offset: jnp.ndarray = None,
batch_segments: jnp.ndarray = None,
@@ -53,6 +55,8 @@
atomic_numbers=atomic_numbers,
idx_i=idx_i,
idx_j=idx_j,
+total_charge=total_charge,
+num_unpaired_electrons=num_unpaired_electrons,
cell=cell,
cell_offset=cell_offset,
batch_segments=batch_segments,
@@ -67,6 +71,8 @@
atomic_numbers: jnp.ndarray,
idx_i: jnp.ndarray,
idx_j: jnp.ndarray,
+total_charge: jnp.ndarray = None,
+num_unpaired_electrons: jnp.ndarray = None,
cell: jnp.ndarray = None,
cell_offset: jnp.ndarray = None,
batch_segments: jnp.ndarray = None,
@@ -88,6 +94,8 @@
atomic_numbers=atomic_numbers,
idx_i=idx_i,
idx_j=idx_j,
+total_charge=total_charge,
+num_unpaired_electrons=num_unpaired_electrons,
cell=cell,
cell_offset=cell_offset,
batch_segments=batch_segments,
@@ -105,6 +113,8 @@ def energy_fn(params,
atomic_numbers: jnp.ndarray,
idx_i: jnp.ndarray,
idx_j: jnp.ndarray,
+total_charge: jnp.ndarray = None,
+num_unpaired_electrons: jnp.ndarray = None,
cell: jnp.ndarray = None,
cell_offset: jnp.ndarray = None,
batch_segments: jnp.ndarray = None,
@@ -122,6 +132,8 @@
atomic_numbers=atomic_numbers,
idx_i=idx_i,
idx_j=idx_j,
+total_charge=total_charge,
+num_unpaired_electrons=num_unpaired_electrons,
cell=cell,
cell_offset=cell_offset,
batch_segments=batch_segments,
@@ -138,6 +150,8 @@ def energy_and_force_fn(params,
atomic_numbers: jnp.ndarray,
idx_i: jnp.ndarray,
idx_j: jnp.ndarray,
+total_charge: jnp.ndarray = None,
+num_unpaired_electrons: jnp.ndarray = None,
cell: jnp.ndarray = None,
cell_offset: jnp.ndarray = None,
batch_segments: jnp.ndarray = None,
@@ -153,6 +167,8 @@
atomic_numbers,
idx_i,
idx_j,
+total_charge,
+num_unpaired_electrons,
cell,
cell_offset,
batch_segments,
@@ -163,81 +179,3 @@
return dict(energy=energy, forces=forces)

return energy_and_force_fn


-# def get_energy_and_force_fn_sparse(model: StackNetSparse):
-# def energy_fn(params,
-# positions: jnp.ndarray,
-# atomic_numbers: jnp.ndarray,
-# idx_i: jnp.ndarray,
-# idx_j: jnp.ndarray,
-# cell: jnp.ndarray = None,
-# cell_offset: jnp.ndarray = None,
-# batch_segments: jnp.ndarray = None,
-# node_mask: jnp.ndarray = None,
-# graph_mask: jnp.ndarray = None):
-# if batch_segments is None:
-# assert graph_mask is None
-# assert node_mask is None
-#
-# graph_mask = jnp.ones((1,)).astype(jnp.bool_) # (1)
-# node_mask = jnp.ones((len(positions),)).astype(jnp.bool_) # (num_nodes)
-# batch_segments = jnp.zeros_like(atomic_numbers) # (num_nodes)
-#
-# inputs = dict(positions=positions,
-# atomic_numbers=atomic_numbers,
-# idx_i=idx_i,
-# idx_j=idx_j,
-# cell=cell,
-# cell_offset=cell_offset,
-# batch_segments=batch_segments,
-# node_mask=node_mask,
-# graph_mask=graph_mask
-# )
-#
-# energy = model.apply(params, inputs)['energy'] # (num_graphs)
-# energy = jnp.where(graph_mask, energy, jnp.asarray(0., dtype=energy.dtype)) # (num_graphs)
-# return -jnp.sum(energy), energy # (), (num_graphs)
-#
-# def energy_and_force_fn(params,
-# positions: jnp.ndarray,
-# atomic_numbers: jnp.ndarray,
-# idx_i: jnp.ndarray,
-# idx_j: jnp.ndarray,
-# cell: jnp.ndarray = None,
-# cell_offset: jnp.ndarray = None,
-# batch_segments: jnp.ndarray = None,
-# node_mask: jnp.ndarray = None,
-# graph_mask: jnp.ndarray = None,
-# *args,
-# **kwargs):
-#
-# # _, energy = energy_fn(
-# # params,
-# # positions,
-# # atomic_numbers,
-# # idx_i,
-# # idx_j,
-# # cell,
-# # cell_offset,
-# # batch_segments,
-# # node_mask,
-# # graph_mask
-# # )
-#
-# forces, energy = jax.jacrev(energy_fn, argnums=1, has_aux=True)(
-# params,
-# positions,
-# atomic_numbers,
-# idx_i,
-# idx_j,
-# cell,
-# cell_offset,
-# batch_segments,
-# node_mask,
-# graph_mask
-# )
-#
-# return dict(energy=energy, forces=forces)
-#
-# return energy_and_force_fn
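With the threading above, callers can pass the two per-graph globals end to end. A hedged sketch (the factory name is taken from the deleted comment block and assumed unchanged; `model` and `params` come from model construction and initialization):

```python
# Illustrative call; both new arguments default to None, so existing
# callers keep working unchanged.
import jax.numpy as jnp

energy_and_force_fn = get_energy_and_force_fn_sparse(model)
out = energy_and_force_fn(
    params,
    positions=jnp.zeros((3, 3)),
    atomic_numbers=jnp.array([8, 1, 1]),
    idx_i=jnp.array([0, 0, 1, 2]),
    idx_j=jnp.array([1, 2, 0, 0]),
    total_charge=jnp.array([0]),             # per graph, shape (num_graphs,)
    num_unpaired_electrons=jnp.array([0]),   # per graph, shape (num_graphs,)
)
print(out['energy'], out['forces'].shape)
```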
2 changes: 2 additions & 0 deletions mlff/utils/jraph_utils.py
@@ -37,6 +37,8 @@ def graph_to_batch_fn(graph: jraph.GraphsTuple):
batch = dict(
positions=graph.nodes.get('positions'),
atomic_numbers=graph.nodes.get('atomic_numbers'),
+total_charge=graph.globals.get('total_charge'),
+num_unpaired_electrons=graph.globals.get('num_unpaired_electrons'),
idx_i=graph.receivers,
idx_j=graph.senders,
cell=graph.edges.get('cell'),
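End to end, the new globals now survive batching. A sketch of a single-graph `GraphsTuple` and the batch it yields (the construction mirrors `entry_to_jraph` above; the import path for `graph_to_batch_fn` is assumed from the file location):

```python
import jraph
import numpy as np

from mlff.utils.jraph_utils import graph_to_batch_fn

graph = jraph.GraphsTuple(
    nodes={'positions': np.zeros((3, 3)),
           'atomic_numbers': np.array([8, 1, 1], dtype=np.int64)},
    edges={},                                # 'cell' omitted -> batch gets None
    globals={'energy': np.array([-76.4]),
             'total_charge': np.array([0], dtype=np.int16),
             'num_unpaired_electrons': np.array([0], dtype=np.int16)},
    senders=np.array([0, 1]),
    receivers=np.array([1, 0]),
    n_node=np.array([3]),
    n_edge=np.array([2]),
)

batch = graph_to_batch_fn(graph)
print(batch['total_charge'])             # [0]
print(batch['num_unpaired_electrons'])   # [0]
```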
