Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve docs on grad accumulation #1817

Merged
merged 5 commits into from
Aug 7, 2023
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions docs/source/usage_guides/gradient_accumulation.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ training on. 🤗 Accelerate automagically does this for you by default. Behind
Below is the finished implementation for performing gradient accumulation with 🤗 Accelerate

```python
from accelerate import Accelerator
accelerator = Accelerator(gradient_accumulation_steps=2)
model, optimizer, training_dataloader, scheduler = accelerator.prepare(
model, optimizer, training_dataloader, scheduler
)
for batch in training_dataloader:
with accelerator.accumulate(model):
inputs, targets = batch
Expand All @@ -138,4 +143,72 @@ for batch in training_dataloader:
optimizer.zero_grad()
```

<Tip warning={true}>
It's important that **only one forward/backward** should be done inside the context manager `with accelerator.accumulate(model)`.
</Tip>
vwxyzjn marked this conversation as resolved.
Show resolved Hide resolved


To learn more about what magic this wraps around, read the [Gradient Synchronization concept guide](../concept_guides/gradient_synchronization)


## Self-contained example

Here is a self-contained example that you can run to see gradient accumulation in action with 🤗 Accelerate:

```python
import torch
import copy
from accelerate import Accelerator
from accelerate.utils import set_seed
from torch.utils.data import TensorDataset, DataLoader

# seed
set_seed(0)

# define toy inputs and labels
x = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.])
y = torch.tensor([2., 4., 6., 8., 10., 12., 14., 16.])
gradient_accumulation_steps = 4
batch_size = len(x) // gradient_accumulation_steps

# define dataset and dataloader
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=batch_size)

# define model, optimizer and loss function
model = torch.zeros((1, 1), requires_grad=True)
model_clone = copy.deepcopy(model)
criterion = torch.nn.MSELoss()
model_optimizer = torch.optim.SGD([model], lr=0.02)
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, model_optimizer, dataloader = accelerator.prepare(model, model_optimizer, dataloader)
model_clone_optimizer = torch.optim.SGD([model_clone], lr=0.02)
print(f"initial model weight is {model.mean().item():.5f}")
print(f"initial model weight is {model_clone.mean().item():.5f}")
for i, (inputs, labels) in enumerate(dataloader):
with accelerator.accumulate(model):
inputs = inputs.view(-1, 1)
print(i, inputs.flatten())
labels = labels.view(-1, 1)
outputs = inputs @ model
loss = criterion(outputs, labels)
accelerator.backward(loss)
model_optimizer.step()
model_optimizer.zero_grad()
loss = criterion(x.view(-1, 1) @ model_clone, y.view(-1, 1))
model_clone_optimizer.zero_grad()
loss.backward()
model_clone_optimizer.step()
print(f"w/ accumulation, the final model weight is {model.mean().item():.5f}")
print(f"w/o accumulation, the final model weight is {model_clone.mean().item():.5f}")
```
```
initial model weight is 0.00000
initial model weight is 0.00000
0 tensor([1., 2.])
1 tensor([3., 4.])
2 tensor([5., 6.])
3 tensor([7., 8.])
w/ accumulation, the final model weight is 2.04000
w/o accumulation, the final model weight is 2.04000
```
Loading