From f80a9831f922e783f7662172f4fa75b638f1c9f8 Mon Sep 17 00:00:00 2001 From: Dayton Vogel Date: Thu, 29 Jun 2023 22:34:15 -0600 Subject: [PATCH 01/22] all files updated with changes for ddp implementation --- install/mala_gpu_base_environment.yml | 6 +- mala/common/check_modules.py | 2 - mala/common/parallelizer.py | 46 +++-- mala/common/parameters.py | 106 ++++++++---- mala/datahandling/data_handler.py | 28 ++- mala/datahandling/data_handler_base.py | 2 +- mala/datahandling/data_scaler.py | 21 +-- mala/datahandling/lazy_load_dataset.py | 19 +- .../lazy_load_dataset_clustered.py | 19 +- mala/datahandling/lazy_load_dataset_single.py | 8 +- mala/network/network.py | 14 +- mala/network/objective_naswot.py | 2 +- mala/network/predictor.py | 5 - mala/network/runner.py | 23 ++- mala/network/tester.py | 5 - mala/network/trainer.py | 163 ++++++++++-------- 16 files changed, 243 insertions(+), 226 deletions(-) diff --git a/install/mala_gpu_base_environment.yml b/install/mala_gpu_base_environment.yml index c3e9e6c9f..7f78d40fd 100644 --- a/install/mala_gpu_base_environment.yml +++ b/install/mala_gpu_base_environment.yml @@ -1,4 +1,6 @@ -name: mala-gpu +name: mala-gpu-ddp channels: - - defaults - conda-forge + - defaults +dependencies: + - python=3.10 diff --git a/mala/common/check_modules.py b/mala/common/check_modules.py index 63fb4e16b..c4bc05017 100644 --- a/mala/common/check_modules.py +++ b/mala/common/check_modules.py @@ -8,8 +8,6 @@ def check_modules(): optional_libs = { "mpi4py": {"available": False, "description": "Enables inference parallelization."}, - "horovod": {"available": False, "description": - "Enables training parallelization."}, "lammps": {"available": False, "description": "Enables descriptor calculation for data preprocessing " "and inference."}, diff --git a/mala/common/parallelizer.py b/mala/common/parallelizer.py index 0d8947934..ae39b9a94 100644 --- a/mala/common/parallelizer.py +++ b/mala/common/parallelizer.py @@ -1,15 +1,13 @@ """Functions for operating MALA in parallel.""" from collections import defaultdict import platform +import os import warnings -try: - import horovod.torch as hvd -except ModuleNotFoundError: - pass import torch +import torch.distributed as dist -use_horovod = False +use_ddp = False use_mpi = False comm = None local_mpi_rank = None @@ -32,41 +30,41 @@ def set_current_verbosity(new_value): current_verbosity = new_value -def set_horovod_status(new_value): +def set_ddp_status(new_value): """ - Set the horovod status. + Set the ddp status. - By setting the horovod status via this function it can be ensured that + By setting the ddp status via this function it can be ensured that printing works in parallel. The Parameters class does that for the user. Parameters ---------- new_value : bool - Value the horovod status has. + Value the ddp status has. """ if use_mpi is True and new_value is True: - raise Exception("Cannot use horovod and inference-level MPI at " + raise Exception("Cannot use ddp and inference-level MPI at " "the same time yet.") - global use_horovod - use_horovod = new_value + global use_ddp + use_ddp = new_value def set_mpi_status(new_value): """ Set the MPI status. - By setting the horovod status via this function it can be ensured that + By setting the ddp status via this function it can be ensured that printing works in parallel. The Parameters class does that for the user. Parameters ---------- new_value : bool - Value the horovod status has. + Value the ddp status has. 
""" - if use_horovod is True and new_value is True: - raise Exception("Cannot use horovod and inference-level MPI at " + if use_ddp is True and new_value is True: + raise Exception("Cannot use ddp and inference-level MPI at " "the same time yet.") global use_mpi use_mpi = new_value @@ -113,8 +111,8 @@ def get_rank(): The rank of the current thread. """ - if use_horovod: - return hvd.rank() + if use_ddp: + return dist.get_rank() if use_mpi: return comm.Get_rank() return 0 @@ -153,8 +151,8 @@ def get_local_rank(): FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - if use_horovod: - return hvd.local_rank() + if use_ddp: + return int(os.environ.get("LOCAL_RANK")) if use_mpi: global local_mpi_rank if local_mpi_rank is None: @@ -181,8 +179,8 @@ def get_size(): size : int The number of ranks. """ - if use_horovod: - return hvd.size() + if use_ddp: + return dist.get_world_size() if use_mpi: return comm.Get_size() @@ -203,8 +201,8 @@ def get_comm(): def barrier(): """General interface for a barrier.""" - if use_horovod: - hvd.allreduce(torch.tensor(0), name='barrier') + if use_ddp: + dist.barrier() if use_mpi: comm.Barrier() return diff --git a/mala/common/parameters.py b/mala/common/parameters.py index 6c0c6908d..fe603854a 100644 --- a/mala/common/parameters.py +++ b/mala/common/parameters.py @@ -6,14 +6,11 @@ import pickle from time import sleep -try: - import horovod.torch as hvd -except ModuleNotFoundError: - pass import numpy as np import torch +import torch.distributed as dist -from mala.common.parallelizer import printout, set_horovod_status, \ +from mala.common.parallelizer import printout, set_ddp_status, \ set_mpi_status, get_rank, get_local_rank, set_current_verbosity, \ parallel_warn from mala.common.json_serializable import JSONSerializable @@ -26,7 +23,7 @@ class ParametersBase(JSONSerializable): def __init__(self,): super(ParametersBase, self).__init__() - self._configuration = {"gpu": False, "horovod": False, "mpi": False, + self._configuration = {"gpu": False, "ddp": False, "mpi": False, "device": "cpu", "openpmd_configuration": {}, "openpmd_granularity": 1} pass @@ -54,8 +51,8 @@ def show(self, indent=""): def _update_gpu(self, new_gpu): self._configuration["gpu"] = new_gpu - def _update_horovod(self, new_horovod): - self._configuration["horovod"] = new_horovod + def _update_ddp(self, new_ddp): + self._configuration["ddp"] = new_ddp def _update_mpi(self, new_mpi): self._configuration["mpi"] = new_mpi @@ -675,10 +672,6 @@ class ParametersRunning(ParametersBase): validation loss has to plateau before the schedule takes effect). Default: 0. - use_compression : bool - If True and horovod is used, horovod compression will be used for - allreduce communication. This can improve performance. - num_workers : int Number of workers to be used for data loading. 
@@ -739,7 +732,6 @@ def __init__(self): self.learning_rate_scheduler = None self.learning_rate_decay = 0.1 self.learning_rate_patience = 0 - self.use_compression = False self.num_workers = 0 self.use_shuffling_for_samplers = True self.checkpoints_each_epoch = 0 @@ -755,8 +747,8 @@ def __init__(self): self.training_report_frequency = 1000 self.profiler_range = [1000, 2000] - def _update_horovod(self, new_horovod): - super(ParametersRunning, self)._update_horovod(new_horovod) + def _update_ddp(self, new_ddp): + super(ParametersRunning, self)._update_ddp(new_ddp) self.during_training_metric = self.during_training_metric self.after_before_training_metric = self.after_before_training_metric @@ -778,9 +770,9 @@ def during_training_metric(self): @during_training_metric.setter def during_training_metric(self, value): if value != "ldos": - if self._configuration["horovod"]: + if self._configuration["ddp"]: raise Exception("Currently, MALA can only operate with the " - "\"ldos\" metric for horovod runs.") + "\"ldos\" metric for ddp runs.") self._during_training_metric = value @property @@ -801,17 +793,17 @@ def after_before_training_metric(self): @after_before_training_metric.setter def after_before_training_metric(self, value): if value != "ldos": - if self._configuration["horovod"]: + if self._configuration["ddp"]: raise Exception("Currently, MALA can only operate with the " - "\"ldos\" metric for horovod runs.") + "\"ldos\" metric for ddp runs.") self._after_before_training_metric = value @during_training_metric.setter def during_training_metric(self, value): if value != "ldos": - if self._configuration["horovod"]: + if self._configuration["ddp"]: raise Exception("Currently, MALA can only operate with the " - "\"ldos\" metric for horovod runs.") + "\"ldos\" metric for ddp runs.") self._during_training_metric = value @property @@ -1178,7 +1170,10 @@ def __init__(self): # Properties self.use_gpu = False - self.use_horovod = False + self.use_ddp = False + self.use_distributed_sampler_train = True + self.use_distributed_sampler_val = True + self.use_distributed_sampler_test = True self.use_mpi = False self.verbosity = 1 self.device = "cpu" @@ -1259,25 +1254,62 @@ def use_gpu(self, value): self.hyperparameters._update_gpu(self.use_gpu) @property - def use_horovod(self): - """Control whether or not horovod is used for parallel training.""" - return self._use_horovod + def use_ddp(self): + """Control whether or not dd is used for parallel training.""" + return self._use_ddp + + @property + def use_distributed_sampler_train(self): + """Control wether or not distributed sampler is used to distribute training data.""" + return self._use_distributed_sampler_train + + @use_distributed_sampler_train.setter + def use_distributed_sampler_train(self, value): + """Control whether or not distributed sampler is used to distribute training data.""" + self._use_distributed_sampler_train = value + + @property + def use_distributed_sampler_val(self): + """Control whether or not distributed sampler is used to distribute validation data.""" + return self._use_distributed_sampler_val + + @use_distributed_sampler_val.setter + def use_distributed_sampler_val(self, value): + """Control whether or not distributed sampler is used to distribute validation data.""" + self._use_distributed_sampler_val = value + + @property + def use_distributed_sampler_test(self): + """Control whether or not distributed sampler is used to distribute test data.""" + return self._use_distributed_sampler_test + + @use_distributed_sampler_test.setter 
+ def use_distributed_sampler_test(self, value): + """Control whether or not distributed sampler is used to distribute test data.""" + self._use_distributed_sampler_test = value - @use_horovod.setter - def use_horovod(self, value): + @use_ddp.setter + def use_ddp(self, value): if value: - hvd.init() + print("initializing torch.distributed.") + # JOSHR: + # We start up torch distributed here. As is fairly standard convention, we get the rank + # and world size arguments via environment variables (RANK, WORLD_SIZE). In addition to + # those variables, LOCAL_RANK, MASTER_ADDR and MASTER_PORT should be set. + rank = int(os.environ.get("RANK")) + world_size = int(os.environ.get("WORLD_SIZE")) + dist.init_process_group("nccl", rank=rank, world_size=world_size) # Invalidate, will be updated in setter. - set_horovod_status(value) + set_ddp_status(value) self.device = None - self._use_horovod = value - self.network._update_horovod(self.use_horovod) - self.descriptors._update_horovod(self.use_horovod) - self.targets._update_horovod(self.use_horovod) - self.data._update_horovod(self.use_horovod) - self.running._update_horovod(self.use_horovod) - self.hyperparameters._update_horovod(self.use_horovod) + self._use_ddp = value + self.network._update_ddp(self.use_ddp) + self.descriptors._update_ddp(self.use_ddp) + self.targets._update_ddp(self.use_ddp) + self.data._update_ddp(self.use_ddp) + self.running._update_ddp(self.use_ddp) + self.hyperparameters._update_ddp(self.use_ddp) @property def device(self): @@ -1301,7 +1333,7 @@ def device(self, value): @property def use_mpi(self): - """Control whether or not horovod is used for parallel training.""" + """Control whether or not ddp is used for parallel training.""" return self._use_mpi @use_mpi.setter diff --git a/mala/datahandling/data_handler.py b/mala/datahandling/data_handler.py index f97b9e886..60184b174 100644 --- a/mala/datahandling/data_handler.py +++ b/mala/datahandling/data_handler.py @@ -1,11 +1,5 @@ """DataHandler class that loads and scales data.""" import os - -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by Parameters class - pass import numpy as np import torch from torch.utils.data import TensorDataset @@ -71,13 +65,13 @@ def __init__(self, parameters: Parameters, target_calculator=None, if self.input_data_scaler is None: self.input_data_scaler \ = DataScaler(self.parameters.input_rescaling_type, - use_horovod=self.use_horovod) + use_ddp=self.use_ddp) self.output_data_scaler = output_data_scaler if self.output_data_scaler is None: self.output_data_scaler \ = DataScaler(self.parameters.output_rescaling_type, - use_horovod=self.use_horovod) + use_ddp=self.use_ddp) # Actual data points in the different categories. 
self.nr_training_data = 0 @@ -576,14 +570,14 @@ def __build_datasets(self): self.input_data_scaler, self.output_data_scaler, self.descriptor_calculator, self.target_calculator, self.grid_dimension, self.grid_size, - self.use_horovod, self.parameters.number_of_clusters, + self.use_ddp, self.parameters.number_of_clusters, self.parameters.train_ratio, self.parameters.sample_ratio)) self.validation_data_sets.append(LazyLoadDataset( self.input_dimension, self.output_dimension, self.input_data_scaler, self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod)) + self.use_ddp)) if self.nr_test_data != 0: self.test_data_sets.append(LazyLoadDataset( @@ -591,7 +585,7 @@ def __build_datasets(self): self.output_dimension, self.input_data_scaler, self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod, + self.use_ddp, input_requires_grad=True)) else: @@ -599,12 +593,12 @@ def __build_datasets(self): self.input_dimension, self.output_dimension, self.input_data_scaler, self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod)) + self.use_ddp)) self.validation_data_sets.append(LazyLoadDataset( self.input_dimension, self.output_dimension, self.input_data_scaler, self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod)) + self.use_ddp)) if self.nr_test_data != 0: self.test_data_sets.append(LazyLoadDataset( @@ -612,7 +606,7 @@ def __build_datasets(self): self.output_dimension, self.input_data_scaler, self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod, + self.use_ddp, input_requires_grad=True)) # Add snapshots to the lazy loading data sets. @@ -646,21 +640,21 @@ def __build_datasets(self): self.input_dimension, self.output_dimension, self.input_data_scaler, self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod)) + self.use_ddp)) if snapshot.snapshot_function == "va": self.validation_data_sets.append(LazyLoadDatasetSingle( self.mini_batch_size, snapshot, self.input_dimension, self.output_dimension, self.input_data_scaler, self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod)) + self.use_ddp)) if snapshot.snapshot_function == "te": self.test_data_sets.append(LazyLoadDatasetSingle( self.mini_batch_size, snapshot, self.input_dimension, self.output_dimension, self.input_data_scaler, self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod, + self.use_ddp, input_requires_grad=True)) else: diff --git a/mala/datahandling/data_handler_base.py b/mala/datahandling/data_handler_base.py index 92bc75126..838f6bec0 100644 --- a/mala/datahandling/data_handler_base.py +++ b/mala/datahandling/data_handler_base.py @@ -32,7 +32,7 @@ class DataHandlerBase(ABC): def __init__(self, parameters: Parameters, target_calculator=None, descriptor_calculator=None): self.parameters: ParametersData = parameters.data - self.use_horovod = parameters.use_horovod + self.use_ddp = parameters.use_ddp # Calculators used to parse data from compatible files. 
self.target_calculator = target_calculator diff --git a/mala/datahandling/data_scaler.py b/mala/datahandling/data_scaler.py index 0a489f7a7..4863a09d0 100644 --- a/mala/datahandling/data_scaler.py +++ b/mala/datahandling/data_scaler.py @@ -1,13 +1,8 @@ """DataScaler class for scaling DFT data.""" import pickle - -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by parameters class - pass import numpy as np import torch +import torch.distributed as dist from mala.common.parameters import printout @@ -33,13 +28,13 @@ class DataScaler: - "feature-wise-normal": Row Min-Max scaling (Scale to be in range 0...1) - use_horovod : bool - If True, the DataScaler will use horovod to check that data is + use_ddp : bool + If True, the DataScaler will use ddp to check that data is only saved on the root process in parallel execution. """ - def __init__(self, typestring, use_horovod=False): - self.use_horovod = use_horovod + def __init__(self, typestring, use_ddp=False): + self.use_ddp = use_ddp self.typestring = typestring self.scale_standard = False self.scale_normal = False @@ -393,9 +388,9 @@ def save(self, filename, save_format="pickle"): save_format : File format which will be used for saving. """ - # If we use horovod, only save the network on root. - if self.use_horovod: - if hvd.rank() != 0: + # If we use ddp, only save the network on root. + if self.use_ddp: + if dist.get_rank() != 0: return if save_format == "pickle": with open(filename, 'wb') as handle: diff --git a/mala/datahandling/lazy_load_dataset.py b/mala/datahandling/lazy_load_dataset.py index df7a61095..b031aa3f9 100644 --- a/mala/datahandling/lazy_load_dataset.py +++ b/mala/datahandling/lazy_load_dataset.py @@ -1,13 +1,8 @@ """DataSet for lazy-loading.""" import os - -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by Parameters class. - pass import numpy as np import torch +import torch.distributed as dist from torch.utils.data import Dataset from mala.common.parallelizer import barrier @@ -46,8 +41,8 @@ class LazyLoadDataset(torch.utils.data.Dataset): target_calculator : mala.targets.target.Target or derivative Used to do unit conversion on output data. - use_horovod : bool - If true, it is assumed that horovod is used. + use_ddp : bool + If true, it is assumed that ddp is used. input_requires_grad : bool If True, then the gradient is stored for the inputs. 
@@ -55,7 +50,7 @@ class LazyLoadDataset(torch.utils.data.Dataset): def __init__(self, input_dimension, output_dimension, input_data_scaler, output_data_scaler, descriptor_calculator, - target_calculator, use_horovod, + target_calculator, use_ddp, input_requires_grad=False): self.snapshot_list = [] self.input_dimension = input_dimension @@ -71,7 +66,7 @@ def __init__(self, input_dimension, output_dimension, input_data_scaler, self.currently_loaded_file = None self.input_data = np.empty(0) self.output_data = np.empty(0) - self.use_horovod = use_horovod + self.use_ddp = use_ddp self.return_outputs_directly = False self.input_requires_grad = input_requires_grad @@ -113,8 +108,8 @@ def mix_datasets(self): """ used_perm = torch.randperm(self.number_of_snapshots) barrier() - if self.use_horovod: - used_perm = hvd.broadcast(used_perm, 0) + if self.use_ddp: + used_perm = dist.broadcast(used_perm, 0) self.snapshot_list = [self.snapshot_list[i] for i in used_perm] self.get_new_data(0) diff --git a/mala/datahandling/lazy_load_dataset_clustered.py b/mala/datahandling/lazy_load_dataset_clustered.py index e46636b73..47835de76 100644 --- a/mala/datahandling/lazy_load_dataset_clustered.py +++ b/mala/datahandling/lazy_load_dataset_clustered.py @@ -1,12 +1,5 @@ """DataSet for lazy-loading.""" import os - -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by Parameters class. - pass - import numpy as np import torch from torch.utils.data import Dataset @@ -47,8 +40,8 @@ class LazyLoadDatasetClustered(torch.utils.data.Dataset): target_calculator : mala.targets.target.Target or derivative Used to do unit conversion on output data. - use_horovod : bool - If true, it is assumed that horovod is used. + use_ddp : bool + If true, it is assumed that ddp is used. input_requires_grad : bool If True, then the gradient is stored for the inputs. @@ -58,7 +51,7 @@ class LazyLoadDatasetClustered(torch.utils.data.Dataset): def __init__(self, input_dimension, output_dimension, input_data_scaler, output_data_scaler, descriptor_calculator, - target_calculator, use_horovod, + target_calculator, use_ddp, number_of_clusters, train_ratio, sample_ratio, input_requires_grad=False): self.snapshot_list = [] @@ -75,7 +68,7 @@ def __init__(self, input_dimension, output_dimension, input_data_scaler, self.currently_loaded_file = None self.input_data = np.empty(0) self.output_data = np.empty(0) - self.use_horovod = use_horovod + self.use_ddp = use_ddp self.return_outputs_directly = False self.input_requires_grad = input_requires_grad @@ -231,8 +224,8 @@ def mix_datasets(self): if self.number_of_snapshots > 1: used_perm = torch.randperm(self.number_of_snapshots) barrier() - if self.use_horovod: - used_perm = hvd.broadcast(used_perm, 0) + if self.use_ddp: + used_perm = dist.broadcast(used_perm, 0) # Not only the snapshots, but also the clustered inputs and samples # per clusters have to be permutated. diff --git a/mala/datahandling/lazy_load_dataset_single.py b/mala/datahandling/lazy_load_dataset_single.py index 90d882a4e..f2c53d7d0 100644 --- a/mala/datahandling/lazy_load_dataset_single.py +++ b/mala/datahandling/lazy_load_dataset_single.py @@ -38,8 +38,8 @@ class LazyLoadDatasetSingle(torch.utils.data.Dataset): target_calculator : mala.targets.target.Target or derivative Used to do unit conversion on output data. - use_horovod : bool - If true, it is assumed that horovod is used. + use_ddp : bool + If true, it is assumed that ddp is used. 
input_requires_grad : bool If True, then the gradient is stored for the inputs. @@ -47,7 +47,7 @@ class LazyLoadDatasetSingle(torch.utils.data.Dataset): def __init__(self, batch_size, snapshot, input_dimension, output_dimension, input_data_scaler, output_data_scaler, descriptor_calculator, - target_calculator, use_horovod, + target_calculator, use_ddp, input_requires_grad=False): self.snapshot = snapshot self.input_dimension = input_dimension @@ -63,7 +63,7 @@ def __init__(self, batch_size, snapshot, input_dimension, output_dimension, self.currently_loaded_file = None self.input_data = np.empty(0) self.output_data = np.empty(0) - self.use_horovod = use_horovod + self.use_ddp = use_ddp self.return_outputs_directly = False self.input_requires_grad = input_requires_grad diff --git a/mala/network/network.py b/mala/network/network.py index 521b7c35f..e2dc3f3a7 100644 --- a/mala/network/network.py +++ b/mala/network/network.py @@ -2,16 +2,12 @@ from abc import abstractmethod import numpy as np import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as functional from mala.common.parameters import Parameters from mala.common.parallelizer import printout -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by parameters class - pass class Network(nn.Module): @@ -67,7 +63,7 @@ def __new__(cls, params: Parameters): def __init__(self, params: Parameters): # copy the network params from the input parameter object - self.use_horovod = params.use_horovod + self.use_ddp = params.use_ddp self.mini_batch_size = params.running.mini_batch_size self.params = params.network @@ -161,9 +157,9 @@ def save_network(self, path_to_file): path_to_file : string Path to the file in which the network should be saved. """ - # If we use horovod, only save the network on root. - if self.use_horovod: - if hvd.rank() != 0: + # If we use ddp, only save the network on root. + if self.use_ddp: + if dist.get_rank() != 0: return torch.save(self.state_dict(), path_to_file, _use_new_zipfile_serialization=False) diff --git a/mala/network/objective_naswot.py b/mala/network/objective_naswot.py index 655af9a85..ca76392ff 100644 --- a/mala/network/objective_naswot.py +++ b/mala/network/objective_naswot.py @@ -69,7 +69,7 @@ def __call__(self, trial): # Load the batchesand get the jacobian. 
do_shuffle = self.params.running.use_shuffling_for_samplers if self.data_handler.parameters.use_lazy_loading or \ - self.params.use_horovod: + self.params.use_ddp: do_shuffle = False if self.params.running.use_shuffling_for_samplers: self.data_handler.mix_datasets() diff --git a/mala/network/predictor.py b/mala/network/predictor.py index c282e118c..dfec05daf 100644 --- a/mala/network/predictor.py +++ b/mala/network/predictor.py @@ -1,10 +1,5 @@ """Tester class for testing a network.""" import ase.io -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by Parameters class - pass import numpy as np import torch diff --git a/mala/network/runner.py b/mala/network/runner.py index 5367c2a7c..a3f5ad158 100644 --- a/mala/network/runner.py +++ b/mala/network/runner.py @@ -2,13 +2,9 @@ import os from zipfile import ZipFile, ZIP_STORED -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by Parameters class - pass import numpy as np import torch +import torch.distributed as dist from mala.common.parameters import ParametersRunning from mala.network.network import Network @@ -353,16 +349,19 @@ def __prepare_to_run(self): """ Prepare the Runner to run the Network. - This includes e.g. horovod setup. + This includes e.g. ddp setup. """ - # See if we want to use horovod. - if self.parameters_full.use_horovod: + # See if we want to use ddp. + if self.parameters_full.use_ddp: if self.parameters_full.use_gpu: # We cannot use "printout" here because this is supposed # to happen on every rank. + size = dist.get_world_size() + rank = dist.get_rank() + local_rank = int(os.environ.get("LOCAL_RANK")) if self.parameters_full.verbosity >= 2: - print("size=", hvd.size(), "global_rank=", hvd.rank(), - "local_rank=", hvd.local_rank(), "device=", - torch.cuda.get_device_name(hvd.local_rank())) + print("size=", size, "global_rank=", rank, + "local_rank=", local_rank, "device=", + torch.cuda.get_device_name(local_rank)) # pin GPU to local rank - torch.cuda.set_device(hvd.local_rank()) + torch.cuda.set_device(local_rank) diff --git a/mala/network/tester.py b/mala/network/tester.py index e8a46ebec..14be01324 100644 --- a/mala/network/tester.py +++ b/mala/network/tester.py @@ -1,9 +1,4 @@ """Tester class for testing a network.""" -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by Parameters class - pass import numpy as np from mala.common.parameters import printout diff --git a/mala/network/trainer.py b/mala/network/trainer.py index 98dc291b8..fc33abf0b 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -4,13 +4,10 @@ from datetime import datetime from packaging import version -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by Parameters class - pass import numpy as np import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP from torch import optim from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter @@ -46,6 +43,16 @@ class Trainer(Runner): def __init__(self, params, network, data, optimizer_dict=None): # copy the parameters into the class. 
super(Trainer, self).__init__(params, network, data) + + if self.parameters_full.use_ddp: + print("wrapping model in ddp..") + # JOSHR: using streams here to maintain compatibility with + # graph capture + s = torch.cuda.Stream() + with torch.cuda.stream(s): + self.network = DDP(self.network) + torch.cuda.current_stream().wait_stream(s) + self.final_test_loss = float("inf") self.initial_test_loss = float("inf") self.final_validation_loss = float("inf") @@ -59,7 +66,7 @@ def __init__(self, params, network, data, optimizer_dict=None): self.validation_data_loaders = [] self.test_data_loaders = [] - # Samplers for the horovod case. + # Samplers for the ddp case. self.train_sampler = None self.test_sampler = None self.validation_sampler = None @@ -230,11 +237,13 @@ def train_network(self): after_before_training_metric) # Collect and average all the losses from all the devices - if self.parameters_full.use_horovod: - vloss = self.__average_validation(vloss, 'average_loss') + if self.parameters_full.use_ddp: + vloss = self.__average_validation(vloss, 'average_loss', + self.parameters._configuration["device"]) self.initial_validation_loss = vloss if self.data.test_data_set is not None: - tloss = self.__average_validation(tloss, 'average_loss') + tloss = self.__average_validation(tloss, 'average_loss', + self.parameters._configuration["device"]) self.initial_test_loss = tloss printout("Initial Guess - validation data loss: ", vloss, @@ -271,7 +280,7 @@ def train_network(self): training_loss_sum = torch.zeros(1, device=self.parameters._configuration["device"]) # train sampler - if self.parameters_full.use_horovod: + if self.train_sampler: self.train_sampler.set_epoch(epoch) # shuffle dataset if necessary @@ -344,8 +353,9 @@ def train_network(self): self.parameters. during_training_metric) - if self.parameters_full.use_horovod: - vloss = self.__average_validation(vloss, 'average_loss') + if self.parameters_full.use_ddp: + vloss = self.__average_validation(vloss, 'average_loss', + self.parameters._configuration["device"]) if self.parameters_full.verbosity > 1: printout("Epoch {0}: validation data loss: {1}, " "training data loss: {2}".format(epoch, vloss, @@ -433,8 +443,9 @@ def train_network(self): "validation", self.parameters. after_before_training_metric) - if self.parameters_full.use_horovod: - vloss = self.__average_validation(vloss, 'average_loss') + if self.parameters_full.use_ddp: + vloss = self.__average_validation(vloss, 'average_loss', + self.parameters._configuration["device"]) # Calculate final loss. self.final_validation_loss = vloss @@ -446,8 +457,9 @@ def train_network(self): "test", self.parameters. after_before_training_metric) - if self.parameters_full.use_horovod: - tloss = self.__average_validation(tloss, 'average_loss') + if self.parameters_full.use_ddp: + tloss = self.__average_validation(tloss, 'average_loss', + self.parameters._configuration["device"]) printout("Final test data loss: ", tloss, min_verbosity=0) self.final_test_loss = tloss @@ -470,13 +482,13 @@ def __prepare_to_train(self, optimizer_dict): if optimizer_dict is not None: self.last_epoch = optimizer_dict['epoch']+1 - # Scale the learning rate according to horovod. - if self.parameters_full.use_horovod: - if hvd.size() > 1 and self.last_epoch == 0: + # Scale the learning rate according to ddp. 
+ if self.parameters_full.use_ddp: + if dist.get_world_size() > 1 and self.last_epoch == 0: printout("Rescaling learning rate because multiple workers are" " used for training.", min_verbosity=1) self.parameters.learning_rate = self.parameters.learning_rate \ - * hvd.size() + * dist.get_world_size() # Choose an optimizer to use. if self.parameters.trainingtype == "SGD": @@ -508,13 +520,10 @@ def __prepare_to_train(self, optimizer_dict): self.patience_counter = optimizer_dict['early_stopping_counter'] self.last_loss = optimizer_dict['early_stopping_last_loss'] - if self.parameters_full.use_horovod: + if self.parameters_full.use_ddp: # scaling the batch size for multiGPU per node # self.batch_size= self.batch_size*hvd.local_size() - compression = hvd.Compression.fp16 if self.parameters_full.\ - running.use_compression else hvd.Compression.none - # If lazy loading is used we do not shuffle the data points on # their own, but rather shuffle them # by shuffling the files themselves and then reading file by file @@ -524,37 +533,26 @@ def __prepare_to_train(self, optimizer_dict): if self.data.parameters.use_lazy_loading: do_shuffle = False - self.train_sampler = torch.utils.data.\ - distributed.DistributedSampler(self.data.training_data_sets[0], - num_replicas=hvd.size(), - rank=hvd.rank(), - shuffle=do_shuffle) - - self.validation_sampler = torch.utils.data.\ - distributed.DistributedSampler(self.data.validation_data_sets[0], - num_replicas=hvd.size(), - rank=hvd.rank(), - shuffle=False) - - if self.data.test_data_sets: - self.test_sampler = torch.utils.data.\ - distributed.DistributedSampler(self.data.test_data_sets[0], - num_replicas=hvd.size(), - rank=hvd.rank(), + if self.parameters_full.use_distributed_sampler_train: + self.train_sampler = torch.utils.data.\ + distributed.DistributedSampler(self.data.training_data_set, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=do_shuffle) + if self.parameters_full.use_distributed_sampler_val: + self.validation_sampler = torch.utils.data.\ + distributed.DistributedSampler(self.data.validation_data_set, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), shuffle=False) - # broadcaste parameters and optimizer state from root device to - # other devices - hvd.broadcast_parameters(self.network.state_dict(), root_rank=0) - hvd.broadcast_optimizer_state(self.optimizer, root_rank=0) - - # Wraps the opimizer for multiGPU operation - self.optimizer = hvd.DistributedOptimizer(self.optimizer, - named_parameters= - self.network. - named_parameters(), - compression=compression, - op=hvd.Average) + if self.parameters_full.use_distributed_sampler_test: + if self.data.test_data_set is not None: + self.test_sampler = torch.utils.data.\ + distributed.DistributedSampler(self.data.test_data_set, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=False) # Instantiate the learning rate scheduler, if necessary. if self.parameters.learning_rate_scheduler == "ReduceLROnPlateau": @@ -581,7 +579,7 @@ def __prepare_to_train(self, optimizer_dict): # This shuffling is done in the dataset themselves. 
do_shuffle = self.parameters.use_shuffling_for_samplers if self.data.parameters.use_lazy_loading or self.parameters_full.\ - use_horovod: + use_ddp: do_shuffle = False # Prepare data loaders.(look into mini-batch size) @@ -645,7 +643,11 @@ def __process_mini_batch(self, network, input_data, target_data): with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision): prediction = network(input_data) - loss = network.calculate_loss(prediction, target_data) + if self.parameters_full.use_ddp: + # JOSHR: We have to use "module" here to access custom method of DDP wrapped model + loss = network.module.calcualte_loss(prediction, target_data) + else: + loss = network.calculate_loss(prediction, target_data) if self.gradscaler: self.gradscaler.scale(loss).backward() @@ -659,12 +661,15 @@ def __process_mini_batch(self, network, input_data, target_data): # Capture graph self.train_graph = torch.cuda.CUDAGraph() - self.network.zero_grad(set_to_none=True) + network.zero_grad(set_to_none=True) with torch.cuda.graph(self.train_graph): with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision): self.static_prediction = network(self.static_input_data) - self.static_loss = network.calculate_loss(self.static_prediction, self.static_target_data) + if self.parameters_full.use_ddp: + self.static_loss = network.module.calculate_loss(self.static_prediction, self.static_target_data) + else: + self.static_loss = network.calculate_loss(self.static_prediction, self.static_target_data) if self.gradscaler: self.gradscaler.scale(self.static_loss).backward() @@ -688,7 +693,10 @@ def __process_mini_batch(self, network, input_data, target_data): torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_push("loss") - loss = network.calculate_loss(prediction, target_data) + if self.parameters_full.use_ddp: + loss = network.module.calculate_loss(prediction, target_data) + else: + loss = network.calculate_loss(prediction, target_data) # loss torch.cuda.nvtx.range_pop() @@ -711,7 +719,10 @@ def __process_mini_batch(self, network, input_data, target_data): return loss else: prediction = network(input_data) - loss = network.calculate_loss(prediction, target_data) + if self.parameters_full.use_ddp: + loss = network.module.calculate_loss(prediction, target_data) + else: + loss = network.calculate_loss(prediction, target_data) loss.backward() self.optimizer.step() self.optimizer.zero_grad() @@ -761,7 +772,10 @@ def __validate_network(self, network, data_set_type, validation_type): for _ in range(20): with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision): prediction = network(x) - loss = network.calculate_loss(prediction, y) + if self.parameters_full.use_ddp: + loss = network.module.calculate_loss(prediction, y) + else: + loss = network.calculate_loss(prediction, y) torch.cuda.current_stream().wait_stream(s) # Create static entry point tensors to graph @@ -773,7 +787,10 @@ def __validate_network(self, network, data_set_type, validation_type): with torch.cuda.graph(self.validation_graph): with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision): self.static_prediction_validation = network(self.static_input_validation) - self.static_loss_validation = network.calculate_loss(self.static_prediction_validation, self.static_target_validation) + if self.parameters_full.use_ddp: + self.static_loss_validation = network.module.calculate_loss(self.static_prediction_validation, self.static_target_validation) + else: + self.static_loss_validation = 
network.calculate_loss(self.static_prediction_validation, self.static_target_validation) if self.validation_graph: self.static_input_validation.copy_(x) @@ -783,7 +800,10 @@ def __validate_network(self, network, data_set_type, validation_type): else: with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision): prediction = network(x) - loss = network.calculate_loss(prediction, y) + if self.parameters_full.use_ddp: + loss = network.module.calculate_loss(prediction, y) + else: + loss = network.calculate_loss(prediction, y) validation_loss_sum += loss if batchid != 0 and (batchid + 1) % report_freq == 0: torch.cuda.synchronize() @@ -804,8 +824,12 @@ def __validate_network(self, network, data_set_type, validation_type): x = x.to(self.parameters._configuration["device"]) y = y.to(self.parameters._configuration["device"]) prediction = network(x) - validation_loss_sum += \ - network.calculate_loss(prediction, y).item() + if self.parameters_full.use_ddp: + validation_loss_sum += \ + network.module.calculate_loss(prediction, y).item() + else: + validation_loss_sum += \ + network.calculate_loss(prediction, y).item() batchid += 1 validation_loss = validation_loss_sum.item() / batchid @@ -939,8 +963,8 @@ def __create_training_checkpoint(self): # Next, we save all the other objects. - if self.parameters_full.use_horovod: - if hvd.rank() != 0: + if self.parameters_full.use_ddp: + if dist.get_rank() != 0: return if self.scheduler is None: save_dict = { @@ -963,8 +987,9 @@ def __create_training_checkpoint(self): self.save_run(self.parameters.checkpoint_name, save_runner=True) @staticmethod - def __average_validation(val, name): + def __average_validation(val, name, device="cpu"): """Average validation over multiple parallel processes.""" - tensor = torch.tensor(val) - avg_loss = hvd.allreduce(tensor, name=name, op=hvd.Average) + tensor = torch.tensor(val, device=device) + dist.all_reduce(tensor) + avg_loss = tensor / dist.get_world_size() return avg_loss.item() From 2f10a23acbf4056c18209ce375c7b9d398b1e9e4 Mon Sep 17 00:00:00 2001 From: Dayton Jon Vogel Date: Fri, 30 Jun 2023 15:07:56 -0700 Subject: [PATCH 02/22] allowing ddp wrapper to push network for saving during checkpoint --- mala/network/runner.py | 3 ++- mala/network/trainer.py | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/mala/network/runner.py b/mala/network/runner.py index a3f5ad158..f4a16e08c 100644 --- a/mala/network/runner.py +++ b/mala/network/runner.py @@ -77,7 +77,8 @@ def save_run(self, run_name, save_path="./", zip_run=True, optimizer_file = run_name+".optimizer.pth" self.parameters_full.save(os.path.join(save_path, params_file)) - self.network.save_network(os.path.join(save_path, model_file)) + self.network.module.save_network(os.path.join(save_path, model_file)) + #self.network.save_network(os.path.join(save_path, model_file)) self.data.input_data_scaler.save(os.path.join(save_path, iscaler_file)) self.data.output_data_scaler.save(os.path.join(save_path, oscaler_file)) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index fc33abf0b..c52c8c5dc 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -241,7 +241,7 @@ def train_network(self): vloss = self.__average_validation(vloss, 'average_loss', self.parameters._configuration["device"]) self.initial_validation_loss = vloss - if self.data.test_data_set is not None: + if self.data.test_data_sets is not None: tloss = self.__average_validation(tloss, 'average_loss', self.parameters._configuration["device"]) 
self.initial_test_loss = tloss @@ -535,21 +535,21 @@ def __prepare_to_train(self, optimizer_dict): if self.parameters_full.use_distributed_sampler_train: self.train_sampler = torch.utils.data.\ - distributed.DistributedSampler(self.data.training_data_set, + distributed.DistributedSampler(self.data.training_data_sets, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=do_shuffle) if self.parameters_full.use_distributed_sampler_val: self.validation_sampler = torch.utils.data.\ - distributed.DistributedSampler(self.data.validation_data_set, + distributed.DistributedSampler(self.data.validation_data_sets, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False) if self.parameters_full.use_distributed_sampler_test: - if self.data.test_data_set is not None: + if self.data.test_data_sets is not None: self.test_sampler = torch.utils.data.\ - distributed.DistributedSampler(self.data.test_data_set, + distributed.DistributedSampler(self.data.test_data_sets, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False) @@ -645,7 +645,7 @@ def __process_mini_batch(self, network, input_data, target_data): prediction = network(input_data) if self.parameters_full.use_ddp: # JOSHR: We have to use "module" here to access custom method of DDP wrapped model - loss = network.module.calcualte_loss(prediction, target_data) + loss = network.module.calculate_loss(prediction, target_data) else: loss = network.calculate_loss(prediction, target_data) From 05f551b206913fc23ae3c9574163e6382f187558 Mon Sep 17 00:00:00 2001 From: Dayton Jon Vogel Date: Sun, 2 Jul 2023 18:33:56 -0700 Subject: [PATCH 03/22] allow checkpoint network save when not using ddp --- mala/network/runner.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mala/network/runner.py b/mala/network/runner.py index f4a16e08c..2ae78d21a 100644 --- a/mala/network/runner.py +++ b/mala/network/runner.py @@ -77,8 +77,10 @@ def save_run(self, run_name, save_path="./", zip_run=True, optimizer_file = run_name+".optimizer.pth" self.parameters_full.save(os.path.join(save_path, params_file)) - self.network.module.save_network(os.path.join(save_path, model_file)) - #self.network.save_network(os.path.join(save_path, model_file)) + if self.parameters_full.use_ddp: + self.network.module.save_network(os.path.join(save_path, model_file)) + else: + self.network.save_network(os.path.join(save_path, model_file)) self.data.input_data_scaler.save(os.path.join(save_path, iscaler_file)) self.data.output_data_scaler.save(os.path.join(save_path, oscaler_file)) From e60e9959f229a56db4192e1f756466c2da676350 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Thu, 25 Apr 2024 17:59:23 +0200 Subject: [PATCH 04/22] blackified remaining files --- mala/common/parallelizer.py | 10 +- mala/common/parameters.py | 8 +- mala/network/objective_naswot.py | 6 +- mala/network/runner.py | 17 ++- mala/network/trainer.py | 196 +++++++++++++++++++++---------- 5 files changed, 158 insertions(+), 79 deletions(-) diff --git a/mala/common/parallelizer.py b/mala/common/parallelizer.py index 746a54476..160695a42 100644 --- a/mala/common/parallelizer.py +++ b/mala/common/parallelizer.py @@ -45,8 +45,9 @@ def set_ddp_status(new_value): """ if use_mpi is True and new_value is True: - raise Exception("Cannot use ddp and inference-level MPI at " - "the same time yet.") + raise Exception( + "Cannot use ddp and inference-level MPI at " "the same time yet." 
+ ) global use_ddp use_ddp = new_value @@ -65,8 +66,9 @@ def set_mpi_status(new_value): """ if use_ddp is True and new_value is True: - raise Exception("Cannot use ddp and inference-level MPI at " - "the same time yet.") + raise Exception( + "Cannot use ddp and inference-level MPI at " "the same time yet." + ) global use_mpi use_mpi = new_value if use_mpi: diff --git a/mala/common/parameters.py b/mala/common/parameters.py index 539030c4d..711d2aaa9 100644 --- a/mala/common/parameters.py +++ b/mala/common/parameters.py @@ -806,7 +806,7 @@ def after_before_training_metric(self, value): if self._configuration["ddp"]: raise Exception( "Currently, MALA can only operate with the " - "\"ldos\" metric for ddp runs." + '"ldos" metric for ddp runs.' ) self._after_before_training_metric = value @@ -816,7 +816,7 @@ def during_training_metric(self, value): if self._configuration["ddp"]: raise Exception( "Currently, MALA can only operate with the " - "\"ldos\" metric for ddp runs." + '"ldos" metric for ddp runs.' ) self._during_training_metric = value @@ -1312,8 +1312,8 @@ def use_distributed_sampler_train(self, value): @property def use_distributed_sampler_val(self): - """Control whether or not distributed sampler is used to distribute validation data.""" - return self._use_distributed_sampler_val + """Control whether or not distributed sampler is used to distribute validation data.""" + return self._use_distributed_sampler_val @use_distributed_sampler_val.setter def use_distributed_sampler_val(self, value): diff --git a/mala/network/objective_naswot.py b/mala/network/objective_naswot.py index b7a49938b..96377e527 100644 --- a/mala/network/objective_naswot.py +++ b/mala/network/objective_naswot.py @@ -74,8 +74,10 @@ def __call__(self, trial): # Load the batchesand get the jacobian. 
do_shuffle = self.params.running.use_shuffling_for_samplers - if self.data_handler.parameters.use_lazy_loading or \ - self.params.use_ddp: + if ( + self.data_handler.parameters.use_lazy_loading + or self.params.use_ddp + ): do_shuffle = False if self.params.running.use_shuffling_for_samplers: self.data_handler.mix_datasets() diff --git a/mala/network/runner.py b/mala/network/runner.py index 18f16518f..81cd54736 100644 --- a/mala/network/runner.py +++ b/mala/network/runner.py @@ -85,7 +85,9 @@ def save_run( self.parameters_full.save(os.path.join(save_path, params_file)) if self.parameters_full.use_ddp: - self.network.module.save_network(os.path.join(save_path, model_file)) + self.network.module.save_network( + os.path.join(save_path, model_file) + ) else: self.network.save_network(os.path.join(save_path, model_file)) self.data.input_data_scaler.save(os.path.join(save_path, iscaler_file)) @@ -431,8 +433,15 @@ def __prepare_to_run(self): rank = dist.get_rank() local_rank = int(os.environ.get("LOCAL_RANK")) if self.parameters_full.verbosity >= 2: - print("size=", size, "global_rank=", rank, - "local_rank=", local_rank, "device=", - torch.cuda.get_device_name(local_rank)) + print( + "size=", + size, + "global_rank=", + rank, + "local_rank=", + local_rank, + "device=", + torch.cuda.get_device_name(local_rank), + ) # pin GPU to local rank torch.cuda.set_device(local_rank) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index 24e9a71ef..5c94b4437 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -45,13 +45,13 @@ def __init__(self, params, network, data, optimizer_dict=None): super(Trainer, self).__init__(params, network, data) if self.parameters_full.use_ddp: - print("wrapping model in ddp..") - # JOSHR: using streams here to maintain compatibility with - # graph capture - s = torch.cuda.Stream() - with torch.cuda.stream(s): - self.network = DDP(self.network) - torch.cuda.current_stream().wait_stream(s) + print("wrapping model in ddp..") + # JOSHR: using streams here to maintain compatibility with + # graph capture + s = torch.cuda.Stream() + with torch.cuda.stream(s): + self.network = DDP(self.network) + torch.cuda.current_stream().wait_stream(s) self.final_test_loss = float("inf") self.initial_test_loss = float("inf") @@ -264,12 +264,16 @@ def train_network(self): # Collect and average all the losses from all the devices if self.parameters_full.use_ddp: - vloss = self.__average_validation(vloss, 'average_loss', - self.parameters._configuration["device"]) + vloss = self.__average_validation( + vloss, "average_loss", self.parameters._configuration["device"] + ) self.initial_validation_loss = vloss if self.data.test_data_sets is not None: - tloss = self.__average_validation(tloss, 'average_loss', - self.parameters._configuration["device"]) + tloss = self.__average_validation( + tloss, + "average_loss", + self.parameters._configuration["device"], + ) self.initial_test_loss = tloss printout( @@ -416,8 +420,11 @@ def train_network(self): ) if self.parameters_full.use_ddp: - vloss = self.__average_validation(vloss, 'average_loss', - self.parameters._configuration["device"]) + vloss = self.__average_validation( + vloss, + "average_loss", + self.parameters._configuration["device"], + ) if self.parameters_full.verbosity > 1: printout( "Epoch {0}: validation data loss: {1}, " @@ -527,15 +534,21 @@ def train_network(self): # CALCULATE FINAL METRICS ############################ - if self.parameters.after_before_training_metric != \ - self.parameters.during_training_metric: 
- vloss = self.__validate_network(self.network, - "validation", - self.parameters. - after_before_training_metric) + if ( + self.parameters.after_before_training_metric + != self.parameters.during_training_metric + ): + vloss = self.__validate_network( + self.network, + "validation", + self.parameters.after_before_training_metric, + ) if self.parameters_full.use_ddp: - vloss = self.__average_validation(vloss, 'average_loss', - self.parameters._configuration["device"]) + vloss = self.__average_validation( + vloss, + "average_loss", + self.parameters._configuration["device"], + ) # Calculate final loss. self.final_validation_loss = vloss @@ -543,13 +556,17 @@ def train_network(self): tloss = float("inf") if len(self.data.test_data_sets) > 0: - tloss = self.__validate_network(self.network, - "test", - self.parameters. - after_before_training_metric) + tloss = self.__validate_network( + self.network, + "test", + self.parameters.after_before_training_metric, + ) if self.parameters_full.use_ddp: - tloss = self.__average_validation(tloss, 'average_loss', - self.parameters._configuration["device"]) + tloss = self.__average_validation( + tloss, + "average_loss", + self.parameters._configuration["device"], + ) printout("Final test data loss: ", tloss, min_verbosity=0) self.final_test_loss = tloss @@ -577,10 +594,14 @@ def __prepare_to_train(self, optimizer_dict): # Scale the learning rate according to ddp. if self.parameters_full.use_ddp: if dist.get_world_size() > 1 and self.last_epoch == 0: - printout("Rescaling learning rate because multiple workers are" - " used for training.", min_verbosity=1) - self.parameters.learning_rate = self.parameters.learning_rate \ - * dist.get_world_size() + printout( + "Rescaling learning rate because multiple workers are" + " used for training.", + min_verbosity=1, + ) + self.parameters.learning_rate = ( + self.parameters.learning_rate * dist.get_world_size() + ) # Choose an optimizer to use. if self.parameters.trainingtype == "SGD": @@ -632,25 +653,34 @@ def __prepare_to_train(self, optimizer_dict): do_shuffle = False if self.parameters_full.use_distributed_sampler_train: - self.train_sampler = torch.utils.data.\ - distributed.DistributedSampler(self.data.training_data_sets, - num_replicas=dist.get_world_size(), - rank=dist.get_rank(), - shuffle=do_shuffle) + self.train_sampler = ( + torch.utils.data.distributed.DistributedSampler( + self.data.training_data_sets, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=do_shuffle, + ) + ) if self.parameters_full.use_distributed_sampler_val: - self.validation_sampler = torch.utils.data.\ - distributed.DistributedSampler(self.data.validation_data_sets, - num_replicas=dist.get_world_size(), - rank=dist.get_rank(), - shuffle=False) + self.validation_sampler = ( + torch.utils.data.distributed.DistributedSampler( + self.data.validation_data_sets, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=False, + ) + ) if self.parameters_full.use_distributed_sampler_test: if self.data.test_data_sets is not None: - self.test_sampler = torch.utils.data.\ - distributed.DistributedSampler(self.data.test_data_sets, - num_replicas=dist.get_world_size(), - rank=dist.get_rank(), - shuffle=False) + self.test_sampler = ( + torch.utils.data.distributed.DistributedSampler( + self.data.test_data_sets, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=False, + ) + ) # Instantiate the learning rate scheduler, if necessary. 
if self.parameters.learning_rate_scheduler == "ReduceLROnPlateau": @@ -676,8 +706,10 @@ def __prepare_to_train(self, optimizer_dict): # epoch. # This shuffling is done in the dataset themselves. do_shuffle = self.parameters.use_shuffling_for_samplers - if self.data.parameters.use_lazy_loading or self.parameters_full.\ - use_ddp: + if ( + self.data.parameters.use_lazy_loading + or self.parameters_full.use_ddp + ): do_shuffle = False # Prepare data loaders.(look into mini-batch size) @@ -774,9 +806,13 @@ def __process_mini_batch(self, network, input_data, target_data): prediction = network(input_data) if self.parameters_full.use_ddp: # JOSHR: We have to use "module" here to access custom method of DDP wrapped model - loss = network.module.calculate_loss(prediction, target_data) + loss = network.module.calculate_loss( + prediction, target_data + ) else: - loss = network.calculate_loss(prediction, target_data) + loss = network.calculate_loss( + prediction, target_data + ) if self.gradscaler: self.gradscaler.scale(loss).backward() @@ -802,9 +838,13 @@ def __process_mini_batch(self, network, input_data, target_data): ) if self.parameters_full.use_ddp: - self.static_loss = network.module.calculate_loss(self.static_prediction, self.static_target_data) + self.static_loss = network.module.calculate_loss( + self.static_prediction, self.static_target_data + ) else: - self.static_loss = network.calculate_loss(self.static_prediction, self.static_target_data) + self.static_loss = network.calculate_loss( + self.static_prediction, self.static_target_data + ) if self.gradscaler: self.gradscaler.scale(self.static_loss).backward() @@ -831,7 +871,9 @@ def __process_mini_batch(self, network, input_data, target_data): torch.cuda.nvtx.range_push("loss") if self.parameters_full.use_ddp: - loss = network.module.calculate_loss(prediction, target_data) + loss = network.module.calculate_loss( + prediction, target_data + ) else: loss = network.calculate_loss(prediction, target_data) # loss @@ -936,9 +978,13 @@ def __validate_network(self, network, data_set_type, validation_type): ): prediction = network(x) if self.parameters_full.use_ddp: - loss = network.module.calculate_loss(prediction, y) + loss = network.module.calculate_loss( + prediction, y + ) else: - loss = network.calculate_loss(prediction, y) + loss = network.calculate_loss( + prediction, y + ) torch.cuda.current_stream( self.parameters._configuration["device"] ).wait_stream(s) @@ -954,12 +1000,24 @@ def __validate_network(self, network, data_set_type, validation_type): # Capture graph self.validation_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.validation_graph): - with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision): - self.static_prediction_validation = network(self.static_input_validation) + with torch.cuda.amp.autocast( + enabled=self.parameters.use_mixed_precision + ): + self.static_prediction_validation = ( + network( + self.static_input_validation + ) + ) if self.parameters_full.use_ddp: - self.static_loss_validation = network.module.calculate_loss(self.static_prediction_validation, self.static_target_validation) + self.static_loss_validation = network.module.calculate_loss( + self.static_prediction_validation, + self.static_target_validation, + ) else: - self.static_loss_validation = network.calculate_loss(self.static_prediction_validation, self.static_target_validation) + self.static_loss_validation = network.calculate_loss( + self.static_prediction_validation, + self.static_target_validation, + ) if 
self.validation_graph: self.static_input_validation.copy_(x) @@ -974,9 +1032,13 @@ def __validate_network(self, network, data_set_type, validation_type): ): prediction = network(x) if self.parameters_full.use_ddp: - loss = network.module.calculate_loss(prediction, y) + loss = network.module.calculate_loss( + prediction, y + ) else: - loss = network.calculate_loss(prediction, y) + loss = network.calculate_loss( + prediction, y + ) validation_loss_sum += loss if ( batchid != 0 @@ -1009,11 +1071,15 @@ def __validate_network(self, network, data_set_type, validation_type): y = y.to(self.parameters._configuration["device"]) prediction = network(x) if self.parameters_full.use_ddp: - validation_loss_sum += \ - network.module.calculate_loss(prediction, y).item() + validation_loss_sum += ( + network.module.calculate_loss( + prediction, y + ).item() + ) else: - validation_loss_sum += \ - network.calculate_loss(prediction, y).item() + validation_loss_sum += network.calculate_loss( + prediction, y + ).item() batchid += 1 validation_loss = validation_loss_sum.item() / batchid From 4175cd0207e7aa4e106d69cababe92a0cf6e7542 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Thu, 25 Apr 2024 18:06:30 +0200 Subject: [PATCH 05/22] Added suggestions by Josh Co-authored-by: Josh Romero --- mala/network/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index 5c94b4437..48b9680bd 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -268,7 +268,7 @@ def train_network(self): vloss, "average_loss", self.parameters._configuration["device"] ) self.initial_validation_loss = vloss - if self.data.test_data_sets is not None: + if self.data.test_data_sets: tloss = self.__average_validation( tloss, "average_loss", @@ -655,7 +655,7 @@ def __prepare_to_train(self, optimizer_dict): if self.parameters_full.use_distributed_sampler_train: self.train_sampler = ( torch.utils.data.distributed.DistributedSampler( - self.data.training_data_sets, + self.data.training_data_sets[0], num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=do_shuffle, @@ -664,7 +664,7 @@ def __prepare_to_train(self, optimizer_dict): if self.parameters_full.use_distributed_sampler_val: self.validation_sampler = ( torch.utils.data.distributed.DistributedSampler( - self.data.validation_data_sets, + self.data.validation_data_sets[0], num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, @@ -675,7 +675,7 @@ def __prepare_to_train(self, optimizer_dict): if self.data.test_data_sets is not None: self.test_sampler = ( torch.utils.data.distributed.DistributedSampler( - self.data.test_data_sets, + self.data.test_data_sets[0], num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, From 2cde3f160458f359618d31bbe9d36454b6218c54 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Thu, 25 Apr 2024 18:40:03 +0200 Subject: [PATCH 06/22] Removed DDP in yaml file --- install/mala_gpu_base_environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/install/mala_gpu_base_environment.yml b/install/mala_gpu_base_environment.yml index 7f78d40fd..340fef170 100644 --- a/install/mala_gpu_base_environment.yml +++ b/install/mala_gpu_base_environment.yml @@ -1,4 +1,4 @@ -name: mala-gpu-ddp +name: mala-gpu channels: - conda-forge - defaults From bd68063df3f6a240886133f7db0e3f7585665723 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Fri, 26 Apr 2024 11:12:53 +0200 Subject: [PATCH 07/22] Minor reformatting --- 
mala/common/parameters.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mala/common/parameters.py b/mala/common/parameters.py index 711d2aaa9..6a431e04f 100644 --- a/mala/common/parameters.py +++ b/mala/common/parameters.py @@ -1335,15 +1335,18 @@ def use_ddp(self, value): if value: print("initializing torch.distributed.") # JOSHR: - # We start up torch distributed here. As is fairly standard convention, we get the rank - # and world size arguments via environment variables (RANK, WORLD_SIZE). In addition to - # those variables, LOCAL_RANK, MASTER_ADDR and MASTER_PORT should be set. + # We start up torch distributed here. As is fairly standard + # convention, we get the rank and world size arguments via + # environment variables (RANK, WORLD_SIZE). In addition to + # those variables, LOCAL_RANK, MASTER_ADDR and MASTER_PORT + # should be set. rank = int(os.environ.get("RANK")) world_size = int(os.environ.get("WORLD_SIZE")) + dist.init_process_group("nccl", rank=rank, world_size=world_size) - # Invalidate, will be updated in setter. set_ddp_status(value) + # Invalidate, will be updated in setter. self.device = None self._use_ddp = value self.network._update_ddp(self.use_ddp) From 3cebb9d810820fd77627f14e0f4e1c5c7b3f3df2 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Fri, 26 Apr 2024 11:21:22 +0200 Subject: [PATCH 08/22] Small bug --- mala/network/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index 48b9680bd..ead4b4d5c 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -672,7 +672,7 @@ def __prepare_to_train(self, optimizer_dict): ) if self.parameters_full.use_distributed_sampler_test: - if self.data.test_data_sets is not None: + if self.data.test_data_sets: self.test_sampler = ( torch.utils.data.distributed.DistributedSampler( self.data.test_data_sets[0], From ffa3082f01f821b282a888b24cbde1aed16ef71c Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Fri, 26 Apr 2024 11:36:56 +0200 Subject: [PATCH 09/22] Model only saved on master rank --- mala/network/runner.py | 90 +++++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 41 deletions(-) diff --git a/mala/network/runner.py b/mala/network/runner.py index 81cd54736..896e8b720 100644 --- a/mala/network/runner.py +++ b/mala/network/runner.py @@ -7,6 +7,7 @@ import torch import torch.distributed as dist +from mala.common.parallelizer import get_rank from mala.common.parameters import ParametersRunning from mala.network.network import Network from mala.datahandling.data_scaler import DataScaler @@ -76,55 +77,62 @@ def save_run( data is already present in the DataHandler object, it can be saved by setting. """ - model_file = run_name + ".network.pth" - iscaler_file = run_name + ".iscaler.pkl" - oscaler_file = run_name + ".oscaler.pkl" - params_file = run_name + ".params.json" - if save_runner: - optimizer_file = run_name + ".optimizer.pth" - - self.parameters_full.save(os.path.join(save_path, params_file)) - if self.parameters_full.use_ddp: - self.network.module.save_network( - os.path.join(save_path, model_file) + # If a model is trained via DDP, we need to make sure saving is only + # performed on rank 0. 
+ if get_rank() == 0: + model_file = run_name + ".network.pth" + iscaler_file = run_name + ".iscaler.pkl" + oscaler_file = run_name + ".oscaler.pkl" + params_file = run_name + ".params.json" + if save_runner: + optimizer_file = run_name + ".optimizer.pth" + + self.parameters_full.save(os.path.join(save_path, params_file)) + if self.parameters_full.use_ddp: + self.network.module.save_network( + os.path.join(save_path, model_file) + ) + else: + self.network.save_network(os.path.join(save_path, model_file)) + self.data.input_data_scaler.save( + os.path.join(save_path, iscaler_file) + ) + self.data.output_data_scaler.save( + os.path.join(save_path, oscaler_file) ) - else: - self.network.save_network(os.path.join(save_path, model_file)) - self.data.input_data_scaler.save(os.path.join(save_path, iscaler_file)) - self.data.output_data_scaler.save( - os.path.join(save_path, oscaler_file) - ) - files = [model_file, iscaler_file, oscaler_file, params_file] - if save_runner: - files += [optimizer_file] - if zip_run: - if additional_calculation_data is not None: - additional_calculation_file = run_name + ".info.json" - if isinstance(additional_calculation_data, str): - self.data.target_calculator.read_additional_calculation_data( - additional_calculation_data - ) - self.data.target_calculator.write_additional_calculation_data( - os.path.join(save_path, additional_calculation_file) - ) - elif isinstance(additional_calculation_data, bool): - if additional_calculation_data: + files = [model_file, iscaler_file, oscaler_file, params_file] + if save_runner: + files += [optimizer_file] + if zip_run: + if additional_calculation_data is not None: + additional_calculation_file = run_name + ".info.json" + if isinstance(additional_calculation_data, str): + self.data.target_calculator.read_additional_calculation_data( + additional_calculation_data + ) self.data.target_calculator.write_additional_calculation_data( os.path.join( save_path, additional_calculation_file ) ) + elif isinstance(additional_calculation_data, bool): + if additional_calculation_data: + self.data.target_calculator.write_additional_calculation_data( + os.path.join( + save_path, additional_calculation_file + ) + ) - files.append(additional_calculation_file) - with ZipFile( - os.path.join(save_path, run_name + ".zip"), - "w", - compression=ZIP_STORED, - ) as zip_obj: - for file in files: - zip_obj.write(os.path.join(save_path, file), file) - os.remove(os.path.join(save_path, file)) + files.append(additional_calculation_file) + with ZipFile( + os.path.join(save_path, run_name + ".zip"), + "w", + compression=ZIP_STORED, + ) as zip_obj: + for file in files: + zip_obj.write(os.path.join(save_path, file), file) + os.remove(os.path.join(save_path, file)) @classmethod def load_run( From 04b00506b7852610722a1aaa184284671018d8f2 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Fri, 26 Apr 2024 11:43:02 +0200 Subject: [PATCH 10/22] Adjusted output for parallel --- mala/network/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index ead4b4d5c..a100d5c35 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -45,7 +45,7 @@ def __init__(self, params, network, data, optimizer_dict=None): super(Trainer, self).__init__(params, network, data) if self.parameters_full.use_ddp: - print("wrapping model in ddp..") + printout("DDP activated, wrapping model in DDP.", min_verbosity=1) # JOSHR: using streams here to maintain compatibility with # graph capture s = 
torch.cuda.Stream() From 1fb2c98e75e0d12e3a96016b09e06fe7df40c7e2 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Tue, 30 Apr 2024 08:50:09 +0200 Subject: [PATCH 11/22] Testing if distributed samplers work as default --- mala/common/parameters.py | 33 ------------------------------- mala/network/trainer.py | 41 ++++++++++++++++++--------------------- 2 files changed, 19 insertions(+), 55 deletions(-) diff --git a/mala/common/parameters.py b/mala/common/parameters.py index 6a431e04f..65523d048 100644 --- a/mala/common/parameters.py +++ b/mala/common/parameters.py @@ -1208,9 +1208,6 @@ def __init__(self): # Properties self.use_gpu = False self.use_ddp = False - self.use_distributed_sampler_train = True - self.use_distributed_sampler_val = True - self.use_distributed_sampler_test = True self.use_mpi = False self.verbosity = 1 self.device = "cpu" @@ -1300,36 +1297,6 @@ def use_ddp(self): """Control whether or not dd is used for parallel training.""" return self._use_ddp - @property - def use_distributed_sampler_train(self): - """Control wether or not distributed sampler is used to distribute training data.""" - return self._use_distributed_sampler_train - - @use_distributed_sampler_train.setter - def use_distributed_sampler_train(self, value): - """Control whether or not distributed sampler is used to distribute training data.""" - self._use_distributed_sampler_train = value - - @property - def use_distributed_sampler_val(self): - """Control whether or not distributed sampler is used to distribute validation data.""" - return self._use_distributed_sampler_val - - @use_distributed_sampler_val.setter - def use_distributed_sampler_val(self, value): - """Control whether or not distributed sampler is used to distribute validation data.""" - self._use_distributed_sampler_val = value - - @property - def use_distributed_sampler_test(self): - """Control whether or not distributed sampler is used to distribute test data.""" - return self._use_distributed_sampler_test - - @use_distributed_sampler_test.setter - def use_distributed_sampler_test(self, value): - """Control whether or not distributed sampler is used to distribute test data.""" - self._use_distributed_sampler_test = value - @use_ddp.setter def use_ddp(self, value): if value: diff --git a/mala/network/trainer.py b/mala/network/trainer.py index a100d5c35..bb9d4d41b 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -652,36 +652,33 @@ def __prepare_to_train(self, optimizer_dict): if self.data.parameters.use_lazy_loading: do_shuffle = False - if self.parameters_full.use_distributed_sampler_train: - self.train_sampler = ( - torch.utils.data.distributed.DistributedSampler( - self.data.training_data_sets[0], - num_replicas=dist.get_world_size(), - rank=dist.get_rank(), - shuffle=do_shuffle, - ) + self.train_sampler = ( + torch.utils.data.distributed.DistributedSampler( + self.data.training_data_sets[0], + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=do_shuffle, ) - if self.parameters_full.use_distributed_sampler_val: - self.validation_sampler = ( + ) + self.validation_sampler = ( + torch.utils.data.distributed.DistributedSampler( + self.data.validation_data_sets[0], + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=False, + ) + ) + + if self.data.test_data_sets: + self.test_sampler = ( torch.utils.data.distributed.DistributedSampler( - self.data.validation_data_sets[0], + self.data.test_data_sets[0], num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, ) ) - if 
self.parameters_full.use_distributed_sampler_test: - if self.data.test_data_sets: - self.test_sampler = ( - torch.utils.data.distributed.DistributedSampler( - self.data.test_data_sets[0], - num_replicas=dist.get_world_size(), - rank=dist.get_rank(), - shuffle=False, - ) - ) - # Instantiate the learning rate scheduler, if necessary. if self.parameters.learning_rate_scheduler == "ReduceLROnPlateau": self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( From a9027a7685860dadce3e5abb6742944ea65e85c3 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Tue, 30 Apr 2024 09:17:16 +0200 Subject: [PATCH 12/22] Added some documentation --- docs/source/advanced_usage/trainingmodel.rst | 57 ++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/docs/source/advanced_usage/trainingmodel.rst b/docs/source/advanced_usage/trainingmodel.rst index ddb429368..d0228237b 100644 --- a/docs/source/advanced_usage/trainingmodel.rst +++ b/docs/source/advanced_usage/trainingmodel.rst @@ -220,3 +220,60 @@ via The full path for ``path_to_visualization`` can be accessed via ``trainer.full_visualization_path``. + + +Training in parallel +******************** + +If large models or large data sets are employed, training may be slow even +if a GPU is used. In this case, multiple GPUs can be employed with MALA +using the ``DistributedDataParallel`` (DDP) formalism of the ``torch`` library. +To use DDP, make sure you have `NCCL `_ +installed on your system. + +To activate and use DDP in MALA, almost no modification of your training script +is necessary. Simply activate DDP in your ``Parameters`` object. Make sure to +also enable GPU, since parallel training is currently only supported on GPUs. + + .. code-block:: python + + parameters = mala.Parameters() + parameters.use_gpu = True + parameters.use_ddp = True + +MALA is now set up for parallel training. DDP works across multiple compute +nodes on HPC infrastructure as well as on a single machine hosting multiple +GPUs. While essentially no modification of the python script is necessary, some +modifications for calling the python script may be necessary, to ensure +that DDP has all the information it needs for inter/intra-node communication. +This setup *may* differ across machines/clusters. During testing, the +following setup was confirmed to work on an HPC cluster using the +``slurm`` scheduler. + + .. code-block:: bash + + #SBATCH --nodes=NUMBER_OF_NODES + #SBATCH --ntasks-per-node=NUMBER_OF_TASKS_PER_NODE + #SBATCH --gres=gpu:NUMBER_OF_TASKS_PER_NODE + # Add more arguments as needed + ... + + # Load more modules as needed + ... + + # This port can be arbitrarily chosen. + export MASTER_PORT=12342 + + # Find out the host node. + echo "NODELIST="${SLURM_NODELIST} + master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) + export MASTER_ADDR=$master_addr + echo "MASTER_ADDR="$MASTER_ADDR + + # Run using torchrun. + torchrun --nnodes NUMBER_OF_NODES --nproc_per_node NUMBER_OF_TASKS_PER_NODE --rdzv_id "$SLURM_JOB_ID" training.py + +This script follows `this tutorial `_. +A tutorial on DDP itself can be found `here `_. 
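The ``training.py`` invoked by this job script is an ordinary MALA training script; no rank-specific logic is needed inside it, since activating ``use_ddp`` initializes the process group and model saving is restricted to rank 0 internally. As a rough sketch only (the snapshot file names and run name below are placeholders, assuming the standard ``DataHandler``/``Network``/``Trainer`` workflow), such a script could look like:

 .. code-block:: python

    import mala

    parameters = mala.Parameters()
    parameters.use_gpu = True
    parameters.use_ddp = True

    # Load training and validation data as usual; file names and paths
    # here are placeholders, not part of this patch.
    data_handler = mala.DataHandler(parameters)
    data_handler.add_snapshot("snapshot0.in.npy", "./",
                              "snapshot0.out.npy", "./", "tr")
    data_handler.add_snapshot("snapshot1.in.npy", "./",
                              "snapshot1.out.npy", "./", "va")
    data_handler.prepare_data()

    # Network and trainer setup are unchanged; the DDP wrapping of the
    # model happens inside the Trainer class.
    network = mala.Network(parameters)
    trainer = mala.Trainer(parameters, network, data_handler)
    trainer.train_network()
    trainer.save_run("my_ddp_model")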
+ + From af1081ead5f7c8e5ec2b66cfc676bc2dd3d617bd Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Tue, 30 Apr 2024 15:21:24 +0200 Subject: [PATCH 13/22] This should fix the inference --- mala/common/parameters.py | 16 ++++++++++++---- mala/network/runner.py | 10 +++++++++- mala/network/trainer.py | 4 ++++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/mala/common/parameters.py b/mala/common/parameters.py index 65523d048..6a8baec76 100644 --- a/mala/common/parameters.py +++ b/mala/common/parameters.py @@ -1543,7 +1543,9 @@ def optuna_singlenode_setup(self, wait_time=0): self.hyperparameters._update_device(device_temp) @classmethod - def load_from_file(cls, file, save_format="json", no_snapshots=False): + def load_from_file( + cls, file, save_format="json", no_snapshots=False, force_no_ddp=False + ): """ Load a Parameters object from a file. @@ -1598,7 +1600,10 @@ def load_from_file(cls, file, save_format="json", no_snapshots=False): not isinstance(json_dict[key], dict) or key == "openpmd_configuration" ): - setattr(loaded_parameters, key, json_dict[key]) + if key == "use_ddp" and force_no_ddp is True: + setattr(loaded_parameters, key, False) + else: + setattr(loaded_parameters, key, json_dict[key]) if no_snapshots is True: loaded_parameters.data.snapshot_directories_list = [] else: @@ -1631,7 +1636,7 @@ def load_from_pickle(cls, file, no_snapshots=False): ) @classmethod - def load_from_json(cls, file, no_snapshots=False): + def load_from_json(cls, file, no_snapshots=False, force_no_ddp=False): """ Load a Parameters object from a json file. @@ -1651,5 +1656,8 @@ def load_from_json(cls, file, no_snapshots=False): """ return Parameters.load_from_file( - file, save_format="json", no_snapshots=no_snapshots + file, + save_format="json", + no_snapshots=no_snapshots, + force_no_ddp=force_no_ddp, ) diff --git a/mala/network/runner.py b/mala/network/runner.py index 896e8b720..5e6ecdafa 100644 --- a/mala/network/runner.py +++ b/mala/network/runner.py @@ -7,6 +7,7 @@ import torch import torch.distributed as dist +import mala from mala.common.parallelizer import get_rank from mala.common.parameters import ParametersRunning from mala.network.network import Network @@ -145,6 +146,7 @@ def load_run( prepare_data=False, load_with_mpi=None, load_with_gpu=None, + load_with_ddp=None, ): """ Load a run. @@ -231,7 +233,13 @@ def load_run( path, run_name + ".params." + params_format ) - loaded_params = Parameters.load_from_json(loaded_params) + # Neither Predictor nor Runner classes can work with DDP. + if cls is mala.Trainer: + loaded_params = Parameters.load_from_json(loaded_params) + else: + loaded_params = Parameters.load_from_json( + loaded_params, force_no_ddp=True + ) # MPI has to be specified upon loading, in contrast to GPU. if load_with_mpi is not None: diff --git a/mala/network/trainer.py b/mala/network/trainer.py index bb9d4d41b..430a0cf47 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -156,6 +156,7 @@ def load_run( params_format="json", load_runner=True, prepare_data=True, + load_with_ddp=None, ): """ Load a run. 
@@ -205,6 +206,9 @@ def load_run( params_format=params_format, load_runner=load_runner, prepare_data=prepare_data, + load_with_gpu=None, + load_with_mpi=None, + load_with_ddp=load_with_ddp, ) @classmethod From 18fa6e23e707237a4e5af658035992a8fb443014 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Thu, 2 May 2024 07:49:56 +0200 Subject: [PATCH 14/22] Trying to fix checkpointing --- mala/network/trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index 430a0cf47..df4e7c848 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -14,6 +14,7 @@ from torch.utils.tensorboard import SummaryWriter from mala.common.parameters import printout +from mala.common.parallelizer import get_local_rank from mala.datahandling.fast_tensor_dataset import FastTensorDataset from mala.network.runner import Runner from mala.datahandling.lazy_load_dataset_single import LazyLoadDatasetSingle @@ -238,7 +239,9 @@ def _load_from_run(cls, params, network, data, file=None): The trainer that was loaded from the file. """ # First, load the checkpoint. - checkpoint = torch.load(file) + if params.use_ddp: + map_location = {"cuda:%d" % 0: "cuda:%d" % get_local_rank()} + checkpoint = torch.load(file, map_location=map_location) # Now, create the Trainer class with it. loaded_trainer = Trainer( From f49e63d5c04314e9f2eeecec59625832c208182a Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Thu, 2 May 2024 08:34:06 +0200 Subject: [PATCH 15/22] Added docs for new loading parameters --- mala/network/runner.py | 12 ++++++++++-- mala/network/trainer.py | 3 +-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/mala/network/runner.py b/mala/network/runner.py index 5e6ecdafa..a5f620071 100644 --- a/mala/network/runner.py +++ b/mala/network/runner.py @@ -174,7 +174,7 @@ def load_run( If True, the data will be loaded into memory. This is needed when continuing a model training. - load_with_mpi : bool + load_with_mpi : bool or None Can be used to actively enable/disable MPI during loading. Default is None, so that the MPI parameters set during training/saving of the model are not overwritten. @@ -182,7 +182,7 @@ def load_run( MPI already has to be activated here, if it was not activated during training! - load_with_gpu : bool + load_with_gpu : bool or None Can be used to actively enable/disable GPU during loading. Default is None, so that the GPU parameters set during training/saving of the model are not overwritten. @@ -191,6 +191,14 @@ def load_run( activated during training. Can also be used to activate a CPU based inference, by setting it to False. + load_with_ddp : bool or None + Can be used to actively disable DDP (pytorch distributed + data parallel used for parallel training) during loading. + Default is None, which for loading a Trainer object will not + interfere with DDP settings. For Predictor and Tester class, + this command will automatically disable DDP during loading, + as inference is using MPI rather than DDP for parallelization. + Return ------ loaded_params : mala.common.parameters.Parameters diff --git a/mala/network/trainer.py b/mala/network/trainer.py index df4e7c848..f8bf391f5 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -157,7 +157,6 @@ def load_run( params_format="json", load_runner=True, prepare_data=True, - load_with_ddp=None, ): """ Load a run. 
@@ -209,7 +208,7 @@ def load_run( prepare_data=prepare_data, load_with_gpu=None, load_with_mpi=None, - load_with_ddp=load_with_ddp, + load_with_ddp=None, ) @classmethod From e1753d0f3f97809841682f13c0c556abaf7948c8 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Thu, 2 May 2024 08:34:18 +0200 Subject: [PATCH 16/22] This should fix lazy loading mixing --- mala/datahandling/data_handler.py | 3 +++ mala/datahandling/lazy_load_dataset.py | 9 ++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/mala/datahandling/data_handler.py b/mala/datahandling/data_handler.py index dae111c0d..96d1dc6c0 100644 --- a/mala/datahandling/data_handler.py +++ b/mala/datahandling/data_handler.py @@ -640,6 +640,7 @@ def __build_datasets(self): self.descriptor_calculator, self.target_calculator, self.use_ddp, + self.parameters._configuration["device"] ) ) self.validation_data_sets.append( @@ -651,6 +652,7 @@ def __build_datasets(self): self.descriptor_calculator, self.target_calculator, self.use_ddp, + self.parameters._configuration["device"] ) ) @@ -664,6 +666,7 @@ def __build_datasets(self): self.descriptor_calculator, self.target_calculator, self.use_ddp, + self.parameters._configuration["device"] input_requires_grad=True, ) ) diff --git a/mala/datahandling/lazy_load_dataset.py b/mala/datahandling/lazy_load_dataset.py index f37fdb60d..a3af4ab64 100644 --- a/mala/datahandling/lazy_load_dataset.py +++ b/mala/datahandling/lazy_load_dataset.py @@ -59,6 +59,7 @@ def __init__( descriptor_calculator, target_calculator, use_ddp, + device, input_requires_grad=False, ): self.snapshot_list = [] @@ -79,6 +80,7 @@ def __init__( self.use_ddp = use_ddp self.return_outputs_directly = False self.input_requires_grad = input_requires_grad + self.device = device @property def return_outputs_directly(self): @@ -119,8 +121,13 @@ def mix_datasets(self): used_perm = torch.randperm(self.number_of_snapshots) barrier() if self.use_ddp: + used_perm.to(device=self.device) used_perm = dist.broadcast(used_perm, 0) - self.snapshot_list = [self.snapshot_list[i] for i in used_perm] + self.snapshot_list = [ + self.snapshot_list[i] for i in used_perm.to("cpu") + ] + else: + self.snapshot_list = [self.snapshot_list[i] for i in used_perm] self.get_new_data(0) def get_new_data(self, file_index): From 36e626c84d7f9147374cbe0ce1177eb4eb0e3ff0 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Thu, 2 May 2024 08:38:05 +0200 Subject: [PATCH 17/22] Missing comma --- mala/datahandling/data_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mala/datahandling/data_handler.py b/mala/datahandling/data_handler.py index 96d1dc6c0..266664e59 100644 --- a/mala/datahandling/data_handler.py +++ b/mala/datahandling/data_handler.py @@ -666,7 +666,7 @@ def __build_datasets(self): self.descriptor_calculator, self.target_calculator, self.use_ddp, - self.parameters._configuration["device"] + self.parameters._configuration["device"], input_requires_grad=True, ) ) From d9c7a73562c6798513ff06bf3c6c5db3e28f468b Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Thu, 2 May 2024 08:40:13 +0200 Subject: [PATCH 18/22] Made printing for DDP init debug only --- mala/common/parameters.py | 3 ++- mala/datahandling/data_handler.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mala/common/parameters.py b/mala/common/parameters.py index 6a8baec76..3627bd40f 100644 --- a/mala/common/parameters.py +++ b/mala/common/parameters.py @@ -1300,7 +1300,8 @@ def use_ddp(self): @use_ddp.setter def use_ddp(self, value): if value: - 
print("initializing torch.distributed.") + if self.verbosity > 1: + print("Initializing torch.distributed.") # JOSHR: # We start up torch distributed here. As is fairly standard # convention, we get the rank and world size arguments via diff --git a/mala/datahandling/data_handler.py b/mala/datahandling/data_handler.py index 266664e59..7b8fc2a43 100644 --- a/mala/datahandling/data_handler.py +++ b/mala/datahandling/data_handler.py @@ -640,7 +640,7 @@ def __build_datasets(self): self.descriptor_calculator, self.target_calculator, self.use_ddp, - self.parameters._configuration["device"] + self.parameters._configuration["device"], ) ) self.validation_data_sets.append( @@ -652,7 +652,7 @@ def __build_datasets(self): self.descriptor_calculator, self.target_calculator, self.use_ddp, - self.parameters._configuration["device"] + self.parameters._configuration["device"], ) ) From 51235b4ea789b69780b2af8ce31e370b54fae7c6 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Thu, 2 May 2024 08:44:18 +0200 Subject: [PATCH 19/22] Forgot an equals sign --- mala/datahandling/lazy_load_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mala/datahandling/lazy_load_dataset.py b/mala/datahandling/lazy_load_dataset.py index a3af4ab64..6a91fe731 100644 --- a/mala/datahandling/lazy_load_dataset.py +++ b/mala/datahandling/lazy_load_dataset.py @@ -121,7 +121,7 @@ def mix_datasets(self): used_perm = torch.randperm(self.number_of_snapshots) barrier() if self.use_ddp: - used_perm.to(device=self.device) + used_perm = used_perm.to(device=self.device) used_perm = dist.broadcast(used_perm, 0) self.snapshot_list = [ self.snapshot_list[i] for i in used_perm.to("cpu") From 325cf658587202268703a0f7544ce7e1bda7551e Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Fri, 3 May 2024 10:03:30 +0200 Subject: [PATCH 20/22] Lazy loading working now --- mala/datahandling/lazy_load_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mala/datahandling/lazy_load_dataset.py b/mala/datahandling/lazy_load_dataset.py index 6a91fe731..00810beb3 100644 --- a/mala/datahandling/lazy_load_dataset.py +++ b/mala/datahandling/lazy_load_dataset.py @@ -122,7 +122,7 @@ def mix_datasets(self): barrier() if self.use_ddp: used_perm = used_perm.to(device=self.device) - used_perm = dist.broadcast(used_perm, 0) + dist.broadcast(used_perm, 0) self.snapshot_list = [ self.snapshot_list[i] for i in used_perm.to("cpu") ] From 873e486678521187c46725ef929af714b3ac39b8 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Fri, 3 May 2024 11:17:34 +0200 Subject: [PATCH 21/22] Adapted docs to use srun instead of torchrun for example --- docs/source/advanced_usage/trainingmodel.rst | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/docs/source/advanced_usage/trainingmodel.rst b/docs/source/advanced_usage/trainingmodel.rst index d0228237b..4413ab078 100644 --- a/docs/source/advanced_usage/trainingmodel.rst +++ b/docs/source/advanced_usage/trainingmodel.rst @@ -262,7 +262,8 @@ following setup was confirmed to work on an HPC cluster using the ... # This port can be arbitrarily chosen. - export MASTER_PORT=12342 + # Given here is the torchrun default + export MASTER_PORT=29500 # Find out the host node. echo "NODELIST="${SLURM_NODELIST} @@ -270,10 +271,17 @@ following setup was confirmed to work on an HPC cluster using the export MASTER_ADDR=$master_addr echo "MASTER_ADDR="$MASTER_ADDR - # Run using torchrun. 
- torchrun --nnodes NUMBER_OF_NODES --nproc_per_node NUMBER_OF_TASKS_PER_NODE --rdzv_id "$SLURM_JOB_ID" training.py + # Run using srun. + srun -u bash -c ' + # Export additional per process variables + export RANK=$SLURM_PROCID + export LOCAL_RANK=$SLURM_LOCALID + export WORLD_SIZE=$SLURM_NTASKS -This script follows `this tutorial `_. -A tutorial on DDP itself can be found `here `_. + python3 -u training.py + ' + +An overview of environment variables to be set can be found `in the official documentation `_. +A general tutorial on DDP itself can be found `here `_. From b58c096c323237e8d28803ece9cfa06e2fc3de17 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Fri, 3 May 2024 11:28:47 +0200 Subject: [PATCH 22/22] Small bugfix to fix CI --- mala/network/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index f8bf391f5..81977c40e 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -240,7 +240,9 @@ def _load_from_run(cls, params, network, data, file=None): # First, load the checkpoint. if params.use_ddp: map_location = {"cuda:%d" % 0: "cuda:%d" % get_local_rank()} - checkpoint = torch.load(file, map_location=map_location) + checkpoint = torch.load(file, map_location=map_location) + else: + checkpoint = torch.load(file) # Now, create the Trainer class with it. loaded_trainer = Trainer(