Is there a way to control which layers to checkpoint when using `recompute-granularity`?
If using Transformer Engine and flash attention, `recompute-activation` does not save any memory at all. The most likely reason is that flash attention already recomputes the attention internals during the backward pass (this is described in the flash attention paper).
The image above shows how a bit more memory can be saved by checkpointing the input to the FFN layer instead of the input to the flash attention layer. The idea in the image comes from the following paper: https://arxiv.org/pdf/2310.03294v2
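To make the idea concrete, here is a minimal sketch in plain PyTorch (not Megatron-LM's or Transformer Engine's actual API; `CheckpointedFFNBlock` and the use of `nn.MultiheadAttention` as a stand-in for flash attention are assumptions for illustration). Only the FFN sub-block is wrapped in `torch.utils.checkpoint`, so its input is what gets saved, while the attention sub-block is left alone because flash attention already recomputes its own internals in backward:

```python
# Sketch only: checkpoint the FFN sub-block instead of the attention sub-block.
# Not the Megatron-LM implementation; module names here are hypothetical.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedFFNBlock(nn.Module):
    """Hypothetical transformer layer that checkpoints only the FFN input."""

    def __init__(self, hidden_size: int, ffn_hidden_size: int, num_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        # Placeholder attention; in practice this would be a flash-attention module.
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, ffn_hidden_size),
            nn.GELU(),
            nn.Linear(ffn_hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention sub-block: no extra checkpointing here, since flash
        # attention already recomputes what it needs for the backward pass.
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out

        # FFN sub-block: only the input to this call is kept; the large
        # ffn_hidden_size intermediate activations are recomputed in backward.
        x = x + checkpoint(lambda y: self.ffn(self.norm2(y)), x, use_reentrant=False)
        return x


if __name__ == "__main__":
    layer = CheckpointedFFNBlock(hidden_size=256, ffn_hidden_size=1024, num_heads=8)
    out = layer(torch.randn(2, 16, 256, requires_grad=True))
    out.sum().backward()
```

The point of checkpointing at the FFN boundary rather than at the layer input is that the FFN's intermediate activations (of size `ffn_hidden_size`) are the ones not already covered by flash attention's internal recomputation, so this is where the extra savings come from.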