Commit

when use_orig_params=True, skip optimizer params update
eljandoubi committed Nov 4, 2024
1 parent c0f34a2 commit cf3ae98
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/accelerate/accelerator.py
@@ -1320,8 +1320,8 @@ def prepare(self, *args, device_placement=None):
             args = self._prepare_ipex_or_xpu(*args)
         if self.fp8_backend == "TE":
             args = self._prepare_te(*args)
-        if self.distributed_type == DistributedType.FSDP:
-            # Wrap models with FSDP and update the optimizer parameters.
+        if self.distributed_type == DistributedType.FSDP and not self.state.fsdp_plugin.use_orig_params:
+            # Wrap models with FSDP and update the optimizers parameters.
             # Other types of wrapping are handled in the next if-else block.
             args = self._prepare_fsdp(*args, device_placement=device_placement)
         if self.distributed_type == DistributedType.DEEPSPEED:
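
For context on the new condition: use_orig_params is a field of accelerate's FullyShardedDataParallelPlugin. When it is True, FSDP keeps exposing the original (unflattened) parameters, so an optimizer built from model.parameters() remains valid and the parameter-group rewrite done by _prepare_fsdp is unnecessary; with this commit, prepare() skips that step when the flag is set. A minimal sketch of how the flag is typically enabled (the toy model and optimizer are placeholders, not part of this commit, and actually running it requires an FSDP-enabled accelerate launch):

import torch
from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Ask FSDP to keep the original parameters visible; the optimizer below then
# does not need its parameter groups remapped onto flattened FSDP parameters.
fsdp_plugin = FullyShardedDataParallelPlugin(use_orig_params=True)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

model = torch.nn.Linear(16, 16)                              # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)   # placeholder optimizer

# With the change above, this call no longer goes through _prepare_fsdp's
# optimizer rewrite when use_orig_params is True.
model, optimizer = accelerator.prepare(model, optimizer)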
@@ -2200,11 +2200,11 @@ def _prepare_fsdp(self, *args, device_placement):
 
         # Replace optimizer parameter groups with the flattened ones.
         for model_idx, opt_idx in model_optimizer_map.items():
-            loal_fsdp_map = fsdp_2_base_layer_map[model_idx]
+            local_fsdp_map = fsdp_2_base_layer_map[model_idx]
             local_layer_group = model_layer_group_map[model_idx]
             for fsdp_layer, param in models[model_idx].named_parameters():
                 local_groups = []
-                for layer in loal_fsdp_map[fsdp_layer]:
+                for layer in local_fsdp_map[fsdp_layer]:
                     local_groups.append(local_layer_group[layer])
                 if local_groups:
                     counter_groups = Counter(local_groups)
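
The context lines above show what still happens when use_orig_params is False: each flattened FSDP parameter is traced back to the original layers it covers, and the per-layer optimizer group indices are tallied with Counter. The diff is truncated right after the Counter call, so the final assignment is not visible; the sketch below is an illustrative guess, using hypothetical layer names and group indices, of how such a tally resolves to a single group:

from collections import Counter

# Hypothetical inputs: one flattened FSDP parameter, the original layers it
# covers, and the optimizer parameter-group index each of those layers had.
local_fsdp_map = {
    "_fsdp_wrapped_module.flat_param": ["linear1.weight", "linear1.bias", "linear2.weight"],
}
local_layer_group = {"linear1.weight": 0, "linear1.bias": 0, "linear2.weight": 1}

for fsdp_layer, covered_layers in local_fsdp_map.items():
    local_groups = [local_layer_group[layer] for layer in covered_layers]
    if local_groups:
        counter_groups = Counter(local_groups)
        # Resolve the tally to the group most of the covered layers came from
        # (an assumption; the real code continues past the visible diff).
        majority_group, _ = counter_groups.most_common(1)[0]
        print(f"{fsdp_layer} -> parameter group {majority_group}")  # -> parameter group 0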
