Stateful dataloader README #1245
Conversation
@andrewkho has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
It seems like the test statuses are still visible in the Checks tab.
It seems that CI gets skipped when the last action is by facebook-github-bot.
Few comments, looks good!
Using pip:

```bash
pip install --pre torchdata --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
@huydhn mentioned that `--index-url` is preferred over `--extra-index-url`. I did have trouble getting the latest nightly build with the latter. It might be good to change this in the main README as well later.
```python
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader


class MySampler(torch.utils.data.Sampler[int]):
```
Should we mention that RandomSampler and BatchSampler have been patched, so things should work out of the box for users who use the default samplers? On a related note, should we also consider making DistributedSampler `Stateful`?
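To illustrate the "out of the box" point, a minimal sketch (assumed usage, not from the README under review; `state_dict`/`load_state_dict` are the StatefulDataLoader methods):

```python
import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = TensorDataset(torch.arange(10))
# shuffle=True uses the default (patched) RandomSampler under the hood,
# so no custom sampler code is needed to checkpoint iteration order.
dl = StatefulDataLoader(dataset, batch_size=2, shuffle=True)

it = iter(dl)
next(it)                    # consume one batch
state = dl.state_dict()     # snapshot mid-epoch

dl2 = StatefulDataLoader(dataset, batch_size=2, shuffle=True)
dl2.load_state_dict(state)  # resumes from the same point
```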
```python
    def state_dict(self):
        return {"rng": torch.get_rng_state()}


for num_workers in [0, 2]:
```
Maybe add a comment here that it works for both single-process and multi-process dataloading setups?
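Something like the following would exercise both modes (a sketch under assumed names, not the README's exact loop body):

```python
import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = TensorDataset(torch.arange(10))

# num_workers=0 runs single-process; num_workers=2 spawns worker processes.
# The same state_dict/load_state_dict round trip works in both modes.
for num_workers in [0, 2]:
    dl = StatefulDataLoader(dataset, batch_size=2, num_workers=num_workers)
    it = iter(dl)
    next(it)
    state = dl.state_dict()

    dl2 = StatefulDataLoader(dataset, batch_size=2, num_workers=num_workers)
    dl2.load_state_dict(state)  # resumes correctly regardless of mode
```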
Tracking iteration order with Iterable-style datasets requires state from each worker-level instance of the dataset to be captured. You can define `state_dict`/`load_state_dict` methods on your dataset which capture worker-level state. `StatefulDataLoader` will handle aggregation across workers and distribution back to the workers. Calling `load_state_dict` requires `StatefulDataLoader` to have the same `num_workers` as the provided `state_dict`.
`num_workers` needs to be the same even for map-style datasets, correct?
Separate discussions:
i) Should we consider relaxing that if the map dataset doesn't have state?
ii) The assert statement at https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/stateful_dataloader.py#L1211 could have an error message to make this check explicit for the user.
`num_workers` can be different for map-style datasets if users only care about iteration order; I can clarify that. If a map-style dataset doesn't have state, it should just work, but we may need to remove that assert.
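For context on the paragraph above, a minimal sketch of worker-level state on an iterable-style dataset (illustrative names, not the README's exact example; per-worker sharding is omitted for brevity):

```python
import torch.utils.data


class CountingDataset(torch.utils.data.IterableDataset):
    """Yields 0..length-1 and checkpoints its position."""

    def __init__(self, length):
        self.length = length
        self.i = 0

    def __iter__(self):
        while self.i < self.length:
            yield self.i
            self.i += 1

    # Called on each worker's copy of the dataset; StatefulDataLoader
    # aggregates the per-worker dicts and routes them back on load.
    def state_dict(self):
        return {"i": self.i}

    def load_state_dict(self, state_dict):
        self.i = state_dict["i"]
```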
```python
else:
    worker_id = 0
    num_workers = 1
self.g.manual_seed(self.seed)
```
Do we need to set the seed to be the same for all workers? Can the one set by the worker loop in worker.py be used?
Because of how this is set up (the same array on all workers, iterated through based on worker_id), this is correct.
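To spell out that pattern (a sketch of the assumed design, not the README's exact code): every worker seeds its generator identically, builds the same permutation, then takes a disjoint stride by `worker_id`:

```python
import torch


def worker_indices(seed, worker_id, num_workers, dataset_len):
    g = torch.Generator()
    g.manual_seed(seed)  # same seed on every worker
    perm = torch.randperm(dataset_len, generator=g)  # identical array everywhere
    return perm[worker_id::num_workers]  # disjoint shard per worker


# Workers 0 and 1 see non-overlapping halves of the same permutation.
print(worker_indices(seed=42, worker_id=0, num_workers=2, dataset_len=12))
print(worker_indices(seed=42, worker_id=1, num_workers=2, dataset_len=12))
```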
```python
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
"""
```
[This can be done in a separate PR too] Can we also list the corner cases where it doesn't work as expected? Basically, the case where the user has to pass in an explicit generator for it to work correctly: https://github.com/pytorch/data/blob/main/test/stateful_dataloader/test_state_dict.py#L686.
Yes, that's a good idea. If we could detect these cases and raise assertion errors, it would be even better.
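A hedged sketch of that corner case (assumed usage based on the linked test; `generator` is the standard DataLoader argument):

```python
import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = TensorDataset(torch.arange(10))

# Passing an explicit generator lets its RNG state be captured in
# state_dict and restored on load, so shuffling resumes correctly.
g = torch.Generator()
g.manual_seed(0)
dl = StatefulDataLoader(dataset, batch_size=2, shuffle=True, generator=g)
```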
Few comments, looks good otherwise!
@andrewkho has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@andrewkho merged this pull request in abb3f01.
Please read through our contribution guide prior to creating your pull request.
Fixes #{issue number}
Changes
Adds README.md for torchdata.stateful_dataloader
Test plan:
Checked rendering and links in the branch preview on GitHub: