Fixup dataloader state dict bugs + incorporate load/save_state API (#3034)

* v1

* More testing, need to try on H100

* Bigger batch for h100 test

* test tweak

* Fixup all tests!

* Bookmark

* Fix issues, working now

* rm num samples

* Uncomment

* Give stateful dl end of dl

* Make skip DL stateful

* Migrate to update_state_dict

* try/finally

* Add comments to test

* rm comment

* Document

* refactor out for eventual override

* Doc nit

* Brute force it
muellerzr authored Aug 23, 2024
1 parent 2d4f1dd commit 726140c
Showing 8 changed files with 219 additions and 75 deletions.
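
For context, the workflow this commit wires up looks roughly like the following minimal sketch (the toy model, dataset, and checkpoint path are illustrative, not taken from the diff; `use_stateful_dataloader=True` assumes the `torchdata` StatefulDataLoader backend is available):

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator, DataLoaderConfiguration

# Toy data/model purely for illustration
dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
model = nn.Linear(4, 1)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Opt in to the resumable (stateful) dataloader
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
accelerator = Accelerator(dataloader_config=dataloader_config)
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

for step, (inputs, targets) in enumerate(train_dataloader):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    if step == 3:
        # With this commit, save_state also persists the dataloader's state dict
        accelerator.save_state("ckpt")  # "ckpt" is an illustrative path
        break

# load_state restores the dataloader position too, so no manual batch skipping is needed
accelerator.load_state("ckpt")
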
22 changes: 15 additions & 7 deletions examples/by_feature/checkpointing.py
@@ -19,9 +19,10 @@
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup

from accelerate import Accelerator, DistributedType
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
from accelerate.utils import set_seed


########################################################################
@@ -125,7 +126,8 @@ def training_function(config, args):
if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
config["num_epochs"] = 2
# Initialize accelerator
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config)
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
lr = config["lr"]
num_epochs = int(config["num_epochs"])
@@ -217,8 +219,11 @@ def training_function(config, args):
model.train()
# New Code #
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We need to skip steps until we reach the resumed step
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
# We need to skip steps until we reach the resumed step only if we are not using a stateful dataloader
if not args.use_stateful_dataloader:
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
overall_step += resume_step
else:
# After the first iteration though, we need to go back to the original dataloader
@@ -248,7 +253,6 @@ def training_function(config, args):
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)

model.eval()
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True` (the default).
@@ -261,7 +265,6 @@
predictions=predictions,
references=references,
)

eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
@@ -309,6 +312,11 @@ def main():
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--use_stateful_dataloader",
action="store_true",
help="If the dataloader should be a resumable stateful dataloader.",
)
args = parser.parse_args()
config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}
training_function(config, args)
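
The resume logic the example scripts now share can be condensed as follows. This is a sketch, not the full script: `starting_epoch`, `resume_step`, and `num_epochs` are assumed to have been recovered from the checkpoint folder name, and the `load_state` call lives in an elided part of the file.

accelerator.load_state(args.resume_from_checkpoint)

for epoch in range(starting_epoch, num_epochs):
    if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
        if not args.use_stateful_dataloader:
            # A plain DataLoader has to be fast-forwarded manually
            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
        else:
            # A stateful dataloader was already fast-forwarded by load_state
            active_dataloader = train_dataloader
    else:
        # After the first resumed epoch, iterate the original dataloader again
        active_dataloader = train_dataloader
    for batch in active_dataloader:
        ...  # training step
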
18 changes: 15 additions & 3 deletions examples/complete_cv_example.py
@@ -23,7 +23,7 @@
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor

from accelerate import Accelerator
from accelerate import Accelerator, DataLoaderConfiguration


########################################################################
@@ -72,12 +72,19 @@ def __getitem__(self, idx):

def training_function(config, args):
# Initialize accelerator
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
if args.with_tracking:
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, log_with="all", project_dir=args.project_dir
cpu=args.cpu,
mixed_precision=args.mixed_precision,
log_with="all",
project_dir=args.project_dir,
dataloader_config=dataloader_config,
)
else:
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config
)

# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
lr = config["lr"]
@@ -297,6 +304,11 @@ def main():
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--use_stateful_dataloader",
action="store_true",
help="If the dataloader should be a resumable stateful dataloader.",
)
parser.add_argument(
"--with_tracking",
action="store_true",
23 changes: 19 additions & 4 deletions examples/complete_nlp_example.py
@@ -21,7 +21,7 @@
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed

from accelerate import Accelerator, DistributedType
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType


########################################################################
@@ -49,12 +49,19 @@

def training_function(config, args):
# Initialize accelerator
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
if args.with_tracking:
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, log_with="all", project_dir=args.project_dir
cpu=args.cpu,
mixed_precision=args.mixed_precision,
dataloader_config=dataloader_config,
log_with="all",
project_dir=args.project_dir,
)
else:
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config
)

if hasattr(args.checkpointing_steps, "isdigit"):
if args.checkpointing_steps == "epoch":
@@ -194,7 +201,10 @@ def collate_fn(examples):
total_loss = 0
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We need to skip steps until we reach the resumed step
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
if not args.use_stateful_dataloader:
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
overall_step += resume_step
else:
# After the first iteration though, we need to go back to the original dataloader
@@ -283,6 +293,11 @@ def main():
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--use_stateful_dataloader",
action="store_true",
help="If the dataloader should be a resumable stateful dataloader.",
)
parser.add_argument(
"--with_tracking",
action="store_true",
11 changes: 11 additions & 0 deletions src/accelerate/checkpointing.py
@@ -127,6 +127,11 @@ def save_accelerator_state(
sampler = dataloader.get_sampler()
if isinstance(sampler, SeedableRandomSampler):
save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
if getattr(dataloader, "use_stateful_dataloader", False):
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
state_dict = dataloader.state_dict()
torch.save(state_dict, output_dataloader_state_dict_file)
logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")

# GradScaler state
@@ -241,6 +246,12 @@ def load_accelerator_state(
sampler = dataloader.get_sampler()
if isinstance(sampler, SeedableRandomSampler):
sampler = dataloader.set_sampler(torch.load(input_sampler_file))
if getattr(dataloader, "use_stateful_dataloader", False):
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
if input_dataloader_state_dict_file.exists():
state_dict = torch.load(input_dataloader_state_dict_file)
dataloader.load_state_dict(state_dict)
logger.info("All dataloader sampler states loaded successfully")

# GradScaler state
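
Per the save logic above, each prepared stateful dataloader gets its own file inside the checkpoint folder: `dl_state_dict.bin` for the first one, `dl_state_dict_{i}.bin` for the rest. A quick way to inspect one (the folder path is hypothetical):

import torch
from pathlib import Path

ckpt_dir = Path("output/checkpoint_0")  # hypothetical folder written by accelerator.save_state
state = torch.load(ckpt_dir / "dl_state_dict.bin")
# Keys mirror torchdata's StatefulDataLoader state plus the flag added later in this diff,
# e.g. "_num_yielded", "_sampler_iter_yielded", "_index_sampler_state", "_iterator_finished"
print(sorted(state.keys()))
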
42 changes: 40 additions & 2 deletions src/accelerate/data_loader.py
@@ -365,6 +365,13 @@ class DataLoaderStateMixin:
- **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
batch size
<Tip warning={true}>
Inheritors of this class should ensure that the class creates a `GradientState()` instance, stored in
`self.gradient_state`.
</Tip>
"""

def __init_subclass__(cls, **kwargs):
@@ -443,7 +450,29 @@ def state_dict(self):

def load_state_dict(self, state_dict):
self.base_dataloader.load_state_dict(state_dict)
self.dl_state_dict = self.state_dict

def adjust_state_dict_for_prefetch(self):
"""
Adjusts the state dict for prefetching. Natively, this will adjust all of the "iters yielded" keys in
`self.dl_state_dict` by a factor of `num_processes - 1`; however, if a custom correction is needed, this can be
overridden.
This should modify `self.dl_state_dict` directly.
"""
# The state dict will be off by a factor of `n-1` batch too many during DDP,
# so we need to adjust it here
if PartialState().distributed_type != DistributedType.NO:
factor = PartialState().num_processes - 1
if self.dl_state_dict["_sampler_iter_yielded"] > 0:
self.dl_state_dict["_sampler_iter_yielded"] -= factor
if self.dl_state_dict["_num_yielded"] > 0:
self.dl_state_dict["_num_yielded"] -= factor
if self.dl_state_dict["_index_sampler_state"] is not None:
if (
"samples_yielded" in self.dl_state_dict["_index_sampler_state"]
and self.dl_state_dict["_index_sampler_state"]["samples_yielded"] > 0
):
self.dl_state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor

def _update_state_dict(self):
# The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
@@ -453,6 +482,10 @@ def _update_state_dict(self):
# _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
if hasattr(self.base_dataloader, "state_dict"):
self.dl_state_dict = self.base_dataloader.state_dict()
# Potentially modify the state_dict to adjust for prefetching
self.adjust_state_dict_for_prefetch()
# Then tag if we are at the end of the dataloader
self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader


class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
Expand Down Expand Up @@ -539,6 +572,7 @@ def __iter__(self):
current_batch = next_batch
except StopIteration:
self.end_of_dataloader = True
self._update_state_dict()
if batch_index >= self.skip_batches:
yield current_batch
break
@@ -809,6 +843,7 @@ def __iter__(self):

if stop_iteration:
self.end_of_dataloader = True
self._update_state_dict()
self.remainder = observed_batch_size
if batch_index >= self.skip_batches:
yield batch
@@ -1146,7 +1181,7 @@ def __len__(self):
return len(self.batch_sampler) - self.skip_batches


class SkipDataLoader(DataLoaderAdapter):
class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of a PyTorch `DataLoader` that will skip the first batches.
@@ -1164,12 +1199,15 @@ class SkipDataLoader(DataLoaderAdapter):
def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
self.skip_batches = skip_batches
self.gradient_state = GradientState()

def __iter__(self):
self.begin()
for index, batch in enumerate(self.base_dataloader.__iter__()):
if index >= self.skip_batches:
self._update_state_dict()
yield batch
self.end()


def skip_first_batches(dataloader, num_batches=0):
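
Since `SkipDataLoader` now also mixes in `DataLoaderStateMixin` and owns a `GradientState`, a dataloader produced for mid-epoch resumption participates in the same begin/end bookkeeping as the other wrappers. A rough usage sketch, reusing the setup assumed in the earlier example:

# Skip the first 100 batches, e.g. when resuming without a stateful dataloader
skipped_dl = accelerator.skip_first_batches(train_dataloader, num_batches=100)
for batch in skipped_dl:
    ...  # training continues from batch 100 onward
# When use_stateful_dataloader=True, the returned wrapper also inherits the adapter's
# state_dict() / load_state_dict() shown above, so its position can be snapshotted.
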
