Fixup dataloader state dict bugs + incorporate load/save_state API #3034

Merged (19 commits) on Aug 23, 2024
24 changes: 16 additions & 8 deletions examples/by_feature/checkpointing.py
@@ -19,9 +19,10 @@
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup

-from accelerate import Accelerator, DistributedType
+from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
+from accelerate.utils import set_seed


########################################################################
@@ -125,7 +126,8 @@ def training_function(config, args):
if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
config["num_epochs"] = 2
# Initialize accelerator
-accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
+dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
+accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config)
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
lr = config["lr"]
num_epochs = int(config["num_epochs"])
@@ -146,7 +148,7 @@ def training_function(config, args):
else:
checkpointing_steps = None

-set_seed(seed)
+set_seed(seed, deterministic=True)

train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)
metric = evaluate.load("glue", "mrpc")
@@ -217,8 +219,11 @@ def training_function(config, args):
model.train()
# New Code #
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
-    # We need to skip steps until we reach the resumed step
-    active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
+    # We need to skip steps until we reach the resumed step only if we are not using a stateful dataloader
+    if not args.use_stateful_dataloader:
+        active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
+    else:
+        active_dataloader = train_dataloader
    overall_step += resume_step
else:
    # After the first iteration though, we need to go back to the original dataloader
@@ -248,7 +253,6 @@ def training_function(config, args):
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)

model.eval()
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True` (the default).
@@ -261,7 +265,6 @@ def training_function(config, args):
predictions=predictions,
references=references,
)

eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
@@ -309,6 +312,11 @@ def main():
    default=None,
    help="If the training should continue from a checkpoint folder.",
)
+parser.add_argument(
+    "--use_stateful_dataloader",
+    action="store_true",
+    help="If the dataloader should be a resumable stateful dataloader.",
+)
args = parser.parse_args()
config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}
training_function(config, args)
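To see the new flag end to end, here is a minimal usage sketch outside the example script. The DataLoaderConfiguration flag and the save_state/load_state calls are what this PR wires together; the toy dataset, checkpoint path, and loop structure are illustrative assumptions:

# Minimal sketch (assumed setup, not part of the diff): resuming without skip_first_batches.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator, DataLoaderConfiguration

dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
accelerator = Accelerator(dataloader_config=dataloader_config)

dataset = TensorDataset(torch.arange(64))
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=4))

for step, (batch,) in enumerate(dataloader):
    if step == 7:
        # The dataloader's position is checkpointed alongside the rest of the training state.
        accelerator.save_state("ckpt")
        break

# After load_state, iteration resumes from the first batch that was not yet consumed,
# so no accelerator.skip_first_batches(dataloader, resume_step) call is needed.
accelerator.load_state("ckpt")
for (batch,) in dataloader:
    pass

The example itself exercises the same path via accelerate launch examples/by_feature/checkpointing.py --use_stateful_dataloader alongside its existing checkpointing arguments.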
10 changes: 10 additions & 0 deletions src/accelerate/checkpointing.py
@@ -127,6 +127,11 @@ def save_accelerator_state(
sampler = dataloader.get_sampler()
if isinstance(sampler, SeedableRandomSampler):
    save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
+if getattr(dataloader, "use_stateful_dataloader", False):
+    dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
+    output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
+    state_dict = dataloader.state_dict()
+    torch.save(state_dict, output_dataloader_state_dict_file)
logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")

# GradScaler state
@@ -241,6 +246,11 @@ def load_accelerator_state(
sampler = dataloader.get_sampler()
if isinstance(sampler, SeedableRandomSampler):
    sampler = dataloader.set_sampler(torch.load(input_sampler_file))
+if getattr(dataloader, "use_stateful_dataloader", False):
+    dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
+    input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
+    state_dict = torch.load(input_dataloader_state_dict_file)
+    dataloader.load_state_dict(state_dict)
logger.info("All dataloader sampler states loaded successfully")

Review note (Member), on the torch.load call above: Probably not something for this PR, but this line could cause trouble when the weights_only switch for torch.load arrives.

# GradScaler state
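Following up on that review note, a small illustrative sketch of the checkpoint layout implied by the hunks above and of a load call that stays robust once torch.load changes its default; the toy file contents and the explicit weights_only=False argument are assumptions, not part of the diff:

# Assumed layout written by save_accelerator_state when stateful dataloaders are enabled:
#   <output_dir>/dl_state_dict.bin      state dict of the first prepared dataloader
#   <output_dir>/dl_state_dict_1.bin    state dict of the second, and so on
import torch

state_dict = {"_num_yielded": 8, "_sampler_iter_yielded": 8}  # toy stand-in for a real state dict
torch.save(state_dict, "dl_state_dict.bin")

# The dataloader state dict holds plain Python objects (counters, sampler state), not just
# tensors, so passing weights_only=False explicitly keeps the load working once torch.load
# flips its default to weights_only=True.
restored = torch.load("dl_state_dict.bin", weights_only=False)
assert restored == state_dict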
23 changes: 21 additions & 2 deletions src/accelerate/data_loader.py
@@ -442,8 +442,21 @@ def state_dict(self):
    return self.dl_state_dict

def load_state_dict(self, state_dict):
    # The state dict will be off by a factor of `n-1` batch too many during DDP,
    # so we need to adjust it here
    if PartialState().distributed_type != DistributedType.NO:
Review thread, anchored on the distributed adjustment above (the method body continues after the thread):

Member: Should this be fixed during loading or rather during saving?

Collaborator Author (@muellerzr): If we make it during saving, that isolates it to users who just do Accelerator.save_state/Accelerator.load_state, which (especially in the Trainer) might not be what users end up doing, since it's always entirely optional. I'd rather it happen in load IMO, but I'll think about whether there's a better way.

Collaborator Author (@muellerzr): The other (probably more correct) option is to fix it in state_dict().

Contributor (@byi8220, Aug 23, 2024): Adding onto this, is it appropriate to hack around with the state_dict's fields such as _iter_yielded and _num_yielded? IIUC this works in the basic case, but torchdata seems to support custom state functions. This makes me feel like this needs to be fixed during saving, and maybe in a way that doesn't make assumptions about the contents of state_dict (make sure to save the right state dict?).

Collaborator Author (@muellerzr): We must; it's not a matter of whether it's appropriate or not. If we don't, the sampler/resuming simply won't work :) as we need to modify the values in their sampler. This is a naive implementation to start, and if we hit edge cases with more things to adjust later, we can. But the base case is supported.

Collaborator Author (@muellerzr, Aug 23, 2024): The only place where that would matter is a non-multi/distributed setup. Otherwise they're likely not using Accelerate dataloaders.

Member: Regarding the question of custom state functions: would it be possible to optionally allow custom data loaders to make the update themselves and otherwise fall back on the solution provided here? So for instance:

if hasattr(self.base_data_loader, "correct_state_for_prefetch"):  # or whatever fitting name
    self.dl_state_dict = self.base_data_loader.correct_state_for_prefetch(self.dl_state_dict, PartialState())
else:
    ...  # existing code

Of course, this needs to be documented so that users can correctly implement this method.

Collaborator Author (@muellerzr): I'm more a fan of this, but I'm not sure we can get there in our current state, since it's either a StatefulDataLoader or a native torch.utils.data.DataLoader that gets built.

Member: Got confused, the type of the data loader is fixed.

Collaborator Author (@muellerzr): I went with an adjust_state_dict_for_prefetch func in DataLoaderAdapter, with documentation on how overriding it should work.
        factor = PartialState().num_processes - 1
        if state_dict["_sampler_iter_yielded"] > 0:
            state_dict["_sampler_iter_yielded"] -= factor
        if state_dict["_num_yielded"] > 0:
            state_dict["_num_yielded"] -= factor
        if state_dict["_index_sampler_state"] is not None:
            if (
                "samples_yielded" in state_dict["_index_sampler_state"]
                and state_dict["_index_sampler_state"]["samples_yielded"] > 0
            ):
                state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor
    self.base_dataloader.load_state_dict(state_dict)
    self.dl_state_dict = self.state_dict
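As a reader's aid, here is a hedged sketch of what an adjust_state_dict_for_prefetch hook of the kind mentioned at the end of the thread above could look like, written as a standalone function; the signature, the return value, and the worked numbers are assumptions rather than the merged implementation:

def adjust_state_dict_for_prefetch(state_dict, num_processes, batch_size):
    # Illustrative only: correct a StatefulDataLoader state dict for the batches that the
    # other ranks have already consumed, mirroring the in-place logic in load_state_dict above.
    factor = num_processes - 1
    if state_dict.get("_sampler_iter_yielded", 0) > 0:
        state_dict["_sampler_iter_yielded"] -= factor
    if state_dict.get("_num_yielded", 0) > 0:
        state_dict["_num_yielded"] -= factor
    index_sampler_state = state_dict.get("_index_sampler_state")
    if index_sampler_state is not None and index_sampler_state.get("samples_yielded", 0) > 0:
        index_sampler_state["samples_yielded"] -= batch_size * factor
    return state_dict

sd = {"_sampler_iter_yielded": 8, "_num_yielded": 8, "_index_sampler_state": {"samples_yielded": 32}}
print(adjust_state_dict_for_prefetch(sd, num_processes=2, batch_size=4))
# {'_sampler_iter_yielded': 7, '_num_yielded': 7, '_index_sampler_state': {'samples_yielded': 28}}

A subclass wrapping a loader with its own prefetching scheme could then override such a hook instead of touching the private field names directly.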

def _update_state_dict(self):
    # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
@@ -453,6 +466,7 @@ def _update_state_dict(self):
    # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
    if hasattr(self.base_dataloader, "state_dict"):
        self.dl_state_dict = self.base_dataloader.state_dict()
        self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
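To make the prefetch skew concrete, a toy illustration with torchdata's StatefulDataLoader (assumes torchdata is installed; the dataset and batch size are arbitrary):

import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

dl = StatefulDataLoader(TensorDataset(torch.arange(8)), batch_size=2)
it = iter(dl)
first = next(it)
# The loader's own state_dict already accounts for the batch just fetched. A wrapper such as
# DataLoaderShard, which keeps its underlying iterator one element ahead of what it yields,
# therefore has to snapshot the state dict at the right moment, which is what
# _update_state_dict and the _iterator_finished flag above provide.
print(dl.state_dict())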


class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
@@ -539,6 +553,7 @@ def __iter__(self):
    current_batch = next_batch
except StopIteration:
    self.end_of_dataloader = True
    self._update_state_dict()
    if batch_index >= self.skip_batches:
        yield current_batch
    break

Collaborator Author (@muellerzr), on the _update_state_dict() call above: additionally, we need to update the state dict when we've hit the StopIteration, to know that we've hit the end of the iterator.
@@ -809,6 +824,7 @@ def __iter__(self):

if stop_iteration:
    self.end_of_dataloader = True
    self._update_state_dict()
    self.remainder = observed_batch_size
if batch_index >= self.skip_batches:
    yield batch
@@ -1146,7 +1162,7 @@ def __len__(self):
return len(self.batch_sampler) - self.skip_batches


-class SkipDataLoader(DataLoaderAdapter):
+class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of a PyTorch `DataLoader` that will skip the first batches.

@@ -1164,12 +1180,15 @@ class SkipDataLoader(DataLoaderAdapter):
def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
    super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
    self.skip_batches = skip_batches
    self.gradient_state = GradientState()

def __iter__(self):
    self.begin()
    for index, batch in enumerate(self.base_dataloader.__iter__()):
        if index >= self.skip_batches:
            self._update_state_dict()
            yield batch
    self.end()


def skip_first_batches(dataloader, num_batches=0):
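For contrast with the stateful path, a brief usage sketch of the batch-skipping path that SkipDataLoader backs (the same accelerator.skip_first_batches call used in the example above); the toy dataset and batch count are assumptions:

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

accelerator = Accelerator()
dataloader = accelerator.prepare(DataLoader(TensorDataset(torch.arange(64)), batch_size=4))

# Without a stateful dataloader, resuming mid-epoch means re-wrapping the dataloader so that
# the already-seen batches are skipped on the next pass.
resumed_dataloader = accelerator.skip_first_batches(dataloader, num_batches=3)
for (batch,) in resumed_dataloader:
    pass  # iteration starts at the fourth batch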
74 changes: 71 additions & 3 deletions src/accelerate/test_utils/scripts/test_distributed_data_loop.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import tempfile
import warnings
from typing import List
from unittest.mock import Mock
@@ -77,12 +77,17 @@ def create_accelerator(even_batches=True):
return accelerator


-def create_dataloader(accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False):
+def create_dataloader(
+    accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False
+):
    """
    Create a simple DataLoader to use during the test cases
    """
+    values = torch.as_tensor(range(dataset_size))
+    if shuffle:
+        values = values[torch.randperm(values.size(0))]
    if iterable:
-        dataset = DummyIterableDataset(torch.as_tensor(range(dataset_size)))
+        dataset = DummyIterableDataset(values)
    else:
        dataset = TensorDataset(torch.as_tensor(range(dataset_size)))

Review thread on the DummyIterableDataset line:
Member: Why not the same for TensorDataset?
Collaborator Author (@muellerzr): TensorDataset can be shuffled, basically. Our iterable here can't.
@@ -260,6 +265,67 @@ def test_data_loader(data_loader, accelerator):
), "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes."


def test_stateful_dataloader(accelerator):
    old_dataloader_config = accelerator.dataloader_config
    accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
    prepared_dl = create_dataloader(
        accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
    )
    untrained_batches = []
    # Calculate what step that will be
    total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
    last_batch_num = total_batches - 1
    for step, batch in enumerate(prepared_dl):
        # Step just before
        if step == last_batch_num - 1:
            state_dict = prepared_dl.state_dict()
        if step >= last_batch_num:
            # Otherwise grab the "unseen" batches
            untrained_batches.append(batch)
    not_skipped_batches = accelerator.gather(untrained_batches)
    prepared_dl.load_state_dict(state_dict)
    resumed_batches = []
    for batch in prepared_dl:
        resumed_batches.append(batch)
    resumed_batches = accelerator.gather(resumed_batches)
    for b1, b2 in zip(not_skipped_batches, resumed_batches):
        for v1, v2 in zip(b1, b2):
            assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"

    accelerator.dataloader_config = old_dataloader_config


def test_stateful_dataloader_save_state(accelerator):
    with tempfile.TemporaryDirectory() as tmpdir:
        old_dataloader_config = accelerator.dataloader_config
        accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
        prepared_dl = create_dataloader(
            accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
        )
        untrained_batches = []
        # Calculate what step that will be
        total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
        last_batch_num = total_batches - 1
        for step, batch in enumerate(prepared_dl):
            # Step just before
            if step == last_batch_num - 1:
                accelerator.save_state(tmpdir)
            if step >= last_batch_num:
                # Otherwise grab the "unseen" batches
                untrained_batches.append(batch)
        not_skipped_batches = accelerator.gather(untrained_batches)
        accelerator.load_state(tmpdir)
        resumed_batches = []
        for batch in prepared_dl:
            resumed_batches.append(batch)
        resumed_batches = accelerator.gather(resumed_batches)
        for b1, b2 in zip(not_skipped_batches, resumed_batches):
            for v1, v2 in zip(b1, b2):
                assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"

        accelerator.dataloader_config = old_dataloader_config


def main():
accelerator = create_accelerator()
torch.manual_seed(accelerator.process_index)
@@ -306,6 +372,8 @@ def main():
    sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
    loader = DataLoader(dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS)
    test_data_loader(loader, accelerator)
+    test_stateful_dataloader(accelerator)
+    test_stateful_dataloader_save_state(accelerator)

accelerator.end_training()
