Add HSDP examples in README.md
Summary: This also fixes a small mismatch in the `init_device_mesh()` usage in the example.

Reviewed By: chuanhaozhuge

Differential Revision: D67111166

fbshipit-source-id: 440e85fabc93656ee9cf0578991ce23cc5fe6976
tsunghsienlee authored and facebook-github-bot committed Dec 12, 2024
1 parent a26c75e commit 269f363
Showing 2 changed files with 61 additions and 2 deletions.
61 changes: 60 additions & 1 deletion distributed_shampoo/README.md
@@ -1,6 +1,6 @@
# PyTorch Distributed Shampoo

- Distributed Shampoo is a preconditioned stochastic gradient optimizer in the adaptive gradient (Adagrad) family of methods [1, 2]. It converges faster by leveraging neural network-specific structures to achieve comparable model quality/accuracy in fewer iterations or epochs at the cost of additional FLOPs and memory, or achieve higher model quality in the same number of iterations or epochs. Our implementation offers specialized support for serial, [Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html), [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Hybrid Sharding Data Parallel](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html#how-to-use-devicemesh-with-hsdp) training.
+ Distributed Shampoo is a preconditioned stochastic gradient optimizer in the adaptive gradient (Adagrad) family of methods [1, 2]. It converges faster by leveraging neural network-specific structures to achieve comparable model quality/accuracy in fewer iterations or epochs at the cost of additional FLOPs and memory, or achieve higher model quality in the same number of iterations or epochs. Our implementation offers specialized support for serial, [Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html), [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Hybrid Sharding Data Parallel (HSDP)](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html#how-to-use-devicemesh-with-hsdp) training.

Distributed Shampoo currently only supports dense parameters.

@@ -382,6 +382,65 @@ optimizer = DistributedShampoo(
```
Please see [`fsdp_cifar10_example.py`](https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/examples/fsdp_cifar10_example.py) as an example.

### HSDP Training Support

Note that we only support PyTorch HSDP with `sharding_strategy=ShardingStrategy.HYBRID_SHARD` and the `use_orig_params=True` option.
```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

from distributed_shampoo import (
    AdamGraftingConfig,
    compile_fsdp_parameter_metadata,
    DistributedShampoo,
    HSDPShampooConfig,
)

# These environment variables are set by distributed launchers such as torchrun.
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
WORLD_RANK = int(os.environ["RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])

dist.init_process_group(
    backend=args.backend,  # e.g., "nccl"; `args` is assumed to come from the script's argument parser
    init_method="env://",
    rank=WORLD_RANK,
    world_size=WORLD_SIZE,
)
device = torch.device("cuda:{}".format(LOCAL_RANK))

# Instantiate device mesh for HSDP Shampoo.
# Assuming 8 GPUs, the mesh will be initialized as a 2 x 4 mesh:
#   [[0, 1, 2, 3], [4, 5, 6, 7]]
# This means the model is replicated across two groups, and each replica is sharded across four GPUs.
device_mesh = init_device_mesh("cuda", (2, 4))
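# A sketch (not part of the original example): the mesh dimensions can also be named
# explicitly, which makes the replicate/shard ordering used by HYBRID_SHARD easier to read:
#   device_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))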

model = instantiate_model().to(device)
model = FSDP(model, device_mesh=device_mesh, sharding_strategy=ShardingStrategy.HYBRID_SHARD, use_orig_params=True)

optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-05,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    use_decoupled_weight_decay=True,
    grafting_config=AdamGraftingConfig(
        beta2=0.999,
        epsilon=1e-12,
    ),
    distributed_config=HSDPShampooConfig(
        param_to_metadata=compile_fsdp_parameter_metadata(model),
        device_mesh=device_mesh,
    ),
)
```
Please see [`hsdp_cifar10_example.py`](https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/examples/hsdp_cifar10_example.py) as an example.

## Checkpointing Support

To checkpoint Distributed Shampoo, we have to use the `torch.distributed.checkpoint` solution with `DTensor`. *Note that we do not currently support the standard PyTorch checkpointing solution because it cannot handle storing process groups or `DTensor` by default.* We have therefore disabled `state_dict` and `load_state_dict` and instead rely on `distributed_state_dict` and `load_distributed_state_dict`.
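
A minimal sketch of how this can look with `torch.distributed.checkpoint`; the `key_to_param` argument and `CHECKPOINT_DIR` path below are assumptions for illustration, so consult the optimizer's docstrings for the exact interface:
```python
import torch.distributed.checkpoint as dist_checkpoint

CHECKPOINT_DIR = "/tmp/shampoo_checkpoint"  # hypothetical path

# `model` and `optimizer` are assumed to be the HSDP-wrapped model and the
# DistributedShampoo instance constructed in the example above.
state_dict = {
    "model": model.state_dict(),
    "optim": optimizer.distributed_state_dict(key_to_param=model.named_parameters()),
}

# Save the DTensor-backed state with the distributed checkpointing solution.
dist_checkpoint.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_checkpoint.FileSystemWriter(CHECKPOINT_DIR),
)

# Later, restore the state in place and hand it back to the optimizer.
dist_checkpoint.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_checkpoint.FileSystemReader(CHECKPOINT_DIR),
)
optimizer.load_distributed_state_dict(
    state_dict["optim"],
    key_to_param=model.named_parameters(),
)
```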
2 changes: 1 addition & 1 deletion distributed_shampoo/examples/hsdp_cifar10_example.py
@@ -89,7 +89,7 @@
# Instantiate device mesh for HSDP Shampoo.
# Assuming 8 GPUs, the mesh will be initialized as a 2 x 4 mesh.
# ([[0, 1, 2, 3], [4, 5, 6, 7]])
- device_mesh = init_device_mesh("cuda", (4, 2))
+ device_mesh = init_device_mesh("cuda", (2, 4))

# instantiate model and loss function
model, loss_function = get_model_and_loss_fn(device)
