From c95d24e6aa5821d72c648a4888155ced0c510baa Mon Sep 17 00:00:00 2001
From: sichu2023 <152360507+sichu2023@users.noreply.github.com>
Date: Thu, 5 Dec 2024 15:58:01 +0100
Subject: [PATCH] Add support for num steps in learning rate scheduler (#489)

Decouple `num_steps` from the number of steps used by the learning rate
scheduler.
---
 .../src/bionemo/esm2/run/recipes.py           | 43 +++++++++++++------
 .../src/bionemo/esm2/scripts/train_esm2.py    | 22 ++++++++--
 .../bionemo/esm2/scripts/test_train_esm2.py   |  2 +
 .../src/bionemo/llm/run/config_models.py      |  2 +
 .../bionemo-llm/src/bionemo/llm/train.py      |  4 +-
 5 files changed, 55 insertions(+), 18 deletions(-)

diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/run/recipes.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/recipes.py
index cc4da8c337..e5cca198d3 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/run/recipes.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/recipes.py
@@ -36,10 +36,10 @@
 from bionemo.llm.utils.logger_utils import WandbConfig
 
 
-def esm2_base_training_config() -> TrainingConfig:
+def esm2_base_training_config(max_steps: int = 500000) -> TrainingConfig:
     """Base training config for ESM2."""
     return TrainingConfig(
-        max_steps=500000,
+        max_steps=max_steps,
         limit_val_batches=1.0,
         val_check_interval=10_000,
         precision="bf16-mixed",
@@ -47,10 +47,16 @@
     )
 
 
-def esm2_base_optimizer_scheduler_config() -> OptimizerSchedulerConfig:
+def esm2_base_optimizer_scheduler_config(max_steps: Optional[int] = None) -> OptimizerSchedulerConfig:
     """Base optimizer scheduler config for ESM2."""
     return OptimizerSchedulerConfig(
-        optimizer="adam", lr=4e-4, interval="step", monitor="val_loss", lr_scheduler="warmup_anneal", warmup_steps=2000
+        optimizer="adam",
+        lr=4e-4,
+        interval="step",
+        monitor="val_loss",
+        lr_scheduler="warmup_anneal",
+        warmup_steps=2000,
+        max_steps=max_steps,
     )
 
 
@@ -128,9 +134,9 @@ def esm2_8m_recipe(args) -> MainConfig[ExposedESM2PretrainConfig, ESM2DataConfig
     return MainConfig(
         data_config=esm2_base_data_config(args),
         parallel_config=esm2_base_parallel_config(),
-        training_config=esm2_base_training_config(),  # no changes for 8m
+        training_config=esm2_base_training_config(max_steps=args.max_steps),  # no changes for 8m
         bionemo_model_config=esm2_8m_model_config(args.initial_ckpt_path),
-        optim_config=esm2_base_optimizer_scheduler_config(),  # no changes for 8m
+        optim_config=esm2_base_optimizer_scheduler_config(max_steps=args.scheduler_max_steps),  # no changes for 8m
         experiment_config=esm2_8m_experiment_config(args.result_dir),
         wandb_config=esm2_8m_wandb_config(),
     )
@@ -183,9 +189,9 @@ def esm2_650m_recipe(args) -> MainConfig[ExposedESM2PretrainConfig, ESM2DataConf
     return MainConfig(
         data_config=esm2_base_data_config(args),
         parallel_config=esm2_base_parallel_config(),
-        training_config=esm2_base_training_config(),  # no changes for 8m
+        training_config=esm2_base_training_config(max_steps=args.max_steps),  # no changes for 8m
         bionemo_model_config=esm2_650m_model_config(args.initial_ckpt_path),
-        optim_config=esm2_base_optimizer_scheduler_config(),  # no changes for 8m
+        optim_config=esm2_base_optimizer_scheduler_config(max_steps=args.scheduler_max_steps),  # no changes for 8m
         experiment_config=esm2_650m_experiment_config(args.result_dir),
         wandb_config=esm2_650m_wandb_config(),
     )
@@ -251,9 +257,9 @@ def esm2_3b_recipe(args) -> MainConfig[ExposedESM2PretrainConfig, ESM2DataConfig
     return MainConfig(
         data_config=esm2_base_data_config(args),
         parallel_config=esm2_3b_parallel_config(),
-        training_config=esm2_base_training_config(),  # no changes for 8m
+        training_config=esm2_base_training_config(max_steps=args.max_steps),  # no changes for 8m
         bionemo_model_config=esm2_3b_model_config(args.initial_ckpt_path),
-        optim_config=esm2_base_optimizer_scheduler_config(),  # no changes for 8m
+        optim_config=esm2_base_optimizer_scheduler_config(max_steps=args.scheduler_max_steps),  # no changes for 8m
         experiment_config=esm2_3b_experiment_config(args.result_dir),
         wandb_config=esm2_3b_wandb_config(),
     )
@@ -282,9 +288,9 @@ def tiny_train_config_recipe() -> TrainingConfig:
     return TrainingConfig(max_steps=10, limit_val_batches=2, val_check_interval=2)
 
 
-def default_adam_optimizer_with_cosine_annealing_recipe() -> OptimizerSchedulerConfig:
+def default_adam_optimizer_with_cosine_annealing_recipe(max_steps: Optional[int] = None) -> OptimizerSchedulerConfig:
     """Default optimizer scheduler config for ESM2."""
-    return OptimizerSchedulerConfig()
+    return OptimizerSchedulerConfig(max_steps=max_steps)
 
 
 def experiment_config_recipe(result_dir="./results") -> ExperimentConfig:
@@ -347,7 +353,7 @@ def esm2_tiny_test_recipe(args):
         seq_length=data_config.max_seq_length, initial_ckpt_path=args.initial_ckpt_path
     )
 
-    optim_config = default_adam_optimizer_with_cosine_annealing_recipe()
+    optim_config = default_adam_optimizer_with_cosine_annealing_recipe(max_steps=args.scheduler_max_steps)
     experiment_config = experiment_config_recipe(args.result_dir)
     wandb_config = WandbConfig(
         project="bionemo2-demo",
@@ -436,6 +442,17 @@ def parse_args():
         help="Path to an existing to a checkpoint directory to restore an existing checkpoint. Not compatible with all recipes.",
     )
 
+    parser.add_argument(
+        "--max-steps", type=int, required=False, default=500000, help="Max steps for training. Defaults to 500000."
+    )
+
+    parser.add_argument(
+        "--scheduler-max-steps",
+        type=int,
+        required=False,
+        help="Set scheduler max_steps directly. Defaults to None, which uses max_steps from the training config.",
+    )
+
     args = parser.parse_args()
     return args
 
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py
index ed5e46899b..7d05455924 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py
@@ -52,12 +52,13 @@ def main(
     max_seq_length: int,
     result_dir: Path,
     num_steps: int,
+    scheduler_num_steps: Optional[int],
     warmup_steps: int,
     limit_val_batches: int,
     val_check_interval: int,
     log_every_n_steps: Optional[int],
     num_dataset_workers: int,
-    biobert_spec_option: BiobertSpecOption,  # TODO(@farhadrgh) clarify how to parse this.
+    biobert_spec_option: BiobertSpecOption,
     lr: float,
     micro_batch_size: int,
     accumulate_grad_batches: int,
@@ -111,6 +112,7 @@ def main(
         num_dataset_workers (int): number of dataset workers
         biobert_spec_option (BiobertSpecOption): the biobert spec option (architecture) to use for this run
         lr (float): learning rate
+        scheduler_num_steps (Optional[int]): Number of steps for the learning rate scheduler. Defaults to num_steps if not provided.
         micro_batch_size (int): micro batch size, from this and parallelism settings we infer the global batch size
         accumulate_grad_batches (int): number of batches to accumulate gradients for
         experiment_name (str): experiment name, this is the name used for the wandb run, and the sub-directory of the
@@ -247,20 +249,27 @@ def main(
         variable_seq_lengths=min_seq_length != max_seq_length,
     )
 
+    if scheduler_num_steps is None:
+        scheduler_num_steps = num_steps
+
     model = biobert_lightning_module(
         esm2_config,
         tokenizer=tokenizer,
         optimizer=MegatronOptimizerModule(
             config=OptimizerConfig(
                 lr=lr,
-                optimizer="adam",  # fused_adam not supported
+                optimizer="adam",
                 use_distributed_optimizer=True,
                 weight_decay=0.01,
                 adam_beta1=0.9,
                 adam_beta2=0.98,
             ),
             lr_scheduler=WarmupAnnealDecayHoldScheduler(
-                warmup_steps=warmup_steps, max_steps=num_steps, max_lr=lr, min_lr=0.0, anneal_percentage=0.10
+                warmup_steps=warmup_steps,
+                max_steps=scheduler_num_steps,
+                max_lr=lr,
+                min_lr=0.0,
+                anneal_percentage=0.10,
             ),
         ),
     )
@@ -328,6 +337,7 @@ def train_esm2_entrypoint():
         num_dataset_workers=args.num_dataset_workers,
         biobert_spec_option=args.biobert_spec_option,
         lr=args.lr,
+        scheduler_num_steps=args.scheduler_num_steps,
         micro_batch_size=args.micro_batch_size,
         pipeline_model_parallel_size=args.pipeline_model_parallel_size,
         tensor_model_parallel_size=args.tensor_model_parallel_size,
@@ -398,6 +408,12 @@ def get_parser():
         default=4e-4,
         help="Learning rate for training. Default is 4e-4",
     )
+    parser.add_argument(
+        "--scheduler-num-steps",
+        type=int,
+        required=False,
+        help="Number of steps for the learning rate scheduler. Defaults to --num-steps if not given.",
+    )
     parser.add_argument(
         "--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger."
     )
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_train_esm2.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_train_esm2.py
index 9023e399eb..435f5bcbac 100644
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_train_esm2.py
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_train_esm2.py
@@ -101,6 +101,7 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra
         wandb_project=None,
         wandb_offline=True,
         num_steps=10,
+        scheduler_num_steps=None,
         warmup_steps=5,
         limit_val_batches=1,
         val_check_interval=1,
@@ -168,6 +169,7 @@ def test_val_dataloader_in_main_runs_with_limit_val_batches(
         wandb_project=None,
         wandb_offline=True,
         num_steps=10,
+        scheduler_num_steps=None,
         warmup_steps=2,
         limit_val_batches=limit_val_batches,
         val_check_interval=1,
diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/run/config_models.py b/sub-packages/bionemo-llm/src/bionemo/llm/run/config_models.py
index f22e7c3650..e6c0f6177d 100644
--- a/sub-packages/bionemo-llm/src/bionemo/llm/run/config_models.py
+++ b/sub-packages/bionemo-llm/src/bionemo/llm/run/config_models.py
@@ -325,6 +325,7 @@ class OptimizerSchedulerConfig(BaseModel):
         monitor (str): Metric to monitor for learning rate adjustments. Default is "val_loss".
         warmup_steps (int): Number of warmup steps for use with the warmup annealing learning rate scheduler. Default is 0.
         lr_scheduler (Literal['warmup_anneal', 'cosine']): Type of learning rate scheduler to use. Default is 'warmup_anneal'. NOTE this is likely to change.
+        max_steps (Optional[int]): max_steps used by the learning rate scheduler. Defaults to None, which uses max_steps from the TrainingConfig.
""" lr: float = 1e-4 @@ -335,6 +336,7 @@ class OptimizerSchedulerConfig(BaseModel): cosine_hold_frac: float = 0.05 warmup_steps: int = 0 lr_scheduler: Literal["warmup_anneal", "cosine"] = "warmup_anneal" + max_steps: Optional[int] = None class ExperimentConfig(BaseModel): diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/train.py b/sub-packages/bionemo-llm/src/bionemo/llm/train.py index 442b01957a..7626deb43c 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/train.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/train.py @@ -203,7 +203,7 @@ def train( # TODO: need an abstraction for LrSchedulerConfig if optim_config.lr_scheduler == "cosine": lr_scheduler = CosineAnnealingScheduler( - max_steps=training_config.max_steps, + max_steps=training_config.max_steps if optim_config.max_steps is None else optim_config.max_steps, min_lr=optim_config.lr / 100, warmup_steps=int(math.ceil(training_config.max_steps * optim_config.cosine_rampup_frac)), interval=optim_config.interval, @@ -213,7 +213,7 @@ def train( elif optim_config.lr_scheduler == "warmup_anneal": lr_scheduler = WarmupAnnealDecayHoldScheduler( warmup_steps=optim_config.warmup_steps, - max_steps=training_config.max_steps, + max_steps=training_config.max_steps if optim_config.max_steps is None else optim_config.max_steps, max_lr=optim_config.lr, min_lr=optim_config.lr / 10.0, anneal_percentage=0.10,