From 1d126db59e5f782d4997945edb805d2a1045df15 Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Tue, 8 Oct 2024 11:07:40 +0100
Subject: [PATCH 1/3] Add argument to fuse sequences in data loading

---
 src/brevitas_examples/llm/main.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index e19390774..7efa02741 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -6,6 +6,7 @@
 import argparse
 import re
 import sys
+from warnings import warn
 
 import numpy as np
 from optimum.amd.brevitas.accelerate_utils import offload_model
@@ -108,6 +109,14 @@ def validate(args):
             assert args.export_target != 'onnx_qcdq', "Cannot export ONNX QCDQ with FX + MHA replacing"
         else:
             assert args.export_target != 'torch_qcdq', "Cannot export Torch QCDQ with FX"
+
+    if not args.fuse_sequences:
+        # 350 is approximately the 99th percentile of the sequence lengths in WikiText2 (train partition, using AutoTokenizer)
+        if args.seqlen >= 350:
+            warn(
+                "Data loading can take a long time or, potentially, enter an infinite loop. Consider setting --fuse-sequences "
+                "or decreasing the sequence length (seqlen)"
+            )
 
 
 def main(args):
@@ -142,7 +151,6 @@ def main(args):
         apply_awq(model, awq_results)
 
     require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge else False
-    fuse_sequences = False
 
     # Load the data for calibration and evaluation.
     calibration_loader = get_dataset_for_model(
@@ -155,7 +163,7 @@ def main(args):
         seed=args.seed,
         require_fx=require_fx,
         device=None,
-        fuse_sequences=fuse_sequences,
+        fuse_sequences=args.fuse_sequences,
     )
 
     validation_loader = get_dataset_for_model(
@@ -168,7 +176,7 @@ def main(args):
         seed=args.seed,
         require_fx=require_fx,
         device=None,
-        fuse_sequences=fuse_sequences,
+        fuse_sequences=args.fuse_sequences,
     )
 
     device = next(iter(model.parameters())).device
@@ -467,6 +475,12 @@ def parse_args(args):
         default=None,
         help="Filename to save checkpoint. If `None`, no checkpoint is saved (default: %(default)s)"
     )
+    parser.add_argument(
+        "--fuse-sequences",
+        action="store_true",
+        default=False,
+        help="Whether to merge the dataset sequences in case they are shorter than the requested sequence length. This is useful when quantizing or evaluating on long sequences (default: %(default)s).",
+    )
     return parser.parse_args(args)
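The warning added above deserves a short illustration. When sequences are not fused, a loader that draws calibration samples by rejection keeps resampling until it finds `nsamples` sequences of at least `seqlen` tokens; once `seqlen` exceeds nearly every sequence in the dataset, that loop rarely or never terminates. The sketch below is illustrative only: `sample_calibration_set` is a hypothetical name, not the actual `get_dataset_for_model` implementation.

```python
# Minimal sketch of the failure mode (hypothetical helper, not the actual
# optimum/brevitas loader): rejection sampling over a dataset in which almost
# no sequence reaches `seqlen` tokens makes this loop spin indefinitely.
import random

def sample_calibration_set(tokenized_sequences, nsamples, seqlen):
    samples = []
    while len(samples) < nsamples:
        candidate = random.choice(tokenized_sequences)
        # On WikiText2 this is almost never true once seqlen >= 350,
        # its approximate 99th percentile of sequence lengths.
        if len(candidate) >= seqlen:
            samples.append(candidate[:seqlen])
    return samples
```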
From 5131e4d2f0675151166904e8c2fb8bf7b94d425c Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Tue, 8 Oct 2024 13:53:04 +0100
Subject: [PATCH 2/3] Pre-commit passed

---
 src/brevitas_examples/llm/main.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 7efa02741..e8d46a491 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -109,14 +109,13 @@ def validate(args):
             assert args.export_target != 'onnx_qcdq', "Cannot export ONNX QCDQ with FX + MHA replacing"
         else:
             assert args.export_target != 'torch_qcdq', "Cannot export Torch QCDQ with FX"
-
+
     if not args.fuse_sequences:
         # 350 is approximately the 99th percentile of the sequence lengths in WikiText2 (train partition, using AutoTokenizer)
         if args.seqlen >= 350:
             warn(
                 "Data loading can take a long time or, potentially, enter an infinite loop. Consider setting --fuse-sequences "
-                "or decreasing the sequence length (seqlen)"
-            )
+                "or decreasing the sequence length (seqlen)")
 
 
 def main(args):
@@ -479,7 +478,8 @@ def parse_args(args):
         "--fuse-sequences",
         action="store_true",
         default=False,
-        help="Whether to merge the dataset sequences in case they are shorter than the requested sequence length. This is useful when quantizing or evaluating on long sequences (default: %(default)s).",
+        help=
+        "Whether to merge the dataset sequences in case they are shorter than the requested sequence length. This is useful when quantizing or evaluating on long sequences (default: %(default)s).",
     )
     return parser.parse_args(args)
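For context on what the flag opts into: fusing presumably concatenates the tokenized corpus and slices it into fixed-size windows, so every sample has exactly `seqlen` tokens and no rejection sampling is needed. A rough sketch, with `fuse_and_chunk` as a hypothetical helper rather than the optimum-amd implementation:

```python
# Rough sketch of the sequence-fusing strategy (hypothetical helper, not the
# optimum-amd code): join the tokenized dataset into one token stream, then
# cut it into non-overlapping windows of exactly `seqlen` tokens.
from itertools import chain

def fuse_and_chunk(tokenized_sequences, seqlen):
    stream = list(chain.from_iterable(tokenized_sequences))
    return [
        stream[i:i + seqlen]
        for i in range(0, len(stream) - seqlen + 1, seqlen)
    ]
```

On the command line this corresponds to passing `--fuse-sequences`, e.g. `python main.py --model facebook/opt-125m --seqlen 2048 --eval --fuse-sequences`.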
From 101e1b7b4af2b4efd9926e1162544a59558b07c6 Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Tue, 8 Oct 2024 14:57:14 +0100
Subject: [PATCH 3/3] Updated README

---
 src/brevitas_examples/llm/README.md | 95 +++++++++++------------------
 1 file changed, 35 insertions(+), 60 deletions(-)

diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index cdf708d17..bc227346f 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -14,31 +14,19 @@ Set the env variable BREVITAS_JIT=1 to speed up the quantization process.
 Currently unsupported whenever export is also toggled or with MSE based scales/zero-points.
 
 ```bash
-usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
-               [--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}]
-               [--weight-bit-width WEIGHT_BIT_WIDTH]
-               [--weight-param-method {stats,mse}]
-               [--weight-scale-precision {float_scale,po2_scale}]
-               [--weight-quant-type {sym,asym}]
-               [--weight-quant-format WEIGHT_QUANT_FORMAT]
-               [--weight-quant-granularity {per_channel,per_tensor,per_group}]
-               [--weight-group-size WEIGHT_GROUP_SIZE]
-               [--quantize-weight-zero-point]
-               [--input-bit-width INPUT_BIT_WIDTH]
-               [--input-quant-format INPUT_QUANT_FORMAT]
-               [--input-param-method {stats,mse}]
-               [--input-scale-precision {float_scale,po2_scale}]
-               [--input-scale-type {static,dynamic,no_scale}]
-               [--input-quant-type {sym,asym}]
-               [--input-quant-granularity {per_tensor,per_row,per_group}]
-               [--input-group-size INPUT_GROUP_SIZE]
-               [--quantize-input-zero-point] [--quantize-last-layer] [--gptq]
-               [--act-calibration] [--bias-corr] [--ln-affine-merge]
-               [--no-quantize] [--no-float16] [--replace-mha]
-               [--weight-equalization]
+usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}]
+               [--weight-bit-width WEIGHT_BIT_WIDTH] [--weight-param-method {stats,mse,hqo}]
+               [--weight-scale-precision {float_scale,po2_scale}] [--weight-quant-type {sym,asym}]
+               [--weight-quant-format WEIGHT_QUANT_FORMAT] [--weight-quant-granularity {per_channel,per_tensor,per_group}]
+               [--weight-group-size WEIGHT_GROUP_SIZE] [--quantize-weight-zero-point] [--input-bit-width INPUT_BIT_WIDTH]
+               [--input-quant-format INPUT_QUANT_FORMAT] [--input-param-method {stats,mse}]
+               [--input-scale-precision {float_scale,po2_scale}] [--input-scale-type {static,dynamic,no_scale}]
+               [--input-quant-type {sym,asym}] [--input-quant-granularity {per_tensor,per_row,per_group}]
+               [--input-group-size INPUT_GROUP_SIZE] [--quantize-input-zero-point] [--quantize-last-layer] [--gptq] [--act-calibration]
+               [--bias-corr] [--ln-affine-merge] [--no-quantize] [--no-float16] [--replace-mha] [--weight-equalization]
                [--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ]
                [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
-               [--checkpoint-name CHECKPOINT_NAME]
+               [--export-prefix EXPORT_PREFIX] [--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]
 
 options:
   -h, --help            show this help message and exit
@@ -51,50 +39,38 @@ options:
                         Dataset to use for quantization (default: wikitext2)
   --weight-bit-width WEIGHT_BIT_WIDTH
                         Weight bit width. Default: 8.
-  --weight-param-method {stats,mse}
+  --weight-param-method {stats,mse,hqo}
                         How scales/zero-point are determined. Default: stats.
   --weight-scale-precision {float_scale,po2_scale}
                         Whether scale is a float value or a po2. Default: po2.
   --weight-quant-type {sym,asym}
                         Weight quantization type. Default: asym.
   --weight-quant-format WEIGHT_QUANT_FORMAT
-                        Weight quantization type. Either int or eXmY, with
-                        X+Y==weight_bit_width-1. It's possible to add
-                        float_ocp_ or float_fnuz_ before the exponent/mantissa
-                        bitwidth. Default: int.
+                        Weight quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. It's possible to add float_ocp_ or
+                        float_fnuz_ before the exponent/mantissa bitwidth. Default: int.
   --weight-quant-granularity {per_channel,per_tensor,per_group}
-                        Granularity for scales/zero-point of weights. Default:
-                        per_group.
+                        Granularity for scales/zero-point of weights. Default: per_group.
   --weight-group-size WEIGHT_GROUP_SIZE
-                        Group size for per_group weight quantization. Default:
-                        128.
+                        Group size for per_group weight quantization. Default: 128.
   --quantize-weight-zero-point
                         Quantize weight zero-point.
   --input-bit-width INPUT_BIT_WIDTH
-                        Input bit width. Default: None (disables input
-                        quantization).
+                        Input bit width. Default: None (disables input quantization).
   --input-quant-format INPUT_QUANT_FORMAT
-                        Input quantization type. Either int or eXmY, with
-                        X+Y==weight_bit_width-1. It's possible to add
-                        float_ocp_ or float_fnuz_ before the exponent/mantissa
-                        bitwidth. Default: int.
+                        Input quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. It's possible to add float_ocp_ or
+                        float_fnuz_ before the exponent/mantissa bitwidth. Default: int.
   --input-param-method {stats,mse}
-                        How scales/zero-point are determined. Default: stats
-                        (percentile for static, absmax or minmax for dynamic).
+                        How scales/zero-point are determined. Default: stats (percentile for static, absmax or minmax for dynamic).
   --input-scale-precision {float_scale,po2_scale}
-                        Whether input scale is a float value or a po2.
-                        Default: float.
+                        Whether input scale is a float value or a po2. Default: float.
   --input-scale-type {static,dynamic,no_scale}
-                        Whether input scale is a static value or a dynamic
-                        value.
+                        Whether input scale is a static value or a dynamic value.
   --input-quant-type {sym,asym}
                         Input quantization type. Default: asym.
   --input-quant-granularity {per_tensor,per_row,per_group}
-                        Granularity for scales/zero-point of inputs. Default:
-                        per_tensor.
+                        Granularity for scales/zero-point of inputs. Default: per_tensor.
   --input-group-size INPUT_GROUP_SIZE
-                        Group size for per_group input quantization. Default:
-                        64.
+                        Group size for per_group input quantization. Default: 64.
   --quantize-input-zero-point
                         Quantize input zero-point.
   --quantize-last-layer
                         Quantize last nn.Linear layer.
   --gptq                Apply GPTQ.
   --act-calibration     Apply activation calibration.
   --bias-corr           Apply bias correction.
   --ln-affine-merge     Merge LN affine params.
   --no-quantize         Disable quantization.
-  --no-float16          Disable float16 as base datatype and switch to
-                        float32.
-  --replace-mha         Replace HuggingFace Attention with a quantizable
-                        version
+  --no-float16          Disable float16 as base datatype and switch to float32.
+  --replace-mha         Replace HuggingFace Attention with a quantizable version
   --weight-equalization
-                        Apply weight equalization. Relevant to ReLU based
-                        models (e.g. OPT).
+                        Apply weight equalization. Relevant to ReLU based models (e.g. OPT).
   --act-equalization {None,layerwise,fx}
-                        Apply activation equalization (SmoothQuant). Layerwise
-                        introduces standalone mul nodes,while fx merges them
-                        whenever possible into previous tensors, which is
-                        possible on ReLU based models (e.g. OPT).
+                        Apply activation equalization (SmoothQuant). Layerwise introduces standalone mul nodes, while fx merges them
+                        whenever possible into previous tensors, which is possible on ReLU based models (e.g. OPT).
   --load-awq LOAD_AWQ   Load the awq search results.
   --export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}
                         Model export.
+  --export-prefix EXPORT_PREFIX
+                        Path prefix to use for the various export flows. If None, a path will be derived from the model name (default:
+                        None)
   --checkpoint-name CHECKPOINT_NAME
-                        Filename to save checkpoint. If `None`, no checkpoint
-                        is saved (default: None)
+                        Filename to save checkpoint. If `None`, no checkpoint is saved (default: None)
+  --fuse-sequences      Whether to merge the dataset sequences in case they are shorter than the requested sequence length. This is
+                        useful when quantizing or evaluating on long sequences (default: False).
 ```
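As a closing note, the 350-token threshold hard-coded in the first patch can be sanity-checked along the following lines. This is a sketch assuming the `datasets` and `transformers` packages; the checkpoint passed to `AutoTokenizer` is an assumption, since the comment in the patch does not name the model.

```python
# Sketch of the measurement behind "350 is approximately the 99th percentile"
# of WikiText2 train-split sequence lengths. Tokenizer choice is assumed.
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
train = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
lengths = [len(tokenizer(text)["input_ids"]) for text in train["text"] if text.strip()]
print(np.percentile(lengths, 99))  # roughly 350 per the comment in patch 1
```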