[RFC] Polylithic: Enabling multi-threaded DataLoading through non-monolithic parallelism #1318
Comments
It looks like we can implement datapipes on top of Polylithic nodes and continue to use existing code without changing it at all. What are your thoughts on this?
@mfbalin I would not want to implement datapipes on top of Polylithic, as there would be twice the surface area to support and test against.
It's not clear in the example code whether […]. The semantics of "Parallel Prefetch" would be just like "monolithic" parallelism, where the entire node plus its dependency nodes are replicated and produce results to a queue. It has a benefit over "Parallel Mapper" because a mapper requires two-sided communication, whereas prefetch workers take no input from the main worker and only have to produce output. It requires care to ensure no duplicates and reproducibility, similar to what people would do today with […].
Thanks @ppwwyyxx for the comment and suggestion; all implementations are up for discussion. Definitely open to a […]. If the intention is just to parallelize source/producer nodes, say to interleave multiple files, then we should just find a way to parallelize there.
OTOH, if we created a ParallelPrefetch with multi-processing that we could configure to have behaviour identical to the current torch.DataLoader, we may be able to consolidate the implementation code while maintaining backwards compatibility.
I'm on board with the idea of […]. Regarding debugging, I encountered several challenges, although these seem to stem more from the nature of modular chained pipelines in Python than from torchdata specifically:
I appreciate the torchdata library and the idea of simple, composable nodes. After building several production pipelines with it, I've started to believe that Python's lack of type safety between pipes, and the inability to skip over redundant "glue" code in connections, make it inherently challenging to create easy-to-debug, composable pipelines. This might also explain why many popular data processing libraries seem to prefer more monolithic approaches (or large modules linked together), as they're less cumbersome to debug. That said, I still like the concept of "horizontal APIs" made up of simple, linked objects that can achieve impressive things when combined. I'm curious whether the issues I've mentioned can be effectively resolved within Python. Any insights or planned improvements would be greatly appreciated.
@josiahls thank you for the feedback and for describing your pain points! These are definitely the questions and issues that we want to prioritize, to make sure the solution is debuggable for developers and users. As far as the questions go, this does sound like something inherent to functional chains of operators.
It would be nice if, with the new design, we could profile the whole data provisioning process, like […].
🚀 The feature
TL;DR - We want to lean into modular Multi-Threading/Multi-Processing instead of the current monolithic Multi-Processing, and steer users away from the monolithic Dataset parallelism approach towards composable Iterables for pre-proc operations, with parallelism configured within each operation. This will enable multi-threaded dataloading (with NoGIL support), auto-tunable parallelism, torch.compile-able and GPU-enabled pre-proc operations, more efficient loading of mixed modalities, and composable dataloading and pre-proc graphs.
Motivation, pitch
Working name for the project: Polylithic (non-monolithic)
Where it will live: torchdata
Why Multi-Threading in the DataLoader?
Why do users need multi-threading in the PyTorch DataLoader? If your GPUs are not starved, then frankly you don't need to change anything. But what happens when your GPUs are starving, for example in training or offline inference where expensive pre-proc like video decoding happens on the trainer? Multi-modal LLM training, for instance, requires loading text, image, and/or video data sources.
Today, torch.utils.data.DataLoader uses multi-processing to perform dataloading and pre-proc in parallel with your training loop, and the first thing you should try is increasing the number of multiprocess workers. By changing a single variable in the DataLoader constructor, you can pretty easily add parallelism to your job, even with a custom Dataset class written in Python. If this keeps your GPUs fed, then you’re done!
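For concreteness, a minimal sketch of that single knob as it is used today; the Dataset below is a placeholder and not part of this proposal:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class PlaceholderDataset(Dataset):
    """Stand-in map-style dataset; real pre-proc (decoding, tokenization) would live here."""
    def __len__(self):
        return 1_000

    def __getitem__(self, idx):
        return torch.randn(3, 224, 224), idx % 10  # fake image + fake label

# num_workers is today's single parallelism lever: each worker process gets a
# full copy of the Dataset object and runs __getitem__ in parallel.
loader = DataLoader(PlaceholderDataset(), batch_size=32, num_workers=8)
for images, labels in loader:
    pass  # training step goes here
```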
But Python multiprocessing may also introduce a lot of friction, and it is heavy-weight in memory, often leading to CPU OOMs, even on ZionEX hosts with 96 cores and 2TB of RAM shared between 8 Nvidia A100s.
At this point, the user has a few options to consider:
What about Multi-Threading?
Within the realm of on-box compute approaches, multi-threading has lower startup and memory costs than multi-processing. Many multi-processing pain points disappear in a multi-threaded world, such as Python's copy-on-read behavior and IPC costs. Subject to the same memory budget, jobs will be able to scale up the number of workers much further than with multi-processing (until hitting CPU limits).
However, multi-threading introduces many new problems as well:
We should still enable multi-threading in PyTorch DataLoader
Despite the problems multi-threading introduces, we should still offer a way to perform multi-threaded dataloading.
Supporting Multi-Modal DataLoading
Llama 3 is here, and Llama 4 will arrive soon with early-fusion multi-modality. Tasks like fine-tuning, alignment, and distillation will require multi-modal dataloading for our internal and external users. LLM training often requires reading from 10s-100s of multi-modal datasets, tokenizing them, and packing them into a “token-buffer” where tokens from individual datasets are shuffled and combined into training examples for the model.
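As a rough illustration of the token-buffer pattern only (not a proposed API; tokenize_fn and the uniform mixing policy are placeholders):

```python
import random
from typing import Callable, Iterator, List

def pack_token_buffer(
    dataset_iters: List[Iterator[str]],
    tokenize_fn: Callable[[str], List[int]],
    seq_len: int = 4096,
) -> Iterator[List[int]]:
    """Interleave tokenized samples from many source datasets and pack them
    into fixed-length training examples."""
    buffer: List[int] = []
    while dataset_iters:
        src = random.choice(dataset_iters)  # real mixing would be weighted and seeded
        try:
            buffer.extend(tokenize_fn(next(src)))
        except StopIteration:
            dataset_iters.remove(src)  # this source is exhausted
            continue
        while len(buffer) >= seq_len:
            yield buffer[:seq_len]
            buffer = buffer[seq_len:]
```

A trivial tokenizer such as `lambda s: [ord(c) for c in s]` is enough to exercise the sketch.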
Audio, Image, and Video datasets may also require heavy-weight decoding operations to be performed before tokenization, and the difference in the data sizes between text, image, and video may be orders of magnitude. GPU decoding of images and video is an option for users as well, and libraries like Nvidia DALI will compile the entire pre-proc pipeline into GPU operations, minimizing the overhead of transfers between CPU and GPU memory.
Existing Context and definitions
Today, torch.utils.data contains the following abstractions:
“Monolithic” parallelism
Currently users have a single lever to control parallelism: num_workers. When num_workers > 0, the DataLoader creates background processes, each of which holds a copy of the entire Dataset object in its own memory, treating the dataset as a "monolithic" object to be parallelized.
Consider the scenario in the figure below, where a user has defined an iterable dataset which combines two text datasets and one image dataset. There is no parallelism in this example.
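A minimal sketch of such a combined dataset (class and argument names here are illustrative, not from the proposal):

```python
from torch.utils.data import IterableDataset

class CombinedMultiModalDataset(IterableDataset):
    """Two text sources and one image source, combined inline with no parallelism."""
    def __init__(self, text_a, text_b, images, tokenize, decode_image):
        self.text_a, self.text_b, self.images = text_a, text_b, images
        self.tokenize, self.decode_image = tokenize, decode_image

    def __iter__(self):
        for sample_a, sample_b, raw_image in zip(self.text_a, self.text_b, self.images):
            yield {
                "tokens": self.tokenize(sample_a) + self.tokenize(sample_b),
                "image": self.decode_image(raw_image),  # the expensive step
            }
```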
Now consider the common case where only the image decoding and tokenization are a bottleneck, causing GPU starvation. With today's tooling, users simply set num_workers > 1 on the DataLoader. The image below depicts how this is done today: the entire IterableDataset is treated as a monolith that is forked/spawned to other processes.
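Continuing the sketch above, the only fix available today replicates the whole combined dataset, text pipelines included, into every worker process (the toy arguments below are hypothetical):

```python
from torch.utils.data import DataLoader

# Each of the 4 worker processes holds a full copy of CombinedMultiModalDataset,
# even though only image decoding and tokenization are the bottleneck.
dataset = CombinedMultiModalDataset(
    text_a=["a cat", "a dog"],
    text_b=["sits", "runs"],
    images=[b"\x00" * 16, b"\x01" * 16],
    tokenize=str.split,
    decode_image=bytes.hex,
)
loader = DataLoader(dataset, batch_size=2, num_workers=4)
```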
A granular parallelism approach
To fix the monolithic parallelism problem, we want to introduce abstractions and tooling that expose more granular parallelism controls to users. This implies a solution where users construct their dataloading and pre-proc pipelines by defining and stitching together data-source and pre-proc nodes into a graph, in a similar fashion to tf.data and datapipes, with data passing between the nodes. The root of the graph is the node which produces batches that are passed to the model. The leaves are data sources which produce data by reading from local disk, remote storage, or e.g. random number generators. Intermediate nodes may transform data, perform pre-fetching, combine data from multiple nodes, perform "enrichments" (e.g. fetching images from blob stores), perform decoding, schedule GPU operations, etc.
Requirements and Constraints
To adequately support Multi Modal LLM training for PyTorch users, address the above pain points, and give us the best chance for wide-adoption, we want our solution to meet the following requirements and constraints:
How will we achieve this/what will we build? Plan of Record
We will introduce a new base class, with the working name class PolylithicNode(torch.utils.data.IterableDataset). Nodes in the graph will be instances of subclasses of PolylithicNode. Nodes will define a .iterator() method instead of overriding __iter__(). This is inspired by nn.Module's implementation, where users define .forward() instead of __call__(). This will allow PolylithicNode to instantiate user-defined iterators and wrap them, insert queues for pipeline parallelism, and measure latency. For backwards compatibility, we'll provide a wrapper which takes an existing IterableDataset. Users can compose their datasets by composing PolylithicNodes (i.e. through iter() and next()).
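A rough sketch of what this could look like under the working names above (not a final API):

```python
from torch.utils.data import IterableDataset

class PolylithicNode(IterableDataset):
    """Sketch: subclasses implement .iterator(); the base class wraps the returned
    iterator, which is where queues for pipeline parallelism and latency
    measurement could be inserted."""

    def iterator(self):
        raise NotImplementedError  # subclasses produce items here

    def __iter__(self):
        # Wrapping point: a real implementation could time each step, hand the
        # iterator to a worker thread, or insert a bounded queue here.
        for item in self.iterator():
            yield item

class TokenizeNode(PolylithicNode):
    """Example intermediate node: consumes another node and applies a tokenizer."""
    def __init__(self, source, tokenize):
        self.source, self.tokenize = source, tokenize

    def iterator(self):
        for text in self.source:
            yield self.tokenize(text)
```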
Example of composing iterable datasets to create a multimodal dataloader. [Note that we are open to ideas on syntactical sugar]
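In place of the original figure, one purely illustrative composition built on the sketch above; every node name and knob here (e.g. num_threads) is hypothetical:

```python
class ListSource(PolylithicNode):
    """Leaf node producing items from an in-memory list (stand-in for file readers)."""
    def __init__(self, items):
        self.items = items

    def iterator(self):
        yield from self.items

class ParallelMapperNode(PolylithicNode):
    """Intermediate node where parallelism would be configured per node. The sketch
    applies fn inline; a real node would fan work out to num_threads threads."""
    def __init__(self, source, fn, num_threads=8):
        self.source, self.fn, self.num_threads = source, fn, num_threads

    def iterator(self):
        for item in self.source:
            yield self.fn(item)

class MixNode(PolylithicNode):
    """Root-side node combining text and image streams into training examples."""
    def __init__(self, text_nodes, image_node):
        self.text_nodes, self.image_node = text_nodes, image_node

    def iterator(self):
        for *texts, image in zip(*self.text_nodes, self.image_node):
            yield {"tokens": [tok for txt in texts for tok in txt], "image": image}

tokenize = str.split
pipeline = MixNode(
    text_nodes=[
        TokenizeNode(ListSource(["a cat sat", "a dog ran"]), tokenize),
        TokenizeNode(ListSource(["on a mat", "in the park"]), tokenize),
    ],
    image_node=ParallelMapperNode(ListSource([b"raw0", b"raw1"]), fn=bytes.hex, num_threads=8),
)

for example in pipeline:  # only the image branch is parallelized, not the whole graph
    print(example)
```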
More complex diagram
Alternatives
One alternative would require __getitem__ to be thread-safe and contention-free, which will be a challenge when users depend on custom libraries.

Additional context
What about DataPipes and DL v2?
To avoid confusion, and to remove potential issues with backwards compatibility, we will still be deprecating DataPipes and DL v2. However, due to the similarity between DataPipes and the current approach, we believe migration should be fairly straightforward. DataLoader2 has very low adoption and we won't be providing something similar to replace it.
DataPipes and DL v2 were designed to address issues like composability, and there is a lot of value in what was built. However, their parallelism and data-sharding structure is still based on a monolithic approach (e.g. plugging a datapipe into DL v1, or into DL v2 with a multiprocess reading service). They required migration/rewrite of datasets, often with no improvement in performance; identifying dataloading/pre-proc bottlenecks was a challenge; and shuffling/sharding pain points weren't adequately addressed.
The proposed approach improves upon DataPipes + DLv2 in the following ways:
We want to maintain the composable aspects of datapipes and their eager execution, and to continue our partnerships with storage and cloud providers (AWS, Azure, GCP), who provide high-performance clients, share customer pain points, and provide recommended solutions and examples to their users.