Commit
Add num_steps support for the learning rate scheduler (#489)
Decouple `num_steps` from the number of steps used by the learning rate scheduler.
sichu2023 authored Dec 5, 2024
1 parent 38be873 commit c95d24e
Showing 5 changed files with 55 additions and 18 deletions.
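For orientation, a minimal sketch of what the decoupling enables, using the recipe helpers changed below (step counts are illustrative, not defaults introduced by this commit):

from bionemo.esm2.run.recipes import esm2_base_optimizer_scheduler_config, esm2_base_training_config

# Train for 500k steps while driving the warmup-anneal LR schedule as if the run were 100k steps long.
training_config = esm2_base_training_config(max_steps=500_000)
optim_config = esm2_base_optimizer_scheduler_config(max_steps=100_000)
# Leaving max_steps=None (the default) keeps the old coupled behaviour: the scheduler follows the training config.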
43 changes: 30 additions & 13 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/run/recipes.py
@@ -36,21 +36,27 @@
from bionemo.llm.utils.logger_utils import WandbConfig


-def esm2_base_training_config() -> TrainingConfig:
+def esm2_base_training_config(max_steps: int = 500000) -> TrainingConfig:
"""Base training config for ESM2."""
return TrainingConfig(
-max_steps=500000,
+max_steps=max_steps,
limit_val_batches=1.0,
val_check_interval=10_000,
precision="bf16-mixed",
include_perplexity=True,
)


-def esm2_base_optimizer_scheduler_config() -> OptimizerSchedulerConfig:
+def esm2_base_optimizer_scheduler_config(max_steps: Optional[int] = None) -> OptimizerSchedulerConfig:
"""Base optimizer scheduler config for ESM2."""
return OptimizerSchedulerConfig(
optimizer="adam", lr=4e-4, interval="step", monitor="val_loss", lr_scheduler="warmup_anneal", warmup_steps=2000
optimizer="adam",
lr=4e-4,
interval="step",
monitor="val_loss",
lr_scheduler="warmup_anneal",
warmup_steps=2000,
max_steps=max_steps,
)


@@ -128,9 +134,9 @@ def esm2_8m_recipe(args) -> MainConfig[ExposedESM2PretrainConfig, ESM2DataConfig
return MainConfig(
data_config=esm2_base_data_config(args),
parallel_config=esm2_base_parallel_config(),
-training_config=esm2_base_training_config(), # no changes for 8m
+training_config=esm2_base_training_config(max_steps=args.max_steps), # no changes for 8m
bionemo_model_config=esm2_8m_model_config(args.initial_ckpt_path),
-optim_config=esm2_base_optimizer_scheduler_config(), # no changes for 8m
+optim_config=esm2_base_optimizer_scheduler_config(max_steps=args.scheduler_max_steps), # no changes for 8m
experiment_config=esm2_8m_experiment_config(args.result_dir),
wandb_config=esm2_8m_wandb_config(),
)
@@ -183,9 +189,9 @@ def esm2_650m_recipe(args) -> MainConfig[ExposedESM2PretrainConfig, ESM2DataConf
return MainConfig(
data_config=esm2_base_data_config(args),
parallel_config=esm2_base_parallel_config(),
-training_config=esm2_base_training_config(), # no changes for 8m
+training_config=esm2_base_training_config(max_steps=args.max_steps), # no changes for 8m
bionemo_model_config=esm2_650m_model_config(args.initial_ckpt_path),
-optim_config=esm2_base_optimizer_scheduler_config(), # no changes for 8m
+optim_config=esm2_base_optimizer_scheduler_config(max_steps=args.scheduler_max_steps), # no changes for 8m
experiment_config=esm2_650m_experiment_config(args.result_dir),
wandb_config=esm2_650m_wandb_config(),
)
@@ -251,9 +257,9 @@ def esm2_3b_recipe(args) -> MainConfig[ExposedESM2PretrainConfig, ESM2DataConfig
return MainConfig(
data_config=esm2_base_data_config(args),
parallel_config=esm2_3b_parallel_config(),
-training_config=esm2_base_training_config(), # no changes for 8m
+training_config=esm2_base_training_config(max_steps=args.max_steps), # no changes for 8m
bionemo_model_config=esm2_3b_model_config(args.initial_ckpt_path),
-optim_config=esm2_base_optimizer_scheduler_config(), # no changes for 8m
+optim_config=esm2_base_optimizer_scheduler_config(max_steps=args.scheduler_max_steps), # no changes for 8m
experiment_config=esm2_3b_experiment_config(args.result_dir),
wandb_config=esm2_3b_wandb_config(),
)
@@ -282,9 +288,9 @@ def tiny_train_config_recipe() -> TrainingConfig:
return TrainingConfig(max_steps=10, limit_val_batches=2, val_check_interval=2)


-def default_adam_optimizer_with_cosine_annealing_recipe() -> OptimizerSchedulerConfig:
+def default_adam_optimizer_with_cosine_annealing_recipe(max_steps: Optional[int] = None) -> OptimizerSchedulerConfig:
"""Default optimizer scheduler config for ESM2."""
-return OptimizerSchedulerConfig()
+return OptimizerSchedulerConfig(max_steps=max_steps)


def experiment_config_recipe(result_dir="./results") -> ExperimentConfig:
@@ -347,7 +353,7 @@ def esm2_tiny_test_recipe(args):
seq_length=data_config.max_seq_length, initial_ckpt_path=args.initial_ckpt_path
)

-optim_config = default_adam_optimizer_with_cosine_annealing_recipe()
+optim_config = default_adam_optimizer_with_cosine_annealing_recipe(max_steps=args.scheduler_max_steps)
experiment_config = experiment_config_recipe(args.result_dir)
wandb_config = WandbConfig(
project="bionemo2-demo",
@@ -436,6 +442,17 @@ def parse_args():
help="Path to an existing to a checkpoint directory to restore an existing checkpoint. Not compatible with all recipes.",
)

+parser.add_argument(
+    "--max-steps", type=int, required=False, default=500000, help="Max steps for training. Default to 500000."
+)
+
+parser.add_argument(
+    "--scheduler-max-steps",
+    type=int,
+    required=False,
+    help="Set scheduler max_steps directly. Otherwise default to None, which uses max_steps from training config.",
+)

args = parser.parse_args()
return args

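A hedged sketch of how the new flags flow into a recipe; the wiring is assumed from parse_args() and esm2_8m_recipe() in this file, and the flag values are illustrative:

# Assumed wiring, appended conceptually to recipes.py above.
args = parse_args()  # e.g. run with --max-steps 500000 --scheduler-max-steps 100000 plus the recipe's other flags
config = esm2_8m_recipe(args)
# config.training_config.max_steps -> 500000; config.optim_config.max_steps -> 100000
# Omitting --scheduler-max-steps leaves it None, and train.py then falls back to training_config.max_steps.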
22 changes: 19 additions & 3 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py
@@ -52,12 +52,13 @@ def main(
max_seq_length: int,
result_dir: Path,
num_steps: int,
+scheduler_num_steps: Optional[int],
warmup_steps: int,
limit_val_batches: int,
val_check_interval: int,
log_every_n_steps: Optional[int],
num_dataset_workers: int,
-biobert_spec_option: BiobertSpecOption,  # TODO(@farhadrgh) clarify how to parse this.
+biobert_spec_option: BiobertSpecOption,
lr: float,
micro_batch_size: int,
accumulate_grad_batches: int,
@@ -111,6 +112,7 @@ def main(
num_dataset_workers (int): number of dataset workers
biobert_spec_option (BiobertSpecOption): the biobert spec option (architecture) to use for this run
lr (float): learning rate
+scheduler_num_steps (Optional[int]): Number of steps in learning rate scheduler. Use num_steps if not provided.
micro_batch_size (int): micro batch size, from this and parallelism settings we infer the global batch size
accumulate_grad_batches (int): number of batches to accumulate gradients for
experiment_name (str): experiment name, this is the name used for the wandb run, and the sub-directory of the
@@ -247,20 +249,27 @@ def main(
variable_seq_lengths=min_seq_length != max_seq_length,
)

+if scheduler_num_steps is None:
+    scheduler_num_steps = num_steps
+
model = biobert_lightning_module(
esm2_config,
tokenizer=tokenizer,
optimizer=MegatronOptimizerModule(
config=OptimizerConfig(
lr=lr,
optimizer="adam", # fused_adam not supported
optimizer="adam",
use_distributed_optimizer=True,
weight_decay=0.01,
adam_beta1=0.9,
adam_beta2=0.98,
),
lr_scheduler=WarmupAnnealDecayHoldScheduler(
-warmup_steps=warmup_steps, max_steps=num_steps, max_lr=lr, min_lr=0.0, anneal_percentage=0.10
+warmup_steps=warmup_steps,
+max_steps=scheduler_num_steps,
+max_lr=lr,
+min_lr=0.0,
+anneal_percentage=0.10,
),
),
)
@@ -328,6 +337,7 @@ def train_esm2_entrypoint():
num_dataset_workers=args.num_dataset_workers,
biobert_spec_option=args.biobert_spec_option,
lr=args.lr,
+scheduler_num_steps=args.scheduler_num_steps,
micro_batch_size=args.micro_batch_size,
pipeline_model_parallel_size=args.pipeline_model_parallel_size,
tensor_model_parallel_size=args.tensor_model_parallel_size,
@@ -398,6 +408,12 @@ def get_parser():
default=4e-4,
help="Learning rate for training. Default is 4e-4",
)
+parser.add_argument(
+    "--scheduler-num-steps",
+    type=int,
+    required=False,
+    help="Number of steps for learning rate scheduler. Will use --num-steps if not given. Default is None.",
+)
parser.add_argument(
"--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger."
)
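To make the precedence concrete, a small self-contained illustration of the fallback main() now applies (the helper name resolve_scheduler_steps is invented for this sketch and is not part of the module):

from typing import Optional

def resolve_scheduler_steps(num_steps: int, scheduler_num_steps: Optional[int]) -> int:
    # Mirrors the fallback added above: an explicit scheduler step count wins, otherwise reuse num_steps.
    return num_steps if scheduler_num_steps is None else scheduler_num_steps

assert resolve_scheduler_steps(500_000, None) == 500_000      # old coupled behaviour
assert resolve_scheduler_steps(500_000, 100_000) == 100_000   # new: LR schedule spans fewer steps than training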
2 changes: 2 additions & 0 deletions (test module for train_esm2; file path not captured in this view)
@@ -101,6 +101,7 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra
wandb_project=None,
wandb_offline=True,
num_steps=10,
+scheduler_num_steps=None,
warmup_steps=5,
limit_val_batches=1,
val_check_interval=1,
@@ -168,6 +169,7 @@ def test_val_dataloader_in_main_runs_with_limit_val_batches(
wandb_project=None,
wandb_offline=True,
num_steps=10,
+scheduler_num_steps=None,
warmup_steps=2,
limit_val_batches=limit_val_batches,
val_check_interval=1,
2 changes: 2 additions & 0 deletions sub-packages/bionemo-llm/src/bionemo/llm/run/config_models.py
@@ -325,6 +325,7 @@ class OptimizerSchedulerConfig(BaseModel):
monitor (str): Metric to monitor for learning rate adjustments. Default is "val_loss".
warmup_steps (int): Number of warmup steps for use with the warmup annealing learning rate scheduler. Default is 0.
lr_scheduler (Literal['warmup_anneal', 'cosine']): Type of learning rate scheduler to use. Default is 'warmup_anneal'. NOTE this is likely to change.
+max_steps (Optional[int]): max_steps used in optimizer. Default to None which uses max_steps from TrainingConfig.
"""

lr: float = 1e-4
@@ -335,6 +336,7 @@ class OptimizerSchedulerConfig(BaseModel):
cosine_hold_frac: float = 0.05
warmup_steps: int = 0
lr_scheduler: Literal["warmup_anneal", "cosine"] = "warmup_anneal"
+max_steps: Optional[int] = None


class ExperimentConfig(BaseModel):
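A brief sketch of the new field in isolation (fields not set here keep the defaults listed above; the import path follows this file's header):

from bionemo.llm.run.config_models import OptimizerSchedulerConfig

coupled = OptimizerSchedulerConfig()                     # max_steps stays None: scheduler follows TrainingConfig.max_steps
decoupled = OptimizerSchedulerConfig(max_steps=100_000)  # scheduler gets its own horizon, independent of training length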
4 changes: 2 additions & 2 deletions sub-packages/bionemo-llm/src/bionemo/llm/train.py
@@ -203,7 +203,7 @@ def train(
# TODO: need an abstraction for LrSchedulerConfig
if optim_config.lr_scheduler == "cosine":
lr_scheduler = CosineAnnealingScheduler(
-max_steps=training_config.max_steps,
+max_steps=training_config.max_steps if optim_config.max_steps is None else optim_config.max_steps,
min_lr=optim_config.lr / 100,
warmup_steps=int(math.ceil(training_config.max_steps * optim_config.cosine_rampup_frac)),
interval=optim_config.interval,
@@ -213,7 +213,7 @@ def train(
elif optim_config.lr_scheduler == "warmup_anneal":
lr_scheduler = WarmupAnnealDecayHoldScheduler(
warmup_steps=optim_config.warmup_steps,
-max_steps=training_config.max_steps,
+max_steps=training_config.max_steps if optim_config.max_steps is None else optim_config.max_steps,
max_lr=optim_config.lr,
min_lr=optim_config.lr / 10.0,
anneal_percentage=0.10,
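A follow-up observation on the cosine branch above: the override changes the anneal horizon, but the warmup length is still computed from training_config.max_steps. A small illustration (the rampup fraction is an assumed value, not taken from this diff):

import math

training_max_steps = 500_000     # training_config.max_steps
scheduler_max_steps = 100_000    # optim_config.max_steps override
cosine_rampup_frac = 0.01        # assumed for the example; the real default is not visible here

warmup_steps = int(math.ceil(training_max_steps * cosine_rampup_frac))
print(warmup_steps)              # 5000: warmup still scales with the training horizon
print(scheduler_max_steps)       # 100000: only the anneal horizon is decoupled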
