Accessing DataPipe state with MultiProcessingReadingService #1033
Comments
When multiprocessing is involved, the partial DataPipe graph is sent to the worker processes, so the main process no longer holds any reference to that partial graph.
Yes, it is, and we are working on a solution for it. We will probably add a new request like https://github.com/pytorch/data/blob/a3b34a00e7d2b6694ea0d5e21fcc084080a3abae/torchdata/dataloader2/communication/messages.py#LL89C7-L89C21 to pass the state request to the worker process and have the worker process send back the state of the graph. Do you have any specific use cases for accessing datapipe state beyond checkpointing?
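To make the proposed request/response idea concrete, here is a minimal sketch. It is not the torchdata implementation: the message classes (GetNextRequest, GetStateRequest, GetStateResponse, TerminateRequest) and the CountingPipe class are hypothetical names invented for illustration, and a thread with queue.Queue stands in for a worker process and its communication channel.

```python
import queue
import threading
from dataclasses import dataclass


# Hypothetical message types, invented for this sketch.
@dataclass
class GetNextRequest:
    pass


@dataclass
class GetStateRequest:
    pass


@dataclass
class GetStateResponse:
    state: dict


@dataclass
class TerminateRequest:
    pass


class CountingPipe:
    """Toy stand-in for a stateful DataPipe living in a worker."""

    def __init__(self, data):
        self.data = list(data)
        self.consumed = 0

    def __iter__(self):
        for item in self.data:
            self.consumed += 1
            yield item

    def __getstate__(self):
        return {"consumed": self.consumed}


def worker_loop(pipe, req_q, res_q):
    """Serve requests until asked to terminate."""
    it = iter(pipe)
    while True:
        request = req_q.get()
        if isinstance(request, TerminateRequest):
            break
        if isinstance(request, GetStateRequest):
            # Send the worker-side state back to the requester.
            res_q.put(GetStateResponse(state=pipe.__getstate__()))
        elif isinstance(request, GetNextRequest):
            res_q.put(next(it, None))


req_q, res_q = queue.Queue(), queue.Queue()
worker = threading.Thread(
    target=worker_loop, args=(CountingPipe(range(5)), req_q, res_q)
)
worker.start()

# Pull three items, then ask the worker for its current state.
items = []
for _ in range(3):
    req_q.put(GetNextRequest())
    items.append(res_q.get())

req_q.put(GetStateRequest())
state = res_q.get().state

req_q.put(TerminateRequest())
worker.join()
print(items, state)
```

The point of the sketch is only that the state lives on the worker side and has to be requested explicitly; a real implementation would also need to handle multiple workers and in-flight prefetching.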
This is what I had envisioned as well. Glad to hear it's being worked on. Our specific use case is a data loading progress bar, but instead of counting after sharding, we want to count batch sizes before sharding. We can have training on multiple ranks and want to avoid multi-rank synchronisation, so we want to see where the rank 0 datapipe currently is pre-sharding. Our datapipe is like so: We have a potential workaround, which is to also return this size with an extra Map before Shard, but we'd prefer not to.
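The workaround mentioned above (an extra Map before Shard that also returns the size) could look roughly like the sketch below. It uses plain Python generators as stand-ins for the Map and Shard steps, and the function names tag_with_running_total and shard are hypothetical; it assumes the sharding step keeps every n-th element of the full stream.

```python
from typing import Iterable, Iterator, List, Tuple


def tag_with_running_total(
    batches: Iterable[List[int]],
) -> Iterator[Tuple[List[int], int]]:
    """Attach the summed pre-sharding batch size to every batch."""
    total = 0
    for batch in batches:
        total += len(batch)
        yield batch, total


def shard(items: Iterable, num_shards: int, shard_id: int) -> Iterator:
    """Keep every num_shards-th element, starting at shard_id."""
    for i, item in enumerate(items):
        if i % num_shards == shard_id:
            yield item


batches = [[0, 1, 2], [3, 4], [5, 6, 7, 8], [9]]

# Rank 0 only receives half of the batches, but each one carries the
# number of samples seen so far across all ranks.
rank0 = list(shard(tag_with_running_total(batches), num_shards=2, shard_id=0))
for batch, seen_so_far in rank0:
    print(batch, seen_so_far)
```

The cost is that every element downstream of the extra step becomes a tuple, which is why it is less attractive than reading the state directly.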
FYI, I have started working on a PR that adds that functionality via the
cc: @NivekT as the POC for snapshot/checkpoint.
Do you mean batch sizes or number of batches?
I mean summed batch sizes. I've now created a PR as an RFC.
Hi @NivekT, thanks for the detailed reply. I'll keep the conversation about state checkpointing in the PR, and will focus on the specific problem I'm trying to solve in this issue. The documentation is quite vague on how to use
@jhoareau Can you tell us more about the setup where you are seeing duplicate data (what is the data pipeline)? For example, here is a multiprocessing example (run with the nightly version):

    dp1 = IterableWrapper(range(10)).sharding_filter().map(_fn)
    dp2 = IterableWrapper(range(10)).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(_fn)
    for dp in [dp1, dp2]:
        rs = MultiProcessingReadingService(num_workers=2)
        dl = DataLoader2(dp, reading_service=rs)
        print(list(dl))  # [0, 1, ..., 9] in both cases

If you are using Let us know if this is unclear.
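As a rough mental model of the two sharding styles in the example above, the following sketch contrasts them in plain Python. It does not use torchdata: filter_shard and round_robin_dispatch are hypothetical helper names, and the sketch assumes that filter-style sharding lets every worker iterate its own full copy of the source, while round-robin dispatch iterates the source once in a single place and hands elements out in turn.

```python
from typing import Callable, Iterable, Iterator, List

reads = {"count": 0}  # counts how many times the source is read


def source() -> Iterator[int]:
    for x in range(10):
        reads["count"] += 1
        yield x


def filter_shard(
    make_source: Callable[[], Iterable[int]], num_workers: int, worker_id: int
) -> List[int]:
    """Each worker reads the whole source and keeps only its share."""
    return [
        x for i, x in enumerate(make_source()) if i % num_workers == worker_id
    ]


def round_robin_dispatch(items: Iterable[int], num_workers: int) -> List[List[int]]:
    """A single reader hands elements to the workers in turn."""
    buckets: List[List[int]] = [[] for _ in range(num_workers)]
    for i, x in enumerate(items):
        buckets[i % num_workers].append(x)
    return buckets


reads["count"] = 0
filtered = [filter_shard(source, 2, w) for w in range(2)]
reads_with_filter = reads["count"]

reads["count"] = 0
dispatched = round_robin_dispatch(source(), 2)
reads_with_dispatch = reads["count"]

print(filtered, reads_with_filter)
print(dispatched, reads_with_dispatch)
```

Both styles produce the same partition here, but only the dispatch style keeps a single copy of the pre-sharding part of the pipeline, which is what makes it interesting for tracking pre-sharding state.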
Hi @NivekT, it works with the sharding filter before the sharding round robin; indeed, we're running multiprocessing + distributed. Thanks for the pointer. However, I needed to monkey-patch the I still see value in extracting state from the underlying datapipes with the MPReadingService, so I'll leave my PR up in the hope that we can also discuss that separately.
Hi TorchData team,
I'm wondering how to access the state of the datapipe in the multi-processing context with DataLoader2 + MultiProcessingReadingService. When using no reading service, we can simply access the graph using dataloader.datapipe, then I can easily access the state of my datapipe using the code shown below.
However, in the multi-processing case, the datapipe graph is replaced with QueueWrapper instances, and I cannot find any way to communicate with the workers to get access to the state of the datapipe (and I get the error that my StatefulIterator cannot be found on the datapipe). If I access dl2._datapipe_before_reading_service_adapt I do get the initial state only, which makes sense since there is no state sync between the main and worker processes.
As far as I understand, this will also be a blocker for state capturing for proper DataLoader checkpointing when the MultiProcessingReadingService is being used.
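The "initial state only" behaviour can be reproduced without torchdata. The sketch below assumes that sending a graph to a worker amounts to pickling it, so the worker iterates an independent copy; pickle.loads(pickle.dumps(...)) stands in for that transfer, and the StatefulIterator class here is a toy written for this illustration, not the class from the original setup.

```python
import pickle


class StatefulIterator:
    """Toy pipe that tracks how many items it has yielded."""

    def __init__(self, data):
        self.data = list(data)
        self.num_yielded = 0

    def __iter__(self):
        for item in self.data:
            self.num_yielded += 1
            yield item


main_copy = StatefulIterator(range(4))

# Simulate shipping the pipe to a worker: the worker gets its own copy.
worker_copy = pickle.loads(pickle.dumps(main_copy))
consumed = list(worker_copy)

# The worker's copy has advanced; the main process's copy has not.
print(main_copy.num_yielded, worker_copy.num_yielded)
```

Any progress made by the worker's copy is invisible to the object kept in the main process unless it is sent back explicitly.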
Potentially, could we add a getstate communication primitive in communication.messages in order to capture the state (via getstate) of a datapipe in a worker process?
We're also open to using sharding_round_robin_dispatch in order to keep more information in the main process, but I'm a bit confused about how to use it. Do you have some sample code for the following case?
Running against today's master (commit a3b34a0):
cc: @ejguan @VitalyFedyunin @NivekT