Recent work by Microsoft and
Google has shown that data parallel
training can be made significantly more efficient by sharding the model
parameters and optimizer state across data parallel workers. These ideas are
encapsulated in the new FullyShardedDataParallel
(FSDP) wrapper provided
by fairscale.
Compared to PyTorch DDP:
- FSDP produces identical results as PyTorch DDP (it's still synchronous data parallel training)
- FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs
- FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the communication can be overlapped with the forward pass
- FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs
FSDP is fully supported in fairseq via the following new arguments:
--ddp-backend=fully_sharded
: enables full sharding via FSDP--cpu-offload
: offloads the optimizer state and FP32 model copy to CPU (combine with--optimizer=cpu_adam
)--no-reshard-after-forward
: increases training speed for large models (1B+ params) and is similar to ZeRO stage 2- other popular options (
--fp16
,--update-freq
,--checkpoint-activations
,--offload-activations
, etc.) continue to work as normal
Limitations
FSDP currently has several limitations compared to fairseq's default DDP backend (PyTorch DDP):
- while FSDP is full compatible with pointwise Optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.), it is not currently compatible with non-pointwise Optimizers (e.g., Adagrad, Adafactor, LAMB, etc.)
- FSDP depends on flattening the parameters, so models that currently require
--fp16-no-flatten-grads
may not be supported
See the fairscale docs for a more detailed explanation of these and other limitations.
The following examples illustrate how to train a very large language model with 13 billion parameters on 1 GPU by offloading parameters and optimizer states to CPU, or on 8 GPUs by fully sharding the params and optimizer states across GPUs.
These examples use the WikiText-103 dataset for demonstration purposes, but in practice a much larger dataset will be needed to achieve good results. Follow the instructions here to preprocess the WikiText-103 dataset using the GPT-2/RoBERTa vocabulary.
The following command trains a 13B parameter GPT-3 model on a single V100 GPU
using the --cpu-offload
feature to offload parameters and optimizer states to
CPU. In this setting, the optimizer step (Adam) happens on CPU. We also use the
--checkpoint-activations
feature (sometimes called gradient checkpointing),
which further saves memory in exchange for a small increase in computation.
Requirements:
- Install the latest master version of fairscale:
pip install git+https://github.com/facebookresearch/fairscale.git@master
- You'll need 32GB of GPU memory and ~256GB of system memory to train the 13B param model.
- If you have less system memory, the 6.7B param model can be trained with ~128GB of system memory, just set
--arch transformer_lm_gpt3_6_7
- We use the CPU Adam optimizer from DeepSpeed, so you'll need to
pip install deepspeed
before running the command.
Notes:
- The command will take ~5 minutes to start training, during which time it will appear to be hung, since randomly initializing 13B weights can be slow.
- The
--cpu-offload
feature requires training in mixed precision (--fp16
). - Tune the
OMP_NUM_THREADS
env variable for best performance with CPU offloading. - The example command below stops training after 10 steps (
--max-update 10
) and does not save checkpoints (--no-save
).
OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0 \
fairseq-train data-bin/wikitext-103-roberta-bpe-bin \
--ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
--cpu-offload --checkpoint-activations \
--task language_modeling --tokens-per-sample 2048 --batch-size 8 \
--arch transformer_lm_gpt3_13 \
--optimizer cpu_adam --adam-betas "(0.9,0.98)" \
--lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
--max-update 10 --no-save --log-format json --log-interval 1
Example output
(...)
2021-03-08 12:29:51 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920)
(...)
2021-03-08 12:29:51 | INFO | fairseq_cli.train | training on 1 devices (GPUs/TPUs)
2021-03-08 12:29:51 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8
(...)
Adam Optimizer #0 is created with AVX2 arithmetic capability.
Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
(...)
2021-03-08 12:31:36 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.475", "ppl": "91120.8", "wps": "0", "ups": "0", "wpb": "16384", "bsz": "8", "num_updates": "1", "lr": "2e-05", "gnorm": "20.751", "loss_scale": "4", "train_wall": "99", "gb_free": "9.3", "wall": "105"}
2021-03-08 12:32:33 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.446", "ppl": "89281.6", "wps": "288.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "2", "lr": "4e-05", "gnorm": "19.777", "loss_scale": "4", "train_wall": "57", "gb_free": "9.3", "wall": "161"}
2021-03-08 12:33:12 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0
2021-03-08 12:33:51 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0
2021-03-08 12:34:45 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "25.22", "ppl": "3.90691e+07", "wps": "123.4", "ups": "0.01", "wpb": "16384", "bsz": "8", "num_updates": "3", "lr": "6e-05", "gnorm": "131.281", "loss_scale": "1", "train_wall": "133", "gb_free": "9.3", "wall": "294"}
2021-03-08 12:35:43 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.079", "ppl": "276809", "wps": "285.5", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "4", "lr": "8e-05", "gnorm": "13.776", "loss_scale": "1", "train_wall": "57", "gb_free": "9.3", "wall": "351"}
2021-03-08 12:36:35 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "23.729", "ppl": "1.39088e+07", "wps": "316.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "72.774", "loss_scale": "1", "train_wall": "52", "gb_free": "9.3", "wall": "403"}
2021-03-08 12:37:28 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "20.429", "ppl": "1.41203e+06", "wps": "307.6", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "6", "lr": "8e-05", "gnorm": "60.846", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "456"}
2021-03-08 12:38:27 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.965", "ppl": "511684", "wps": "279.4", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "7", "lr": "6e-05", "gnorm": "22.687", "loss_scale": "1", "train_wall": "59", "gb_free": "9.3", "wall": "515"}
2021-03-08 12:39:18 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.345", "ppl": "332887", "wps": "319.1", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "8", "lr": "4e-05", "gnorm": "8.451", "loss_scale": "1", "train_wall": "51", "gb_free": "9.3", "wall": "566"}
2021-03-08 12:40:11 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "18.262", "ppl": "314336", "wps": "305.9", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "9", "lr": "2e-05", "gnorm": "6.457", "loss_scale": "1", "train_wall": "54", "gb_free": "9.3", "wall": "620"}
2021-03-08 12:41:04 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "17.556", "ppl": "192686", "wps": "311.8", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "10", "lr": "0", "gnorm": "5.796", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "673"}
2021-03-08 12:41:04 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10
2021-03-08 12:41:04 | INFO | fairseq_cli.train | begin validation on "valid" subset
2021-03-08 12:43:15 | INFO | valid | {"epoch": 1, "valid_loss": "17.953", "valid_ppl": "253807", "valid_wps": "1868.4", "valid_wpb": "15400.2", "valid_bsz": "7.6", "valid_num_updates": "10"}
2021-03-08 12:43:15 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
2021-03-08 12:43:15 | INFO | train | {"epoch": 1, "train_loss": "19.351", "train_ppl": "668509", "train_wps": "210.9", "train_ups": "0.01", "train_wpb": "16384", "train_bsz": "8", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "36.26", "train_loss_scale": "1", "train_train_wall": "667", "train_gb_free": "9.3", "train_wall": "804"}
2021-03-08 12:43:15 | INFO | fairseq_cli.train | done training in 798.6 seconds
FSDP can also shard the parameters and optimizer states across multiple GPUs, reducing memory requirements significantly. On 8 x 32GB GPUs, sharding enables training the same 13B parameter model without offloading the parameters to CPU. However, without CPU offloading we'd only be able to fit a batch size of 1 per GPU, which would cause training speed to suffer.
We obtain the best performance on 8 GPUs by combining full sharding and CPU offloading. The following command trains the same 13B parameter GPT-3 model as before on 8 x 32GB V100 GPUs; training speed increases superlinearly from ~310 words per second to ~3200 words per second.
OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
fairseq-train data-bin/wikitext-103-roberta-bpe-bin \
--ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
--cpu-offload --checkpoint-activations \
--task language_modeling --tokens-per-sample 2048 --batch-size 8 \
--arch transformer_lm_gpt3_13 \
--optimizer cpu_adam --adam-betas "(0.9,0.98)" \
--lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
--max-update 10 --no-save --log-format json --log-interval 1
Example output
(...)
2021-03-08 18:04:09 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920)
(...)
2021-03-08 18:04:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs)
2021-03-08 18:04:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8
(...)
Adam Optimizer #0 is created with AVX2 arithmetic capability.
Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
(...)
2021-03-08 18:05:06 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "16.408", "ppl": "86945.6", "wps": "0", "ups": "0", "wpb": "131072", "bsz": "64", "num_updates": "1", "lr": "2e-05", "gnorm": "18.27", "loss_scale": "4", "train_wall": "47", "gb_free": "9.3", "wall": "56"}
2021-03-08 18:05:45 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "16.352", "ppl": "83644.3", "wps": "3283.4", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "2", "lr": "4e-05", "gnorm": "18.411", "loss_scale": "4", "train_wall": "40", "gb_free": "9.3", "wall": "96"}
2021-03-08 18:06:21 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0
2021-03-08 18:06:56 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0
2021-03-08 18:07:37 | INFO | train_inner | {"epoch": 1, "update": 0.006, "loss": "23.682", "ppl": "1.34537e+07", "wps": "1176.6", "ups": "0.01", "wpb": "131072", "bsz": "64", "num_updates": "3", "lr": "6e-05", "gnorm": "119.682", "loss_scale": "1", "train_wall": "111", "gb_free": "9.3", "wall": "208"}
2021-03-08 18:08:18 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "18.988", "ppl": "519921", "wps": "3189.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "4", "lr": "8e-05", "gnorm": "14.934", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "249"}
2021-03-08 18:08:59 | INFO | train_inner | {"epoch": 1, "update": 0.008, "loss": "20.08", "ppl": "1.10798e+06", "wps": "3223.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "5", "lr": "0.0001", "gnorm": "59.92", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "289"}
2021-03-08 18:09:39 | INFO | train_inner | {"epoch": 1, "update": 0.009, "loss": "18.323", "ppl": "327980", "wps": "3256.6", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "6", "lr": "8e-05", "gnorm": "37.425", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "330"}
2021-03-08 18:10:20 | INFO | train_inner | {"epoch": 1, "update": 0.01, "loss": "17.264", "ppl": "157354", "wps": "3188.7", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "7", "lr": "6e-05", "gnorm": "10.824", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "371"}
2021-03-08 18:11:01 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "16.794", "ppl": "113647", "wps": "3230", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "8", "lr": "4e-05", "gnorm": "5.616", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "411"}
2021-03-08 18:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.012, "loss": "16.706", "ppl": "106938", "wps": "3384", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "9", "lr": "2e-05", "gnorm": "5.318", "loss_scale": "1", "train_wall": "39", "gb_free": "9.3", "wall": "450"}
2021-03-08 18:12:19 | INFO | train_inner | {"epoch": 1, "update": 0.013, "loss": "16.548", "ppl": "95796.2", "wps": "3274.4", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "10", "lr": "0", "gnorm": "5.22", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "490"}
2021-03-08 18:12:19 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10
2021-03-08 18:12:19 | INFO | fairseq_cli.train | begin validation on "valid" subset
2021-03-08 18:12:45 | INFO | valid | {"epoch": 1, "valid_loss": "16.624", "valid_ppl": "101000", "valid_wps": "10855.9", "valid_wpb": "123202", "valid_bsz": "60.5", "valid_num_updates": "10"}
2021-03-08 18:12:45 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
2021-03-08 18:12:45 | INFO | train | {"epoch": 1, "train_loss": "18.114", "train_ppl": "283776", "train_wps": "2567.8", "train_ups": "0.02", "train_wpb": "131072", "train_bsz": "64", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "29.562", "train_loss_scale": "1", "train_train_wall": "480", "train_gb_free": "9.3", "train_wall": "516"}
2021-03-08 18:12:45 | INFO | fairseq_cli.train | done training in 509.9 seconds