# [checkpoint] feat: open source fast checkpoint system #38
## Summary

We improved `vescale.checkpoint` with the following new features for fast checkpointing (the first three are built-in techniques that require no manual activation):

- **Saving Plan Caching**: During training, the program may save model and optimizer checkpoints every n steps. Once a saving plan is created, it remains unchanged as long as the model does. We implemented plan caching to avoid regenerating the plan when checkpointing a model or optimizer multiple times, reducing unnecessary compute and communication costs. As of 05/30/2024, PyTorch DCP does not support plan caching.
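Plan caching kicks in automatically on repeated saves. A minimal sketch of a training loop that benefits from it (the checkpoint path, the `checkpoint_state` layout, and the training-loop helpers are illustrative assumptions):

```python
import vescale

# Illustrative checkpoint state; distributed_model / distributed_optimizer
# stand in for objects built with veScale's parallelism APIs.
checkpoint_state = {"model": distributed_model, "optimizer": distributed_optimizer}

for step, batch in enumerate(dataloader):  # dataloader is a placeholder
    train_one_step(batch)                  # placeholder training step
    if step % 1000 == 0:
        # The first call generates the saving plan; later calls reuse the
        # cached plan as long as the model is unchanged, skipping the
        # plan's compute and communication costs.
        vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state)
```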
- **Saving Plan Load-Balancing**: In data parallel training, models are replicated across GPUs that have different data parallel ranks but the same pipeline and tensor parallel ranks. Existing PyTorch DCP (as of 05/30/2024) deduplicates replicated tensors using a simple algorithm, causing GPUs with data parallel rank 0 to save the entire model and leading to load imbalance. We implemented a load-balancing algorithm to address this issue when deduplicating model tensors.
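To illustrate the load-balancing idea only (a toy sketch, not the algorithm veScale actually ships), replicated tensors can be assigned to writers round-robin across the data parallel group instead of always electing rank 0:

```python
# Toy sketch of load-balanced deduplication: each replicated tensor is
# assigned to exactly one data parallel rank, spread round-robin instead
# of a naive dedup that always elects rank 0 as the writer.
def assign_writers(tensor_names: list[str], dp_ranks: list[int]) -> dict[str, int]:
    assignment = {}
    for i, name in enumerate(sorted(tensor_names)):
        # Round-robin over the replica group balances tensors written per
        # rank (a real implementation would balance by bytes, not count).
        assignment[name] = dp_ranks[i % len(dp_ranks)]
    return assignment

# Example: 4 tensors replicated across data parallel ranks 0 and 1;
# each rank now saves half the tensors instead of rank 0 saving all four.
writers = assign_writers(["embed", "layer0", "layer1", "head"], [0, 1])
```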
- **D2H Tensor Copying via Pinned Memory**: When copying tensors from GPU to host memory, `vescale.checkpoint` uses pinned host memory, reducing memory allocation costs each time a checkpoint is saved. As of 05/30/2024, PyTorch DCP does not support pinned memory.
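This optimization is built in and transparent. As a rough sketch of the technique itself (the buffer cache and function names here are ours, not `vescale.checkpoint` internals):

```python
import torch

# Pinned (page-locked) staging buffers, allocated once and reused across
# checkpoint saves instead of paying the allocation cost every time.
_pinned_buffers: dict = {}

def d2h_copy(gpu_tensor: torch.Tensor) -> torch.Tensor:
    key = (tuple(gpu_tensor.shape), gpu_tensor.dtype)
    if key not in _pinned_buffers:
        _pinned_buffers[key] = torch.empty(
            gpu_tensor.shape, dtype=gpu_tensor.dtype, pin_memory=True
        )
    host_tensor = _pinned_buffers[key]
    # Pinned memory allows an asynchronous DMA copy that overlaps with
    # compute; synchronize before serializing host_tensor to storage.
    host_tensor.copy_(gpu_tensor, non_blocking=True)
    return host_tensor
```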
- **Checkpoint Broadcasting**: In data parallel training, models are replicated across GPUs that have different data parallel ranks but the same pipeline and tensor parallel ranks. If `broadcast_checkpoint` is enabled, `vescale.checkpoint.load` lets GPUs with data parallel rank 0 load the model and broadcast it to the GPUs with higher data parallel ranks. If the GPUs are connected with NCCL and I/O bandwidth is fully utilized, broadcasting model tensors speeds up checkpoint loading compared to having all GPUs load the model from persistent storage.
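For example (the checkpoint path and `checkpoint_state` layout below are illustrative assumptions):

```python
import vescale

# Illustrative checkpoint state for a distributed model and optimizer.
checkpoint_state = {"model": distributed_model, "optimizer": distributed_optimizer}

# Data parallel rank 0 reads the checkpoint from persistent storage and
# broadcasts the tensors to the higher data parallel ranks over NCCL.
vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state, broadcast_checkpoint=True)
```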
- **Asynchronous Checkpointing**: When `vescale.checkpoint.save` is called, it first generates a saving plan and then synchronously copies tensors from GPU to host memory. If `async_checkpoint` is enabled, the training program can continue after the D2H copying, while `vescale.checkpoint.save` continues to serialize tensors and dump the checkpoint to persistent storage asynchronously without blocking training. As of 05/30/2024, PyTorch DCP does not support asynchronous checkpointing.
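For example (again with an illustrative path and `checkpoint_state`):

```python
import vescale

# Illustrative checkpoint state for a distributed model and optimizer.
checkpoint_state = {"model": distributed_model, "optimizer": distributed_optimizer}

# Returns once the plan and D2H copy are done; tensor serialization and
# the write to persistent storage proceed in the background, so the
# training loop is not blocked on I/O.
vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state, async_checkpoint=True)
```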
## Acknowledgement

We sincerely appreciate all contributors, including but not limited to @shanesyy-1992, @raywan-110, @lazychao, @AHEADer, and @MingjiHan99.