Commit e867144

Merge pull request #26 from thorben-frank/v1.0-itp-net

add ITPNet

thorben-frank authored Mar 15, 2024
2 parents a03ec63 + 6042eb3 commit e867144
Showing 15 changed files with 953 additions and 6 deletions.
28 changes: 28 additions & 0 deletions mlff/CLI/run_training_itp_net.py
@@ -0,0 +1,28 @@
import argparse
import json
from mlff.config import from_config
from ml_collections import config_dict
import pathlib
import yaml


def train_itp_net():
    # Create the parser
    parser = argparse.ArgumentParser(description='Train an ITPNet model.')
    parser.add_argument('--config', type=str, required=True, help='Path to the config file.')

    args = parser.parse_args()

    config = pathlib.Path(args.config).expanduser().absolute().resolve()
    if config.suffix == '.json':
        with open(config, mode='r') as fp:
            cfg = config_dict.ConfigDict(json.load(fp=fp))
    elif config.suffix == '.yaml':
        with open(config, mode='r') as fp:
            cfg = config_dict.ConfigDict(yaml.load(fp, Loader=yaml.FullLoader))
    else:
        raise ValueError(f'Config file must be .json or .yaml, got {config.suffix!r}.')

    from_config.run_training(cfg, model='itp_net')


if __name__ == '__main__':
    train_itp_net()
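A minimal usage sketch (an illustration mirroring the calls above, not part of the commit; it assumes the shipped config file shown below sits in the working directory): the same training run can be launched from Python without the CLI wrapper.

import yaml
from ml_collections import config_dict
from mlff.config import from_config

# Load the YAML config shipped with the package (see mlff/config/config_itp_net.yaml below).
with open('config_itp_net.yaml', mode='r') as fp:
    cfg = config_dict.ConfigDict(yaml.load(fp, Loader=yaml.FullLoader))

from_config.run_training(cfg, model='itp_net')  # dispatches to ITPNet (see from_config.py below)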
67 changes: 67 additions & 0 deletions mlff/config/config_itp_net.yaml
@@ -0,0 +1,67 @@
workdir: first_experiment_itp # Working directory. Checkpoints and hyperparameters are saved there.
data:
  filepath: null # Path to the data file. Either ASE-digestible files or .npz with appropriate column names are supported.
  energy_unit: eV # Energy unit.
  length_unit: Angstrom # Length unit.
  shift_mode: null # Options are null, mean, custom.
  energy_shifts: null # Energy shifts to subtract.
  split_seed: 0 # Seed used for splitting the data into training, validation and test.
model:
  num_features: 128 # Number of invariant features.
  radial_basis_fn: reciprocal_bernstein # Radial basis function to use.
  num_radial_basis_fn: 32 # Number of radial basis functions.
  cutoff: 5.0 # Local cutoff to use.
  cutoff_fn: smooth_cutoff # Cutoff function to use.
  filter_num_layers: 2 # Number of filter layers.
  filter_activation_fn: identity # Activation function for the filter.
  mp_max_degree: 2
  mp_post_res_block: true
  mp_post_res_block_activation_fn: identity
  itp_num_features: 16
  itp_max_degree: 2
  itp_num_updates: 2
  itp_post_res_block: true
  itp_post_res_block_activation_fn: identity
  itp_connectivity: dense
  feature_collection_over_layers: last
  include_pseudotensors: false
  message_normalization: avg_num_neighbors # How to normalize the message function. Options are identity, sqrt_num_features, avg_num_neighbors.
  output_is_zero_at_init: true # The output of the full network is zero at initialization.
  energy_regression_dim: 128 # Dimension to which final features are projected before atomic energies are calculated.
  energy_activation_fn: identity # Activation function applied on the energy_regression_dim features before atomic energies are calculated.
  energy_learn_atomic_type_scales: false
  energy_learn_atomic_type_shifts: false
  input_convention: positions # Input convention.
optimizer:
  name: adam # Name of the optimizer. See https://optax.readthedocs.io/en/latest/api.html#common-optimizers for available ones.
  learning_rate: 0.001 # Learning rate to use.
  learning_rate_schedule: exponential_decay # Which learning rate schedule to use. See https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules for available ones.
  learning_rate_schedule_args: # Arguments passed to the learning rate schedule. See https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules.
    decay_rate: 0.75
    transition_steps: 125000
  gradient_clipping: identity
  gradient_clipping_args: null
  num_of_nans_to_ignore: 0 # Number of repeated update/gradient steps for which NaNs are ignored before an error is raised.
training:
  allow_restart: false # Restarting from a checkpoint is allowed. This will overwrite existing checkpoints, so only use it if that is desired.
  num_epochs: 100 # Number of epochs.
  num_train: 950 # Number of training points to draw from data.filepath.
  num_valid: 50 # Number of validation points to draw from data.filepath.
  batch_max_num_nodes: null # Maximal number of nodes per batch. Must be at least the maximal number of atoms + 1 in the data set.
  batch_max_num_edges: null # Maximal number of edges per batch. Must be at least the maximal number of edges + 1 in the data set.
  # If batch_max_num_nodes and batch_max_num_edges are set to null, they are determined from batch_max_num_graphs.
  # If they are set to values, each batch contains as many molecular structures/graphs as fit such that none of
  # batch_max_num_nodes, batch_max_num_edges and batch_max_num_graphs is exceeded.
  batch_max_num_graphs: 6 # Maximal number of graphs per batch.
  # Since one padding graph is involved, an effective batch size of 5 corresponds to batch_max_num_graphs = 6.
  eval_every_num_steps: 1000 # Number of gradient steps after which the metrics on the validation set are calculated.
  loss_weights:
    energy: 0.01 # Loss weight for the energy.
    forces: 0.99 # Loss weight for the forces.
  model_seed: 0 # Seed used for the initialization of the model parameters.
  training_seed: 0 # Seed used for shuffling the batches during training.
  log_gradient_values: false # Log the norm of the gradients for each set of weights.
  wandb_init_args: # Arguments to wandb.init(). See https://docs.wandb.ai/ref/python/init. The config itself is passed as config to wandb.init().
    name: first_training_run
    project: mlff
    group: null
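A small sketch (an assumption, not part of this commit; the file name and ASE usage are purely illustrative) of deriving batch_max_num_nodes by hand instead of leaving it null, following the "+ 1 for the padding graph" rule from the comments above; batch_max_num_edges can be derived analogously from the neighbor lists.

from ase.io import read

# Read all structures from an ASE-digestible file (hypothetical path).
structures = read('data.extxyz', index=':')

# One node per atom, plus 1 node reserved for the padding graph.
batch_max_num_nodes = max(len(s) for s in structures) + 1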
56 changes: 52 additions & 4 deletions mlff/config/from_config.py
@@ -61,6 +61,46 @@ def make_so3krates_sparse_from_config(config: config_dict.ConfigDict = None):
)


def make_itp_net_from_config(config: config_dict.ConfigDict):
    """Make an iterated tensor product model from a config.

    Args:
        config (): The config.

    Returns:
        ITP flax model.
    """

    model_config = config.model

    return nn.ITPNet(
        num_features=model_config.num_features,
        radial_basis_fn=model_config.radial_basis_fn,
        num_radial_basis_fn=model_config.num_radial_basis_fn,
        cutoff_fn=model_config.cutoff_fn,
        filter_num_layers=model_config.filter_num_layers,
        filter_activation_fn=model_config.filter_activation_fn,
        mp_max_degree=model_config.mp_max_degree,
        mp_post_res_block=model_config.mp_post_res_block,
        mp_post_res_block_activation_fn=model_config.mp_post_res_block_activation_fn,
        itp_max_degree=model_config.itp_max_degree,
        itp_num_features=model_config.itp_num_features,
        itp_post_res_block=model_config.itp_post_res_block,
        itp_post_res_block_activation_fn=model_config.itp_post_res_block_activation_fn,
        itp_connectivity=model_config.itp_connectivity,
        feature_collection_over_layers=model_config.feature_collection_over_layers,
        include_pseudotensors=model_config.include_pseudotensors,
        message_normalization=model_config.message_normalization,
        avg_num_neighbors=config.data.avg_num_neighbors if model_config.message_normalization == 'avg_num_neighbors' else None,
        output_is_zero_at_init=model_config.output_is_zero_at_init,
        input_convention=model_config.input_convention,
        energy_regression_dim=model_config.energy_regression_dim,
        energy_activation_fn=model_config.energy_activation_fn,
        energy_learn_atomic_type_scales=model_config.energy_learn_atomic_type_scales,
        energy_learn_atomic_type_shifts=model_config.energy_learn_atomic_type_shifts
    )
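A construction sketch for reference (assumes cfg is a ConfigDict loaded from the YAML above). Note that avg_num_neighbors is read from config.data only when message_normalization is 'avg_num_neighbors'; the YAML above does not set it, so it is presumably filled in from dataset statistics before this function is called.

net = make_itp_net_from_config(cfg)  # a flax.linen module; parameters come from net.init(...), predictions from net.apply(...)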


def make_optimizer_from_config(config: config_dict.ConfigDict = None):
"""Make optax optimizer from config.
Expand All @@ -85,11 +125,12 @@ def make_optimizer_from_config(config: config_dict.ConfigDict = None):
)


def run_training(config: config_dict.ConfigDict):
def run_training(config: config_dict.ConfigDict, model: str = None):
    """Run training given a config.

    Args:
        config (): The config.
        model (): The model to train. Defaults to SO3krates.

    Returns:
@@ -190,10 +231,17 @@ def run_training(config: config_dict.ConfigDict):
    ))

    opt = make_optimizer_from_config(config)
    so3k = make_so3krates_sparse_from_config(config)
    if model is None or model == 'so3krates':
        net = make_so3krates_sparse_from_config(config)
    elif model == 'itp_net':
        net = make_itp_net_from_config(config)
    else:
        raise ValueError(
            f'{model=} is not a valid model.'
        )

    loss_fn = training_utils.make_loss_fn(
        get_energy_and_force_fn_sparse(so3k),
        get_energy_and_force_fn_sparse(net),
        weights=config.training.loss_weights
    )

@@ -238,7 +286,7 @@ def run_training(config: config_dict.ConfigDict):
    wandb.init(config=config.to_dict(), **config.training.wandb_init_args)
    logging.mlff('Training is starting!')
    training_utils.fit(
        model=so3k,
        model=net,
        optimizer=opt,
        loss_fn=loss_fn,
        graph_to_batch_fn=jraph_utils.graph_to_batch_fn,
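The new dispatch in run_training, sketched as calls (cfg as loaded in the CLI scripts above):

run_training(cfg)                     # default: SO3krates
run_training(cfg, model='so3krates')  # explicit SO3krates
run_training(cfg, model='itp_net')    # the new ITPNet path
run_training(cfg, model='foo')        # raises ValueError: model='foo' is not a valid model.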
5 changes: 3 additions & 2 deletions mlff/nn/__init__.py
@@ -1,7 +1,8 @@
from .representation import (So3krates,
                             So3kratACE,
                             SchNet,
                             SO3kratesSparse)
                             SO3kratesSparse,
                             ITPNet)

from .stacknet import (get_observable_fn,
                       get_energy_force_stress_fn,
@@ -15,4 +16,4 @@
from .observable import (Energy,
                         ZBLRepulsion)

from .embed import GeometryEmbedSparse
from .embed import GeometryEmbedSparse, GeometryEmbedE3x
1 change: 1 addition & 0 deletions mlff/nn/embed/__init__.py
@@ -6,6 +6,7 @@

from .embed_sparse import (
    GeometryEmbedSparse,
    GeometryEmbedE3x,
    AtomTypeEmbedSparse
)

59 changes: 59 additions & 0 deletions mlff/nn/embed/embed_sparse.py
@@ -5,6 +5,7 @@
from typing import (Any, Dict, Sequence)

import flax.linen as nn
import e3x

from mlff.nn.base.sub_module import BaseSubModule
from mlff.masking.mask import safe_mask
@@ -14,6 +15,64 @@
from mlff.basis_function.spherical import init_sph_fn


class GeometryEmbedE3x(BaseSubModule):
    prop_keys: Dict
    max_degree: int
    radial_basis_fn: str
    num_radial_basis_fn: int
    cutoff_fn: str
    cutoff: float
    input_convention: str = 'positions'
    module_name: str = 'geometry_embed_e3x'

    def __call__(self, inputs, *args, **kwargs):

        idx_i = inputs['idx_i']  # shape: (num_pairs)
        idx_j = inputs['idx_j']  # shape: (num_pairs)
        cell = inputs.get('cell')  # shape: (num_graphs, 3, 3)
        cell_offsets = inputs.get('cell_offset')  # shape: (num_pairs, 3)

        if self.input_convention == 'positions':
            positions = inputs['positions']  # (N, 3)

            # Calculate pairwise distance vectors.
            r_ij = jax.vmap(
                lambda i, j: positions[j] - positions[i]
            )(idx_i, idx_j)  # (num_pairs, 3)

            # Apply minimal image convention if needed.
            if cell is not None:
                r_ij = add_cell_offsets_sparse(
                    r_ij=r_ij,
                    cell=cell,
                    cell_offsets=cell_offsets
                )  # shape: (num_pairs, 3)

        # Here it is assumed that PBC (if present) have already been respected in the displacement calculation.
        elif self.input_convention == 'displacements':
            positions = None
            r_ij = inputs['displacements']
        else:
            raise ValueError(f"{self.input_convention} is not a valid argument for `input_convention`.")

        basis, cut = e3x.nn.basis(
            r=r_ij,
            max_degree=self.max_degree,
            radial_fn=getattr(e3x.nn, self.radial_basis_fn),
            num=self.num_radial_basis_fn,
            cutoff_fn=partial(getattr(e3x.nn, self.cutoff_fn), cutoff=self.cutoff),
            return_cutoff=True
        )  # (num_pairs, 1, (max_degree+1)^2, num_radial_basis_fn), (num_pairs,)

        geometric_data = {'positions': positions,
                          'basis': basis,
                          'r_ij': r_ij,
                          'cut': cut,
                          }

        return geometric_data


class GeometryEmbedSparse(BaseSubModule):
    prop_keys: Dict
    degrees: Sequence[int]
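A quick illustration of the embedding's core call (a sketch mirroring the e3x.nn.basis invocation above, with radial_basis_fn, num_radial_basis_fn, cutoff_fn and cutoff taken from config_itp_net.yaml; the displacement vectors and max_degree value are made up):

import jax.numpy as jnp
import e3x
from functools import partial

# Three made-up displacement vectors (num_pairs = 3).
r_ij = jnp.array([[1.0, 0.0, 0.0],
                  [0.0, 1.5, 0.0],
                  [0.0, 0.0, 2.0]])

basis, cut = e3x.nn.basis(
    r=r_ij,
    max_degree=2,
    radial_fn=e3x.nn.reciprocal_bernstein,
    num=32,
    cutoff_fn=partial(e3x.nn.smooth_cutoff, cutoff=5.0),
    return_cutoff=True
)
# basis: (3, 1, 9, 32), since (max_degree+1)^2 = 9; cut: (3,) cutoff values.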
3 changes: 3 additions & 0 deletions mlff/nn/embed/h_register.py
@@ -8,6 +8,7 @@

from .embed_sparse import (
    GeometryEmbedSparse,
    GeometryEmbedE3x,
    AtomTypeEmbedSparse
)

@@ -19,6 +20,8 @@ def get_embedding_module(name: str, h: Dict):
        return AtomTypeEmbedSparse(**h)
    elif name == 'geometry_embed':
        return GeometryEmbed(**h)
    elif name == 'geometry_embed_e3x':
        return GeometryEmbedE3x(**h)
    elif name == 'geometry_embed_sparse':
        return GeometryEmbedSparse(**h)
    elif name == 'one_hot_embed':
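A registry usage sketch (an assumption for illustration; the hyperparameter values are taken from config_itp_net.yaml, while prop_keys is left empty because its expected contents are not shown in this diff):

from mlff.nn.embed.h_register import get_embedding_module

h = {
    'prop_keys': {},  # dataset-specific property keys (illustrative placeholder)
    'max_degree': 2,
    'radial_basis_fn': 'reciprocal_bernstein',
    'num_radial_basis_fn': 32,
    'cutoff_fn': 'smooth_cutoff',
    'cutoff': 5.0,
}
embedding = get_embedding_module('geometry_embed_e3x', h)  # returns a GeometryEmbedE3x module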
1 change: 1 addition & 0 deletions mlff/nn/layer/__init__.py
@@ -2,4 +2,5 @@
from .so3krates_layer import So3kratesLayer
from .so3kratace_layer import So3krataceLayer
from .so3krates_layer_sparse import SO3kratesLayerSparse
from .itp_layer import ITPLayer
from .h_register import get_layer
3 changes: 3 additions & 0 deletions mlff/nn/layer/h_register.py
@@ -2,6 +2,7 @@
from .so3krates_layer import So3kratesLayer
from .so3kratace_layer import So3krataceLayer
from .schnet_layer import SchNetLayer
from .itp_layer import ITPLayer


def get_layer(name: str, h: Dict):
@@ -11,6 +12,8 @@ def get_layer(name: str, h: Dict):
        return So3krataceLayer(**h)
    elif name == 'schnet_layer':
        return SchNetLayer(**h)
    elif name == 'itp_layer':
        return ITPLayer(**h)
    elif name == 'spookynet_layer':
        raise NotImplementedError('SpookyNet not implemented!')
        return SpookyNetLayer(**h)