
Importing `torchdata.stateful_dataloader` causes the test `check_seedable_sampler` to fail #2894

Closed
byi8220 opened this issue Jun 26, 2024 · 9 comments


@byi8220
Contributor

byi8220 commented Jun 26, 2024

Merely importing `StatefulDataLoader` from the nightly torchdata package (i.e. putting the line `from torchdata.stateful_dataloader import StatefulDataLoader` anywhere in the code) causes one of the unit tests, `check_seedable_sampler`, to fail.

Stack trace obtained by running the tests with the import in place:

```
stderr: Traceback (most recent call last):
stderr:   File "redacted/accelerate/src/accelerate/test_utils/scripts/test_script.py", line 827, in <module>
stderr:     main()
stderr:   File "redacted/accelerate/src/accelerate/test_utils/scripts/test_script.py", line 802, in main
stderr:     check_seedable_sampler()
stderr:   File "redacted/accelerate/src/accelerate/test_utils/scripts/test_script.py", line 381, in check_seedable_sampler
stderr:     assert torch.allclose(original_items, new_items), "Did not obtain the same items with the same seed and epoch."
stderr:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
stderr: AssertionError: Did not obtain the same items with the same seed and epoch.
```

I suspect it has something to do with torchdata overriding torch's `BatchSampler` in this code. This is supported by the fact that, if I add the import and some logging, `SeedableRandomSampler.__iter__()` appears to be called one fewer time than expected:

We should see the epoch/seed sequence `[(0, 42), (1, 43), (2, 44)]` twice, but the first call with seed 42 is missing; it looks like the first epoch's samples are drawn without the seed being set:

```
stdout: Shuffled central dataloader passing.
stdout: {'x': tensor([-1.3022,  0.1278], device='cuda:0'), 'y': tensor([0.3097, 3.2926], device='cuda:0')}
stdout: {'x': tensor([-1.0400,  0.7505], device='cuda:0'), 'y': tensor([0.9978, 4.5075], device='cuda:0')}
stdout: {'x': tensor([ 0.3047, -0.0168], device='cuda:0'), 'y': tensor([3.6974, 3.0542], device='cuda:0')}
stdout: {'x': tensor([-0.8530,  0.9406], device='cuda:0'), 'y': tensor([1.2889, 4.9939], device='cuda:0')}
stdout: {'x': tensor([-0.3162, -1.9510], device='cuda:0'), 'y': tensor([ 2.2716, -0.8553], device='cuda:0')}
stdout: --
stdout: Setting seed at epoch 1 43
stdout: {'x': tensor([-0.0168, -1.9510], device='cuda:0'), 'y': tensor([ 3.0542, -0.8553], device='cuda:0')}
stdout: {'x': tensor([-0.8530,  0.3047], device='cuda:0'), 'y': tensor([1.2889, 3.6974], device='cuda:0')}
stdout: {'x': tensor([-1.3022, -1.0400], device='cuda:0'), 'y': tensor([0.3097, 0.9978], device='cuda:0')}
stdout: {'x': tensor([ 0.1278, -0.3162], device='cuda:0'), 'y': tensor([3.2926, 2.2716], device='cuda:0')}
stdout: {'x': tensor([0.7505, 0.9406], device='cuda:0'), 'y': tensor([4.5075, 4.9939], device='cuda:0')}
stdout: --
stdout: Setting seed at epoch 2 44
stdout: {'x': tensor([0.7505, 0.1278], device='cuda:0'), 'y': tensor([4.5075, 3.2926], device='cuda:0')}
stdout: {'x': tensor([ 0.9406, -0.0168], device='cuda:0'), 'y': tensor([4.9939, 3.0542], device='cuda:0')}
stdout: {'x': tensor([-0.3162, -0.8530], device='cuda:0'), 'y': tensor([2.2716, 1.2889], device='cuda:0')}
stdout: {'x': tensor([-1.0400, -1.9510], device='cuda:0'), 'y': tensor([ 0.9978, -0.8553], device='cuda:0')}
stdout: {'x': tensor([-1.3022,  0.3047], device='cuda:0'), 'y': tensor([0.3097, 3.6974], device='cuda:0')}
stdout: --
stdout: Resetting epoch and seed
stdout: Setting seed at epoch 0 42
stdout: {'x': tensor([0.7505, 0.1278], device='cuda:0'), 'y': tensor([4.5075, 3.2926], device='cuda:0')}
stdout: {'x': tensor([-1.0400, -0.0168], device='cuda:0'), 'y': tensor([0.9978, 3.0542], device='cuda:0')}
stdout: {'x': tensor([-1.9510, -1.3022], device='cuda:0'), 'y': tensor([-0.8553,  0.3097], device='cuda:0')}
stdout: {'x': tensor([ 0.3047, -0.8530], device='cuda:0'), 'y': tensor([3.6974, 1.2889], device='cuda:0')}
stdout: {'x': tensor([ 0.9406, -0.3162], device='cuda:0'), 'y': tensor([4.9939, 2.2716], device='cuda:0')}
stdout: --
stdout: Setting seed at epoch 1 43
stdout: {'x': tensor([-0.0168, -1.9510], device='cuda:0'), 'y': tensor([ 3.0542, -0.8553], device='cuda:0')}
stdout: {'x': tensor([-0.8530,  0.3047], device='cuda:0'), 'y': tensor([1.2889, 3.6974], device='cuda:0')}
stdout: {'x': tensor([-1.3022, -1.0400], device='cuda:0'), 'y': tensor([0.3097, 0.9978], device='cuda:0')}
stdout: {'x': tensor([ 0.1278, -0.3162], device='cuda:0'), 'y': tensor([3.2926, 2.2716], device='cuda:0')}
stdout: {'x': tensor([0.7505, 0.9406], device='cuda:0'), 'y': tensor([4.5075, 4.9939], device='cuda:0')}
stdout: --
stdout: Setting seed at epoch 2 44
stdout: {'x': tensor([0.7505, 0.1278], device='cuda:0'), 'y': tensor([4.5075, 3.2926], device='cuda:0')}
stdout: {'x': tensor([ 0.9406, -0.0168], device='cuda:0'), 'y': tensor([4.9939, 3.0542], device='cuda:0')}
stdout: {'x': tensor([-0.3162, -0.8530], device='cuda:0'), 'y': tensor([2.2716, 1.2889], device='cuda:0')}
stdout: {'x': tensor([-1.0400, -1.9510], device='cuda:0'), 'y': tensor([ 0.9978, -0.8553], device='cuda:0')}
stdout: {'x': tensor([-1.3022,  0.3047], device='cuda:0'), 'y': tensor([0.3097, 3.6974], device='cuda:0')}
stdout: --
stdout: original_items:
stdout:  tensor([-1.3022,  0.1278, -1.0400,  0.7505,  0.3047, -0.0168, -0.8530,  0.9406,
stdout:         -0.3162, -1.9510, -0.0168, -1.9510, -0.8530,  0.3047, -1.3022, -1.0400,
stdout:          0.1278, -0.3162,  0.7505,  0.9406,  0.7505,  0.1278,  0.9406, -0.0168,
stdout:         -0.3162, -0.8530, -1.0400, -1.9510, -1.3022,  0.3047], device='cuda:0')
stdout: new_items:
stdout:  tensor([ 0.7505,  0.1278, -1.0400, -0.0168, -1.9510, -1.3022,  0.3047, -0.8530,
stdout:          0.9406, -0.3162, -0.0168, -1.9510, -0.8530,  0.3047, -1.3022, -1.0400,
stdout:          0.1278, -0.3162,  0.7505,  0.9406,  0.7505,  0.1278,  0.9406, -0.0168,
stdout:         -0.3162, -0.8530, -1.0400, -1.9510, -1.3022,  0.3047], device='cuda:0')
```
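One quick way to check the suspicion (a sketch, not part of the test suite) is to compare torch's sampler classes and their methods before and after the import, since the patch could either rebind the class names or patch their methods in place:

```python
# Sketch: detect whether importing torchdata.stateful_dataloader rebinds
# torch's sampler classes or patches their methods in place.
import torch.utils.data as tud

cls_before = (tud.BatchSampler, tud.RandomSampler)
iter_before = (tud.BatchSampler.__iter__, tud.RandomSampler.__iter__)

import torchdata.stateful_dataloader  # noqa: F401  # the import under suspicion

print("BatchSampler rebound:", tud.BatchSampler is not cls_before[0])
print("RandomSampler rebound:", tud.RandomSampler is not cls_before[1])
print("BatchSampler.__iter__ patched:", tud.BatchSampler.__iter__ is not iter_before[0])
print("RandomSampler.__iter__ patched:", tud.RandomSampler.__iter__ is not iter_before[1])
```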

How to reproduce:

1. Install the torchdata nightly:

```sh
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu
```

2. Import `stateful_dataloader` in the test `check_seedable_sampler`:

```diff
 def check_seedable_sampler():
+    import torchdata.stateful_dataloader
     # Set seed
     set_seed(42)
     train_set = RegressionDataset(length=10, seed=42)
     train_dl = DataLoader(train_set, batch_size=2, shuffle=True)
     ...
```

3. Run `make test` and expect one failing test.
@byi8220
Contributor Author

byi8220 commented Jun 28, 2024

Alright, so I'm pretty convinced that these lines are the culprit:

  1. https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L61-L62
  2. https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L134-L135

Multiple attempts at looking into this show inconsistency in which sampler implementation is used, and by simply importing stateful_dataloader before anything else, I have managed to get this test to pass.

If this is the case, I can think of three possible solutions:

  1. Make sure that we always import this first. This seems very fragile.
  2. Somehow resolve this issue through code, e.g. by restoring the original classes after the import (a rough sketch follows this list). Unless someone better at Python than me has seen this problem before, this seems like an absolute nightmare of a code change.
  3. File an issue with the torchdata maintainers. I'm not sure if this is their problem to fix, but the fact that their imports redefine fundamental PyTorch types such as `RandomSampler` and `BatchSampler` is dubious to me.
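For illustration, option 2 might look something like this. Purely a sketch, and it assumes the patch works by reassigning attributes on `torch.utils.data` at import time; method-level patches on the original classes would need more careful handling:

```python
# Sketch of option 2: snapshot the original sampler classes, perform the
# import, then put them back so native DataLoader behavior is unchanged.
import torch.utils.data as tud

_originals = {name: getattr(tud, name) for name in ("BatchSampler", "RandomSampler")}

import torchdata.stateful_dataloader  # noqa: F401  # patching (if any) happens here

for _name, _cls in _originals.items():
    setattr(tud, _name, _cls)
```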

@muellerzr
Collaborator

Ah, the joys of monkey-patching 😓

This is a them issue, IMO. At most we can guard imports once we officially support their dataloaders, but Accelerate is designed to work with native PyTorch dataloaders; monkey-patching such a core feature needs to be carefully guarded. Something like the guard sketched below is probably the most we'd do on our side.
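A minimal sketch of such an import guard (the flag name here is hypothetical): use `StatefulDataLoader` only when torchdata provides it, otherwise fall back to the native PyTorch `DataLoader`:

```python
# Hypothetical import guard: prefer torchdata's StatefulDataLoader when it
# is installed, fall back to the native PyTorch DataLoader otherwise.
try:
    from torchdata.stateful_dataloader import StatefulDataLoader
    _TORCHDATA_STATEFUL_AVAILABLE = True
except ImportError:
    from torch.utils.data import DataLoader as StatefulDataLoader
    _TORCHDATA_STATEFUL_AVAILABLE = False
```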

@muellerzr
Collaborator

I've pinged the torchdata team internally, we'll come to some solution :)

@byi8220
Contributor Author

byi8220 commented Jul 1, 2024

Thanks!

Well, given this situation, I can't really recommend merging #2895 even if it looks good.

@muellerzr
Collaborator

No worries, and thank you so much for working on the initial support! Based on what happens next, we can move the discussion of how to go forward to that PR.

@muellerzr
Collaborator

I'll give that PR a review in the AM, when I can look thoroughly at what you've done (great work, BTW).

@byi8220
Contributor Author

byi8220 commented Jul 1, 2024

Sounds good, thanks. It passes the test cases I've written, but admittedly the implementation is rather hacky, and I am only able to test on my local machine (a single RTX 3060 Ti).

I was just trying to write the bare minimum of code needed to get it working, and even that turned out to be incredibly invasive.

@muellerzr
Collaborator

A good solution has been found 🤗

pytorch/data#1281

@byi8220
Contributor Author

byi8220 commented Jul 3, 2024

This should be good to close, since tests are now passing after upgrading torchdata to the nightly `0.7.1.dev20240703+cpu`.
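For reference, the upgrade uses the same nightly index as in the repro steps:

```sh
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu
```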

byi8220 closed this as completed on Jul 3, 2024