🐛 Describe the bug

Consider the following code:
import torch
from unittest import TestCase

# Imports were omitted in the original snippet; these paths are assumed.
from torchdata.stateful_dataloader import Stateful, StatefulDataLoader


class DatasetStateIterable(torch.utils.data.IterableDataset, Stateful):
    def __init__(self, length):
        self.length = length

    def __iter__(self):
        return iter(list(range(self.length)))

    def state_dict(self):
        print("Calling state dict")
        return {"key": "value"}

    def load_state_dict(self, state_dict):
        pass


class TestSimple(TestCase):
    def test(self):
        dataset = DatasetStateIterable(100)
        dl = StatefulDataLoader(
            dataset=dataset,
            num_workers=1,
            snapshot_every_n_steps=10,
        )
        it = iter(dl)
        for _ in range(30):
            next(it)
        self.assertTrue(False)  # deliberate failure so the worker's printed output is shown
Here the snapshot frequency is set to every 10 steps and the iteration is carried out for 30 steps, yet state_dict is called on the dataset 12 times:
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
This is expected because we need to eagerly request state_dict from the workers, and we have no idea whether other workers are about to send StopIteration, so we need to ask for more snapshots than expected.
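To quantify the over-calling without counting printed lines, here is a minimal counting variant of the reproduction (a sketch under the same assumptions as above; the counter lives in the single worker's copy of the dataset, so the running total is printed from the worker rather than returned to the main process):

import torch

from torchdata.stateful_dataloader import Stateful, StatefulDataLoader  # assumed import path


class CountingDataset(torch.utils.data.IterableDataset, Stateful):
    """Like DatasetStateIterable above, but numbers each state_dict call."""

    def __init__(self, length):
        self.length = length
        self.calls = 0  # incremented in the worker process's copy of the dataset

    def __iter__(self):
        return iter(range(self.length))

    def state_dict(self):
        self.calls += 1
        print(f"state_dict call #{self.calls}")  # visible because the worker shares stdout
        return {"calls": self.calls}

    def load_state_dict(self, state_dict):
        pass


if __name__ == "__main__":
    dl = StatefulDataLoader(CountingDataset(100), num_workers=1, snapshot_every_n_steps=10)
    it = iter(dl)
    for _ in range(30):
        next(it)
    # Naively one would expect 30 / 10 = 3 calls at this point; the eager
    # requests described above produce 12 instead.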
Versions
Latest git commit - 82918dd
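For reference, the installed build can be confirmed with the one-liner below (assuming a from-source build, whose version string typically embeds the commit hash):

python -c "import torchdata; print(torchdata.__version__)"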