Fix (examples/llm): Fix infinite loop in LLM entrypoint with WikiText2 #1044

Merged: 4 commits, Oct 8, 2024
7 changes: 6 additions & 1 deletion src/brevitas_examples/llm/README.md
@@ -40,7 +40,7 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
                 [--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ]
                 [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
                 [--export-prefix EXPORT_PREFIX]
-                [--checkpoint-name CHECKPOINT_NAME]
+                [--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]
 
 options:
   -h, --help            show this help message and exit
@@ -131,5 +131,10 @@ options:
   --checkpoint-name CHECKPOINT_NAME
                         Filename to save checkpoint. If `None`, no checkpoint
                         is saved (default: None)
+  --fuse-sequences      Whether to merge the dataset sequences in case they
+                        are shorter than the requested sequence length
+                        (seqlen). This is useful in case you would like to
+                        quantize or evaluate on long sequences (default:
+                        False).
 
 ```
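With the flag documented, a long-sequence run can be launched as, for example, `python main.py --seqlen 2048 --fuse-sequences` (the seqlen value here is only illustrative; all other options keep their defaults).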
20 changes: 17 additions & 3 deletions src/brevitas_examples/llm/main.py
@@ -6,6 +6,7 @@
 import argparse
 import re
 import sys
+from warnings import warn
 
 import numpy as np
 from optimum.amd.brevitas.accelerate_utils import offload_model
@@ -109,6 +110,13 @@ def validate(args):
     else:
         assert args.export_target != 'torch_qcdq', "Cannot export Torch QCDQ with FX"
 
+    if not args.fuse_sequences:
+        # 350 is approximately the 99th percentile of sequence lengths in WikiText2 (train partition, using AutoTokenizer)
+        if args.seqlen >= 350:
+            warn(
+                "Data loading can take a long time or, potentially, enter an infinite loop. Consider setting --fuse-sequences "
+                "or decreasing the sequence length (seqlen)")
+
 
 def main(args):
     validate(args)
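For reviewers, a note on the failure mode this guards against: the WikiText2 loader appears to build each calibration sample by repeatedly picking a raw sequence and trying to cut a window of `seqlen` tokens from it, so once sequences of that length become rare the retry loop slows down, and if none exist it never exits. A minimal sketch under that assumption (function name, signature, and sampling details are illustrative, not the Brevitas implementation):

```python
import random


def sample_calibration_windows(tokenized_seqs, nsamples, seqlen, fuse_sequences):
    """Illustrative sketch of the failure mode; not the Brevitas loader."""
    if fuse_sequences:
        # Concatenate everything into one stream so that a window of any
        # practical seqlen can always be served.
        tokenized_seqs = [[tok for seq in tokenized_seqs for tok in seq]]
    samples = []
    while len(samples) < nsamples:
        seq = random.choice(tokenized_seqs)
        if len(seq) < seqlen:
            # Rejection sampling: when sequences of length >= seqlen are rare
            # (as in WikiText2 beyond ~350 tokens, its approximate 99th
            # percentile), this loop becomes very slow; when none exist at
            # all, it never terminates.
            continue
        start = random.randrange(len(seq) - seqlen + 1)
        samples.append(seq[start:start + seqlen])
    return samples
```

Fusing trades window independence for availability: after concatenation a single stream spans the whole corpus, so any `seqlen` up to the corpus length can be served.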
Expand Down Expand Up @@ -142,7 +150,6 @@ def main(args):
apply_awq(model, awq_results)

require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge else False
fuse_sequences = False

# Load the data for calibration and evaluation.
calibration_loader = get_dataset_for_model(
@@ -155,7 +162,7 @@
         seed=args.seed,
         require_fx=require_fx,
         device=None,
-        fuse_sequences=fuse_sequences,
+        fuse_sequences=args.fuse_sequences,
     )
 
     validation_loader = get_dataset_for_model(
@@ -168,7 +175,7 @@
         seed=args.seed,
         require_fx=require_fx,
         device=None,
-        fuse_sequences=fuse_sequences,
+        fuse_sequences=args.fuse_sequences,
     )
 
     device = next(iter(model.parameters())).device
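Worth noting: both `get_dataset_for_model` calls now read `args.fuse_sequences`, so the calibration and validation loaders stay in sync, whereas the removed hard-coded `fuse_sequences = False` forced the slow sampling path regardless of what was passed on the CLI.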
@@ -474,6 +481,13 @@ def parse_args(args):
         default=None,
         help="Filename to save checkpoint. If `None`, no checkpoint is saved (default: %(default)s)"
     )
+    parser.add_argument(
+        "--fuse-sequences",
+        action="store_true",
+        default=False,
+        help=
+        "Whether to merge the dataset sequences in case they are shorter than the requested sequence length (seqlen). This is useful in case you would like to quantize or evaluate on long sequences (default: %(default)s).",
+    )
     return parser.parse_args(args)


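Finally, a quick sanity check of the new flag wiring, a sketch assuming `brevitas_examples.llm.main` is importable in the current environment and that `--seqlen` is the CLI spelling of `args.seqlen`:

```python
from brevitas_examples.llm.main import parse_args

# Default: sequences are not fused, so validate() warns for seqlen >= 350.
args = parse_args(["--seqlen", "2048"])
print(args.fuse_sequences)  # False

# Opting in takes the fused-sequences path and skips the warning.
args = parse_args(["--seqlen", "2048", "--fuse-sequences"])
print(args.fuse_sequences)  # True
```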