Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add initial configs for perf testing on ESM2 in JET (bionemo2) #497

Open
wants to merge 46 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
835fd20
added initial config for partial conv of bionemo2
dorotat-nv Nov 13, 2024
4598078
Merge remote-tracking branch 'origin/main' into dorotat/esm2-partial-…
dorotat-nv Nov 21, 2024
0b244dc
adding wandb job type support
dorotat-nv Nov 21, 2024
d394d68
bugfixes
dorotat-nv Nov 21, 2024
5462905
added smoke test config
dorotat-nv Nov 21, 2024
47cebab
adding long esm2 650 training
dorotat-nv Nov 21, 2024
2360c26
disabled checkpointing and updated max steps
dorotat-nv Nov 27, 2024
369a19f
Merge remote-tracking branch 'origin' into dorotat/esm2-partial-conv-…
dorotat-nv Dec 2, 2024
ccd5cfa
reduced number of steps
dorotat-nv Dec 2, 2024
ffd3a30
reduced number of steps
dorotat-nv Dec 3, 2024
0c27b78
Merge remote-tracking branch 'origin/main' into dorotat/esm2-partial-…
dorotat-nv Dec 3, 2024
1fb3848
added perf benchmarks
dorotat-nv Dec 4, 2024
f1c4e95
Merge remote-tracking branch 'origin' into dorotat/esm2-perf-jet-bion…
dorotat-nv Dec 4, 2024
c4faf3b
updated config
dorotat-nv Dec 4, 2024
fb3d461
add option to disable checkpointing
dorotat-nv Dec 5, 2024
7c3a7fb
added option to benchmarks
dorotat-nv Dec 5, 2024
fca2dbd
added option to benchmarks
dorotat-nv Dec 5, 2024
bb3cb50
Merge remote-tracking branch 'origin/main' into dorotat/esm2-perf-jet…
dorotat-nv Dec 10, 2024
e26457a
fixed disable checkpointing
dorotat-nv Dec 11, 2024
9ccbc09
added initial config for partial conv of bionemo2
dorotat-nv Nov 13, 2024
64bd14b
adding wandb job type support
dorotat-nv Nov 21, 2024
e4b8044
bugfixes
dorotat-nv Nov 21, 2024
d1612ae
added smoke test config
dorotat-nv Nov 21, 2024
48fc669
adding long esm2 650 training
dorotat-nv Nov 21, 2024
c0d2e38
disabled checkpointing and updated max steps
dorotat-nv Nov 27, 2024
d725855
reduced number of steps
dorotat-nv Dec 2, 2024
379471a
reduced number of steps
dorotat-nv Dec 3, 2024
e5c6be6
added perf benchmarks
dorotat-nv Dec 4, 2024
0e23248
updated config
dorotat-nv Dec 4, 2024
87e50d9
add option to disable checkpointing
dorotat-nv Dec 5, 2024
7c78ef5
added option to benchmarks
dorotat-nv Dec 5, 2024
08dee15
added option to benchmarks
dorotat-nv Dec 5, 2024
5d4c3ec
fixed disable checkpointing
dorotat-nv Dec 11, 2024
be53772
addded support for checkpointing - esm2
dorotat-nv Dec 12, 2024
957871a
Merge remote-tracking branch 'origin/dorotat/esm2-perf-jet-bionemo2' …
dorotat-nv Dec 12, 2024
111772c
added new parameters to train geneformer
dorotat-nv Dec 12, 2024
3477124
Merge remote-tracking branch 'origin' into dorotat/esm2-perf-jet-bion…
dorotat-nv Dec 12, 2024
e3d91d0
added unit test
dorotat-nv Dec 12, 2024
b61fea0
adding support for genformer
dorotat-nv Dec 12, 2024
9da1d95
Merge remote-tracking branch 'origin/main' into dorotat/esm2-perf-jet…
dorotat-nv Dec 13, 2024
f815896
adding a default value to wandb config
dorotat-nv Dec 13, 2024
4a3225b
bugfix
dorotat-nv Dec 13, 2024
ee7972e
fixed geneformer test
dorotat-nv Dec 13, 2024
dea271c
bugfix
dorotat-nv Dec 13, 2024
428058b
Merge branch 'main' into dorotat/esm2-perf-jet-bionemo2
dorotat-nv Dec 13, 2024
125d17e
Merge branch 'main' into dorotat/esm2-perf-jet-bionemo2
dorotat-nv Dec 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions ci/benchmarks/partial-conv/esm2_pretrain.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
scope: partial-conv
dorotat-nv marked this conversation as resolved.
Show resolved Hide resolved
time_limit: 14400
script_args:
# All arguments referenced in the script string must be specified here.
# Arguments not referenced in the script string must have the 'arg' field specified.
# See jet/core/configs.py for the specification of the configuration class
workspace:
value: /workspace/bionemo2
key_segment: False
data_path:
value: /data/20240809_uniref_2024_03/data
key_segment: False
model:
value: esm2
variant:
value: train
config_name:
value: 650M
precision:
value: [bf16-mixed]
nodes:
value: [4]
gpus:
value: 8
batch_size:
value: 16
max_steps:
value: 26500
script: |-
WANDB_API_KEY=$BIONEMO_WANDB_API_KEY ${variant}_${model} \
dorotat-nv marked this conversation as resolved.
Show resolved Hide resolved
--train-cluster-path=${data_path}/train_clusters.parquet \
--train-database-path=${data_path}/train.db \
--valid-cluster-path=${data_path}/valid_clusters.parquet \
--valid-database-path=${data_path}/validation.db \
--micro-batch-size=${batch_size} \
--num-nodes=${nodes} \
--num-gpus=${gpus} \
--val-check-interval=1000 \
--limit-val-batches=1 \
--num-steps=${max_steps} \
--min-seq-length=1024 \
--max-seq-length=1024 \
--num-layers=33 \
--hidden-size=1280 \
--num-attention-heads=20 \
--ffn-hidden-size=5120 \
--create-tensorboard-logger \
--experiment-name=${batch_size}bs_${nodes}node_${gpus}gpu_${max_steps}s_${precision}prec \
--result-dir=${tensorboard_dir} \
--wandb-project=${wandb_project_name} \
--wandb-group=${model}_${variant}_${config_name} \
--wandb-job-type=${pipeline_label} \
--log-every-n-steps=50 \
--disable-checkpointing;
65 changes: 65 additions & 0 deletions ci/benchmarks/perf/esm2_pretrain.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
scope: perf
time_limit: 1800
script_args:
# All arguments referenced in the script string must be specified here.
# Arguments not referenced in the script string must have the 'arg' field specified.
# See jet/core/configs.py for the specification of the configuration class
workspace:
value: /workspace/bionemo2
key_segment: False
data_path:
value: /data/20240809_uniref_2024_03/data
key_segment: False
model: esm2
variant: train
config_name: 650M
precision: bf16-mixed
max_steps: 200
gpus: 8
acc_grad: 1
products:
- nodes: 1
batch_size: 16
pp: 1
tp: 1
- nodes: 2
batch_size: 16
pp: 2
tp: 1
- nodes: 2
batch_size: 16
pp: 1
tp: 2
- nodes: 2
batch_size: 16
pp: 1
tp: 1
script: |-
WANDB_API_KEY=$BIONEMO_WANDB_API_KEY ${variant}_${model} \
--train-cluster-path=${data_path}/train_clusters.parquet \
--train-database-path=${data_path}/train.db \
--valid-cluster-path=${data_path}/valid_clusters.parquet \
--valid-database-path=${data_path}/validation.db \
--micro-batch-size=${batch_size} \
--num-nodes=${nodes} \
--num-gpus=${gpus} \
--val-check-interval=50 \
--limit-val-batches=1 \
--num-steps=${max_steps} \
--min-seq-length=1024 \
--max-seq-length=1024 \
--num-layers=33 \
--hidden-size=1280 \
--num-attention-heads=20 \
--ffn-hidden-size=5120 \
--create-tensorboard-logger \
--experiment-name=${batch_size}bs_${nodes}node_${gpus}gpu_${max_steps}s_${precision}prec \
--result-dir=${tensorboard_dir} \
--wandb-project=${wandb_project_name} \
--wandb-group=${model}_${variant}_${config_name} \
--wandb-job-type=${pipeline_label} \
--log-every-n-steps=10 \
--accumulate-grad-batches=${acc_grad} \
--pipeline-model-parallel-size=${pp} \
--tensor-model-parallel-size={tp} \
--disable-checkpointing;
21 changes: 19 additions & 2 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/run/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ def parse_args():
default=[0],
help="Enable nsys profiling for these ranks.",
)
parser.add_argument(
"--disable-checkpointing",
action="store_false",
default=True,
dest="create_checkpoint_callback",
help="Disable creating a ModelCheckpoint callback.",
)
return parser.parse_args()

def string_to_class(path: str):
Expand All @@ -87,7 +94,12 @@ def string_to_class(path: str):
module = importlib.import_module(module_path)
return getattr(module, class_name)

def load_config(config_path: str, model_config_cls: Optional[str], data_config_cls: Optional[str]) -> MainConfig:
def load_config(
config_path: str,
model_config_cls: Optional[str],
data_config_cls: Optional[str],
create_checkpoint_callback: bool,
) -> MainConfig:
with open(config_path, "r") as f:
config_dict = yaml.safe_load(f)

Expand All @@ -109,10 +121,15 @@ def load_config(config_path: str, model_config_cls: Optional[str], data_config_c
elif isinstance(data_config_cls, str):
data_config_cls = string_to_class(data_config_cls)

# disable checkpointing if called from the command line
if not create_checkpoint_callback:
config_dict["training_config"]["enable_checkpointing"] = create_checkpoint_callback
config_dict["experiment_config"]["create_checkpoint_callback"] = create_checkpoint_callback

return MainConfig[model_config_cls, data_config_cls](**config_dict)

args = parse_args()
config = load_config(args.config, args.model_config_cls, args.data_config_cls)
config = load_config(args.config, args.model_config_cls, args.data_config_cls, args.create_checkpoint_callback)

if args.nsys_profiling:
nsys_config = NsysConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@ def main(
wandb_offline: bool = False,
wandb_tags: Optional[List[str]] = None,
wandb_group: Optional[str] = None,
wandb_job_type: Optional[str] = None,
wandb_id: Optional[str] = None,
wandb_anonymous: Optional[bool] = False,
wandb_log_model: bool = False,
pipeline_model_parallel_size: int = 1,
tensor_model_parallel_size: int = 1,
create_tensorboard_logger: bool = False,
nemo1_init_path: Optional[Path] = None,
create_checkpoint_callback: bool = True,
restore_from_checkpoint_path: Optional[str] = None,
save_best_checkpoint: bool = True,
save_last_checkpoint: bool = True,
Expand Down Expand Up @@ -124,13 +126,15 @@ def main(
wandb_offline (bool): Run offline (data can be streamed later to wandb servers).
wandb_tags (Optional[List[str]]): Tags associated with this run
wandb_group (Optional[str]): A unique string shared by all runs in a given group
wandb_job_type (Optional[str]): Type of run, which is useful when you're grouping runs together into larger experiments using group.
wandb_id (Optional[str]): Sets the version, mainly used to resume a previous run
wandb_anonymous (Optional[bool]): Enables or explicitly disables anonymous logging
wandb_log_model (bool): Save checkpoints in wandb dir to upload on W&B servers
pipeline_model_parallel_size (int): pipeline model parallel size
tensor_model_parallel_size (int): tensor model parallel size
create_tensorboard_logger (bool): create the tensorboard logger
nemo1_init_path (Optional[Path]): Nemo 1 initialization path
create_checkpoint_callback (bool): create a ModelCheckpoint callback and attach it to the pytorch lightning trainer
restore_from_checkpoint_path (Optional[str]): If set, restores the model from the directory passed in. Expects the
checkpoint to be created by using the ModelCheckpoint class and always_save_context=True.
save_best_checkpoint (bool): whether to save the best checkpoint
Expand Down Expand Up @@ -182,6 +186,7 @@ def main(
entity=wandb_entity,
tags=wandb_tags,
group=wandb_group,
job_type=wandb_job_type,
id=wandb_id,
anonymous=wandb_anonymous,
log_model=wandb_log_model,
Expand Down Expand Up @@ -213,6 +218,7 @@ def main(
log_every_n_steps=log_every_n_steps,
num_nodes=num_nodes,
callbacks=callbacks,
enable_checkpointing=create_checkpoint_callback,
plugins=nl.MegatronMixedPrecision(precision=precision),
)

Expand Down Expand Up @@ -275,14 +281,17 @@ def main(
)

# Configure our custom Checkpointer
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=save_last_checkpoint,
monitor=metric_to_monitor_for_checkpoints, # "val_loss",
save_top_k=save_top_k,
every_n_train_steps=val_check_interval,
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
filename="{epoch}-{val_loss:.2f}-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this.
)
if create_checkpoint_callback:
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=save_last_checkpoint,
monitor=metric_to_monitor_for_checkpoints, # "val_loss",
save_top_k=save_top_k,
every_n_train_steps=val_check_interval,
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
filename="{epoch}-{val_loss:.2f}-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this.
)
else:
checkpoint_callback = None
dorotat-nv marked this conversation as resolved.
Show resolved Hide resolved

# Setup the logger and train the model
nemo_logger = setup_nemo_lightning_logger(
Expand Down Expand Up @@ -325,6 +334,7 @@ def train_esm2_entrypoint():
wandb_project=args.wandb_project,
wandb_tags=args.wandb_tags,
wandb_group=args.wandb_group,
wandb_job_type=args.wandb_job_type,
wandb_id=args.wandb_id,
wandb_anonymous=args.wandb_anonymous,
wandb_log_model=args.wandb_log_model,
Expand All @@ -346,6 +356,7 @@ def train_esm2_entrypoint():
experiment_name=args.experiment_name,
resume_if_exists=args.resume_if_exists,
nemo1_init_path=args.nemo1_init_path,
create_checkpoint_callback=args.create_checkpoint_callback,
dorotat-nv marked this conversation as resolved.
Show resolved Hide resolved
restore_from_checkpoint_path=args.restore_from_checkpoint_path,
save_best_checkpoint=args.save_best_checkpoint,
save_last_checkpoint=args.save_last_checkpoint,
Expand Down Expand Up @@ -432,6 +443,12 @@ def get_parser():
parser.add_argument(
"--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group"
)
parser.add_argument(
"--wandb-job-type",
type=str,
default=None,
help="A unique string representing a type of run, which is useful when you're grouping runs together into larger experiments using group.",
)
parser.add_argument(
"--wandb-id", type=str, default=None, help="Sets the version, mainly used to resume a previous run"
)
Expand Down Expand Up @@ -553,6 +570,13 @@ def get_parser():
required=False,
help="Path to nemo1 file, if desired to load at init time.",
)
parser.add_argument(
"--disable-checkpointing",
action="store_false",
default=True,
dest="create_checkpoint_callback",
help="Disable creating a ModelCheckpoint callback.",
)
dorotat-nv marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument(
"--save-best-checkpoint",
action="store_true",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def dummy_parquet_train_val_inputs(tmp_path):
return train_cluster_path, valid_cluster_path


def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_train_val_inputs):
@pytest.mark.parametrize("create_checkpoint_callback", [True, False])
def test_main_runs(
monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_train_val_inputs, create_checkpoint_callback
):
train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs

result_dir = Path(tmpdir.mkdir("results"))
Expand Down Expand Up @@ -119,6 +122,7 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra
num_attention_heads=2,
hidden_size=4,
ffn_hidden_size=4 * 4,
create_checkpoint_callback=create_checkpoint_callback,
)

assert (result_dir / "test_experiment").exists(), "Could not find test experiment directory."
Expand All @@ -129,9 +133,20 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra
assert (
result_dir / "test_experiment" / uq_rundir / "checkpoints"
).exists(), "Could not find test experiment checkpoints directory."
assert (
result_dir / "test_experiment" / uq_rundir / "checkpoints"
).is_dir(), "Test experiment checkpoints directory is supposed to be a directory."

expected_exists = create_checkpoint_callback
actual_exists = (result_dir / "test_experiment" / uq_rundir / "checkpoints").exists()

assert expected_exists == actual_exists, (
f"Checkpoints directory existence mismatch. "
f"Expected: {'exists' if expected_exists else 'does not exist'}, "
f"Found: {'exists' if actual_exists else 'does not exist'}."
)

if create_checkpoint_callback:
assert (
result_dir / "test_experiment" / uq_rundir / "checkpoints"
).is_dir(), "Test experiment checkpoints directory is supposed to be a directory."
assert (
result_dir / "test_experiment" / uq_rundir / "nemo_log_globalrank-0_localrank-0.txt"
).is_file(), "Could not find experiment log."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ def parse_args():
default=[0],
help="Enable nsys profiling for these ranks.",
)
parser.add_argument(
"--disable-checkpointing",
action="store_false",
default=True,
dest="create_checkpoint_callback",
help="Disable creating a ModelCheckpoint callback.",
)

return parser.parse_args()

Expand All @@ -92,7 +99,12 @@ def string_to_class(path: str):
module = importlib.import_module(module_path)
return getattr(module, class_name)

def load_config(config_path: str, model_config_cls: Optional[str], data_config_cls: Optional[str]) -> MainConfig:
def load_config(
config_path: str,
model_config_cls: Optional[str],
data_config_cls: Optional[str],
create_checkpoint_callback: bool,
) -> MainConfig:
with open(config_path, "r") as f:
config_dict = yaml.safe_load(f)

Expand All @@ -106,6 +118,11 @@ def load_config(config_path: str, model_config_cls: Optional[str], data_config_c
# We assume we get a string to some importable config... e.g. in the sub-package jensen, 'bionemo.jensen.configs.MyConfig'
model_config_cls = string_to_class(model_config_cls)

# disable checkpointing if called from the command line
if not create_checkpoint_callback:
config_dict["training_config"]["enable_checkpointing"] = create_checkpoint_callback
config_dict["experiment_config"]["create_checkpoint_callback"] = create_checkpoint_callback

if data_config_cls is None:
data_config_cls = GeneformerPretrainingDataConfig
elif isinstance(data_config_cls, str):
Expand Down
Loading
Loading