Tasks
One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py).

Reproduction
When using the official training script for Diffusers (https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) with DeepSpeed and ZeRO-2, I'm trying to save the gradients of the model at each training step. However, I'm encountering difficulties due to the way Accelerate wraps DeepSpeed's operations.

Current Behavior
I modified the code between accelerator.backward(loss) (line 1030) and optimizer.step() (line 1033) as follows:
from deepspeed.utils import safe_get_full_grad

for n, lp in unet.named_parameters():
    # 1. Access the full states
    # 1.1) gradient lookup
    # For zero1 and zero2, gradient lookup must be called after `backward` and before `step`
    # For zero3, gradient lookup must be called after `backward`
    hp_grad = safe_get_full_grad(lp)
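For context, this lookup sits in the training loop of train_text_to_image.py roughly as sketched below. The surrounding lines are paraphrased from the script (exact line numbers may have shifted), and the grads dict is just a local container I added for collecting the results:

# Paraphrased excerpt of the training loop in train_text_to_image.py
accelerator.backward(loss)  # ~line 1030

# my insertion: try to read the full gradients between backward and step
from deepspeed.utils import safe_get_full_grad

grads = {}
for n, lp in unet.named_parameters():
    # with ZeRO-1/2 this is supposed to be callable after `backward` and before `step`,
    # but see "Problem" below: Accelerate has already run DeepSpeed's step by this point
    grads[n] = safe_get_full_grad(lp)

if accelerator.sync_gradients:
    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
optimizer.step()            # ~line 1033
lr_scheduler.step()
optimizer.zero_grad()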
Problem
The current implementation of Accelerate wraps both DeepSpeed's backward and step operations into a single accelerator.backward call. This prevents users from accessing the gradients between these two operations, which is necessary for gradient analysis or custom gradient processing.
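For reference, the relevant code path in Accelerate looks roughly like the sketch below (paraphrased from Accelerate's source, not verbatim; exact signatures and line numbers vary by version):

# Paraphrased sketch of Accelerate's DeepSpeed wrapper (accelerate/utils/deepspeed.py).
class DeepSpeedEngineWrapper:
    def __init__(self, engine):
        self.engine = engine

    def backward(self, loss, **kwargs):
        self.engine.backward(loss, **kwargs)
        # engine.step() clips gradients, runs the optimizer step, and zeroes the grads,
        # so the full gradients are already gone by the time accelerator.backward() returns.
        self.engine.step()

# Accelerator.backward() dispatches to this wrapper when DeepSpeed is enabled, and the
# user-facing optimizer.step() / optimizer.zero_grad() calls become no-ops.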
Suggested Solution
Modify Accelerate's DeepSpeed integration to allow users to access gradients between the backward and step operations. This could be achieved by:
Separating the backward and step operations in Accelerate's DeepSpeed wrapper; a hypothetical sketch of what this could look like follows below. (As an aside, I don't understand why DeepSpeed's backward and step are coupled together in Accelerate's wrapper in the first place.)
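For illustration only, a decoupled flow could look something like this; the step=False argument and accelerator.deepspeed_step() method are hypothetical names made up for this sketch, not existing Accelerate APIs:

from deepspeed.utils import safe_get_full_grad

# Hypothetical, decoupled version of the training step:
accelerator.backward(loss, step=False)      # would only call engine.backward()

grads = {}
for n, lp in unet.named_parameters():
    # under ZeRO-1/2 the full gradients would still be alive at this point
    grads[n] = safe_get_full_grad(lp)

accelerator.deepspeed_step()                # hypothetical: would call engine.step()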
Finally, I obtained the desired gradients with the following temporary solution for accessing full gradients when using DeepSpeed with Accelerate: I modified the code in accelerate/utils/deepspeed.py (line 178) and accelerate/accelerator.py (line 2188).
# In DeepSpeedEngineWrapper.backward (accelerate/utils/deepspeed.py), with the method
# signature extended to accept an optional `gradients` dict to fill in.
self.engine.backward(loss, **kwargs)

# Deepspeed's `engine.step` performs the following operations:
# - gradient accumulation check
# - gradient clipping
# - optimizer step
# - zero grad
# - checking overflow
# - lr_scheduler step (only if engine.lr_scheduler is not None)
if gradients is not None:
    import torch
    from deepspeed.utils import safe_get_full_grad

    with torch.no_grad():
        for n, lp in self.engine.module.named_parameters():
            # 1. Access the full states
            # 1.1) gradient lookup
            # For zero1 and zero2, gradient lookup must be called after `backward` and before `step`
            # For zero3, gradient lookup must be called after `backward`
            if lp.grad is None:
                gradients[n] = safe_get_full_grad(lp)
            else:
                gradients[n] = lp.grad

self.engine.step()

if gradients is not None:
    return gradients
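The accelerator.py side of the change is not shown above; presumably it only needs to thread the new argument through Accelerator.backward and return the result. Below is a rough sketch of that change, plus how the training script can then collect and save the gradients. The gradients keyword and the grads_step_*.pt filename are my own additions rather than official Accelerate APIs, and global_step refers to the training script's step counter:

# Sketch of the matching edit in Accelerator.backward (accelerate/accelerator.py),
# assuming `DeepSpeedEngineWrapper.backward` was given the `gradients` argument above.
def backward(self, loss, gradients=None, **kwargs):
    ...
    if self.distributed_type == DistributedType.DEEPSPEED:
        return self.deepspeed_engine_wrapped.backward(loss, gradients=gradients, **kwargs)
    ...

# In the training script, the full gradients can then be collected and saved:
grads = {}
accelerator.backward(loss, gradients=grads)
if accelerator.is_main_process:
    import torch
    torch.save(
        {name: g.detach().cpu() for name, g in grads.items() if g is not None},
        f"grads_step_{global_step}.pt",
    )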
Expected behavior
I should be able to access and save the full gradients of the model parameters at each training step when using DeepSpeed with ZeRO-2.

Sorry for hijacking the thread, but how does one get the unwrapped model's gradients when using Accelerate + FSDP? I print the shape of a gradient in the middle of training with print(list(model.parameters())[0].grad.shape), and it comes out as a 1D flattened tensor rather than in the shape of the parameter itself (which should be 2D).