diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index 06effc7b9..838eaed03 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -40,7 +40,7 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ] [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}] [--export-prefix EXPORT_PREFIX] - [--checkpoint-name CHECKPOINT_NAME] + [--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences] options: -h, --help show this help message and exit @@ -131,5 +131,10 @@ options: --checkpoint-name CHECKPOINT_NAME Filename to save checkpoint. If `None`, no checkpoint is saved (default: None) + --fuse-sequences Whether to merge the dataset sequences in case they + are shorter than the requested number of samples per + sequence. This is useful in case you would like to + quantize or evaluate on long sequences (default: + False). ``` diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 67c8144c7..cf6f01895 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -6,6 +6,7 @@ import argparse import re import sys +from warnings import warn import numpy as np from optimum.amd.brevitas.accelerate_utils import offload_model @@ -109,6 +110,13 @@ def validate(args): else: assert args.export_target != 'torch_qcdq', "Cannot export Torch QCDQ with FX" + if not args.fuse_sequences: + # 350 is approximately the 99% percentile for the sequence length in WikiText2 (train partition, using AutoTokenizer) + if args.seqlen >= 350: + warn( + "Data loading can take a long time or, potentially, enter an infinite loop. 
Consider setting --fuse-sequences " + "or decreasing the sequence length (seqlen)") + def main(args): validate(args) @@ -142,7 +150,6 @@ def main(args): apply_awq(model, awq_results) require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge else False - fuse_sequences = False # Load the data for calibration and evaluation. calibration_loader = get_dataset_for_model( @@ -155,7 +162,7 @@ def main(args): seed=args.seed, require_fx=require_fx, device=None, - fuse_sequences=fuse_sequences, + fuse_sequences=args.fuse_sequences, ) validation_loader = get_dataset_for_model( @@ -168,7 +175,7 @@ def main(args): seed=args.seed, require_fx=require_fx, device=None, - fuse_sequences=fuse_sequences, + fuse_sequences=args.fuse_sequences, ) device = next(iter(model.parameters())).device @@ -474,6 +481,13 @@ def parse_args(args): default=None, help="Filename to save checkpoint. If `None`, no checkpoint is saved (default: %(default)s)" ) + parser.add_argument( + "--fuse-sequences", + action="store_true", + default=False, + help= + "Whether to merge the dataset sequences in case they are shorter than the requested number of samples per sequence. This is useful in case you would like to quantize or evaluate on long sequences (default: %(default)s).", + ) return parser.parse_args(args)