-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #26 from thorben-frank/v1.0-itp-net
add ITPNet
- Loading branch information
Showing
15 changed files
with
953 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import argparse | ||
import json | ||
from mlff.config import from_config | ||
from ml_collections import config_dict | ||
import pathlib | ||
import yaml | ||
|
||
|
||
def train_itp_net():
    """Command-line entry point that trains an ITPNet model from a config file.

    Reads the required ``--config`` argument from the command line, loads the
    JSON or YAML file it points to into a ``ConfigDict`` and passes it to
    ``from_config.run_training`` with ``model='itp_net'``.

    Raises:
        ValueError: If the config file suffix is not one of ``.json``,
            ``.yaml`` or ``.yml`` (compared case-insensitively).
    """
    # Create the parser
    parser = argparse.ArgumentParser(description='Train an ITPNet model.')
    parser.add_argument('--config', type=str, required=True, help='Path to the config file.')

    args = parser.parse_args()

    config = pathlib.Path(args.config).expanduser().absolute().resolve()
    suffix = config.suffix.lower()
    if suffix == '.json':
        with open(config, mode='r') as fp:
            cfg = config_dict.ConfigDict(json.load(fp=fp))
    elif suffix in ('.yaml', '.yml'):
        with open(config, mode='r') as fp:
            # NOTE(review): FullLoader can construct more than plain data types;
            # consider yaml.safe_load if config files may come from untrusted sources.
            cfg = config_dict.ConfigDict(yaml.load(fp, Loader=yaml.FullLoader))
    else:
        # Fail early with a clear message instead of hitting an unbound `cfg` below.
        raise ValueError(
            f"Unsupported config file suffix '{config.suffix}' for {config}. "
            "Expected one of '.json', '.yaml' or '.yml'."
        )

    from_config.run_training(cfg, model='itp_net')


if __name__ == '__main__':
    train_itp_net()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
workdir: first_experiment_itp # Working directory. Checkpoints and hyperparameters are saved there. | ||
data: | ||
filepath: null # Path to the data file. Either ASE digestible or .npz with appropriate column names are supported. | ||
energy_unit: eV # Energy unit. | ||
length_unit: Angstrom # Length unit. | ||
shift_mode: null # Options are null, mean, custom. | ||
energy_shifts: null # Energy shifts to subtract. | ||
split_seed: 0 # Seed used for splitting the data into training, validation and test. | ||
model: | ||
num_features: 128 # Number of invariant features. | ||
radial_basis_fn: reciprocal_bernstein # Radial basis function to use. | ||
num_radial_basis_fn: 32 # Number of radial basis functions. | ||
cutoff: 5.0 # Local cutoff to use. | ||
cutoff_fn: smooth_cutoff # Cutoff function to use. | ||
filter_num_layers: 2 # Number of filter layers. | ||
filter_activation_fn: identity # Activation function for the filter. | ||
mp_max_degree: 2 | ||
mp_post_res_block: true | ||
mp_post_res_block_activation_fn: identity | ||
itp_num_features: 16 | ||
itp_max_degree: 2 | ||
itp_num_updates: 2 | ||
itp_post_res_block: true | ||
itp_post_res_block_activation_fn: identity | ||
itp_connectivity: dense | ||
feature_collection_over_layers: last | ||
include_pseudotensors: false | ||
message_normalization: avg_num_neighbors # How to normalize the message function. Options are (identity, sqrt_num_features, avg_num_neighbors) | ||
output_is_zero_at_init: true # The output of the full network is zero at initialization. | ||
energy_regression_dim: 128 # Dimension to which final features are projected, before atomic energies are calculated. | ||
energy_activation_fn: identity # Activation function to use on the energy_regression_dim before atomic energies are calculated. | ||
energy_learn_atomic_type_scales: false | ||
energy_learn_atomic_type_shifts: false | ||
input_convention: positions # Input convention. | ||
optimizer: | ||
name: adam # Name of the optimizer. See https://optax.readthedocs.io/en/latest/api.html#common-optimizers for available ones. | ||
learning_rate: 0.001 # Learning rate to use. | ||
learning_rate_schedule: exponential_decay # Which learning rate schedule to use. See https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules for available ones. | ||
learning_rate_schedule_args: # Arguments passed to the learning rate schedule. See https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules. | ||
decay_rate: 0.75 | ||
transition_steps: 125000 | ||
gradient_clipping: identity | ||
gradient_clipping_args: null | ||
num_of_nans_to_ignore: 0 # Number of repeated update/gradient steps that ignore NaNs before raising an error. | ||
training: | ||
allow_restart: false # Re-starting from checkpoint is allowed. This will overwrite existing checkpoints so only use if this is desired. | ||
num_epochs: 100 # Number of epochs. | ||
num_train: 950 # Number of training points to draw from data.filepath. | ||
num_valid: 50 # Number of validation points to draw from data.filepath. | ||
batch_max_num_nodes: null # Maximal number of nodes per batch. Must be at least maximal number of atoms + 1 in the data set. | ||
batch_max_num_edges: null # Maximal number of edges per batch. Must be at least maximal number of edges + 1 in the data set. | ||
# If batch_max_num_nodes and batch_max_num_edges are set to null, they will be determined from batch_max_num_graphs. | ||
# If they are set to values, each batch will contain as many molecular structures/graphs as possible such that none of the three values | ||
# batch_max_num_nodes, batch_max_num_edges and batch_max_num_graphs is exceeded. | ||
batch_max_num_graphs: 6 # Maximal number of graphs per batch. | ||
# Since one padding graph is involved, an effective batch size of 5 corresponds to batch_max_num_graphs = 6. | ||
eval_every_num_steps: 1000 # Number of gradient steps after which the metrics on the validation set are calculated. | ||
loss_weights: | ||
energy: 0.01 # Loss weight for the energy. | ||
forces: 0.99 # Loss weight for the forces. | ||
model_seed: 0 # Seed used for the initialization of the model parameters. | ||
training_seed: 0 # Seed used for shuffling the batches during training. | ||
log_gradient_values: False # Log the norm of the gradients for each set of weights. | ||
wandb_init_args: # Arguments to wandb.init(). See https://docs.wandb.ai/ref/python/init. The config itself is passed as config to wandb.init(). | ||
name: first_training_run | ||
project: mlff | ||
group: null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
|
||
from .embed_sparse import ( | ||
GeometryEmbedSparse, | ||
GeometryEmbedE3x, | ||
AtomTypeEmbedSparse | ||
) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.