Scaling up AllenNLP to 11B Parameter Models #5433
@epwalsh Thanks for the great write-up. Maybe I am missing something here, but your claim seems contradictory to the claim in the DeepSpeed documentation. In fact, they demonstrate the ability to train a 10B parameter model on a single V100 GPU: https://www.deepspeed.ai/tutorials/zero-offload/#training-a-10b-parameter-gpt-2-on-a-single-v100-gpu Is there a reason for such a huge discrepancy? Is it something specific that can't be achieved without DeepSpeed?
A deep dive into the challenges of large scale training and the tools we used to get there
One of our goals this year was to scale up the AllenNLP library to be able to train 11 billion+ parameter models on a single node.
We chose 11B parameters as the target because that's the size of the largest official T5 model - one of the biggest publicly released models used in research today - and we wanted to be able to do it on a single node since training across multiple nodes is prohibitively expensive for many and quite involved from an engineering perspective.
In the end we met our goal: we were able to fine-tune all of the parameters of T5-11B with AllenNLP on a single machine equipped with 640GiB of GPU memory (this is still a lot, but Google Cloud does offer a generally available virtual machine of that size). We did so by leveraging the optimizations proposed by the paper ZeRO: Memory Optimizations Toward Training Trillion Parameter Models via the FairScale library. This post takes a deep dive into the challenges of training models of that scale and the tools and tricks we used to get there.
Background
The 3 dimensions of parallelism.
Parallelizing the computation involved in deep learning is the most effective way to scale, and luckily deep learning is inherently parallelizable in multiple ways. In fact there are 3 high-level parallelization strategies for training: data parallelism (DP), model parallelism (MP), and pipeline parallelism (PP).
Data parallelism (DP)
Data parallelism is probably the most used form of parallelism in training deep learning models because it's conceptually simple and easy to use: libraries like PyTorch support it out-of-the-box (see the `DistributedDataParallel` wrapper class).

DP simply involves computing the forward and backward pass on separate partitions of each mini-batch in parallel across different devices (or even different machines), and then combining the resulting gradients by averaging them before taking a step with the optimizer. For example, if you have 4 DP workers and mini-batches with 8 instances in them, then each worker would compute gradients for a micro-batch of 2 instances.
DP works because loss functions are typically in the form of an average of the loss across individual data points. Since the derivative of a sum is the sum of the derivatives, the gradient for all of the parameters involved in the computation of the loss for each batch can be calculated as the average of the gradients associated with each individual instance in the batch.
The traditional strategy for implementing DP - and the one used by PyTorch's `DistributedDataParallel` - is to just keep a full copy of the weights and optimizer state on each device. This is very computationally efficient but not very memory efficient, since memory usage scales linearly with the number of workers.

Another potential downside of DP is that the effective batch size is tied to the number of workers, so as you scale up the number of DP workers you're also scaling up the batch size. That's not always a bad thing, but for many applications of deep learning there is a point at which the batch size becomes too big, causing test performance to suffer. This is the generalization vs utilization trade-off [BTT19], although it doesn't seem to apply to large language models.
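To make the traditional DP setup concrete, here is a minimal sketch of a `DistributedDataParallel` training loop in plain PyTorch. This is a generic illustration, not AllenNLP code: the `model` and `dataset` arguments are stand-ins, and it assumes the model's `forward()` returns the loss.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def train(rank: int, world_size: int, model: torch.nn.Module, dataset):
    # One process per GPU; each process keeps a *full* copy of the weights and optimizer state.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    ddp_model = DistributedDataParallel(model.cuda(rank), device_ids=[rank])
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-4)

    # The sampler gives each of the `world_size` workers its own slice (micro-batch) of every mini-batch.
    loader = DataLoader(dataset, batch_size=2, sampler=DistributedSampler(dataset))

    for batch in loader:
        loss = ddp_model(**batch)  # forward/backward on this worker's micro-batch only
        loss.backward()            # DDP averages gradients across workers here (all-reduce)
        optimizer.step()
        optimizer.zero_grad()
```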
Model parallelism (MP)
Model parallelism is another strategy for splitting up the work of training a deep neural network. With MP, instead of dividing up the data sent to workers, a copy of each mini-batch is sent to all workers and different parts of the model are computed on different workers. In other words, the model itself is split across workers instead of the data.
Naturally this requires more inter-worker communication than DP, but it doesn't suffer from the utilization vs generalization trade-off since the effective batch size is independent of the number of workers. It also has better scaling properties with respect to memory usage compared to traditional DP since the weights of the model can be divided among workers: each worker only stores the weights it requires for its part of the computation.
The simplest form of MP - which we'll call naive-MP, because there's nothing actually parallel about it - is when you divide the model by complete layers. For example, if you were training a transformer model with 12 layers across 4 GPUs, you could split up the model by putting the first 3 layers on 1 GPU, the next 3 layers on another GPU, etc. This naive form does save memory, but it creates serial dependencies in that the output from the first GPU is the input to the second GPU, and so on. So only 1 GPU can be working at a time.
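A minimal sketch of naive-MP in plain PyTorch, with a toy 12-layer stack of linear layers split across 4 GPUs (the layer type and sizes here are made up for illustration):

```python
import torch
import torch.nn as nn

class NaiveModelParallel(nn.Module):
    """A toy 12-layer network split by whole layers across 4 GPUs."""

    def __init__(self, hidden_size: int = 1024):
        super().__init__()
        # Three layers per device -- each stage only stores its own weights.
        self.stages = nn.ModuleList(
            nn.Sequential(*[nn.Linear(hidden_size, hidden_size) for _ in range(3)]).to(f"cuda:{i}")
            for i in range(4)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The output of one GPU is the input to the next, so only one GPU is ever
        # busy at a time -- this is the serial dependency described above.
        for i, stage in enumerate(self.stages):
            x = stage(x.to(f"cuda:{i}"))
        return x
```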
Although there are tricks to make this kind of setup more efficient (which is where pipeline parallelism comes in), true model parallelism involves splitting up some or all of the weights within layers so that all GPUs can be kept busy at all times.
For example, you could split up the heads of each multi-headed attention layer in your model. Now, if that was the only type of layer you split in your model, there would probably be a lot of duplicated computation because each worker would have to do the full computation for every layer that is not split.
In general, there's always a trade-off with communication volume when splitting up a layer: when you split up a layer you may be saving memory and gaining computational efficiency because less work will be duplicated, but this involves more communication between workers. For certain types of layers the communication overhead of MP may be so high that it overshadows any gains. Convolutional layers, for instance, are not very efficient with MP [BTT19].
Overall MP is probably used much less than DP because any implementation is necessarily architecture-dependent. Careful thought must be put in to determine which layers should be split and how to split them, and the code can become quite complex.
Pipeline parallelism (PP)
The final dimension is pipeline parallelism (or "pipelining" [BTT19]). Similar to MP, PP can improve memory and compute efficiency by partitioning the weights of a model across workers. In fact it's almost identical to naive-MP in that sequential layers of a model are put on different devices, but it's done in such a way that the forward pass, backward pass, and weight updates of each layer can overlap to some degree.
For example, consider a simple 4 layer neural net that we want to train on 4 GPUs. As in figure (a) below, we could partition the model by putting the weights for each layer on a single GPU.
Figure (b) shows how the forward and backward passes would be done with naive-MP, which is very inefficient because only 1 GPU would be active at a time. But with PP (figure (c)) each mini-batch is split into smaller micro-batches that can be processed by different workers at the same time.
This strategy can be used with any sequential NN without customizing the code for individual layers, and there are several tools out there that make using PP very simple (DeepSpeed, FairScale, and GPipe, for example). Though just like the other two forms of parallelism, there are trade-offs with PP as well.
For one, micro-batches have to arrive at each worker at the right rate to keep the devices fully utilized, and the latency within the system is proportional to the number of partitions [BTT19]. GPipe and other implementations also require that the batch size be proportional to the number of pipeline groups, so PP often suffers from the same inherent scaling limitation as DP due to the utilization vs generalization trade-off.
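For illustration, here is roughly what pipelining a small sequential stack looks like with FairScale's `Pipe` wrapper (one of the tools mentioned above). Treat this as a sketch: the exact arguments and automatic device placement may differ between FairScale versions.

```python
import torch
import torch.nn as nn
import fairscale

# A toy 4-module stack; balance=[2, 2] puts the first two modules on one GPU and
# the last two on another, and chunks=4 splits every mini-batch into 4 micro-batches.
model = nn.Sequential(
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
)
model = fairscale.nn.Pipe(model, balance=[2, 2], chunks=4)

x = torch.randn(32, 1024).cuda(0)  # input starts on the first partition's device
y = model(x)                       # output comes back on the last partition's device
```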
Parallelism in AllenNLP
The 3 dimensions of deep learning parallelism are not mutually exclusive. They can all be used together, in fact. For example, if you had 2 nodes with 8 GPUs each, you could use DP across the nodes and then split each node into 4 pipeline groups of 2 GPUs each, utilizing MP within each group of 2 GPUs.
While scaling up AllenNLP, however, we decided to focus solely on data parallelism since DP is already integrated into AllenNLP and can be used with any model, not just sequential models. But the traditional strategy for DP - where a copy of the weights and optimizer state are stored on each device - is far too memory inefficient to work on 11B parameter models...
"Where did all the memory go?" - Section III from [ZERO]
Training a deep neural net takes a ton of memory, but the main reasons might surprise you.
The primary factor that makes large scale training so difficult is the amount of (GPU) memory it requires. But it's not totally apparent why training models this size takes up so much memory. For instance, the T5-11B weights themselves only take up around 45GiB of space, so you might think that an 80GiB A100 GPU would be big enough to perform a full forward and backward pass without any other tricks. But it turns out that 80GiB is not even close to the amount of memory you'd need.
The memory consumed during training can be broken down into 2 groups: model states, which include the weights of the model's parameters, the optimizer's state, and gradients, and residual states, which include activations, temporary buffers, and sections of fragmented memory that can't be used for anything else [ZERO]. It turns out that the model weights themselves take up only a small proportion of the overall memory.
Model states
We can easily estimate how much memory the model states will require.
Consider training a model with N parameters using the ADAM optimizer. The optimizer state of ADAM includes a copy of the momentum and variance associated with each gradient. So assuming we're training in full precision (FP32), the model parameters, gradients, optimizer momentum, and optimizer variance will each take up 4N bytes for a total of 16N bytes.
And it's not any better when you're training with mixed-precision.
With mixed-precision training, the forward and backward pass are done in FP16 - requiring 2N and 2N bytes for the parameters and gradients - yet the optimizer must still keep a full precision copy of the model's weights along with the other parts of the optimizer state. So using ADAM with mixed-precision still requires 2N + 2N + 4N + 4N + 4N = 16N bytes.
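Plugging the 11 billion parameters of T5-11B into that 16N-byte estimate gives a sense of scale (model states only, before counting any activations or buffers):

```python
n_params = 11e9                        # T5-11B
bytes_per_param = 2 + 2 + 4 + 4 + 4    # fp16 params + fp16 grads + fp32 copy + momentum + variance
total_gib = n_params * bytes_per_param / 2**30
print(f"{total_gib:.0f} GiB of model state")  # ~164 GiB -- already more than two 80GiB A100s
```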
Residual states
During a forward pass, the activations from each layer need to be stored in order for the gradients to be calculated during the backward pass. This requires a substantial amount of memory proportional to the size of the model.
With GPT-2 large, for example, which has 1.5B parameters, the activations themselves take up 60GiB of memory to process 32K tokens (batch size of 32 with sequences of 1000 tokens) [ZERO].
Gradient checkpointing (often called activation checkpointing) is one strategy for reducing the amount of activation memory in exchange for additional computation, but even with checkpointing the activations for GPT-2 large would still require 8GiB of memory [ZERO].
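In plain PyTorch, gradient checkpointing is available through `torch.utils.checkpoint`. A sketch of how it's typically applied to a stack of layers (the layers here are placeholders):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedStack(nn.Module):
    def __init__(self, num_layers: int = 12, hidden_size: int = 1024):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            # Don't store this layer's activations; recompute them during the backward pass.
            x = checkpoint(layer, x)
        return x
```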
Temporary buffers needed to store transient data and left-over fragmented memory can also take up a non-trivial amount of space. It's hard to predict exactly how much memory these will consume, but in the ZeRO paper for instance, they found that in extreme cases fragmentation caused out-of-memory errors despite over 30% of the total memory being unused.
Improving data parallelism
Insights from the ZeRO paper.
The huge memory requirements needed to train an 11B parameter model mean that traditional DP - which doesn't provide any per-GPU memory reduction at all - is ill-equipped for the task. ZeRO (Zero Redundancy Optimizer), however, proposes a simple yet effective idea for drastically reducing the memory consumption of DP.
ZeRO-DP, as it's called in the paper, saves memory by partitioning the model states across DP workers so that there are no redundant copies of weights or optimizer state. This results in drastic memory savings that scale with the number of workers. In theory you could train arbitrarily large models with ZeRO-DP as long as you have enough devices to share the model state.
Similar to MP, this can be done layer-by-layer. That is, all of the weights and corresponding optimizer state within each layer of the NN are partitioned across devices. During a forward pass, right before each layer is computed the weights for that layer are gathered from all processes and promptly released from memory as soon as the activations are obtained. During an optimizer step, the optimizer only has to update parameters local to each device.
Unlike MP though, all workers perform the same computation (on different data) and it doesn't matter how exactly the weights within a layer are partitioned. In the FairScale implementation of ZeRO-DP (`FullyShardedDataParallel`), for example, all of the weights within a layer are simply flattened and collected into a single vector which is then partitioned into chunks of equal size across all devices.

Choosing your units of partitioning
ZeRO-DP doesn't just work with sequential models, either. It can work with any model. The only choice a programmer has to make when integrating ZeRO-DP into their training pipeline is which blocks of their model to partition together.
I like to call these blocks units of partitioning. In a sequential model such as a transformer language model, the natural choice for the units of partitioning is the layers. In general, you want to choose your units of partitioning in a way that minimizes the frequency of inter-worker communication while providing just enough memory savings.
On one hand, if your entire model was your single unit of partitioning, communication would be at a minimum since the weights would only need to be gathered once per forward pass, but you'd have no memory savings since all of the weights of the model would have to be held in memory by all workers during the forward pass. This would be a silly thing to do because you're essentially just back to traditional DP with a little more communication.
On the other extreme, if your units of partitioning are too fine-grained you may save a lot of memory but training will be slower than it needs to be due to the high frequency of inter-worker communication.
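With FairScale's FSDP, the units of partitioning are simply whatever modules you wrap. A rough sketch of the common pattern for a sequential model, assuming the default `torch.distributed` process group is already initialized (one process per GPU):

```python
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP

def shard_by_layer(layers: nn.ModuleList) -> nn.Module:
    # Each wrapped layer becomes its own unit of partitioning: its flattened weights
    # are sharded across workers and only gathered right before the layer runs.
    sharded = [FSDP(layer) for layer in layers]
    # The outer wrapper shards anything not covered by an inner wrapper.
    return FSDP(nn.Sequential(*sharded))
```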
The modest cost of ZeRO
ZeRO-DP comes with additional communication compared to traditional DP, but one of the surprising insights from the paper is that this extra communication overhead is actually relatively modest. Part of the reason for that is that the communication required for traditional DP is already quite high.
In traditional DP, the communication overhead comes from the `all-reduce` operation that is applied at the end of each backward pass to average the gradients across workers. Most `all-reduce` implementations consist of a `reduce-scatter` followed by an `all-gather`, each of which involves moving approximately N elements (where N is the number of parameters) [ZERO]. Therefore the total amount of data moved across processes during each step is roughly 2N.

With ZeRO-DP, however, the state is partitioned, so during a forward pass all weights need to be broadcast across workers at some point with an `all-gather`, which involves moving N elements in total. During the backward pass a similar `all-gather` has to be done to compute the gradients corresponding to each worker, which involves moving another N elements in total, followed by a `reduce-scatter` to average the gradients across workers for each partition, which involves moving yet another N elements.

So the total communication volume of ZeRO-DP is 3N, which is only 50% more than traditional DP. This seems like a small price to pay given that traditional DP won't even work for large models!
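To put rough numbers on that for our use case, assuming the 11B parameters move over the wire in FP16 (2 bytes per element):

```python
n = 11e9              # parameters in T5-11B
elem_bytes = 2        # assuming FP16 weights/gradients on the wire
dp_gb   = 2 * n * elem_bytes / 1e9   # traditional DP: reduce-scatter + all-gather
zero_gb = 3 * n * elem_bytes / 1e9   # ZeRO-DP: two all-gathers + one reduce-scatter
print(f"traditional DP: ~{dp_gb:.0f} GB/step, ZeRO-DP: ~{zero_gb:.0f} GB/step")  # ~44 vs ~66
```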
It's also interesting that the communication volume is only dependent on the model size. MP, on the other hand, involves communication proportional to the batch size, and even for modest batch sizes this can be significantly more communication than ZeRO-DP [ZERO].
Integrating ZeRO into AllenNLP
Putting it all together with FairScale.
We were pretty confident that ZeRO-DP would get us most of the way, if not all the way, to our goal of training 11B parameter models in AllenNLP on a single machine. Ultimately we chose to use the `FullyShardedDataParallel` (FSDP) class from FairScale to integrate it into our library. It was straight-forward to get this working with our T5 implementation, and it's now available in AllenNLP as a `DdpAccelerator` registered as `"fairscale_fsdp"`.
But we were still getting out-of-memory errors when training T5-11B, so we needed one more trick. That trick just turned out to be using gradient checkpointing (activation checkpointing) in addition to FSDP. This was pretty easy since FairScale comes with an improved `checkpoint_wrapper` that works with FSDP out-of-the-box. This is available in AllenNLP now too as a `CheckpointWrapper` registered as `"fairscale"`.
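Outside of AllenNLP, the same combination looks roughly like this with FairScale directly (a sketch; check FairScale's docs for the current import paths and signatures):

```python
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import checkpoint_wrapper

def wrap_block(block: nn.Module) -> nn.Module:
    # Activation checkpointing on the inside, sharding on the outside -- the same order
    # in which AllenNLP applies CheckpointWrapper.wrap_module() and DdpAccelerator.wrap_module().
    return FSDP(checkpoint_wrapper(block))
```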
The added challenge of fine-tuning
At that point we were able to train T5-11B smoothly on a single machine, but we quickly realized that fine-tuning a pretrained T5-11B model actually presented another issue. We needed a way for each worker to load the pretrained weights into their model's `state_dict`. But with FairScale, any call to `.load_state_dict()` involves synchronization across workers, so it needs to be called from each worker at the same time, and that requires first gathering all of the shards within each worker.

So if we were to just call `.load_state_dict()` at the model level within each worker with the full pretrained `state_dict`, each worker would have to hold a copy of the `state_dict` in (CPU) memory as well as a full copy of the current (randomized) weights of the model at the same time. In our case, with 8 workers and ~45GiB of weights, that would require 45 x 2 x 8 = 720GiB of (CPU) memory.

The machine we were using actually could have accommodated this demand; nevertheless, we thought this was way too inefficient to leave as-is. Our solution was to implement a function we call `load_state_dict_distributed()`. With this, only the main process needs to hold a copy of the full `state_dict` in memory. The function works by recursively looking into the model to find its units of partitioning, and when a unit of partitioning is found - which in our case is a module that is wrapped with `DdpAccelerator.wrap_module()` - the corresponding parameters in the `state_dict` are broadcast to all workers and `.load_state_dict()` is called from that module only.
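The core idea can be sketched with plain `torch.distributed` primitives. This is not the actual `load_state_dict_distributed()` implementation, just an illustration of the broadcast-one-unit-at-a-time pattern; the `is_unit_of_partitioning` predicate is a stand-in for however you identify wrapped modules.

```python
import torch.distributed as dist
import torch.nn as nn

def load_pretrained_distributed(model: nn.Module, full_state_dict=None):
    """Sketch: rank 0 passes the full pretrained state_dict; all other ranks pass None."""
    for name, module in model.named_modules():
        if not is_unit_of_partitioning(module):  # stand-in for "is this a wrapped block?"
            continue
        prefix = name + "."
        # Rank 0 slices out just this unit's weights; every other rank receives them.
        payload = [
            {k[len(prefix):]: v for k, v in full_state_dict.items() if k.startswith(prefix)}
            if dist.get_rank() == 0
            else None
        ]
        dist.broadcast_object_list(payload, src=0)
        # All workers call .load_state_dict() on this unit at the same time,
        # which is what the sharded wrapper expects; the payload can then be freed.
        module.load_state_dict(payload[0])
```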
How to use AllenNLP's new features
Getting started on your own tasks.
If you're interested in fine-tuning T5-11B on your own dataset using AllenNLP, you'll just need to implement a custom `DatasetReader` (a minimal skeleton is sketched at the end of this section). Then you can adapt this config that we used to fine-tune T5-11B on the CNN-DM dataset.

On the other hand, if you want to use the FairScale `DdpAccelerator` and `CheckpointWrapper` with a different model, you should look at the source code for our T5 implementation. Basically you just need to wrap the units of partitioning of your model with `CheckpointWrapper.wrap_module()` and `DdpAccelerator.wrap_module()`, in that order. In T5, for example, we just used each layer/block as a unit of partitioning (see these lines).

Of course, you'll need a machine with roughly 640GiB of GPU memory to train an 11B parameter model. Luckily Google Cloud does offer a machine this big: the A2 VM. But before you spin one of these up, keep in mind they are quite pricey. If you were to use an A2 VM with 16 40GiB A100s (the largest they offer) and a multi-terabyte drive, it would cost about $10K per week. It would be substantially cheaper if you made use of preemptible machines, though preemptible machines only stay on for at most 24 hours, after which you'd have to spin up a new one and restart your training job. So if you go that route, make sure you're saving checkpoints often!
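For reference, a bare-bones custom `DatasetReader` might look something like the sketch below. The class name, registered name, file format, and field names here are all hypothetical; match the field names to whatever your model's `forward()` expects.

```python
from typing import Iterable, Optional

from allennlp.data import DatasetReader, Instance
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer


@DatasetReader.register("my_summarization_reader")  # hypothetical name
class MySummarizationReader(DatasetReader):
    def __init__(self, model_name: str = "t5-11b", **kwargs) -> None:
        super().__init__(**kwargs)
        self._tokenizer = PretrainedTransformerTokenizer(model_name)
        self._token_indexers = {"tokens": PretrainedTransformerIndexer(model_name)}

    def _read(self, file_path: str) -> Iterable[Instance]:
        # Hypothetical format: one tab-separated (source, target) pair per line.
        with open(file_path) as data_file:
            for line in data_file:
                source, target = line.rstrip("\n").split("\t")
                yield self.text_to_instance(source, target)

    def text_to_instance(self, source: str, target: Optional[str] = None) -> Instance:
        fields = {"source_tokens": TextField(self._tokenizer.tokenize(source), self._token_indexers)}
        if target is not None:
            fields["target_tokens"] = TextField(self._tokenizer.tokenize(target), self._token_indexers)
        return Instance(fields)
```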
Additional info
Why we chose FairScale and not DeepSpeed
While FairScale and DeepSpeed have a very similar feature set, DeepSpeed does do a little bit more. For comparison:
Yet we found that it was easier to integrate FairScale into AllenNLP because DeepSpeed is too high-level for us. It requires using the DeepSpeed engine, and that would have forced us to implement an entirely new trainer class -- something we wanted to avoid. We also weren't too concerned about FairScale's lack of full support for mixed-precision training, since we only planned to utilize A100 GPUs, which are quite fast even when training in full precision.
References