diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index e6df0759f..25c885459 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -21,8 +21,8 @@
 from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge
 from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers
 from brevitas_examples.llm.llm_quant.quantize import quantize_model
-from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
 from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32
+from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -132,8 +132,7 @@
     help='Apply activation equalization (SmoothQuant). Layerwise introduces standalone mul nodes,'
     'while fx merges them whenever possible into previous tensors, which is possible on ReLU based models (e.g. OPT).'
 )
-parser.add_argument('--load-awq', type=str, default=None,
-                    help="Load the awq search results.")
+parser.add_argument('--load-awq', type=str, default=None, help="Load the awq search results.")
 parser.add_argument(
     '--export-target',
     default=None,
@@ -217,7 +216,6 @@ def main():
         awq_results = torch.load(args.load_awq, map_location="cpu")
         with CastFloat16ToFloat32():
             apply_awq(model, awq_results)
-
 
     if (args.export_target or args.eval or args.act_equalization or args.act_calibration or
             args.gptq or args.bias_corr or args.ln_affine_merge or args.weight_equalization):