[Infra] Update to pytorch-lightning 2.0
albertfgu committed Jul 8, 2023
1 parent fffbeee commit 976c9e9
Showing 10 changed files with 135 additions and 18 deletions.
@@ -25,7 +25,7 @@ trainer:
max_epochs: 310
precision: 16
devices: 1
-  replace_sampler_ddp: ${eval:"${dataset.num_aug_repeats} == 0"} # only True if using RepeatAug
+  use_distributed_sampler: ${eval:"${dataset.num_aug_repeats} == 0"} # only True if using RepeatAug
accumulate_grad_batches: ${eval:${train.global_batch_size} // ${.devices} // ${loader.batch_size}}

train:
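The only change in this config is the Trainer flag rename: pytorch-lightning 2.0 renamed replace_sampler_ddp to use_distributed_sampler with unchanged semantics. A minimal sketch of the new flag, outside this repo's Hydra wiring:

import pytorch_lightning as pl

# False tells PL not to inject a DistributedSampler, e.g. when the dataset
# supplies its own sampler (as with RepeatAug in the config above).
trainer = pl.Trainer(use_distributed_sampler=False)  # was replace_sampler_ddp before 2.0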
@@ -49,7 +49,7 @@ train:
remove_test_loader_in_eval: true # null means we do use test loader
global_batch_size: ${loader.batch_size} # effective batch size (handled with multiple gpus, and accumulate_grad_batches)
pretrained_model_strict_load: False
-  replace_sampler_ddp: False # ${eval:"${trainer.devices} > 1"}
+  use_distributed_sampler: False # ${eval:"${trainer.devices} > 1"}
pretrained_model_state_hook:
_name_: convnext_timm_tiny_s4nd_2d_to_3d

@@ -31,7 +31,7 @@ trainer:
max_epochs: 310
precision: 16
devices: 8
-  replace_sampler_ddp: ${eval:"${dataset.num_aug_repeats} == 0"} # only true if using RepeatAug
+  use_distributed_sampler: ${eval:"${dataset.num_aug_repeats} == 0"} # only true if using RepeatAug
accumulate_grad_batches: ${eval:${train.global_batch_size} // ${.devices} // ${loader.batch_size}}

train:
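These configs compute the flag with an eval resolver, so use_distributed_sampler is only true when num_aug_repeats == 0. OmegaConf only understands ${eval:...} if the code registers such a resolver; a minimal sketch of the registration this repo presumably performs before resolving configs:

from omegaconf import OmegaConf

# Assumption: the repo registers an "eval" resolver roughly like this.
OmegaConf.register_new_resolver("eval", eval)

cfg = OmegaConf.create({"num_aug_repeats": 0, "flag": '${eval:"${num_aug_repeats} == 0"}'})
print(cfg.flag)  # True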
2 changes: 1 addition & 1 deletion configs/experiment/s4nd/vit/vit_b_16_s4_imagenet_v2.yaml
@@ -46,7 +46,7 @@ trainer:
max_epochs: 310
precision: 16
devices: 8
-  replace_sampler_ddp: ${eval:"${dataset.num_aug_repeats} == 0"} # only true if using RepeatAug
+  use_distributed_sampler: ${eval:"${dataset.num_aug_repeats} == 0"} # only true if using RepeatAug

train:
seed: 1112
1 change: 0 additions & 1 deletion configs/trainer/debug.yaml
@@ -17,5 +17,4 @@ overfit_batches: 0
limit_train_batches: 0.1
limit_val_batches: 0.1
limit_test_batches: 0.1
-track_grad_norm: -1
terminate_on_nan: False
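track_grad_norm is dropped because pytorch-lightning 2.0 removed the Trainer flag entirely. The 2.0 replacement is to log norms from the LightningModule; a sketch of the equivalent of the old track_grad_norm: 2, not code from this repo:

import pytorch_lightning as pl
from pytorch_lightning.utilities import grad_norm

class MyModule(pl.LightningModule):
    def on_before_optimizer_step(self, optimizer):
        # L2 norm per parameter plus the total, logged every optimizer step
        self.log_dict(grad_norm(self, norm_type=2))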
3 changes: 1 addition & 2 deletions configs/trainer/default.yaml
@@ -1,7 +1,7 @@
# See Docs for full flags and descriptions
# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-class-api
accelerator: gpu
-strategy: null
+strategy: auto
devices: 1
accumulate_grad_batches: 1 # Gradient accumulation every n batches
max_epochs: 200
@@ -11,4 +11,3 @@ log_every_n_steps: 10
limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
enable_model_summary: false # Can turn on if RichModelSummary is disabled
-track_grad_norm: -1 # Set to 2 to track norms of gradients
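strategy: null worked in 1.x because None meant "pick a strategy for me"; 2.0 requires a string or Strategy object, and "auto" is the new spelling of that default. Assuming this config block is splatted into the Trainer constructor, as is typical for Hydra setups like this one, the resulting call is roughly:

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    strategy="auto",            # 2.0: must be a string/Strategy, not None
    devices=1,
    accumulate_grad_batches=1,
    max_epochs=200,
)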
3 changes: 1 addition & 2 deletions configs/trainer/lm.yaml
@@ -7,11 +7,10 @@ gradient_clip_val: null # Gradient clipping
log_every_n_steps: 10
precision: 16
enable_model_summary: false # Can turn on if RichModelSummary is disabled
-track_grad_norm: -1 # Set to 2 to track norms of gradients
limit_train_batches: 1.0
limit_val_batches: 1.0
# We use the dataloader from Transformer-XL to ensure adjacent minibatches
# are from text that are next to each other.
# So that dataloader has to deal with DDP, and we don't want PL to handle
# that.
-replace_sampler_ddp: False
+use_distributed_sampler: False
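As the comment explains, the Transformer-XL style loader handles DDP itself, so PL must not swap in a DistributedSampler. A hypothetical sketch (not this repo's loader) of a dataset that shards per rank on its own, which is the situation use_distributed_sampler: False exists for:

import torch.distributed as dist
from torch.utils.data import IterableDataset

class RankShardedStream(IterableDataset):
    def __init__(self, chunks):
        # chunks: corpus pieces ordered so adjacent items are adjacent text
        self.chunks = chunks

    def __iter__(self):
        rank = dist.get_rank() if dist.is_initialized() else 0
        world = dist.get_world_size() if dist.is_initialized() else 1
        # Each rank reads its own slice, so no external sampler is needed.
        yield from self.chunks[rank::world]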
2 changes: 1 addition & 1 deletion requirements.txt
@@ -8,7 +8,7 @@ rich
torchtext
lit # Getting installation errors with torch 2.0 if this isn't installed
# torchvision
-pytorch-lightning==1.9.3
+pytorch-lightning==2.0.4
hydra-core
omegaconf
wandb
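A quick sanity check after reinstalling, since the removed 1.x flags (replace_sampler_ddp, track_grad_norm) now fail at Trainer construction:

import pytorch_lightning as pl
assert pl.__version__.startswith("2."), pl.__version__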
2 changes: 1 addition & 1 deletion src/callbacks/norms.py
@@ -26,7 +26,7 @@ def on_after_training_step(self, batch, batch_idx, trainer: pl.Trainer, pl_module

def on_after_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
# example to inspect gradient information in tensorboard
-        if OmegaConf.select(trainer.hparams, 'trainer.track_grad_norms'): # TODO dot notation should work with omegaconf?
+        if OmegaConf.select(trainer.hparams, 'train.track_grad_norms'): # TODO dot notation should work with omegaconf?
norms = {}
for name, p in pl_module.named_parameters():
if p.grad is None:
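The fix moves the lookup from trainer.track_grad_norms to train.track_grad_norms, matching where the key actually lives in these configs. On the inline TODO: dot notation does work with OmegaConf.select, and a default avoids errors on missing keys; for example:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"train": {"track_grad_norms": True}})
OmegaConf.select(cfg, "train.track_grad_norms", default=False)    # True
OmegaConf.select(cfg, "trainer.track_grad_norms", default=False)  # False (missing key)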
134 changes: 127 additions & 7 deletions train.py
@@ -349,9 +349,9 @@ def on_train_epoch_start(self):
# Reset training torchmetrics
self.task._reset_torchmetrics("train")

-    def training_epoch_end(self, outputs):
+    def on_train_epoch_end(self):
# Log training torchmetrics
-        super().training_epoch_end(outputs)
+        super().on_train_epoch_end()
self.log_dict(
{f"train/{k}": v for k, v in self.task.get_torchmetrics("train").items()},
on_step=False,
@@ -367,9 +367,9 @@ def on_validation_epoch_start(self):
for name in self.val_loader_names:
self.task._reset_torchmetrics(name)

-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self):
# Log all validation torchmetrics
-        super().validation_epoch_end(outputs)
+        super().on_validation_epoch_end()
for name in self.val_loader_names:
self.log_dict(
{f"{name}/{k}": v for k, v in self.task.get_torchmetrics(name).items()},
@@ -386,9 +386,9 @@ def on_test_epoch_start(self):
for name in self.test_loader_names:
self.task._reset_torchmetrics(name)

-    def test_epoch_end(self, outputs):
+    def on_test_epoch_end(self):
# Log all test torchmetrics
-        super().test_epoch_end(outputs)
+        super().on_test_epoch_end()
for name in self.test_loader_names:
self.log_dict(
{f"{name}/{k}": v for k, v in self.task.get_torchmetrics(name).items()},
@@ -411,7 +411,7 @@ def training_step(self, batch, batch_idx):
loss_epoch,
on_step=True,
on_epoch=False,
-            prog_bar=False,
+            prog_bar=True,
add_dataloader_idx=False,
sync_dist=True,
)
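These renames follow the 2.0 removal of the training_epoch_end/validation_epoch_end/test_epoch_end hooks: the on_*_epoch_end replacements receive no outputs argument. That is harmless here because the torchmetrics objects accumulate state across steps themselves. Code that did use outputs must now collect them manually; a generic sketch of that 2.0 pattern, not code from this diff:

import torch
import pytorch_lightning as pl

class Module(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        self.validation_step_outputs.append(loss)
        return loss

    def on_validation_epoch_end(self):
        self.log("val/loss", torch.stack(self.validation_step_outputs).mean())
        self.validation_step_outputs.clear()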
@@ -666,6 +666,23 @@ def create_trainer(config):
# Stage params are resolution and epochs, pretty print
print(f"\tStage {i}: {e['resolution']} @ {e['epochs']} epochs")

+    # Additional ModelCheckpoint callback for preemption
+    if config.tolerance.id is not None:
+        pass
+        # if 'model_checkpoint' in config.callbacks.keys():
+        #     callback_args = config.callbacks['model_checkpoint']
+        #     callback_args._name_ = 'model_checkpoint'  # For the registry
+        #     # Save last two checkpoints to be extra fault tolerant
+        #     callback_args.save_top_k = 2
+        #     callback_args.monitor = 'trainer/epoch'
+        #     callback_args.mode = 'max'
+        #     callback_args.save_last = False
+        #     callback_args.filename = 'last'
+        #     # callback_args.save_on_train_epoch_end = True  # this is False for the other checkpoint callback
+        #     ckpt_callback = utils.instantiate(registry.callbacks, callback_args)
+        #     # ckpt_callback.CHECKPOINT_NAME_LAST = 'last_'  # now we have two last checkpoints, last.ckpt and last_.ckpt
+        #     callbacks.append(ckpt_callback)
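The disabled block above sketches a second, preemption-oriented ModelCheckpoint. With plain pytorch-lightning, bypassing this repo's callback registry and assuming a trainer/epoch value is logged each epoch, the equivalent callback would be roughly:

from pytorch_lightning.callbacks import ModelCheckpoint

ckpt_callback = ModelCheckpoint(
    save_top_k=2,             # keep two checkpoints to be extra fault tolerant
    monitor="trainer/epoch",  # assumes this metric is logged every epoch
    mode="max",
    save_last=False,
    filename="last",
)
callbacks.append(ckpt_callback)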

trainer = pl.Trainer(
logger=logger,
callbacks=callbacks,
@@ -681,6 +698,31 @@ def train(config):
trainer = create_trainer(config)
model = SequenceLightningModule(config)

+    # Load pretrained_model if specified
+    if config.train.get("pretrained_model_path", None) is not None:
+        # PTL style. Note, method returns a new model object, and need to pass config.
+        model = SequenceLightningModule.load_from_checkpoint(
+            config.train.pretrained_model_path,
+            config=config,
+            strict=config.train.pretrained_model_strict_load,
+        )
+        print("Loaded pretrained model from", config.train.pretrained_model_path)
+
+    # Added by KS for pre-training
+    # [22-07-21 AG] refactored, untested
+    if config.train.get("ignore_pretrained_layers", False):
+        pretrained_dict = pretrained_model.state_dict()
+        model_dict = model.state_dict()
+        for k, v in model_dict.items():
+            for ignore_layer in config.train.ignore_pretrained_layers:
+                if ignore_layer in k:
+                    pretrained_dict[k] = v
+        model.load_state_dict(pretrained_dict)
+    if config.train.get("pretrained_freeze_encoder", False):
+        for name, param in model.named_parameters():
+            if not("decoder" in name): param.requires_grad = False
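One catch in the block flagged "refactored, untested": pretrained_dict = pretrained_model.state_dict() reads a name this hunk never defines, so the ignore_pretrained_layers path would raise NameError as committed. A hypothetical repair that takes the checkpoint's weights as the base instead:

import torch

# Assumption: PL checkpoints store weights under "state_dict".
ckpt = torch.load(config.train.pretrained_model_path, map_location="cpu")
pretrained_dict = ckpt["state_dict"]
model_dict = model.state_dict()
for k, v in model_dict.items():
    if any(layer in k for layer in config.train.ignore_pretrained_layers):
        pretrained_dict[k] = v  # keep the fresh weights for ignored layers
model.load_state_dict(pretrained_dict)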


# Run initial validation epoch (useful for debugging, finetuning)
if config.train.validate_at_start:
print("Running validation before training")
@@ -693,6 +735,82 @@
if config.train.test:
trainer.test(model)



+def preemption_setup(config):
+    if config.tolerance.id is None:
+        return config
+
+    # Create path ./logdir/id/ to store information for resumption
+    resume_dir = os.path.join(get_original_cwd(), config.tolerance.logdir, str(config.tolerance.id))
+
+    if os.path.exists(resume_dir):
+        print(f"Resuming from {resume_dir}")
+
+        # Load path to the last checkpoint
+        with open(os.path.join(resume_dir, "hydra.txt"), "r") as f:
+            hydra_paths = list(f.readlines())
+
+        # Look at the previous runs in reverse order
+        checkpoint_path = None
+        for hydra_path in reversed(hydra_paths):
+            hydra_path = hydra_path.rstrip('\n')
+
+            # Get the paths to the last.ckpt and last_.ckpt files
+            last_path = os.path.join(hydra_path, "checkpoints", "last.ckpt")
+            # last__path = os.path.join(hydra_path, "checkpoints", "last_.ckpt")
+            # last_exists, last__exists = os.path.exists(last_path), os.path.exists(last__path)
+
+            # if not last_exists or not last__exists:
+            #     # This run doesn't have both checkpoints, so skip it
+            #     print(f"\tSkipping {hydra_path}, not suitable for resuming (last_exists = {last_exists}, last__exists = {last__exists})")
+            #     continue
+
+            # # Read timestamp when checkpoints were modified
+            # # We want to load the _earlier_ checkpoint, since that is guaranteed to be uncorrupted
+            # last_timestamp = os.path.getmtime(last_path)
+            # last__timestamp = os.path.getmtime(last__path)
+            # print("\t\tlast_timestamp =", last_timestamp)
+            # print("\t\tlast__timestamp =", last__timestamp)
+
+            # if last_timestamp < last__timestamp:
+            #     checkpoint_path = last_path
+            # else:
+            #     checkpoint_path = last__path
+            # checkpoint_path = last_path
+            # config.train.ckpt = checkpoint_path
+
+            if os.path.exists(last_path):
+                print("\tFound checkpoint at", last_path)
+                config.train.ckpt = last_path
+                # HACK TODO
+                config.train.pretrained_model_path = None
+                config.train.pretrained_model_state_hook._name_ = None
+                # config.train.pretrained_model_reinit_hook._name_ = None
+                break
+
+        # If we didn't find a checkpoint
+        if checkpoint_path is None:
+            print("\tNo suitable checkpoint found, starting from scratch")
+
+        # Set wandb run id to resume
+        if os.path.exists(os.path.join(hydra_path, 'wandb')):
+            run_info = [e for e in os.listdir(os.path.join(hydra_path, 'wandb')) if e.startswith('run-')][0]
+            run_id = run_info.split('-')[-1]
+            try:
+                config.wandb.id = run_id
+            except AttributeError:
+                pass
+
+    os.makedirs(resume_dir, exist_ok=True)
+
+    # Store path to Hydra output folder
+    with open(os.path.join(resume_dir, 'hydra.txt'), 'a') as f:
+        f.write(os.getcwd() + '\n')
+
+    return config
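preemption_setup only records the recovered path in config.train.ckpt; consuming it also changed in 2.0, where Trainer(resume_from_checkpoint=...) no longer exists and the path is passed to fit instead. Presumably the repo's train loop does the equivalent of:

trainer.fit(model, ckpt_path=config.train.ckpt)  # None means start fresh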


@hydra.main(config_path="configs", config_name="config.yaml")
def main(config: OmegaConf):

@@ -705,6 +823,8 @@ def main(config: OmegaConf):
# Pretty print config using Rich library
utils.train.print_config(config, resolve=True)

+    config = preemption_setup(config)

train(config)


