Fixup dataloader state dict bugs + incorporate load/save_state API #3034

Merged · 19 commits · Aug 23, 2024
22 changes: 15 additions & 7 deletions examples/by_feature/checkpointing.py
@@ -19,9 +19,10 @@
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup

from accelerate import Accelerator, DistributedType
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
from accelerate.utils import set_seed


########################################################################
@@ -125,7 +126,8 @@ def training_function(config, args):
if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
config["num_epochs"] = 2
# Initialize accelerator
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config)
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
lr = config["lr"]
num_epochs = int(config["num_epochs"])
@@ -217,8 +219,11 @@ def training_function(config, args):
model.train()
# New Code #
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We need to skip steps until we reach the resumed step
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
# We need to skip steps until we reach the resumed step only if we are not using a stateful dataloader
if not args.use_stateful_dataloader:
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
overall_step += resume_step
else:
# After the first iteration though, we need to go back to the original dataloader
@@ -248,7 +253,6 @@ def training_function(config, args):
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)

model.eval()
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True` (the default).
@@ -261,7 +265,6 @@ def training_function(config, args):
predictions=predictions,
references=references,
)

eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
@@ -309,6 +312,11 @@ def main():
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--use_stateful_dataloader",
action="store_true",
help="If the dataloader should be a resumable stateful dataloader.",
)
args = parser.parse_args()
config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}
training_function(config, args)
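The behavior this example now relies on is that torchdata's StatefulDataLoader can snapshot and restore its own iteration position, which is why the manual skip_first_batches fast-forward is only needed for a plain DataLoader. A minimal sketch of that semantics outside of Accelerate (assumes torchdata is installed; none of these objects come from the PR):

import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = TensorDataset(torch.arange(8))
loader = StatefulDataLoader(dataset, batch_size=2)

it = iter(loader)
next(it)                        # consume the first batch
snapshot = loader.state_dict()  # records how far the iterator has advanced

resumed = StatefulDataLoader(dataset, batch_size=2)
resumed.load_state_dict(snapshot)
print(list(resumed))            # only the remaining three batches are yielded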
18 changes: 15 additions & 3 deletions examples/complete_cv_example.py
@@ -23,7 +23,7 @@
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor

from accelerate import Accelerator
from accelerate import Accelerator, DataLoaderConfiguration


########################################################################
@@ -72,12 +72,19 @@ def __getitem__(self, idx):

def training_function(config, args):
# Initialize accelerator
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
if args.with_tracking:
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, log_with="all", project_dir=args.project_dir
cpu=args.cpu,
mixed_precision=args.mixed_precision,
log_with="all",
project_dir=args.project_dir,
dataloader_config=dataloader_config,
)
else:
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config
)

# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
lr = config["lr"]
@@ -297,6 +304,11 @@ def main():
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--use_stateful_dataloader",
action="store_true",
help="If the dataloader should be a resumable stateful dataloader.",
)
parser.add_argument(
"--with_tracking",
action="store_true",
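Both complete examples wire the new flag through the same entry point; distilled from the hunks above, enabling resumable dataloaders comes down to one configuration object passed to the Accelerator:

from accelerate import Accelerator, DataLoaderConfiguration

# Dataloaders prepared by this Accelerator keep their iteration state, so it is
# written out by save_state() and restored by load_state().
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
accelerator = Accelerator(dataloader_config=dataloader_config)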
23 changes: 19 additions & 4 deletions examples/complete_nlp_example.py
@@ -21,7 +21,7 @@
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed

from accelerate import Accelerator, DistributedType
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType


########################################################################
@@ -49,12 +49,19 @@

def training_function(config, args):
# Initialize accelerator
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
if args.with_tracking:
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, log_with="all", project_dir=args.project_dir
cpu=args.cpu,
mixed_precision=args.mixed_precision,
dataloader_config=dataloader_config,
log_with="all",
project_dir=args.project_dir,
)
else:
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config
)

if hasattr(args.checkpointing_steps, "isdigit"):
if args.checkpointing_steps == "epoch":
@@ -194,7 +201,10 @@ def collate_fn(examples):
total_loss = 0
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We need to skip steps until we reach the resumed step
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
if not args.use_stateful_dataloader:
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
overall_step += resume_step
else:
# After the first iteration though, we need to go back to the original dataloader
@@ -283,6 +293,11 @@ def main():
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--use_stateful_dataloader",
action="store_true",
help="If the dataloader should be a resumable stateful dataloader.",
)
parser.add_argument(
"--with_tracking",
action="store_true",
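Putting the pieces together, the intended end-to-end flow is a save_state/load_state round trip that also covers the dataloader position. A hedged sketch of that usage (assumed behavior of the new API, not code from the PR; runs on a single CPU process):

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator, DataLoaderConfiguration

accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True)
)
dataloader = accelerator.prepare(
    DataLoader(TensorDataset(torch.arange(16)), batch_size=4)
)

it = iter(dataloader)
next(it)                          # one batch has been consumed
accelerator.save_state("ckpt")    # checkpoint now includes the dataloader state

accelerator.load_state("ckpt")
for batch in dataloader:          # iteration resumes from the restored position
    print(batch)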
11 changes: 11 additions & 0 deletions src/accelerate/checkpointing.py
@@ -127,6 +127,11 @@ def save_accelerator_state(
sampler = dataloader.get_sampler()
if isinstance(sampler, SeedableRandomSampler):
save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
if getattr(dataloader, "use_stateful_dataloader", False):
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
state_dict = dataloader.state_dict()
torch.save(state_dict, output_dataloader_state_dict_file)
logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")

# GradScaler state
@@ -241,6 +246,12 @@ def load_accelerator_state(
sampler = dataloader.get_sampler()
if isinstance(sampler, SeedableRandomSampler):
sampler = dataloader.set_sampler(torch.load(input_sampler_file))
if getattr(dataloader, "use_stateful_dataloader", False):
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
if input_dataloader_state_dict_file.exists():
state_dict = torch.load(input_dataloader_state_dict_file)
dataloader.load_state_dict(state_dict)
logger.info("All dataloader sampler states loaded successfully")

# GradScaler state
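In short, the two hunks above give each stateful dataloader its own state-dict file next to the other checkpoint artifacts. A condensed sketch of the save/load convention they implement (here dataloaders, output_dir, and input_dir stand in for the function arguments):

import torch

# Saving: one file per stateful dataloader, dl_state_dict.bin for the first one.
for i, dataloader in enumerate(dataloaders):
    if getattr(dataloader, "use_stateful_dataloader", False):
        name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
        torch.save(dataloader.state_dict(), output_dir.joinpath(name))

# Loading: restore each dataloader's position if its file exists in the checkpoint.
for i, dataloader in enumerate(dataloaders):
    if getattr(dataloader, "use_stateful_dataloader", False):
        name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
        path = input_dir.joinpath(name)
        if path.exists():
            dataloader.load_state_dict(torch.load(path))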
42 changes: 40 additions & 2 deletions src/accelerate/data_loader.py
@@ -365,6 +365,13 @@ class DataLoaderStateMixin:
- **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
batch size

<Tip warning={true}>

Inheritors of this class should ensure that the class creates a `GradientState()` instance, stored in
`self.gradient_state`.

</Tip>

"""

def __init_subclass__(cls, **kwargs):
@@ -443,7 +450,29 @@ def state_dict(self):

def load_state_dict(self, state_dict):
self.base_dataloader.load_state_dict(state_dict)
self.dl_state_dict = self.state_dict

def adjust_state_dict_for_prefetch(self):
"""
Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
`self.dl_state_dict` by a factor of `num_processes - 1`, however if a custom correction is needed, this can be
overridden.

This should modify `self.dl_state_dict` directly
"""
# The state dict will be off by a factor of `n-1` batch too many during DDP,
# so we need to adjust it here
if PartialState().distributed_type != DistributedType.NO:
Review thread on this line:

Member: Should this be fixed during loading or rather during saving?

muellerzr (Collaborator Author): If we make it during saving, that isolates it to users who only do Accelerator.save_state/Accelerator.load_state, which (especially in the Trainer) might not be what users end up doing, since it's always entirely optional. I'd rather it happen in load, IMO, but I'll think about whether there's a better way.

muellerzr (Collaborator Author): The other (probably more correct) option is to fix it in state_dict().

byi8220 (Contributor, Aug 23, 2024): Adding onto this, is it appropriate to hack around with the state_dict's fields such as _iter_yielded and _num_yielded? IIUC this works in the basic case, but torchdata seems to support custom state functions. This makes me feel like it needs to be fixed during saving, and maybe in a way that doesn't make assumptions about the contents of state_dict (make sure to save the right state dict?).

muellerzr (Collaborator Author): We must; it's not a matter of whether it's appropriate or not. If we don't, the sampler/resuming simply won't work, since we need to modify the values in their sampler. This is a naive implementation to start, and if we hit edge cases with more things to adjust later, we can. But the base case is supported.

muellerzr (Collaborator Author, Aug 23, 2024): The only case where that would come up is a non-multi/distributed setup. Otherwise they're likely not using Accelerate dataloaders.

Member: Regarding the question of custom state functions: would it be possible to optionally allow custom dataloaders to make the update themselves and otherwise fall back on the solution provided here? For instance:

if hasattr(self.base_data_loader, "correct_state_for_prefetch"):  # or whatever fitting name
    self.dl_state_dict = self.base_data_loader.correct_state_for_prefetch(self.dl_state_dict, PartialState())
else:
    ...  # existing code

Of course, this needs to be documented so that users can correctly implement this method.

muellerzr (Collaborator Author): I'm more a fan of this, but I'm not sure we can get there in our current state, since it's either a StatefulDataLoader or a native torch.utils.data.DataLoader that gets built.

Member: Got confused, the type of the dataloader is fixed.

muellerzr (Collaborator Author): I went with an adjust_state_dict_for_prefetch function in DataLoaderAdapter, with documentation on how overriding it should work.
factor = PartialState().num_processes - 1
if self.dl_state_dict["_sampler_iter_yielded"] > 0:
self.dl_state_dict["_sampler_iter_yielded"] -= factor
if self.dl_state_dict["_num_yielded"] > 0:
self.dl_state_dict["_num_yielded"] -= factor
if self.dl_state_dict["_index_sampler_state"] is not None:
if (
"samples_yielded" in self.dl_state_dict["_index_sampler_state"]
and self.dl_state_dict["_index_sampler_state"]["samples_yielded"] > 0
):
self.dl_state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor

def _update_state_dict(self):
# The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
@@ -453,6 +482,10 @@ def _update_state_dict(self):
# _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
if hasattr(self.base_dataloader, "state_dict"):
self.dl_state_dict = self.base_dataloader.state_dict()
# Potentially modify the state_dict to adjust for prefetching
self.adjust_state_dict_for_prefetch()
# Then tag if we are at the end of the dataloader
self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader


class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
@@ -539,6 +572,7 @@ def __iter__(self):
current_batch = next_batch
except StopIteration:
self.end_of_dataloader = True
self._update_state_dict()
Review comment from muellerzr (Collaborator Author) on this line: Additionally, we need to update the state dict when we've hit the StopIteration, so we know we've reached the end of the iterator.

if batch_index >= self.skip_batches:
yield current_batch
break
@@ -809,6 +843,7 @@ def __iter__(self):

if stop_iteration:
self.end_of_dataloader = True
self._update_state_dict()
self.remainder = observed_batch_size
if batch_index >= self.skip_batches:
yield batch
@@ -1146,7 +1181,7 @@ def __len__(self):
return len(self.batch_sampler) - self.skip_batches


class SkipDataLoader(DataLoaderAdapter):
class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of a PyTorch `DataLoader` that will skip the first batches.

@@ -1164,12 +1199,15 @@ class SkipDataLoader(DataLoaderAdapter):
def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
self.skip_batches = skip_batches
self.gradient_state = GradientState()

def __iter__(self):
self.begin()
for index, batch in enumerate(self.base_dataloader.__iter__()):
if index >= self.skip_batches:
self._update_state_dict()
yield batch
self.end()


def skip_first_batches(dataloader, num_batches=0):
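For completeness, the non-stateful resume path used by the examples goes through skip_first_batches, which builds the SkipDataLoader modified here. A small usage sketch (made-up dataset and counts; runs on a single CPU process):

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

accelerator = Accelerator()
dataloader = accelerator.prepare(
    DataLoader(TensorDataset(torch.arange(16)), batch_size=4)
)

# Resume mid-epoch without a stateful dataloader: drop the first two batches manually.
resumed = accelerator.skip_first_batches(dataloader, num_batches=2)
for batch in resumed:
    print(batch)  # only the last two batches of the epoch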