def get_layer_index(self, rank, tensor_parallelism, pipeline_parallelism, total_layers):
    """Return the inclusive (start_layer, end_layer) range of layers owned by *rank*.

    Models Megatron-DeepSpeed layer partitioning: layers are spread across
    pipeline stages as evenly as possible, with the first
    ``total_layers % pipeline_parallelism`` stages taking one extra layer.

    :param rank: global rank of this process
    :param tensor_parallelism: tensor-parallel degree (assumed >= 1 — TODO confirm config validation)
    :param pipeline_parallelism: pipeline-parallel degree (assumed >= 1)
    :param total_layers: number of model layers to checkpoint
    :return: tuple ``(start_layer, end_layer)``, both inclusive
    """
    # When TP is enabled the patch adds one extra layer per tensor-parallel
    # shard (presumably embedding/head shards checkpointed like layers —
    # NOTE(review): verify against the Megatron-DeepSpeed layout).
    if tensor_parallelism > 1:
        total_layers += tensor_parallelism
    # Even split: `remainder` stages get max_layers_per_stage, the rest get
    # min_layers_per_stage. (total_layers // pp equals the original
    # divisible_layers // pp, since divisible_layers = pp * (total // pp).)
    remainder = total_layers % pipeline_parallelism
    min_layers_per_stage = total_layers // pipeline_parallelism
    max_layers_per_stage = min_layers_per_stage + 1
    # Ranks are laid out tensor-parallel-fastest, so each contiguous group of
    # `tensor_parallelism` ranks belongs to the same pipeline stage.
    pipeline_rank = (rank // tensor_parallelism) % pipeline_parallelism
    if pipeline_rank < remainder:
        # The first `remainder` stages each take one extra layer.
        start_layer = pipeline_rank * max_layers_per_stage
        end_layer = start_layer + max_layers_per_stage - 1
    else:
        start_layer = (remainder * max_layers_per_stage
                       + (pipeline_rank - remainder) * min_layers_per_stage)
        end_layer = start_layer + min_layers_per_stage - 1
    return start_layer, end_layer

@abstractmethod
def checkpoint(self, epoch, step_number):
    """Write one checkpoint for this rank.

    Model and optimizer state are written by every rank (or by rank 0 only
    when ``checkpoint_type`` is RANK_ZERO); layer files are written for the
    contiguous slice of layers this rank's pipeline stage owns.

    :param epoch: epoch number, embedded in the output file names
    :param step_number: step number, embedded in the output file names
    """
    my_rank = DLIOMPI.get_instance().rank()
    rank_to_checkpoint = my_rank
    if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
        rank_to_checkpoint = 0
    if rank_to_checkpoint == my_rank:
        if self.model_state:
            self.save_state(suffix=f"model-{epoch}-{step_number}-{my_rank}", state=self.model_state)
        if self.optimization_state:
            self.save_state(suffix=f"optimizer-{epoch}-{step_number}-{my_rank}", state=self.optimization_state)
        # Guard restored (present before this patch, and consistent with the
        # model/optimizer checks above): skip layer files entirely when no
        # layer parameters are configured.
        if self.layer_state and self.args.num_layers > 0:
            start_layer, end_layer = self.get_layer_index(
                my_rank, self.args.tensor_parallelism,
                self.args.pipeline_parallelism, self.args.num_layers)
            for layer_index in range(start_layer, end_layer + 1):
                self.save_state(suffix=f"layer-{layer_index}-{epoch}-{step_number}-{my_rank}",
                                state=self.layer_state)
train: - epochs: 311541 + total_training_steps: 311541 + epochs: 1 computation_time: 0.03 # every iteration has 290 steps and each iteration is 8.9 sec. checkpoint: @@ -32,5 +33,7 @@ checkpoint: model_size: 30102 type: all_ranks optimization_groups: [1009254400, 865075200, 793600] - num_layers: 44 - layer_parameters: [129761280, 20971520] + num_layers: 40 + pipeline_parallelism: 8 + tensor_parallelism: 4 + layer_parameters: [52583936, 209715200]