diff --git a/composer/checkpoint/load.py b/composer/checkpoint/load.py
index 1ddb2f11f6..3ad52fce3d 100644
--- a/composer/checkpoint/load.py
+++ b/composer/checkpoint/load.py
@@ -11,7 +11,7 @@
 import textwrap
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, Optional, Sequence, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Union

 import torch
 import torch.distributed.checkpoint as DCP
@@ -139,7 +139,18 @@ def load_checkpoint(
         assert model is not None
         assert model_child_path is not None
         model_load_path = os.path.join(load_path, model_child_path)
-        load_model_checkpoint(model, load_path=model_load_path, load_options=load_options)
+        if state is not None:
+            state.automicrobatch_fsdp_hook_handles, state.fsdp_modules = load_model_checkpoint(
+                model,
+                load_path=model_load_path,
+                load_options=load_options,
+            )
+        else:
+            load_model_checkpoint(
+                model,
+                load_path=model_load_path,
+                load_options=load_options,
+            )

     if load_options.load_optimizer:
         assert optim_child_path is not None
@@ -159,7 +170,7 @@ def load_model_checkpoint(
     load_path: Optional[str] = None,
     load_options: Optional[Union[CheckpointLoadOptions, Dict]] = None,
     seed: int = 42,
-):
+) -> Tuple[list, dict]:
     """Load a model checkpoint from the specified path into the model.

     Args:
@@ -178,10 +189,13 @@ def load_model_checkpoint(
     if load_options.include_keys is not None or load_options.ignore_keys is not None:
         load_options.strict = False

+    automicrobatch_fsdp_hook_handles = []
+    fsdp_modules = {}
+
     if load_options.sharded_checkpoint:
         if not _is_model_fsdp(model):
             if load_options.shard_as_needed_during_load:
-                _shard_with_fsdp(
+                automicrobatch_fsdp_hook_handles, fsdp_modules = _shard_with_fsdp(
                     model,
                     fsdp_config=load_options.fsdp_config,
                     precision=load_options.precision,
@@ -205,7 +219,13 @@ def load_model_checkpoint(
                 load_options.fsdp_config.update({'sync_module_states': True})
             else:
                 load_options.fsdp_config.sync_module_states = True
-            _shard_with_fsdp(model, fsdp_config=load_options.fsdp_config, precision=load_options.precision, seed=seed)
+            automicrobatch_fsdp_hook_handles, fsdp_modules = _shard_with_fsdp(
+                model,
+                fsdp_config=load_options.fsdp_config,
+                precision=load_options.precision,
+                seed=seed,
+            )
+    return automicrobatch_fsdp_hook_handles, fsdp_modules


 def _shard_with_fsdp(
@@ -214,18 +234,19 @@ def _shard_with_fsdp(
     fsdp_config: Optional[Union[FSDPConfig, dict]] = None,
     precision: Optional[str] = None,
     seed: int = 42,
-):
+) -> Tuple[list, dict]:
     if fsdp_config is None:
         fsdp_config = FSDPConfig()
     if isinstance(fsdp_config, dict):
         fsdp_config = FSDPConfig(**fsdp_config)
     with reproducibility.seed_context(seed):
-        prepare_fsdp_module(
+        automicrobatch_fsdp_hook_handles, fsdp_modules = prepare_fsdp_module(
             model,
             optimizers=optimizer,
             fsdp_config=fsdp_config,
             precision=precision,
         )
+    return automicrobatch_fsdp_hook_handles, fsdp_modules


 def _load_sharded_model_checkpoint(
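Note: with this change `load_model_checkpoint` always returns a `(hook_handles, fsdp_modules)` tuple, both empty unless the load sharded the model via `_shard_with_fsdp`. A minimal caller-side sketch; the module and checkpoint path below are placeholders, not part of this diff:

    import torch

    from composer.checkpoint.load import load_model_checkpoint

    # Placeholder module and path, for illustration only.
    model = torch.nn.Linear(8, 3)
    hook_handles, fsdp_modules = load_model_checkpoint(
        model,
        load_path='/path/to/model.pt',  # hypothetical checkpoint location
        load_options={'sharded_checkpoint': False},
    )
    # When no FSDP sharding happened during load, both stay empty, so callers
    # such as load_checkpoint() can assign them onto State unconditionally.
    assert hook_handles == [] and fsdp_modules == {}
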
diff --git a/composer/core/state.py b/composer/core/state.py
index a4a8b6fdde..5c429a1cd4 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -547,6 +547,9 @@ def __init__(
         self.fsdp_config = parallelism_config.fsdp if parallelism_config is not None else None
         self.tp_config = parallelism_config.tp if parallelism_config is not None else None

+        self.automicrobatch_fsdp_hook_handles = []
+        self.fsdp_modules = {}
+
         self._validate_parallelism_configs()

         self.device_mesh: Optional[DeviceMesh] = _create_device_mesh(self.device, self.fsdp_config, self.tp_config)
@@ -1387,7 +1390,7 @@ def load_model_state(
                 with reproducibility.seed_context(self.rank_zero_seed):
                     from composer.distributed import prepare_fsdp_module
-                    prepare_fsdp_module(
+                    self.automicrobatch_fsdp_hook_handles, self.fsdp_modules = prepare_fsdp_module(
                         self.model,
                         self.optimizers,
                         self.fsdp_config,
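Since the handles and module map now live on `State` rather than on `Trainer`, anything that receives `state` can observe them. A hypothetical callback (not part of this diff) sketching that visibility:

    from composer.core import Callback, State
    from composer.loggers import Logger

    class HookHandleMonitor(Callback):
        """Hypothetical callback: reports how many auto-microbatching sync
        hooks are currently registered, now readable straight off State."""

        def batch_end(self, state: State, logger: Logger) -> None:
            logger.log_metrics({
                'automicrobatch/num_hook_handles': len(state.automicrobatch_fsdp_hook_handles),
                'automicrobatch/num_fsdp_modules': len(state.fsdp_modules),
            })
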
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 0e92604fc3..8b1c6d8f93 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -1261,7 +1261,6 @@ def __init__(
         self.cumulative_alloc_retries = 0
         self.num_consecutive_thrashes = 0
         self.num_consecutive_non_OOM_batches = 0
-        self.automicrobatch_fsdp_hook_handles = []

         if auto_microbatching and profiler:
             raise ValueError(
@@ -1766,7 +1765,7 @@ def __init__(
         if self.state.fsdp_config is not None and self.state.fsdp_config.auto_wrap and not self.state.load_monolith_rank0_only:
             # Init with globally fixed seed so all HSDP replicas have the same initial weights
             with reproducibility.seed_context(self.state.rank_zero_seed):
-                self.automicrobatch_fsdp_hook_handles, self.fsdp_modules = prepare_fsdp_module(
+                self.state.automicrobatch_fsdp_hook_handles, self.state.fsdp_modules = prepare_fsdp_module(
                     model,
                     optimizers,
                     self.state.fsdp_config,
@@ -1937,7 +1936,7 @@ def __init__(
         ):
             # Init with globally fixed seed so all HSDP replicas have the same initial weights
             with reproducibility.seed_context(self.state.rank_zero_seed):
-                self.automicrobatch_fsdp_hook_handles, self.fsdp_modules = prepare_fsdp_module(
+                self.state.automicrobatch_fsdp_hook_handles, self.state.fsdp_modules = prepare_fsdp_module(
                     model,
                     optimizers,
                     self.state.fsdp_config,
@@ -2917,8 +2916,11 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]:
                 all_ranks_finished = all_ranks_finished_tensor.item() == 1
             if found_cuda_oom == 1:
                 # Readd sync hooks if they were previously turned off
-                if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handles) == 0:
-                    self.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(self.fsdp_modules, sync_hook)
+                if self.state.fsdp_enabled and len(self.state.automicrobatch_fsdp_hook_handles) == 0:
+                    self.state.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(
+                        self.state.fsdp_modules,
+                        sync_hook,
+                    )
                 _adjust_device_train_microbatch_size(self.state)
                 self.num_consecutive_thrashes = 0
                 self.num_consecutive_non_OOM_batches = 0
@@ -2934,8 +2936,11 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]:
                 )
                 if self.num_consecutive_thrashes >= 2:
                     # Readd sync hooks if they were previously turned off
-                    if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handles) == 0:
-                        self.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(self.fsdp_modules, sync_hook)
+                    if self.state.fsdp_enabled and len(self.state.automicrobatch_fsdp_hook_handles) == 0:
+                        self.state.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(
+                            self.state.fsdp_modules,
+                            sync_hook,
+                        )
                     _adjust_device_train_microbatch_size(self.state)
                     self.num_consecutive_thrashes = 0
                     continue
@@ -2949,12 +2954,12 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]:
                 )
                 self.num_consecutive_non_OOM_batches += 1
                 if self.state.fsdp_enabled and len(
-                    self.automicrobatch_fsdp_hook_handles,
+                    self.state.automicrobatch_fsdp_hook_handles,
                 ) > 0 and self.num_consecutive_non_OOM_batches >= 3:
                     patch_unshard_for_automicrobatching(auto_microbatch_size_found=True)
-                    for handle in self.automicrobatch_fsdp_hook_handles:
+                    for handle in self.state.automicrobatch_fsdp_hook_handles:
                         handle.remove()
-                    self.automicrobatch_fsdp_hook_handles.clear()
+                    self.state.automicrobatch_fsdp_hook_handles.clear()
                 if torch.cuda.is_available():
                     memory_stats = torch.cuda.memory_stats()
                     self.cumulative_alloc_retries = memory_stats['num_alloc_retries']
@@ -3753,10 +3758,11 @@ def _eval_loop(
             self.state.dataloader_len = original_num_batches

         # If training occurs after evaluation, readd hooks in case of memory spike
-        sync_hook = _create_sync_hook(self.state)
-        if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handles) == 0:
-            self.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(self.fsdp_modules, sync_hook)
-        self.num_consecutive_non_OOM_batches = 0
+        if self.state.auto_microbatching:
+            sync_hook = _create_sync_hook(self.state)
+            if self.state.fsdp_enabled and len(self.state.automicrobatch_fsdp_hook_handles) == 0:
+                self.state.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(self.state.fsdp_modules, sync_hook)
+            self.num_consecutive_non_OOM_batches = 0

     def _use_grad_scaling(self, precision: Union[str, Precision], scaler: Optional[GradScaler]) -> bool:
         """Determines based on precision when to use grad scaling.
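The add/remove/re-add cycle in `_train_batch` follows the standard PyTorch hook-handle lifecycle. A self-contained sketch in plain PyTorch, with stand-ins for composer's `sync_hook` and `_readd_fsdp_sync_hooks`:

    import torch

    def sync_hook(module, args):
        # Stand-in for composer's sync hook; the real one coordinates ranks
        # so that an OOM on one rank is observed by all of them.
        return None

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    fsdp_modules = dict(model.named_children())  # analogue of state.fsdp_modules

    # Analogue of _readd_fsdp_sync_hooks: register hooks and keep the handles.
    hook_handles = [m.register_forward_pre_hook(sync_hook) for m in fsdp_modules.values()]
    model(torch.randn(1, 4))  # hooks fire on every forward

    # Analogue of the stable path above: once enough consecutive batches
    # succeed, drop the hooks and clear the list so a later OOM can re-add them.
    for handle in hook_handles:
        handle.remove()
    hook_handles.clear()
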
diff --git a/tests/checkpoint/test_load.py b/tests/checkpoint/test_load.py
index a3db3ac8d3..aa4f50059e 100644
--- a/tests/checkpoint/test_load.py
+++ b/tests/checkpoint/test_load.py
@@ -7,6 +7,7 @@
 from pathlib import Path

 import pytest
+from torch.utils.data import DataLoader

 from composer.checkpoint.load import (
     load_checkpoint,
@@ -26,8 +27,12 @@
     get_optim_state_dict,
     get_resumption_state_dict,
 )
+from composer.trainer import Trainer
 from composer.utils import dist
 from tests.checkpoint.helpers import init_model, init_model_and_optimizer, init_state
+from tests.common import (
+    RandomClassificationDataset,
+)
 from tests.common.compare import deep_compare
@@ -333,3 +338,96 @@ def test_load_checkpoint(
     deep_compare(original_model_state_dict, new_state_dict)
     deep_compare(original_optim_state_dict, new_optim_state_dict)
     deep_compare(original_resumption_state, new_resumption_state, ignore_keys=['rng', 'run_name'])
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(
+    'world_size,sharded_model,sharded_checkpoint,shard_as_needed_during_load',
+    [
+        # Loading an unsharded checkpoint into an unsharded model on a single GPU (not sharding after)
+        pytest.param(1, False, False, False, marks=pytest.mark.world_size(1)),
+
+        # Loading a sharded checkpoint into a sharded model in a distributed setting
+        pytest.param(2, True, True, False, marks=pytest.mark.world_size(2)),
+
+        # Loading a sharded checkpoint into an unsharded model (sharding it before load)
+        pytest.param(2, False, True, True, marks=pytest.mark.world_size(2)),
+
+        # Loading an unsharded checkpoint into an unsharded model and sharding it after
+        pytest.param(2, False, False, True, marks=pytest.mark.world_size(2)),
+
+        # The other three permutations of the above tests are:
+        # 2 gpu, sharded model, sharded checkpoint, with additional sharding -> no need to shard an already sharded model
+        # 2 gpu, sharded model, unsharded checkpoint, with additional sharding -> no need to shard an already sharded model
+        # 2 gpu, unsharded model, unsharded checkpoint, without additional sharding -> no need to try this on 2 gpus
+    ],
+)
+def test_load_model_checkpoint_and_eval(
+    world_size: int,
+    tmp_path: Path,
+    sharded_model: bool,
+    sharded_checkpoint: bool,
+    shard_as_needed_during_load: bool,
+):
+    if sharded_model and not sharded_checkpoint:
+        pytest.xfail(
+            'Loading an unsharded checkpoint into a sharded model is not supported and causes OOMs when running with these tests',
+        )
+    # Ensure all ranks use the same path
+    destination_dir = os.path.join(tmp_path, str(uuid.uuid4())[:8])
+    destination_dir = dist.all_gather_object(destination_dir)[0]
+
+    # Save a model checkpoint
+    model, _ = init_model(use_composer_model=True, use_fsdp=sharded_checkpoint, device='cuda')
+    save_path = os.path.join(destination_dir, 'model.pt') if not sharded_checkpoint else destination_dir
+    saved_path = save_model_to_disk(model, save_path, sharded_checkpoint=sharded_checkpoint)
+
+    # Get the original model's state dict
+    original_state_dict = get_model_state_dict(model, sharded_state_dict=False)
+
+    # Load the model checkpoint
+    new_model, _ = init_model(use_composer_model=True, use_fsdp=sharded_model, device='cuda')
+    if saved_path is not None:
+        load_path = saved_path if not sharded_checkpoint else str(Path(saved_path).parent)
+    else:
+        load_path = ''
+
+    if not sharded_model and sharded_checkpoint and not shard_as_needed_during_load:
+        context_manager = pytest.raises(ValueError)
+    else:
+        context_manager = contextlib.nullcontext()
+
+    with context_manager:
+        load_model_checkpoint(
+            new_model,
+            load_path=load_path,
+            load_options=dict(
+                sharded_checkpoint=sharded_checkpoint,
+                shard_as_needed_during_load=shard_as_needed_during_load,
+            ),
+        )
+        # Check if model is sharded when it should be
+        if shard_as_needed_during_load:
+            assert _is_model_fsdp(new_model), 'Model should be sharded after load'
+
+        # Get the new model's state dict
+        new_state_dict = get_model_state_dict(new_model, sharded_state_dict=False)
+
+        if dist.get_global_rank() == 0:
+            deep_compare(original_state_dict, new_state_dict)
+
+        dataset = RandomClassificationDataset(
+            shape=(8,),
+            size=100,
+            num_classes=3,
+        )
+
+        trainer = Trainer(
+            eval_dataloader=DataLoader(
+                dataset=dataset,
+                sampler=dist.get_sampler(dataset),
+            ),
+            model=new_model,  # type: ignore
+        )
+
+        # Evaluate the model
+        trainer.eval()
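For reference, the eval step at the end of the new test boils down to the following single-process usage. A sketch with a stand-in random dataset matching the test's shapes; `ComposerClassifier` and `Trainer` are composer's real APIs, the data and batch size are arbitrary:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from composer.models import ComposerClassifier
    from composer.trainer import Trainer

    # Random 8-feature, 3-class data, mirroring the test's
    # RandomClassificationDataset(shape=(8,), size=100, num_classes=3).
    dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 3, (100,)))
    model = ComposerClassifier(torch.nn.Linear(8, 3), num_classes=3)

    # Eval-only Trainer, as in test_load_model_checkpoint_and_eval.
    trainer = Trainer(model=model, eval_dataloader=DataLoader(dataset, batch_size=16))
    trainer.eval()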