Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save state in dataset_iter_state when dataset is also an iterator #1279

Merged
merged 9 commits into from
Jun 24, 2024
155 changes: 136 additions & 19 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,7 +1285,7 @@ def test(self):
for _ in range((num_workers + 1) * 2):
next(it)
state_dict = dl.state_dict()
worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["dataset_state"]
worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["dataset_iter_state"]
self.assertEqual(len(worker_state), 7)
deep_copy_state_dict = deepcopy(state_dict)

Expand All @@ -1295,7 +1295,9 @@ def test(self):
next_state_dict = dl.state_dict()
self.assertEqual(state_dict, deep_copy_state_dict)
self.assertFalse(state_dict == next_state_dict)
worker_state = next_state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["dataset_state"]
worker_state = next_state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"][
"dataset_iter_state"
]
self.assertEqual(len(worker_state), 11)

dl = StatefulDataLoader(
Expand All @@ -1311,7 +1313,7 @@ def test(self):
exp.extend(next(it))
state_dict = dl.state_dict()
self.assertEqual(exp, [3, 3])
worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["dataset_state"]
worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["dataset_iter_state"]
self.assertEqual(len(worker_state), 9)


Expand All @@ -1334,16 +1336,15 @@ def test(self):
if num_workers > 0:
for i in range(num_workers):
# Ensure worker state is stored only once if the dataset is also the iterator
self.assertTrue(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"])
self.assertEqual(
self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None)
self.assertTrue(
state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
"dataset_iter_state"
],
None,
]
)
else:
self.assertTrue(state_dict["dataset_state"])
self.assertEqual(state_dict["fetcher_state"]["dataset_iter_state"], None)
self.assertEqual(state_dict["dataset_state"], None)
self.assertTrue(state_dict["fetcher_state"]["dataset_iter_state"])


class PeriodicStateIterableDataset(torch.utils.data.IterableDataset):
Expand Down Expand Up @@ -1511,25 +1512,26 @@ def load_state_dict(self, state_dict):
class CountIterCallsIter(torch.utils.data.IterableDataset):
def __init__(self, length):
self.length = length
self.iter_calls = 0
self.local_iter_calls = 0
self.prev_state_iter_calls = 0
self.items = []

def __iter__(self):
self.items = list(range(self.length))
self.iter_calls += 1
self.local_iter_calls += 1
return self

def __next__(self):
if len(self.items) > 0:
self.items.popleft()
return self.items.pop(0)
else:
raise StopIteration

def state_dict(self):
return {"iter_calls": self.iter_calls, "items": deepcopy(self.items)}
return {"iter_calls": self.local_iter_calls + self.prev_state_iter_calls, "items": deepcopy(self.items)}

def load_state_dict(self, state_dict):
self.iter_calls = state_dict["iter_calls"]
self.prev_state_iter_calls = state_dict["iter_calls"]
self.items = state_dict["items"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's revert this, except delete the self.iter_calls in load_state_dict and just use it to track calls in the test



Expand All @@ -1540,26 +1542,52 @@ def _get_iter_calls(self, state):
else:
w_states = list(state["_snapshot"]["_worker_snapshots"].values())

return [x["dataset_state"]["iter_calls"] for x in w_states]
if w_states[0]["dataset_state"] is not None:
return [x["dataset_state"]["iter_calls"] for x in w_states]
return [x["fetcher_state"]["dataset_iter_state"]["iter_calls"] for x in w_states]

def _run_test(self, num_workers, dataset):
# Need to make a copy here as iter calls is tracked and its count is stored in dataset state which persists across calls for single process runs.
dataset_copy = deepcopy(dataset) if num_workers == 0 else dataset
Copy link
Contributor

@andrewkho andrewkho Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Need to make a copy here as iter calls is tracked and its count is stored in dataset state which persists across calls for single process runs.
dataset_copy = deepcopy(dataset) if num_workers == 0 else dataset
# Need to make a copy here as iter calls is tracked and its count is stored in dataset state which persists across calls for single process runs.
dataset = deepcopy(dataset)

Let's not have the test behave differently for different num_workers as much as possible, a test that is modified just to pass is not really a good test and I'd prefer we delete it

dl = StatefulDataLoader(
dataset=dataset,
andrewkho marked this conversation as resolved.
Show resolved Hide resolved
num_workers=num_workers,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
persistent_workers=True if num_workers else False,
Copy link
Contributor

@andrewkho andrewkho Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behaviour should be the same for both PW = True and False

)
iter(dl)
it = iter(dl)
state = dl.state_dict()
# Ensure iter is called only once per worker
self.assertEqual(self._get_iter_calls(state), [1] * max(1, num_workers))

for _ in range(10):
next(it)
state = dl.state_dict()
# Ensure that iter has not been invoked again
self.assertEqual(self._get_iter_calls(state), [1] * max(1, num_workers))

# Call iter on dl to see if iter is called again
iter(dl)
state = dl.state_dict()
self.assertEqual(self._get_iter_calls(state), [2] * max(1, num_workers))

dl2 = StatefulDataLoader(
dataset=dataset,
dataset=dataset_copy,
num_workers=num_workers,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
persistent_workers=True if num_workers else False,
)
dl2.load_state_dict(state)
iter(dl2)
it = iter(dl2)
state2 = dl2.state_dict()
self.assertEqual(self._get_iter_calls(state2), [2] * max(1, num_workers))
# Ensure that iter is called only once per worker even when dataloader resumes from a state
self.assertEqual(self._get_iter_calls(state2), [3] * max(1, num_workers))

for _ in range(10):
next(it)
state = dl2.state_dict()
# Ensure that iter has not been invoked again
self.assertEqual(self._get_iter_calls(state2), [3] * max(1, num_workers))
Copy link
Contributor

@andrewkho andrewkho Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's change this test to just assert that iter is only called once for new dataloaders with and without state-resumes, but not enforce the actual value of state["iter_calls"] after resume state


def test_inline(self):
self._run_test(0, CountIterCalls(100))
Expand All @@ -1574,5 +1602,94 @@ def test_mp_iter(self):
self._run_test(2, CountIterCallsIter(100))


class IterationState:
    """Serializable cursor over a half-open range ``[curr, end)``."""

    def __init__(self, start, end):
        # Current position and exclusive upper bound of this shard.
        self.curr = start
        self.end = end

    def set_state(self, state):
        """Restore the cursor and bound from a dict produced by ``get_state``."""
        self.curr, self.end = state["curr"], state["end"]

    def get_state(self):
        """Snapshot the cursor as a plain, picklable dict."""
        return dict(curr=self.curr, end=self.end)


class StatesInitializationDataset(torch.utils.data.IterableDataset):
    """Iterable dataset whose iteration state (an ``IterationState``) is
    created lazily in ``__iter__`` rather than in ``__init__``.

    Each worker iterates a contiguous, equal-sized shard of
    ``range(length)``; callers are expected to keep ``length`` divisible by
    the worker count (the test harness asserts this), otherwise the tail
    samples are silently dropped.
    """

    def __init__(self, length):
        # Total number of samples across all workers.
        self.length = length

    def __iter__(self):
        # Outside a worker process get_worker_info() returns None: treat the
        # main process as the single worker (id 0 of 1). Fetch it once
        # instead of the original's two redundant calls.
        info = torch.utils.data.get_worker_info()
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1

        # Floor division replaces the C-style `(int)(a / b)` cast: same
        # result for non-negative ints, without a float round-trip.
        num_samples = self.length // num_workers
        self.iter_state = IterationState(num_samples * worker_id, num_samples * (worker_id + 1))
        return self

    def __next__(self):
        if self.iter_state.curr >= self.iter_state.end:
            raise StopIteration
        value = self.iter_state.curr
        self.iter_state.curr += 1
        return value

    def state_dict(self):
        # Only the cursor/bound pair needs to survive a checkpoint/resume.
        return {"state": self.iter_state.get_state()}

    def load_state_dict(self, state_dict):
        # NOTE(review): assumes __iter__ already ran so iter_state exists —
        # same precondition as the original; confirm against loader flow.
        self.iter_state.set_state(state_dict["state"])


class TestStateInitializationDataset(TestCase):
    def _run_test(self, num_workers, dataset):
        """Consume most of the dataset, snapshot, resume in a fresh loader,
        and verify the two runs together yield every sample exactly once."""
        total = dataset.length

        # Guard against parameter combinations the dataset's simple sharding
        # does not cover (uneven shards, datasets too short to split).
        if num_workers > 0:
            self.assertTrue(total % num_workers == 0)
        self.assertTrue(total > 30)

        mp_ctx = "forkserver" if IS_MACOS and num_workers else None
        loader = StatefulDataLoader(
            dataset=dataset,
            num_workers=num_workers,
            collate_fn=identity,
            multiprocessing_context=mp_ctx,
        )
        batches = iter(loader)
        snapshot = loader.state_dict()
        collected = []

        # Drain all but the last 30 samples, keeping the latest snapshot.
        for _ in range(total - 30):
            collected.extend(next(batches))
            snapshot = loader.state_dict()

        # Resume from the final snapshot with a brand-new loader.
        resumed = StatefulDataLoader(
            dataset=dataset,
            num_workers=num_workers,
            collate_fn=identity,
            multiprocessing_context=mp_ctx,
        )
        resumed.load_state_dict(snapshot)
        batches = iter(resumed)

        for _ in range(30):
            collected.extend(next(batches))

        # Multi-worker output interleaves shards non-deterministically, so
        # compare as sets rather than ordered lists.
        self.assertEqual(set(collected), set(range(total)))

    def test_inline(self):
        self._run_test(0, StatesInitializationDataset(100))

    def test_mp(self):
        self._run_test(2, StatesInitializationDataset(100))


# Allow running this test module directly; unittest discovers and runs all
# TestCase subclasses defined above.
if __name__ == "__main__":
    unittest.main()
9 changes: 4 additions & 5 deletions torchdata/stateful_dataloader/stateful_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,14 +332,13 @@ def _next_data(self):

def state_dict(self):
if self._dataset_kind == _DatasetKind.Iterable:
iter_state = None
if self._dataset_fetcher.dataset_iter is not self._dataset_fetcher.dataset:
iter_state = try_to_serialize(self._dataset_fetcher.dataset_iter)
fetcher_state = {
_DATASET_ITER_STATE: iter_state,
_DATASET_ITER_STATE: try_to_serialize(self._dataset_fetcher.dataset_iter),
_FETCHER_ENDED: self._dataset_fetcher.ended,
}
dataset_state = try_to_serialize(self._dataset_fetcher.dataset)
dataset_state = None
if self._dataset_fetcher.dataset_iter is not self._dataset_fetcher.dataset:
dataset_state = try_to_serialize(self._dataset_fetcher.dataset)
else:
fetcher_state = None
dataset_state = try_to_serialize(self._dataset_fetcher.dataset)
Expand Down
9 changes: 4 additions & 5 deletions torchdata/stateful_dataloader/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,13 @@ def _make_state_dict(worker_id, dataset_kind, fetcher, dataset) -> Dict[str, Any
from torch.utils.data import _DatasetKind

if dataset_kind == _DatasetKind.Iterable:
iter_state = None
if fetcher.dataset_iter is not fetcher.dataset:
iter_state = try_to_serialize(fetcher.dataset_iter)
fetcher_state = {
_DATASET_ITER_STATE: iter_state,
_DATASET_ITER_STATE: try_to_serialize(fetcher.dataset_iter),
_FETCHER_ENDED: fetcher.ended,
}
dataset_state = try_to_serialize(fetcher.dataset)
dataset_state = None
if fetcher.dataset_iter is not fetcher.dataset:
dataset_state = try_to_serialize(fetcher.dataset)
else:
fetcher_state = None
# Pick up any user-defined dataset state
Expand Down
Loading