Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewkho committed Apr 25, 2024
1 parent 9cef2a5 commit e37d99d
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torchdata/stateful_dataloader/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ handles aggregation and distribution of state across multiprocess workers (but n
Using pip:

```bash
pip install --pre torchdata --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu
```

Using conda:
Expand Down Expand Up @@ -65,6 +65,9 @@ import torch
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader

# If you are using the default RandomSampler and BatchSampler from torch.utils.data,
# they are patched when you import torchdata.stateful_dataloader, so defining
# a custom sampler here is unnecessary
class MySampler(torch.utils.data.Sampler[int]):
def __init__(self, high: int, seed: int, limit: int):
self.seed, self.high, self.limit = seed, high, limit
Expand Down Expand Up @@ -106,6 +109,7 @@ class NoisyRange(torch.utils.data.Dataset):
def state_dict(self):
return {"rng": torch.get_rng_state()}

# Test both single-process and multi-process dataloading
for num_workers in [0, 2]:
print(f"{num_workers=}")
dl = StatefulDataLoader(NoisyRange(5, 1, 1), sampler=MySampler(5, 1, 10),
Expand Down Expand Up @@ -176,6 +180,7 @@ class MyIterableDataset(torch.utils.data.IterableDataset):
def load_state_dict(self, state_dict):
self.i = state_dict["i"]

# Test both single-process and multi-process dataloading
for num_workers in [0, 2]:
print(f"{num_workers=}")
dl = StatefulDataLoader(
Expand Down

0 comments on commit e37d99d

Please sign in to comment.