From f6f74fb705d38ba94bf0e1ecf815b264fa328c9b Mon Sep 17 00:00:00 2001 From: Jared T Nielsen Date: Fri, 22 May 2020 19:10:55 -0700 Subject: [PATCH] Cleaner argument handling & nlp/common/ folder (#16) By moving the arguments into their own dataclass (available in Python 3.7), we can group certain types of arguments, such as ModelArguments and SageMakerArguments. This lets us consolidate the sagemaker scripts into a single file, and makes the arguments simpler to pass around in functions. Moves several files to common/. Users will need to set PYTHONPATH=/path/to/deep-learning-models/nlp. Also fixes PYTHONPATH to /opt/ml/... in the SageMaker container, so those jobs should run. Also adds support to log hyperparameters in TensorBoard. --- models/nlp/albert/README.md | 37 ++- models/nlp/albert/arguments.py | 123 ---------- models/nlp/albert/launch_sagemaker.py | 54 +++++ models/nlp/albert/run_pretraining.py | 214 +++++++++++------- models/nlp/albert/run_squad.py | 97 ++++---- models/nlp/albert/run_squad_evaluation.py | 19 +- models/nlp/albert/sagemaker_pretraining.py | 51 ----- models/nlp/albert/sagemaker_squad.py | 51 ----- models/nlp/common/arguments.py | 152 +++++++++++++ models/nlp/{albert => common}/datasets.py | 0 .../learning_rate_schedules.py | 0 models/nlp/{albert => common}/models.py | 0 .../nlp/{albert => common}/sagemaker_utils.py | 2 +- models/nlp/{albert => common}/utils.py | 0 models/nlp/docker/hvd_kubernetes.Dockerfile | 15 +- models/nlp/docker/ngc_sagemaker.Dockerfile | 4 + setup.cfg | 2 +- 17 files changed, 459 insertions(+), 362 deletions(-) delete mode 100644 models/nlp/albert/arguments.py create mode 100644 models/nlp/albert/launch_sagemaker.py delete mode 100644 models/nlp/albert/sagemaker_pretraining.py delete mode 100644 models/nlp/albert/sagemaker_squad.py create mode 100644 models/nlp/common/arguments.py rename models/nlp/{albert => common}/datasets.py (100%) rename models/nlp/{albert => common}/learning_rate_schedules.py (100%) rename models/nlp/{albert => common}/models.py (100%) rename models/nlp/{albert => common}/sagemaker_utils.py (100%) rename models/nlp/{albert => common}/utils.py (100%) diff --git a/models/nlp/albert/README.md b/models/nlp/albert/README.md index af32cba8..c75394b4 100644 --- a/models/nlp/albert/README.md +++ b/models/nlp/albert/README.md @@ -20,7 +20,9 @@ Language models help AWS customers to improve search results, text classificatio 3. Create an Amazon Elastic Container Registry (ECR) repository. Then build a Docker image from `docker/ngc_sagemaker.Dockerfile` and push it to ECR. ```bash -export IMAGE=${ACCOUNT_ID}.dkr.ecr.us-east-1.amazonaws.com/${REPO}:ngc_tf21_sagemaker +export ACCOUNT_ID= +export REPO= +export IMAGE=${ACCOUNT_ID}.dkr.ecr.us-east-1.amazonaws.com/${REPO}:ngc_tf210_sagemaker docker build -t ${IMAGE} -f docker/ngc_sagemaker.Dockerfile . $(aws ecr get-login --no-include-email) docker push ${IMAGE} @@ -39,8 +41,13 @@ export SAGEMAKER_SECURITY_GROUP_IDS=sg-123,sg-456 5. Launch the SageMaker job. ```bash -python sagemaker_pretraining.py \ +# Add the main folder to your PYTHONPATH +export PYTHONPATH=$PYTHONPATH:/path/to/deep-learning-models/models/nlp + +python launch_sagemaker.py \ --source_dir=. \ + --entry_point=run_pretraining.py \ + --sm_job_name=albert-pretrain \ --instance_type=ml.p3dn.24xlarge \ --instance_count=1 \ --load_from=scratch \ @@ -52,9 +59,35 @@ python sagemaker_pretraining.py \ --total_steps=125000 \ --learning_rate=0.00176 \ --optimizer=lamb \ + --log_frequency=10 \ --name=myfirstjob ``` +6. 
Launch a SageMaker finetuning job. + +```bash +python launch_sagemaker.py \ + --source_dir=. \ + --entry_point=run_squad.py \ + --sm_job_name=albert-squad \ + --instance_type=ml.p3dn.24xlarge \ + --instance_count=1 \ + --load_from=scratch \ + --model_type=albert \ + --model_size=base \ + --batch_size=6 \ + --total_steps=8144 \ + --warmup_steps=814 \ + --learning_rate=3e-5 \ + --task_name=squadv2 +``` + +7. Enter the Docker container to debug and edit code. + +```bash +docker run -it -v=/fsx:/fsx --gpus=all --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --rm ${IMAGE} /bin/bash +``` + diff --git a/models/nlp/albert/arguments.py b/models/nlp/albert/arguments.py deleted file mode 100644 index 0ba40038..00000000 --- a/models/nlp/albert/arguments.py +++ /dev/null @@ -1,123 +0,0 @@ -""" Since arguments are duplicated in run_pretraining.py and sagemaker_pretraining.py, they have -been abstracted into this file. It also makes the training scripts much shorter. -""" - -import argparse -import os - - -def populate_pretraining_parser(parser: argparse.ArgumentParser) -> None: - parser.add_argument("--model_dir", help="Unused, but passed by SageMaker") - parser.add_argument("--model_type", default="albert", choices=["albert", "bert"]) - parser.add_argument("--model_size", default="base", choices=["base", "large"]) - parser.add_argument("--batch_size", type=int, default=32, help="per GPU") - parser.add_argument("--gradient_accumulation_steps", type=int, default=2) - parser.add_argument("--max_seq_length", type=int, default=512, choices=[128, 512]) - parser.add_argument("--warmup_steps", type=int, default=3125) - parser.add_argument("--total_steps", type=int, default=125000) - parser.add_argument("--learning_rate", type=float, default=0.00176) - parser.add_argument("--end_learning_rate", type=float, default=3e-5) - parser.add_argument("--learning_rate_decay_power", type=float, default=1.0) - parser.add_argument("--hidden_dropout_prob", type=float, default=0.0) - parser.add_argument("--max_grad_norm", type=float, default=1.0) - parser.add_argument("--optimizer", default="lamb", choices=["lamb", "adam"]) - parser.add_argument("--name", default="", help="Additional info to append to metadata") - parser.add_argument("--log_frequency", type=int, default=1000) - parser.add_argument( - "--load_from", default="scratch", choices=["scratch", "checkpoint", "huggingface"], - ) - parser.add_argument("--checkpoint_path", default=None) - parser.add_argument( - "--fsx_prefix", - default="/fsx", - choices=["/fsx", "/opt/ml/input/data/training"], - help="Change to /opt/ml/input/data/training on SageMaker", - ) - # SageMaker does not work with 'store_const' args, since it parses into a dictionary - # We will treat any value not equal to None as True, and use --skip_xla=true - parser.add_argument( - "--skip_xla", - choices=["true"], - help="For debugging. Faster startup time, slower runtime, more GPU vRAM.", - ) - parser.add_argument( - "--eager", - choices=["true"], - help="For debugging. Faster launch, slower runtime, more GPU vRAM.", - ) - parser.add_argument( - "--skip_sop", choices=["true"], help="Only use MLM loss, and exclude the SOP loss.", - ) - parser.add_argument( - "--skip_mlm", choices=["true"], help="Only use SOP loss, and exclude the MLM loss.", - ) - parser.add_argument( - "--pre_layer_norm", - choices=["true"], - help="Place layer normalization before the attention & FFN, rather than after adding the residual connection. 
https://openreview.net/pdf?id=B1x8anVFPr", - ) - parser.add_argument("--extra_squad_steps", type=str) - parser.add_argument("--fast_squad", choices=["true"]) - parser.add_argument("--dummy_eval", choices=["true"]) - parser.add_argument("--seed", type=int, default=42) - - -def populate_squad_parser(parser: argparse.ArgumentParser) -> None: - # Model loading - parser.add_argument("--model_type", default="albert", choices=["albert", "bert"]) - parser.add_argument("--model_size", default="base", choices=["base", "large"]) - parser.add_argument("--load_from", required=True) - parser.add_argument("--load_step", type=int) - parser.add_argument("--skip_xla", choices=["true"]) - parser.add_argument("--eager", choices=["true"]) - parser.add_argument( - "--pre_layer_norm", - choices=["true"], - help="See https://github.com/huggingface/transformers/pull/3929", - ) - parser.add_argument( - "--fsx_prefix", - default="/fsx", - choices=["/fsx", "/opt/ml/input/data/training"], - help="Change to /opt/ml/input/data/training on SageMaker", - ) - # Hyperparameters from https://arxiv.org/pdf/1909.11942.pdf#page=17 - parser.add_argument("--batch_size", default=6, type=int) - parser.add_argument("--total_steps", default=8144, type=int) - parser.add_argument("--warmup_steps", default=814, type=int) - parser.add_argument("--learning_rate", default=3e-5, type=float) - parser.add_argument("--dataset", default="squadv2") - parser.add_argument("--seed", type=int, default=42) - # Logging information - parser.add_argument("--name", default="default") - parser.add_argument("--validate_frequency", default=1000, type=int) - parser.add_argument("--checkpoint_frequency", default=500, type=int) - parser.add_argument("--model_dir", help="Unused, but passed by SageMaker") - - -def populate_sagemaker_parser(parser: argparse.ArgumentParser) -> None: - # SageMaker parameters - parser.add_argument( - "--source_dir", - help="For example, /Users/myusername/Desktop/deep-learning-models/models/nlp/albert", - ) - parser.add_argument("--entry_point", default="run_pretraining.py") - parser.add_argument("--role", default=os.environ["SAGEMAKER_ROLE"]) - parser.add_argument("--image_name", default=os.environ["SAGEMAKER_IMAGE_NAME"]) - parser.add_argument("--fsx_id", default=os.environ["SAGEMAKER_FSX_ID"]) - parser.add_argument( - "--subnet_ids", help="Comma-separated string", default=os.environ["SAGEMAKER_SUBNET_IDS"] - ) - parser.add_argument( - "--security_group_ids", - help="Comma-separated string", - default=os.environ["SAGEMAKER_SECURITY_GROUP_IDS"], - ) - # Instance specs - parser.add_argument( - "--instance_type", - type=str, - default="ml.p3dn.24xlarge", - choices=["ml.p3dn.24xlarge", "ml.p3.16xlarge", "ml.g4dn.12xlarge"], - ) - parser.add_argument("--instance_count", type=int, default=1) diff --git a/models/nlp/albert/launch_sagemaker.py b/models/nlp/albert/launch_sagemaker.py new file mode 100644 index 00000000..9c9d84fb --- /dev/null +++ b/models/nlp/albert/launch_sagemaker.py @@ -0,0 +1,54 @@ +import argparse +import dataclasses + +from transformers import HfArgumentParser + +from common.arguments import ( + DataTrainingArguments, + LoggingArguments, + ModelArguments, + SageMakerArguments, + TrainingArguments, +) +from common.sagemaker_utils import launch_sagemaker_job + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser = HfArgumentParser( + ( + ModelArguments, + DataTrainingArguments, + TrainingArguments, + LoggingArguments, + SageMakerArguments, + ) + ) + model_args, data_args, train_args, log_args, 
sm_args = parser.parse_args_into_dataclasses() + + hyperparameters = dict() + for args in [model_args, data_args, train_args, log_args]: + for key, value in dataclasses.asdict(args).items(): + if value is not None: + hyperparameters[key] = value + hyperparameters["fsx_prefix"] = "/opt/ml/input/data/training" + + instance_abbr = { + "ml.p3dn.24xlarge": "p3dn", + "ml.p3.16xlarge": "p316", + "ml.g4dn.12xlarge": "g4dn", + }[sm_args.instance_type] + job_name = f"{sm_args.sm_job_name}-{sm_args.instance_count}x{instance_abbr}" + + launch_sagemaker_job( + hyperparameters=hyperparameters, + job_name=job_name, + source_dir=sm_args.source_dir, + entry_point=sm_args.entry_point, + instance_type=sm_args.instance_type, + instance_count=sm_args.instance_count, + role=sm_args.role, + image_name=sm_args.image_name, + fsx_id=sm_args.fsx_id, + subnet_ids=sm_args.subnet_ids, + security_group_ids=sm_args.security_group_ids, + ) diff --git a/models/nlp/albert/run_pretraining.py b/models/nlp/albert/run_pretraining.py index 6da0757b..35224e40 100644 --- a/models/nlp/albert/run_pretraining.py +++ b/models/nlp/albert/run_pretraining.py @@ -23,7 +23,6 @@ """ -import argparse import datetime import glob import logging @@ -33,20 +32,27 @@ import numpy as np import tensorflow as tf import tqdm +from tensorboard.plugins.hparams import api as hp from tensorflow_addons.optimizers import LAMB, AdamW from transformers import ( AutoConfig, GradientAccumulator, + HfArgumentParser, TFAlbertModel, TFAutoModelForPreTraining, TFBertForPreTraining, ) -from arguments import populate_pretraining_parser -from datasets import get_mlm_dataset -from learning_rate_schedules import LinearWarmupPolyDecaySchedule +from common.arguments import ( + DataTrainingArguments, + LoggingArguments, + ModelArguments, + TrainingArguments, +) +from common.datasets import get_mlm_dataset +from common.learning_rate_schedules import LinearWarmupPolyDecaySchedule +from common.utils import TqdmLoggingHandler, gather_indexes, rewrap_tf_function from run_squad import get_squad_results_while_pretraining -from utils import TqdmLoggingHandler, gather_indexes, rewrap_tf_function # See https://github.com/huggingface/transformers/issues/3782; this import must come last import horovod.tensorflow as hvd # isort:skip @@ -60,7 +66,7 @@ def get_squad_steps(extra_steps_str: str) -> List[int]: # fmt: off default_squad_steps = [ k * 1000 - for k in [5, 10, 20,40, 60, 80, 100, 120, 140, 160, 180, 200, 220, 240, 260, 280, 300, 320, 340, 360, 380, 400] + for k in [5, 10, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200, 220, 240, 260, 280, 300, 320, 340, 360, 380, 400] ] # fmt: on return extra_squad_steps + default_squad_steps @@ -175,6 +181,7 @@ def allreduce(model, opt, gradient_accumulator, loss, mlm_loss, mlm_acc, sop_los # Placing before also gives a 20% speedup when training BERT-large, probably because the # gradient operations can be fused by XLA. 
(grads, grad_norm) = tf.clip_by_global_norm(grads, clip_norm=max_grad_norm) + weight_norm = tf.math.sqrt( tf.math.reduce_sum([tf.norm(var, ord=2) ** 2 for var in model.trainable_variables]) ) @@ -345,31 +352,29 @@ def get_checkpoint_paths_from_prefix(prefix: str) -> Tuple[str, str]: def main(): - parser = argparse.ArgumentParser() - populate_pretraining_parser(parser) - args = parser.parse_args() - tf.random.set_seed(args.seed) + parser = HfArgumentParser( + (ModelArguments, DataTrainingArguments, TrainingArguments, LoggingArguments) + ) + model_args, data_args, train_args, log_args = parser.parse_args_into_dataclasses() + + tf.random.set_seed(train_args.seed) tf.autograph.set_verbosity(0) # Settings init parse_bool = lambda arg: arg == "true" - max_predictions_per_seq = 20 - checkpoint_frequency = 5000 - validate_frequency = 2000 - histogram_frequency = 100 - do_gradient_accumulation = args.gradient_accumulation_steps > 1 - do_xla = not parse_bool(args.skip_xla) - do_eager = parse_bool(args.eager) - skip_sop = parse_bool(args.skip_sop) - skip_mlm = parse_bool(args.skip_mlm) - pre_layer_norm = parse_bool(args.pre_layer_norm) - fast_squad = parse_bool(args.fast_squad) - dummy_eval = parse_bool(args.dummy_eval) - squad_steps = get_squad_steps(args.extra_squad_steps) - is_sagemaker = args.fsx_prefix.startswith("/opt/ml") + do_gradient_accumulation = train_args.gradient_accumulation_steps > 1 + do_xla = not parse_bool(train_args.skip_xla) + do_eager = parse_bool(train_args.eager) + skip_sop = parse_bool(train_args.skip_sop) + skip_mlm = parse_bool(train_args.skip_mlm) + pre_layer_norm = parse_bool(model_args.pre_layer_norm) + fast_squad = parse_bool(log_args.fast_squad) + dummy_eval = parse_bool(log_args.dummy_eval) + squad_steps = get_squad_steps(log_args.extra_squad_steps) + is_sagemaker = data_args.fsx_prefix.startswith("/opt/ml") disable_tqdm = is_sagemaker global max_grad_norm - max_grad_norm = args.max_grad_norm + max_grad_norm = train_args.max_grad_norm # Horovod init hvd.init() @@ -394,32 +399,33 @@ def main(): loss_str = "" metadata = ( - f"{args.model_type}" - f"-{args.model_size}" - f"-{args.load_from}" + f"{model_args.model_type}" + f"-{model_args.model_size}" + f"-{model_args.load_from}" f"-{hvd.size()}gpus" - f"-{args.batch_size}batch" - f"-{args.gradient_accumulation_steps}accum" - f"-{args.learning_rate}maxlr" - f"-{args.end_learning_rate}endlr" - f"-{args.learning_rate_decay_power}power" - f"-{args.max_grad_norm}maxgrad" - f"-{args.optimizer}opt" - f"-{args.total_steps}steps" - f"-{args.max_seq_length}seq" + f"-{train_args.batch_size}batch" + f"-{train_args.gradient_accumulation_steps}accum" + f"-{train_args.learning_rate}maxlr" + f"-{train_args.end_learning_rate}endlr" + f"-{train_args.learning_rate_decay_power}power" + f"-{train_args.max_grad_norm}maxgrad" + f"-{train_args.optimizer}opt" + f"-{train_args.total_steps}steps" + f"-{data_args.max_seq_length}seq" + f"-{data_args.max_predictions_per_seq}preds" f"-{'preln' if pre_layer_norm else 'postln'}" f"{loss_str}" - f"-{args.hidden_dropout_prob}dropout" - f"-{args.seed}seed" + f"-{model_args.hidden_dropout_prob}dropout" + f"-{train_args.seed}seed" ) - run_name = f"{current_time}-{platform}-{metadata}-{args.name if args.name else 'unnamed'}" + run_name = f"{current_time}-{platform}-{metadata}-{train_args.name if train_args.name else 'unnamed'}" # Logging should only happen on a single process # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time level = logging.INFO format = 
"%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s" handlers = [ - logging.FileHandler(f"{args.fsx_prefix}/logs/albert/{run_name}.log"), + logging.FileHandler(f"{data_args.fsx_prefix}/logs/albert/{run_name}.log"), TqdmLoggingHandler(), ] logging.basicConfig(level=level, format=format, handlers=handlers) @@ -429,25 +435,25 @@ def main(): wrap_global_functions(do_gradient_accumulation) - if args.model_type == "albert": - model_desc = f"albert-{args.model_size}-v2" - elif args.model_type == "bert": - model_desc = f"bert-{args.model_size}-uncased" + if model_args.model_type == "albert": + model_desc = f"albert-{model_args.model_size}-v2" + elif model_args.model_type == "bert": + model_desc = f"bert-{model_args.model_size}-uncased" config = AutoConfig.from_pretrained(model_desc) config.pre_layer_norm = pre_layer_norm - config.hidden_dropout_prob = args.hidden_dropout_prob + config.hidden_dropout_prob = model_args.hidden_dropout_prob model = TFAutoModelForPreTraining.from_config(config) # Create optimizer and enable AMP loss scaling. schedule = LinearWarmupPolyDecaySchedule( - max_learning_rate=args.learning_rate, - end_learning_rate=args.end_learning_rate, - warmup_steps=args.warmup_steps, - total_steps=args.total_steps, - power=args.learning_rate_decay_power, + max_learning_rate=train_args.learning_rate, + end_learning_rate=train_args.end_learning_rate, + warmup_steps=train_args.warmup_steps, + total_steps=train_args.total_steps, + power=train_args.learning_rate_decay_power, ) - if args.optimizer == "lamb": + if train_args.optimizer == "lamb": opt = LAMB( learning_rate=schedule, weight_decay_rate=0.01, @@ -456,22 +462,24 @@ def main(): epsilon=1e-6, exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], ) - elif args.optimizer == "adam": + elif train_args.optimizer == "adam": opt = AdamW(weight_decay=0.0, learning_rate=schedule) opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt, loss_scale="dynamic") gradient_accumulator = GradientAccumulator() loaded_opt_weights = None - if args.load_from == "scratch": + if model_args.load_from == "scratch": pass - elif args.load_from.startswith("huggingface"): - assert args.model_type == "albert", "Only loading pretrained albert models is supported" - huggingface_name = f"albert-{args.model_size}-v2" - if args.load_from == "huggingface": + elif model_args.load_from.startswith("huggingface"): + assert ( + model_args.model_type == "albert" + ), "Only loading pretrained albert models is supported" + huggingface_name = f"albert-{model_args.model_size}-v2" + if model_args.load_from == "huggingface": albert = TFAlbertModel.from_pretrained(huggingface_name, config=config) model.albert = albert else: - model_ckpt, opt_ckpt = get_checkpoint_paths_from_prefix(args.checkpoint_path) + model_ckpt, opt_ckpt = get_checkpoint_paths_from_prefix(model_args.checkpoint_path) model = TFAutoModelForPreTraining.from_config(config) if hvd.rank() == 0: @@ -480,20 +488,25 @@ def main(): # We do not set the weights yet, we have to do a first step to initialize the optimizer. # Train filenames are [1, 2047], Val filenames are [0]. 
Note the different subdirectories - train_glob = f"{args.fsx_prefix}/albert_pretraining/tfrecords/train/max_seq_len_{args.max_seq_length}_max_predictions_per_seq_{max_predictions_per_seq}_masked_lm_prob_15/albert_*.tfrecord" - validation_glob = f"{args.fsx_prefix}/albert_pretraining/tfrecords/validation/max_seq_len_{args.max_seq_length}_max_predictions_per_seq_{max_predictions_per_seq}_masked_lm_prob_15/albert_*.tfrecord" + # Move to same folder structure and remove if/else + if model_args.model_type == "albert": + train_glob = f"{data_args.fsx_prefix}/albert_pretraining/tfrecords/train/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/albert_*.tfrecord" + validation_glob = f"{data_args.fsx_prefix}/albert_pretraining/tfrecords/validation/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/albert_*.tfrecord" + if model_args.model_type == "bert": + train_glob = f"{data_args.fsx_prefix}/bert_pretraining/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/training/*.tfrecord" + validation_glob = f"{data_args.fsx_prefix}/bert_pretraining/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/validation/*.tfrecord" train_filenames = glob.glob(train_glob) validation_filenames = glob.glob(validation_glob) train_dataset = get_mlm_dataset( filenames=train_filenames, - max_seq_length=args.max_seq_length, - max_predictions_per_seq=max_predictions_per_seq, - batch_size=args.batch_size, + max_seq_length=data_args.max_seq_length, + max_predictions_per_seq=data_args.max_predictions_per_seq, + batch_size=train_args.batch_size, ) # Of shape [batch_size, ...] # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, batch_size, ...] - train_dataset = train_dataset.batch(args.gradient_accumulation_steps) + train_dataset = train_dataset.batch(train_args.gradient_accumulation_steps) # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps. 
train_dataset = train_dataset.prefetch(buffer_size=8) @@ -501,14 +514,14 @@ def main(): if hvd.rank() == 0: validation_dataset = get_mlm_dataset( filenames=validation_filenames, - max_seq_length=args.max_seq_length, - max_predictions_per_seq=max_predictions_per_seq, - batch_size=args.batch_size, + max_seq_length=data_args.max_seq_length, + max_predictions_per_seq=data_args.max_predictions_per_seq, + batch_size=train_args.batch_size, ) # validation_dataset = validation_dataset.batch(1) validation_dataset = validation_dataset.prefetch(buffer_size=8) - pbar = tqdm.tqdm(args.total_steps, disable=disable_tqdm) + pbar = tqdm.tqdm(train_args.total_steps, disable=disable_tqdm) summary_writer = None # Only create a writer if we make it through a successful step logger.info(f"Starting training, job name {run_name}") @@ -522,7 +535,7 @@ def main(): opt=opt, gradient_accumulator=gradient_accumulator, batch=batch, - gradient_accumulation_steps=args.gradient_accumulation_steps, + gradient_accumulation_steps=train_args.gradient_accumulation_steps, skip_sop=skip_sop, skip_mlm=skip_mlm, ) @@ -535,17 +548,17 @@ def main(): hvd.broadcast_variables(opt.variables(), root_rank=0) i = opt.get_weights()[0] - 1 - is_final_step = i >= args.total_steps - 1 + is_final_step = i >= train_args.total_steps - 1 do_squad = i in squad_steps or is_final_step # Squad requires all the ranks to train, but results are only returned on rank 0 if do_squad: squad_results = get_squad_results_while_pretraining( model=model, - model_size=args.model_size, - fsx_prefix=args.fsx_prefix, + model_size=model_args.model_size, + fsx_prefix=data_args.fsx_prefix, step=i, - fast=args.fast_squad, - dummy_eval=args.dummy_eval, + fast=log_args.fast_squad, + dummy_eval=log_args.dummy_eval, ) if hvd.rank() == 0: squad_exact, squad_f1 = squad_results["exact"], squad_results["f1"] @@ -554,9 +567,9 @@ def main(): wrap_global_functions(do_gradient_accumulation) if hvd.rank() == 0: - do_log = i % args.log_frequency == 0 - do_checkpoint = (i % checkpoint_frequency == 0) or is_final_step - do_validation = (i % validate_frequency == 0) or is_final_step + do_log = i % log_args.log_frequency == 0 + do_checkpoint = ((i > 0) and (i % log_args.checkpoint_frequency == 0)) or is_final_step + do_validation = ((i > 0) and (i % log_args.validation_frequency == 0)) or is_final_step pbar.update(1) description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}" @@ -566,12 +579,12 @@ def main(): if i == 0: logger.info(f"First step: {elapsed_time:.3f} secs") else: - it_per_sec = args.log_frequency / elapsed_time + it_per_sec = log_args.log_frequency / elapsed_time logger.info(f"Train step {i} -- {description} -- It/s: {it_per_sec:.2f}") start_time = time.perf_counter() if do_checkpoint: - checkpoint_prefix = f"{args.fsx_prefix}/checkpoints/albert/{run_name}-step{i}" + checkpoint_prefix = f"{data_args.fsx_prefix}/checkpoints/albert/{run_name}-step{i}" model_ckpt = f"{checkpoint_prefix}.ckpt" opt_ckpt = f"{checkpoint_prefix}-opt.npy" logger.info(f"Saving model at {model_ckpt}, optimizer at {opt_ckpt}") @@ -595,8 +608,47 @@ def main(): # Create summary_writer after the first step if summary_writer is None: summary_writer = tf.summary.create_file_writer( - f"{args.fsx_prefix}/logs/albert/{run_name}" + f"{data_args.fsx_prefix}/logs/albert/{run_name}" ) + with summary_writer.as_default(): + HP_MODEL_TYPE = hp.HParam("model_type", hp.Discrete(["albert", "bert"])) + HP_MODEL_SIZE = hp.HParam("model_size", 
hp.Discrete(["base", "large"])) + HP_LEARNING_RATE = hp.HParam("learning_rate", hp.RealInterval(1e-5, 1e-1)) + HP_BATCH_SIZE = hp.HParam("global_batch_size", hp.IntInterval(1, 64)) + HP_PRE_LAYER_NORM = hp.HParam("pre_layer_norm", hp.Discrete([True, False])) + HP_HIDDEN_DROPOUT = hp.HParam("hidden_dropout") + hparams = [ + HP_MODEL_TYPE, + HP_MODEL_SIZE, + HP_BATCH_SIZE, + HP_LEARNING_RATE, + HP_PRE_LAYER_NORM, + HP_HIDDEN_DROPOUT, + ] + + HP_F1 = hp.Metric("squad_f1") + HP_EXACT = hp.Metric("squad_exact") + HP_MLM = hp.Metric("val_mlm_acc") + HP_SOP = hp.Metric("val_sop_acc") + HP_TRAIN_LOSS = hp.Metric("train_loss") + HP_VAL_LOSS = hp.Metric("val_loss") + metrics = [HP_TRAIN_LOSS, HP_VAL_LOSS, HP_F1, HP_EXACT, HP_MLM, HP_SOP] + + hp.hparams_config( + hparams=hparams, metrics=metrics, + ) + hp.hparams( + { + HP_MODEL_TYPE: model_args.model_type, + HP_MODEL_SIZE: model_args.model_size, + HP_LEARNING_RATE: train_args.learning_rate, + HP_BATCH_SIZE: train_args.batch_size * hvd.size(), + HP_PRE_LAYER_NORM: model_args.pre_layer_norm == "true", + HP_HIDDEN_DROPOUT: model_args.hidden_dropout_prob, + }, + trial_id=run_name, + ) + # Log to TensorBoard with summary_writer.as_default(): tf.summary.scalar("weight_norm", weight_norm, step=i) diff --git a/models/nlp/albert/run_squad.py b/models/nlp/albert/run_squad.py index 36de045f..67d52612 100644 --- a/models/nlp/albert/run_squad.py +++ b/models/nlp/albert/run_squad.py @@ -10,7 +10,6 @@ """ -import argparse import datetime import logging import math @@ -24,6 +23,7 @@ from transformers import ( AlbertTokenizer, AutoConfig, + HfArgumentParser, PretrainedConfig, PreTrainedTokenizer, TFAutoModelForQuestionAnswering, @@ -37,11 +37,16 @@ SquadV2Processor, ) -from arguments import populate_squad_parser -from learning_rate_schedules import LinearWarmupPolyDecaySchedule -from models import load_qa_from_pretrained +from common.arguments import ( + DataTrainingArguments, + LoggingArguments, + ModelArguments, + TrainingArguments, +) +from common.learning_rate_schedules import LinearWarmupPolyDecaySchedule +from common.models import load_qa_from_pretrained +from common.utils import TqdmLoggingHandler, f1_score, get_dataset, get_tokenizer from run_squad_evaluation import get_evaluation_metrics -from utils import f1_score, get_dataset, get_tokenizer # See https://github.com/huggingface/transformers/issues/3782; this import must come last import horovod.tensorflow as hvd # isort:skip @@ -186,7 +191,7 @@ def print_eval_metrics(results, step) -> None: f"HasAnsEM: {results['HasAns_exact']:.3f}, HasAnsF1: {results['HasAns_f1']:.3f}, " f"NoAnsEM: {results['NoAns_exact']:.3f}, NoAnsF1: {results['NoAns_f1']:.3f}\n" ) - print(description) + logger.info(description) def tensorboard_eval_metrics(summary_writer, results: Dict, step: int) -> None: @@ -246,7 +251,6 @@ def get_squad_results_while_pretraining( pre_layer_norm=cloned_model.config.pre_layer_norm, model_size=model_size, load_from=cloned_model, - load_step=None, batch_size=per_gpu_batch_size, # This will be less than 3, so no OOM errors checkpoint_frequency=None, validate_frequency=None, @@ -268,7 +272,6 @@ def run_squad_and_get_results( pre_layer_norm: bool, model_size: str, load_from: Union[str, tf.keras.Model], - load_step: int, batch_size: int, checkpoint_frequency: Optional[int], validate_frequency: Optional[int], @@ -281,6 +284,8 @@ def run_squad_and_get_results( ) -> Dict: checkpoint_frequency = checkpoint_frequency or 1000000 validate_frequency = validate_frequency or 1000000 + is_sagemaker = 
fsx_prefix.startswith("/opt/ml") + disable_tqdm = is_sagemaker if isinstance(load_from, tf.keras.Model): config = load_from.config @@ -343,8 +348,8 @@ def run_squad_and_get_results( ) if hvd.rank() == 0: - print("Starting finetuning") - pbar = tqdm.tqdm(total_steps) + logger.info("Starting finetuning") + pbar = tqdm.tqdm(total_steps, disable=disable_tqdm) summary_writer = None # Only create a writer if we make it through a successful step val_dataset = get_dataset( tokenizer=tokenizer, @@ -383,7 +388,7 @@ def run_squad_and_get_results( pbar.set_description(description) if do_validate: - print("Running validation") + logger.info("Running validation") ( val_loss, val_acc, @@ -396,8 +401,8 @@ def run_squad_and_get_results( f"Step {step} validation - Loss: {val_loss:.3f}, Acc: {val_acc:.3f}, " f"EM: {val_exact_match:.3f}, F1: {val_f1:.3f}" ) - print(description) - print("Running evaluation") + logger.info(description) + logger.info("Running evaluation") if dummy_eval: results = { "exact": 0.8169797018445212, @@ -424,7 +429,7 @@ def run_squad_and_get_results( checkpoint_path = ( f"{fsx_prefix}/checkpoints/albert-squad/{run_name}-step{step}.ckpt" ) - print(f"Saving checkpoint at {checkpoint_path}") + logger.info(f"Saving checkpoint at {checkpoint_path}") model.save_weights(checkpoint_path) if summary_writer is None: @@ -457,17 +462,26 @@ def run_squad_and_get_results( # Can we return a value only on a single rank? if hvd.rank() == 0: pbar.close() - print(f"Finished finetuning, job name {run_name}") + logger.info(f"Finished finetuning, job name {run_name}") return results def main(): - parser = argparse.ArgumentParser() - populate_squad_parser(parser) - args = parser.parse_args() - tf.random.set_seed(args.seed) + parser = HfArgumentParser( + (ModelArguments, DataTrainingArguments, TrainingArguments, LoggingArguments) + ) + model_args, data_args, train_args, log_args = parser.parse_args_into_dataclasses() + + tf.random.set_seed(train_args.seed) tf.autograph.set_verbosity(0) + level = logging.INFO + format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s" + handlers = [ + TqdmLoggingHandler(), + ] + logging.basicConfig(level=level, format=format, handlers=handlers) + # Horovod init hvd.init() gpus = tf.config.list_physical_devices("GPU") @@ -477,45 +491,44 @@ def main(): tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU") # XLA, AMP, AutoGraph parse_bool = lambda arg: arg == "true" - tf.config.optimizer.set_jit(not parse_bool(args.skip_xla)) - tf.config.experimental_run_functions_eagerly(parse_bool(args.eager)) + tf.config.optimizer.set_jit(not parse_bool(train_args.skip_xla)) + tf.config.experimental_run_functions_eagerly(parse_bool(train_args.eager)) if hvd.rank() == 0: # Run name should only be used on one process to avoid race conditions current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") - platform = "eks" if args.fsx_prefix == "/fsx" else "sm" - if args.load_from.startswith("amazon"): - load_name = f"{args.load_from}{args.load_step}" + platform = "eks" if data_args.fsx_prefix == "/fsx" else "sm" + if model_args.load_from.startswith("amazon"): + load_name = f"{model_args.load_from}" else: - load_name = args.load_from - run_name = f"{current_time}-{platform}-{args.model_size}-{args.dataset}-{load_name}-{hvd.size()}gpus-{args.batch_size}batch-{args.learning_rate}lr-{args.name}" + load_name = model_args.load_from + run_name = 
f"{current_time}-{platform}-{model_args.model_size}-{data_args.task_name}-{load_name}-{hvd.size()}gpus-{train_args.batch_size}batch-{train_args.learning_rate}lr-{train_args.name}" else: # We only use run_name on rank 0, but need all ranks to pass a value in function args run_name = None - if args.model_type == "albert": - model_desc = f"albert-{args.model_size}-v2" + if model_args.model_type == "albert": + model_desc = f"albert-{model_args.model_size}-v2" else: - model_desc = f"bert-{args.model_size}-uncased" + model_desc = f"bert-{model_args.model_size}-uncased" results = run_squad_and_get_results( run_name=run_name, - fsx_prefix=args.fsx_prefix, - pre_layer_norm=parse_bool(args.pre_layer_norm), - model_size=args.model_size, - load_from=args.load_from, - load_step=args.load_step, - batch_size=args.batch_size, - checkpoint_frequency=args.checkpoint_frequency, - validate_frequency=args.validate_frequency, - learning_rate=args.learning_rate, - warmup_steps=args.warmup_steps, - total_steps=args.total_steps, - dataset=args.dataset, + fsx_prefix=data_args.fsx_prefix, + pre_layer_norm=parse_bool(model_args.pre_layer_norm), + model_size=model_args.model_size, + load_from=model_args.load_from, + batch_size=train_args.batch_size, + checkpoint_frequency=log_args.checkpoint_frequency, + validate_frequency=log_args.validation_frequency, + learning_rate=train_args.learning_rate, + warmup_steps=train_args.warmup_steps, + total_steps=train_args.total_steps, + dataset=data_args.task_name, config=AutoConfig.from_pretrained(model_desc), ) if hvd.rank() == 0: - print(results) + logger.info(results) if __name__ == "__main__": diff --git a/models/nlp/albert/run_squad_evaluation.py b/models/nlp/albert/run_squad_evaluation.py index 26e24d3d..ae607bfd 100644 --- a/models/nlp/albert/run_squad_evaluation.py +++ b/models/nlp/albert/run_squad_evaluation.py @@ -14,11 +14,16 @@ SquadV2Processor, ) -from utils import get_dataset, get_tokenizer +from common.utils import get_dataset, get_tokenizer def get_evaluation_metrics( - model, data_dir: str, filename: str, batch_size: int = 32, num_batches: int = None, + model, + data_dir: str, + filename: str, + batch_size: int = 32, + num_batches: int = None, + disable_tqdm: bool = False, ) -> Dict[str, "Number"]: """ Return an OrderedDict in the format: @@ -72,6 +77,7 @@ def get_evaluation_metrics( features=features, batch_size=batch_size, num_batches=num_batches, + disable_tqdm=disable_tqdm, ) write_prediction_files = False @@ -110,11 +116,12 @@ def get_squad_results( features: List[SquadFeatures], batch_size: int, num_batches: int, + disable_tqdm: bool, ) -> List[SquadResult]: results = [] total_steps = math.ceil(len(features) / batch_size) - pbar = tqdm.tqdm(total=total_steps) + pbar = tqdm.tqdm(total=total_steps, disable=disable_tqdm) pbar.set_description(f"Evaluating with batch size {batch_size}") if num_batches: @@ -169,6 +176,10 @@ def get_squad_results( val_filename = "dev-v2.0.json" results = get_evaluation_metrics( - model=model, data_dir=data_dir, filename=val_filename, batch_size=args.batch_size + model=model, + data_dir=data_dir, + filename=val_filename, + batch_size=args.batch_size, + disable_tqdm=False, ) print(dict(results)) diff --git a/models/nlp/albert/sagemaker_pretraining.py b/models/nlp/albert/sagemaker_pretraining.py deleted file mode 100644 index ded17d6d..00000000 --- a/models/nlp/albert/sagemaker_pretraining.py +++ /dev/null @@ -1,51 +0,0 @@ -import argparse - -from arguments import populate_pretraining_parser, populate_sagemaker_parser -from 
sagemaker_utils import launch_sagemaker_job, pop_sagemaker_args - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - populate_sagemaker_parser(parser) - populate_pretraining_parser(parser) - args = parser.parse_args() - - args_dict = args.__dict__ - # Pop off the SageMaker parameters - ( - source_dir, - entry_point, - role, - image_name, - fsx_id, - subnet_ids, - security_group_ids, - instance_type, - instance_count, - ) = pop_sagemaker_args(args_dict) - # Only the script parameters remain - hyperparameters = dict() - for key, value in args_dict.items(): - if value is not None: - hyperparameters[key] = value - hyperparameters["fsx_prefix"] = "/opt/ml/input/data/training" - - instance_abbr = { - "ml.p3dn.24xlarge": "p3dn", - "ml.p3.16xlarge": "p316", - "ml.g4dn.12xlarge": "g4dn", - }[instance_type] - job_name = f"albert-pretrain-{instance_count}x{instance_abbr}" - - launch_sagemaker_job( - job_name=job_name, - source_dir=source_dir, - entry_point=entry_point, - instance_type=instance_type, - instance_count=instance_count, - hyperparameters=hyperparameters, - role=role, - image_name=image_name, - fsx_id=fsx_id, - subnet_ids=subnet_ids, - security_group_ids=security_group_ids, - ) diff --git a/models/nlp/albert/sagemaker_squad.py b/models/nlp/albert/sagemaker_squad.py deleted file mode 100644 index abe3bed4..00000000 --- a/models/nlp/albert/sagemaker_squad.py +++ /dev/null @@ -1,51 +0,0 @@ -import argparse - -from arguments import populate_sagemaker_parser, populate_squad_parser -from sagemaker_utils import launch_sagemaker_job, pop_sagemaker_args - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - populate_sagemaker_parser(parser) - populate_squad_parser(parser) - args = parser.parse_args() - - args_dict = args.__dict__ - # Pop off the SageMaker parameters - ( - source_dir, - entry_point, - role, - image_name, - fsx_id, - subnet_ids, - security_group_ids, - instance_type, - instance_count, - ) = pop_sagemaker_args(args_dict) - # Only the script parameters remain - hyperparameters = dict() - for key, value in args_dict.items(): - if value is not None: - hyperparameters[key] = value - hyperparameters["fsx_prefix"] = "/opt/ml/input/data/training" - - instance_abbr = { - "ml.p3dn.24xlarge": "p3dn", - "ml.p3.16xlarge": "p316", - "ml.g4dn.12xlarge": "g4dn", - }[instance_type] - job_name = f"squad-{instance_count}x{instance_abbr}-{args.load_from}" - - launch_sagemaker_job( - job_name=job_name, - source_dir=source_dir, - entry_point=entry_point, - instance_type=instance_type, - instance_count=instance_count, - hyperparameters=hyperparameters, - role=role, - image_name=image_name, - fsx_id=fsx_id, - subnet_ids=subnet_ids, - security_group_ids=security_group_ids, - ) diff --git a/models/nlp/common/arguments.py b/models/nlp/common/arguments.py new file mode 100644 index 00000000..c0f7dc0c --- /dev/null +++ b/models/nlp/common/arguments.py @@ -0,0 +1,152 @@ +""" +Since arguments are duplicated in run_pretraining.py and sagemaker_pretraining.py, they have +been abstracted into this file. It also makes the training scripts much shorter. +""" + +import dataclasses +import json +import logging +import os +from dataclasses import dataclass, field +from typing import Any, Dict + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + ModelArguments is the subset of arguments relating to the model instantiation. + So config options such as dropout fall under this, but skip_xla does not because it is + used at training time. 
+ """ + + model_type: str = field(default="albert", metadata={"choices": ["albert", "bert"]}) + model_size: str = field(default="base", metadata={"choices": ["base", "large"]}) + load_from: str = field( + default="scratch", metadata={"choices": ["scratch", "checkpoint", "huggingface"]} + ) + checkpoint_path: str = field( + default=None, + metadata={ + "help": "For example, `/fsx/checkpoints/albert/2020..step125000`. No .ckpt on the end." + }, + ) + pre_layer_norm: str = field( + default=None, + metadata={ + "choices": ["true"], + "help": "Place layer normalization before the attention & FFN, rather than after adding the residual connection. https://openreview.net/pdf?id=B1x8anVFPr", + }, + ) + hidden_dropout_prob: float = field(default=0.0) + + +@dataclass +class DataTrainingArguments: + task_name: str = field(default="squadv2", metadata={"choices": ["squadv1", "squadv2"]}) + max_seq_length: int = field(default=512, metadata={"choices": [128, 512]}) + max_predictions_per_seq: int = field(default=20, metadata={"choices": [20, 80]}) + fsx_prefix: str = field( + default="/fsx", + metadata={ + "choices": ["/fsx", "/opt/ml/input/data/training"], + "help": "Change to /opt/ml/input/data/training on SageMaker", + }, + ) + + +@dataclass +class TrainingArguments: + """ + TrainingArguments is the subset of the arguments we use in our example scripts + **which relate to the training loop itself**. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + # model_dir: str = field(default=None, metadata={"help": "Unused, but passed by SageMaker"}) + seed: int = field(default=42) + # TODO: Change this to per_gpu_train_batch_size + batch_size: int = field(default=32) + gradient_accumulation_steps: int = field( + default=1, + metadata={ + "help": "Number of updates steps to accumulate before performing a backward/update pass." + }, + ) + optimizer: str = field(default="lamb", metadata={"choices": ["lamb", "adam"]}) + warmup_steps: int = field(default=3125) + total_steps: int = field(default=125000) + learning_rate: float = field(default=0.00176) + end_learning_rate: float = field(default=3e-5) + learning_rate_decay_power: float = field(default=1.0) + max_grad_norm: float = field(default=1.0) + name: str = field(default="", metadata={"help": "Additional info to append to metadata"}) + + skip_xla: str = field(default=None, metadata={"choices": ["true"]}) + eager: str = field(default=None, metadata={"choices": ["true"]}) + skip_sop: str = field(default=None, metadata={"choices": ["true"]}) + skip_mlm: str = field(default=None, metadata={"choices": ["true"]}) + + def to_json_string(self): + """ + Serializes this instance to a JSON string. 
+ """ + return json.dumps(dataclasses.asdict(self), indent=2) + + def to_sanitized_dict(self) -> Dict[str, Any]: + """ + Sanitized serialization to use with TensorBoard’s hparams + """ + d = dataclasses.asdict(self) + valid_types = [bool, int, float, str] + return {k: v if type(v) in valid_types else str(v) for k, v in d.items()} + + +@dataclass +class LoggingArguments: + log_frequency: int = field(default=1000) + validation_frequency: int = field(default=2000) + checkpoint_frequency: int = field(default=5000) + extra_squad_steps: str = field(default=None) + fast_squad: str = field(default=None, metadata={"choices": ["true"]}) + dummy_eval: str = field(default=None, metadata={"choices": ["true"]}) + + +@dataclass +class SageMakerArguments: + source_dir: str = field( + metadata={ + "help": "For example, /Users/myusername/Desktop/deep-learning-models/models/nlp/albert" + } + ) + entry_point: str = field(metadata={"help": "For example, run_pretraining.py or run_squad.py"}) + sm_job_name: str = field(default="albert") + + role: str = field(default=None) + image_name: str = field(default=None) + fsx_id: str = field(default=None) + subnet_ids: str = field(default=None, metadata={"help": "Comma-separated string"}) + security_group_ids: str = field(default=None, metadata={"help": "Comma-separated string"}) + instance_type: str = field( + default="ml.p3dn.24xlarge", + metadata={"choices": ["ml.p3dn.24xlarge", "ml.p3.16xlarge", "ml.g4dn.12xlarge"]}, + ) + instance_count: int = field(default=1) + + def __post_init__(self): + # Dataclass are evaluated at import-time, so we need to wrap these in a post-init method + # in case the env-vars don't exist. + self.role = self.role or os.environ["SAGEMAKER_ROLE"] + self.image_name = self.image_name or os.environ["SAGEMAKER_IMAGE_NAME"] + self.fsx_id = self.fsx_id or os.environ["SAGEMAKER_FSX_ID"] + self.subnet_ids = self.subnet_ids or os.environ["SAGEMAKER_SUBNET_IDS"] + self.security_group_ids = ( + self.security_group_ids or os.environ["SAGEMAKER_SECURITY_GROUP_IDS"] + ) + + self.subnet_ids = self.subnet_ids.replace(" ", "").split(",") + self.security_group_ids = self.security_group_ids.replace(" ", "").split(",") diff --git a/models/nlp/albert/datasets.py b/models/nlp/common/datasets.py similarity index 100% rename from models/nlp/albert/datasets.py rename to models/nlp/common/datasets.py diff --git a/models/nlp/albert/learning_rate_schedules.py b/models/nlp/common/learning_rate_schedules.py similarity index 100% rename from models/nlp/albert/learning_rate_schedules.py rename to models/nlp/common/learning_rate_schedules.py diff --git a/models/nlp/albert/models.py b/models/nlp/common/models.py similarity index 100% rename from models/nlp/albert/models.py rename to models/nlp/common/models.py diff --git a/models/nlp/albert/sagemaker_utils.py b/models/nlp/common/sagemaker_utils.py similarity index 100% rename from models/nlp/albert/sagemaker_utils.py rename to models/nlp/common/sagemaker_utils.py index 8cae0dfd..921356ae 100644 --- a/models/nlp/albert/sagemaker_utils.py +++ b/models/nlp/common/sagemaker_utils.py @@ -43,12 +43,12 @@ def pop_sagemaker_args(args_dict: Dict) -> Tuple: def launch_sagemaker_job( + hyperparameters: Dict[str, Any], job_name: str, source_dir: str, entry_point: str, instance_type: str, instance_count: int, - hyperparameters: Dict[str, Any], role: str, image_name: str, fsx_id: str, diff --git a/models/nlp/albert/utils.py b/models/nlp/common/utils.py similarity index 100% rename from models/nlp/albert/utils.py rename to 
models/nlp/common/utils.py diff --git a/models/nlp/docker/hvd_kubernetes.Dockerfile b/models/nlp/docker/hvd_kubernetes.Dockerfile index 97059ed0..bdd558ed 100644 --- a/models/nlp/docker/hvd_kubernetes.Dockerfile +++ b/models/nlp/docker/hvd_kubernetes.Dockerfile @@ -10,8 +10,8 @@ ENV CUDNN_VERSION=7.6.5.32-1+cuda10.1 ENV NCCL_VERSION=2.4.8-1+cuda10.1 ENV MXNET_VERSION=1.6.0 -# Python 2.7 or 3.6 is supported by Ubuntu Bionic out of the box -ARG python=3.6 +# Python 3.6 is supported by Ubuntu Bionic out of the box +ARG python=3.7 ENV PYTHON_VERSION=${python} # Set default shell to /bin/bash @@ -33,13 +33,11 @@ RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held- libpng-dev \ python${PYTHON_VERSION} \ python${PYTHON_VERSION}-dev \ + python${PYTHON_VERSION}-distutils \ librdmacm1 \ libibverbs1 \ ibverbs-providers -RUN if [[ "${PYTHON_VERSION}" == "3.6" ]]; then \ - apt-get install -y python${PYTHON_VERSION}-distutils; \ - fi RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ @@ -93,15 +91,20 @@ WORKDIR "/examples" ###### Modifications to horovod Dockerfile below +# tensorflow_addons is tightly coupled to TF version. TF 2.1 = 0.9.1, TF 2.2 = 0.10.0 RUN pip install --no-cache-dir --upgrade pip && \ pip install --no-cache-dir \ scikit-learn \ gputil \ requests \ tensorflow-addons==0.9.1 -# TODO: Why does installing torch break TF XLA support? ENV HDF5_USE_FILE_LOCKING "FALSE" WORKDIR /fsx CMD ["/bin/bash"] + +# When you enter this file, you'll need to run two commands manually: +# pip install -e /fsx/transformers +# export PYTHONPATH="${PATH}:/fsx/deep-learning-models/models/nlp" +# These are done in the MPIJob launch script when using Kubernetes, but not for a shell. diff --git a/models/nlp/docker/ngc_sagemaker.Dockerfile b/models/nlp/docker/ngc_sagemaker.Dockerfile index afca4735..bb4240d7 100644 --- a/models/nlp/docker/ngc_sagemaker.Dockerfile +++ b/models/nlp/docker/ngc_sagemaker.Dockerfile @@ -23,4 +23,8 @@ RUN pip install --no-cache-dir \ mpi4py \ sagemaker-containers \ tensorflow-addons==0.9.1 +# TODO: Why does installing torch break TF XLA support? + RUN pip install git+git://github.com/jarednielsen/transformers.git@tfsquad +ENV PYTHONPATH "${PYTHONPATH}:/fsx/deep-learning-models/models/nlp" +ENV PYTHONPATH "${PYTHONPATH}:/opt/ml/input/data/training/deep-learning-models/models/nlp" diff --git a/setup.cfg b/setup.cfg index a403a51c..e4949cac 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,4 +4,4 @@ include_trailing_comma=true multi_line_output=3 use_parentheses=true default_section=THIRDPARTY -known_first_party=arguments,datasets,learning_rate_schedules,models,utils,run_pretraining,run_squad_evaluation,run_squad,sagemaker_pretraining,sagemaker_squad,sagemaker_utils,utils +known_first_party=common,run_pretraining,run_squad_evaluation,run_squad
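For readers new to `HfArgumentParser`, here is a minimal, illustrative sketch (not part of the diff) of the dataclass-based parsing that `common/arguments.py` and the entry scripts now share: each dataclass becomes a group of command-line flags, and parsing returns one populated instance per class. The dataclasses below are trimmed stand-ins for the real `ModelArguments` and `TrainingArguments`.

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    # Trimmed stand-in for common/arguments.py:ModelArguments.
    model_type: str = field(default="albert", metadata={"choices": ["albert", "bert"]})
    model_size: str = field(default="base", metadata={"choices": ["base", "large"]})


@dataclass
class TrainingArguments:
    # Trimmed stand-in for common/arguments.py:TrainingArguments.
    batch_size: int = field(default=32, metadata={"help": "Per-GPU batch size"})
    learning_rate: float = field(default=0.00176)


if __name__ == "__main__":
    # HfArgumentParser generates --model_type, --batch_size, etc. from the dataclass fields
    # and returns one instance per dataclass, in the order the classes were passed in.
    parser = HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, train_args = parser.parse_args_into_dataclasses()
    print(model_args.model_type, train_args.batch_size)
```

Invoked as `python sketch.py --model_type=bert --batch_size=16`, the matching fields are populated and every omitted flag falls back to its dataclass default, which is why the launcher can forward the same flags unchanged as SageMaker hyperparameters.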
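The TensorBoard hyperparameter logging added in `run_pretraining.py` uses the HParams plugin API. The calls below are a standalone sketch of that mechanism under an assumed throwaway log directory; the training script itself writes to `{fsx_prefix}/logs/albert/{run_name}` on rank 0 only.

```python
import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

# Assumed demo directory; the real script logs under the FSx prefix.
logdir = "/tmp/hparams_demo"

HP_LEARNING_RATE = hp.HParam("learning_rate", hp.RealInterval(1e-5, 1e-1))
HP_OPTIMIZER = hp.HParam("optimizer", hp.Discrete(["lamb", "adam"]))

with tf.summary.create_file_writer(logdir).as_default():
    # Declare the hyperparameter/metric schema once so the HParams dashboard can build its table.
    hp.hparams_config(
        hparams=[HP_LEARNING_RATE, HP_OPTIMIZER],
        metrics=[hp.Metric("train_loss"), hp.Metric("val_loss")],
    )
    # Record this run's hyperparameter values; trial_id groups the run in the dashboard.
    hp.hparams({HP_LEARNING_RATE: 0.00176, HP_OPTIMIZER: "lamb"}, trial_id="demo-run")
    # Scalars logged under the declared metric names appear alongside the hyperparameters.
    tf.summary.scalar("train_loss", 10.0, step=0)
```

Launching `tensorboard --logdir /tmp/hparams_demo` then shows the run in the HPARAMS tab with its hyperparameter values and logged metrics side by side.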
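A convention worth noting across the argument dataclasses: SageMaker passes hyperparameters as a flat key=value dictionary, so `store_true` flags cannot be used; boolean options are instead string fields restricted to `"true"` and decoded with `parse_bool` inside the training scripts. The `DebugArguments` dataclass below is a hypothetical, self-contained illustration of that pattern; the real fields (`skip_xla`, `eager`, `skip_sop`, `skip_mlm`, `fast_squad`, `dummy_eval`) live on `TrainingArguments` and `LoggingArguments`.

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DebugArguments:
    # Hypothetical example field. SageMaker cannot pass store_true flags, so the scripts
    # accept --skip_xla=true and treat any other value (including omission) as False.
    skip_xla: str = field(default=None, metadata={"choices": ["true"]})


parse_bool = lambda arg: arg == "true"  # same helper the training scripts define inline

if __name__ == "__main__":
    (debug_args,) = HfArgumentParser(DebugArguments).parse_args_into_dataclasses()
    do_xla = not parse_bool(debug_args.skip_xla)
    print(f"XLA enabled: {do_xla}")
```

Passing `--skip_xla=true` disables XLA, while leaving the flag off keeps the default behavior, so the same string-valued hyperparameter works identically on the command line and through the SageMaker estimator.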