diff --git a/docs/source/advanced_usage/trainingmodel.rst b/docs/source/advanced_usage/trainingmodel.rst index 290aa15f..9b118d86 100644 --- a/docs/source/advanced_usage/trainingmodel.rst +++ b/docs/source/advanced_usage/trainingmodel.rst @@ -194,22 +194,64 @@ keyword, you can fine-tune the number of new snapshots being created. By default, the same number of snapshots as had been provided will be created (if possible). -Using tensorboard -****************** +Logging metrics during training +******************************* + +Training progress in MALA can be visualized via tensorboard or wandb, as also shown +in the file ``advanced/ex03_tensor_board``. Simply select a logger prior to training as + + .. code-block:: python + + parameters.running.logger = "tensorboard" + parameters.running.logging_dir = "mala_vis" -Training routines in MALA can be visualized via tensorboard, as also shown -in the file ``advanced/ex03_tensor_board``. Simply enable tensorboard -visualization prior to training via +or .. code-block:: python - # 0: No visualizatuon, 1: loss and learning rate, 2: like 1, - # but additionally weights and biases are saved - parameters.running.logging = 1 + import wandb + wandb.init( + project="mala_training", + entity="your_wandb_entity" + ) + parameters.running.logger = "wandb" parameters.running.logging_dir = "mala_vis" where ``logging_dir`` specifies some directory in which to save the -MALA logging data. Afterwards, you can run the training without any +MALA logging data. You can also select which metrics to record via + + .. code-block:: python + + parameters.validation_metrics = ["ldos", "dos", "density", "total_energy"] + +Full list of available metrics: + - "ldos": MSE of the LDOS. + - "band_energy": Band energy. + - "band_energy_actual_fe": Band energy computed with ground truth Fermi energy. + - "total_energy": Total energy. + - "total_energy_actual_fe": Total energy computed with ground truth Fermi energy. + - "fermi_energy": Fermi energy. + - "density": Electron density. + - "density_relative": Rlectron density (Mean Absolute Percentage Error). + - "dos": Density of states. + - "dos_relative": Density of states (Mean Absolute Percentage Error). + +To save time and resources you can specify the logging interval via + + .. code-block:: python + + parameters.running.validate_every_n_epochs = 10 + +If you want to monitor the degree to which the model overfits to the training data, +you can use the option + + .. code-block:: python + + parameters.running.validate_on_training_data = True + +MALA will evaluate the validation metrics on the training set as well as the validation set. + +Afterwards, you can run the training without any other modifications. Once training is finished (or during training, in case you want to use tensorboard to monitor progress), you can launch tensorboard via @@ -221,6 +263,7 @@ via The full path for ``path_to_log_directory`` can be accessed via ``trainer.full_logging_path``. +If you're using wandb, you can monitor the training progress on the wandb website. Training in parallel ******************** diff --git a/mala/common/parameters.py b/mala/common/parameters.py index 5aa41afc..30ce695b 100644 --- a/mala/common/parameters.py +++ b/mala/common/parameters.py @@ -289,13 +289,13 @@ def __init__(self): self.layer_activations = ["Sigmoid"] self.loss_function_type = "mse" - # for LSTM/Gru + Transformer - self.num_hidden_layers = 1 - # for LSTM/Gru self.no_hidden_state = False self.bidirection = False + # for LSTM/Gru + Transformer + self.num_hidden_layers = 1 + # for transformer net self.dropout = 0.1 self.num_heads = 10 @@ -744,13 +744,16 @@ class ParametersRunning(ParametersBase): a "by snapshot" basis. checkpoints_each_epoch : int - If not 0, checkpoint files will be saved after eac + If not 0, checkpoint files will be saved after each checkpoints_each_epoch epoch. checkpoint_name : string Name used for the checkpoints. Using this, multiple runs can be performed in the same directory. + run_name : string + Name of the run used for logging. + logging_dir : string Name of the folder that logging files will be saved to. @@ -759,6 +762,34 @@ class ParametersRunning(ParametersBase): in a subfolder of logging_dir labelled with the starting date of the logging, to avoid having to change input scripts often. + logger : string + Name of the logger to be used. + Currently supported are: + + - "tensorboard": Tensorboard logger. + - "wandb": Weights and Biases logger. + + validation_metrics : list + List of metrics to be used for validation. Default is ["ldos"]. + Possible options are: + + - "ldos": MSE of the LDOS. + - "band_energy": Band energy. + - "band_energy_actual_fe": Band energy computed with ground truth Fermi energy. + - "total_energy": Total energy. + - "total_energy_actual_fe": Total energy computed with ground truth Fermi energy. + - "fermi_energy": Fermi energy. + - "density": Electron density. + - "density_relative": Rlectron density (MAPE). + - "dos": Density of states. + - "dos_relative": Density of states (MAPE). + + validate_on_training_data : bool + Whether to validate on the training data as well. Default is False. + + validate_every_n_epochs : int + Determines how often validation is performed. Default is 1. + training_log_interval : int Determines how often detailed performance info is printed during training (only has an effect if the verbosity is high enough).