Merge branch 'main' into configs-refactor
runame committed Dec 17, 2024
2 parents 839a40a + 269f363 · commit 84a44f3
Showing 3 changed files with 65 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -26,7 +26,7 @@ Optimizers is BSD licensed, as found in the LICENSE file.
## Installation and Dependencies
This code requires `python>=3.10` and `torch>=2.5.0`.
Install `distributed_shampoo` with all dependencies:
-```
+```bash
git clone [email protected]:facebookresearch/optimizers.git
cd optimizers
pip install .
67 changes: 63 additions & 4 deletions distributed_shampoo/README.md
@@ -1,6 +1,6 @@
# PyTorch Distributed Shampoo

-Distributed Shampoo is a preconditioned stochastic gradient optimizer in the adaptive gradient (Adagrad) family of methods [1, 2]. It converges faster by leveraging neural network-specific structures to achieve comparable model quality/accuracy in fewer iterations or epochs at the cost of additional FLOPs and memory, or achieve higher model quality in the same number of iterations or epochs. Our implementation offers specialized support for serial, [Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html), [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Hybrid Sharding Data Parallel](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html#how-to-use-devicemesh-with-hsdp) training.
+Distributed Shampoo is a preconditioned stochastic gradient optimizer in the adaptive gradient (Adagrad) family of methods [1, 2]. It converges faster by leveraging neural network-specific structures to achieve comparable model quality/accuracy in fewer iterations or epochs at the cost of additional FLOPs and memory, or achieve higher model quality in the same number of iterations or epochs. Our implementation offers specialized support for serial, [Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html), [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Hybrid Sharding Data Parallel (HSDP)](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html#how-to-use-devicemesh-with-hsdp) training.

Distributed Shampoo currently only supports dense parameters.
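
For orientation, here is a minimal single-process sketch; the tiny linear model is a hypothetical stand-in, and the hyperparameters simply mirror the examples in this README:
```python
import torch

from distributed_shampoo import AdamGraftingConfig, DistributedShampoo

# Any nn.Module with dense parameters works; a tiny linear layer for illustration.
model = torch.nn.Linear(10, 4)

optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-05,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    use_decoupled_weight_decay=True,
    grafting_config=AdamGraftingConfig(
        beta2=0.999,
        epsilon=1e-12,
    ),
)

# One standard PyTorch optimization step.
loss = model(torch.randn(8, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```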

@@ -329,7 +329,7 @@ optimizer = DistributedShampoo(
),
)
```
-Please see `ddp_cifar10_example.py` as an example.
+Please see [`ddp_cifar10_example.py`](https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/examples/ddp_cifar10_example.py) as an example.
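
For DDP training (the full example is collapsed above), the optimizer takes a `DDPShampooConfig` as its `distributed_config`. A sketch, assuming `DDPShampooConfig` is default-constructible (its optional communication-tuning fields are not guessed at here):
```python
import os

import torch
import torch.distributed as dist

from distributed_shampoo import AdamGraftingConfig, DDPShampooConfig, DistributedShampoo

# Launch with torchrun so RANK / WORLD_SIZE / LOCAL_RANK are set.
dist.init_process_group(backend="nccl", init_method="env://")
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")

# A tiny stand-in model; any nn.Module with dense parameters works.
model = torch.nn.Linear(10, 4).to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device.index])

optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-05,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    use_decoupled_weight_decay=True,
    grafting_config=AdamGraftingConfig(
        beta2=0.999,
        epsilon=1e-12,
    ),
    # Assumed default construction; see the DDP section above for tuned fields.
    distributed_config=DDPShampooConfig(),
)
```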

### FSDP Training Support

@@ -383,7 +383,66 @@ optimizer = DistributedShampoo(
),
)
```
-Please see `fsdp_cifar10_example.py` as an example.
+Please see [`fsdp_cifar10_example.py`](https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/examples/fsdp_cifar10_example.py) as an example.
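
The FSDP wiring (collapsed above) parallels the HSDP example that follows. A sketch, assuming `FSDPShampooConfig` takes the same `param_to_metadata` mapping produced by `compile_fsdp_parameter_metadata`, both of which appear in the HSDP example below:
```python
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from distributed_shampoo import (
    AdamGraftingConfig,
    compile_fsdp_parameter_metadata,
    DistributedShampoo,
    FSDPShampooConfig,
)

dist.init_process_group(backend="nccl", init_method="env://")
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")

# A tiny stand-in model; use_orig_params=True is required for Shampoo.
model = torch.nn.Linear(10, 4).to(device)
model = FSDP(model, use_orig_params=True)

optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-05,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    use_decoupled_weight_decay=True,
    grafting_config=AdamGraftingConfig(
        beta2=0.999,
        epsilon=1e-12,
    ),
    # Assumed signature, mirroring HSDPShampooConfig minus the device mesh.
    distributed_config=FSDPShampooConfig(
        param_to_metadata=compile_fsdp_parameter_metadata(model),
    ),
)
```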

### HSDP Training Support

Note that we only support PyTorch HSDP with `sharding_strategy=ShardingStrategy.HYBRID_SHARD` and the `use_orig_params=True` option.
```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

from distributed_shampoo import (
    AdamGraftingConfig,
    compile_fsdp_parameter_metadata,
    DistributedShampoo,
    HSDPShampooConfig,
)

LOCAL_RANK = int(os.environ["LOCAL_RANK"])
WORLD_RANK = int(os.environ["RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])

dist.init_process_group(
    backend="nccl",  # or any backend appropriate for your setup
    init_method="env://",
    rank=WORLD_RANK,
    world_size=WORLD_SIZE,
)
device = torch.device(f"cuda:{LOCAL_RANK}")

# Instantiate device mesh for HSDP Shampoo.
# Assuming 8 GPUs, this creates a 2 x 4 mesh:
# ([[0, 1, 2, 3], [4, 5, 6, 7]])
# The model is replicated across the 2 mesh rows, and each replica is
# sharded across the 4 GPUs within its row.
device_mesh = init_device_mesh("cuda", (2, 4))

# instantiate_model() is a placeholder for your own model construction.
model = instantiate_model().to(device)
model = FSDP(model, device_mesh=device_mesh, sharding_strategy=ShardingStrategy.HYBRID_SHARD, use_orig_params=True)

optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-05,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    use_decoupled_weight_decay=True,
    grafting_config=AdamGraftingConfig(
        beta2=0.999,
        epsilon=1e-12,
    ),
    distributed_config=HSDPShampooConfig(
        param_to_metadata=compile_fsdp_parameter_metadata(model),
        device_mesh=device_mesh,
    ),
)
```
Please see [`hsdp_cifar10_example.py`](https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/examples/hsdp_cifar10_example.py) as an example.
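
The example above expects torchrun-style environment variables (`RANK`, `WORLD_SIZE`, `LOCAL_RANK`). On a single 8-GPU host it could be launched along these lines, though the example script's own CLI flags may differ:
```bash
# Launch 8 processes on one host; torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK.
torchrun --standalone --nproc_per_node=8 distributed_shampoo/examples/hsdp_cifar10_example.py
```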

## Checkpointing Support

@@ -415,7 +474,7 @@ model.load_state_dict(state_dict["model"])
optimizer.load_distributed_state_dict(state_dict["optim"], key_to_param=model.named_parameters())
```

-You can also refer to `ddp_cifar10_example.py` as an example.
+You can also refer to [`ddp_cifar10_example.py`](https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/examples/ddp_cifar10_example.py) as an example.
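
Continuing from the loading snippet above, a sketch of the save side, assuming the save-side counterpart of `load_distributed_state_dict` is `distributed_state_dict` with the same `key_to_param` convention:
```python
import torch.distributed.checkpoint as dcp

# Build the checkpointable state; keys are mapped via named_parameters,
# matching the key_to_param argument used when loading.
state_dict = {
    "model": model.state_dict(),
    "optim": optimizer.distributed_state_dict(key_to_param=model.named_parameters()),
}

# Each rank writes its shard; torch.distributed.checkpoint handles the layout.
dcp.save(state_dict, checkpoint_id="checkpoint_dir")
```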

## Hyperparameter Tuning

2 changes: 1 addition & 1 deletion distributed_shampoo/examples/hsdp_cifar10_example.py
@@ -89,7 +89,7 @@
# Instantiate device mesh for HSDP Shampoo.
# Assuming 8 GPUs, will be initialized as 2 x 4 mesh.
# ([[0, 1, 2, 3], [4, 5, 6, 7]])
device_mesh = init_device_mesh("cuda", (4, 2))
device_mesh = init_device_mesh("cuda", (2, 4))

# instantiate model and loss function
model, loss_function = get_model_and_loss_fn(device)
