Importing torchdata.stateful_dataloader causes the test check_seedable_sampler to fail #2894
Alright, so I'm pretty convinced that these lines are the culprit:
Multiple attempts at looking into this show inconsistency in which sampler implementation gets used, triggered simply by the import. If this is the case, I can think of 3 possible solutions.
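To make the failure mode concrete, here is a minimal, torch-free sketch of why an import-time monkey-patch produces this kind of inconsistency. All names here (`fake_torch_data`, `PatchedBatchSampler`) are stand-ins for illustration, not the actual torchdata code: a consumer that captured the class before the patch and code that looks it up afterwards end up disagreeing on which implementation is in use.

```python
import types

# Stand-in for a module like torch.utils.data, holding the "original" class.
fake_torch_data = types.ModuleType("fake_torch_data")

class BatchSampler:
    def __iter__(self):
        # Original batching behavior.
        yield from ([0, 1], [2, 3])

fake_torch_data.BatchSampler = BatchSampler

# A consumer captures the class at its own import time:
CapturedBatchSampler = fake_torch_data.BatchSampler

# Later, importing a patching package rebinds the module attribute:
class PatchedBatchSampler(BatchSampler):
    def __iter__(self):
        # Different iteration behavior than the original.
        yield from ([0, 1, 2], [3])

fake_torch_data.BatchSampler = PatchedBatchSampler

# Anything constructed via the module attribute now uses the patch,
# so identity checks against the captured class no longer hold:
sampler = fake_torch_data.BatchSampler()
print(type(sampler) is CapturedBatchSampler)  # False
print(list(sampler))                          # [[0, 1, 2], [3]]
```

Because the patch happens as a side effect of the import, whether a given piece of code sees the original or the patched class depends purely on import order, which is what makes this so hard to reason about.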
Ah the joys of monkey-patching 😓
I've pinged the torchdata team internally; we'll come to some solution :)
Thanks! Well, I can't really recommend merging #2895, even if it looks good, given this situation.
No worries, thank you for working on the initial support! Based on what happens next, we can move the discussion to that PR on how to go forward.
I'll give that PR a review in the AM when I can look thoroughly at what you've done (great work, BTW).
SG, thanks. It passes the test cases I've written, but admittedly the implementation is rather hacky, and I'm only able to test on my local machine (with a single RTX 3060 Ti). I was just trying to write the bare minimum code needed to get it working, and even that turned out to be incredibly invasive.
A good solution has been found 🤗 |
This should be good to close, since tests are now passing after upgrading.
Merely importing StatefulDataLoader from the nightly torchdata package (i.e. putting the line `from torchdata.stateful_dataloader import StatefulDataLoader` anywhere in the code) causes one of the unit tests, `check_seedable_sampler`, to fail.

Stack trace obtained by running tests with the import
I suspect it has something to do with torchdata overriding torch's BatchSampler in this code. This is supported by the fact that if I add the import and some logging, `SeedableRandomSampler.__iter__()` seems to be called one fewer time than expected.

How to reproduce:
```diff
 def check_seedable_sampler():
+    import torchdata.stateful_dataloader
     # Set seed
     set_seed(42)
     train_set = RegressionDataset(length=10, seed=42)
     train_dl = DataLoader(train_set, batch_size=2, shuffle=True)
     ...
```
Run `make test` and expect 1 failing test.
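The "add some logging" diagnostic described above can be sketched without torch at all. This is a hypothetical stand-in, not the real `SeedableRandomSampler`: wrapping `__iter__` with a counter is enough to detect the "called one fewer time than expected" symptom, since each epoch of iteration should bump the counter exactly once.

```python
class CountingSampler:
    """Stand-in for a sampler, instrumented to count __iter__ calls."""

    def __init__(self, n):
        self.n = n
        self.iter_calls = 0

    def __iter__(self):
        # Each full pass over the sampler (one epoch) calls this once.
        self.iter_calls += 1
        return iter(range(self.n))

sampler = CountingSampler(10)
for _epoch in range(3):      # simulate three training epochs
    for _idx in sampler:
        pass

print(sampler.iter_calls)    # 3 — one __iter__ call per epoch
```

If a patched BatchSampler consumed or cached the underlying sampler's iterator differently, this counter would come up short of the epoch count, which is exactly the discrepancy reported here.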