From cf3ae9814230d1b55756107f5d43a8f86adde3bd Mon Sep 17 00:00:00 2001
From: eljandoubi
Date: Mon, 4 Nov 2024 02:59:59 +0100
Subject: [PATCH] When use_orig_params=True, skip optimizer params update

---
 src/accelerate/accelerator.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 0504306f7b3..5997a0cc302 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1320,8 +1320,8 @@ def prepare(self, *args, device_placement=None):
             args = self._prepare_ipex_or_xpu(*args)
         if self.fp8_backend == "TE":
             args = self._prepare_te(*args)
-        if self.distributed_type == DistributedType.FSDP:
-            # Wrap models with FSDP and update the optimizer parameters.
+        if self.distributed_type == DistributedType.FSDP and not self.state.fsdp_plugin.use_orig_params:
+            # Wrap models with FSDP and update the optimizers' parameters.
             # Other types of wrapping are handled in the next if-else block.
             args = self._prepare_fsdp(*args, device_placement=device_placement)
         if self.distributed_type == DistributedType.DEEPSPEED:
@@ -2200,11 +2200,11 @@ def _prepare_fsdp(self, *args, device_placement):
 
         # Replace optimizer parameter groups with the flattened ones.
         for model_idx, opt_idx in model_optimizer_map.items():
-            loal_fsdp_map = fsdp_2_base_layer_map[model_idx]
+            local_fsdp_map = fsdp_2_base_layer_map[model_idx]
             local_layer_group = model_layer_group_map[model_idx]
             for fsdp_layer, param in models[model_idx].named_parameters():
                 local_groups = []
-                for layer in loal_fsdp_map[fsdp_layer]:
+                for layer in local_fsdp_map[fsdp_layer]:
                     local_groups.append(local_layer_group[layer])
                 if local_groups:
                     counter_groups = Counter(local_groups)
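
For context, the patch makes Accelerator.prepare skip the FSDP-specific optimizer parameter-group rewrite when the FSDP plugin is configured with use_orig_params=True, since in that mode the optimizer can keep referencing the original (unflattened) parameters. The snippet below is a minimal sketch of the calling pattern this change targets, not part of the patch itself; it uses the real FullyShardedDataParallelPlugin and Accelerator APIs from Accelerate, and assumes the script is launched under a distributed FSDP setup (e.g. via accelerate launch).

# Minimal sketch (not part of the patch): an FSDP plugin created with
# use_orig_params=True, so Accelerator.prepare can leave the optimizer's
# parameter groups untouched. Assumes an FSDP launch, e.g. `accelerate launch`.
import torch
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(use_orig_params=True)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# With use_orig_params=True the parameter-group rewrite in _prepare_fsdp is
# skipped; model wrapping is handled by the generic preparation path instead.
model, optimizer = accelerator.prepare(model, optimizer)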