Skip to content

Commit

Permalink
GPxQ generalization
Browse files (browse the repository at this point in the history)
  • Loading branch information
Giuseppe5 committed Oct 11, 2024
1 parent 58d6f15 commit da1a249
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/brevitas_examples/llm/main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -188,6 +188,8 @@ def main(args):

if require_fx:
model = get_fx(model)
# Blockwise optimization does not work with FX at the moment
args.gpxq_block_name = None

# Apply LN affine merging before inserting MHA layers
# since currently there is support only for merging into Linear
Expand Down Expand Up @@ -284,12 +286,17 @@ def main(args):
calibration_loader,
act_order=args.gpxq_act_order,
use_quant_activations=args.gpxq_use_quant_activations,
create_weight_orig=args.gpxq_create_weight_orig)
create_weight_orig=args.gpxq_create_weight_orig,
block_name=args.gpxq_block_name)
print("GPTQ applied.")

if args.gpfq:
print("Applying GPFQ...")
apply_gpfq(model, calibration_loader, act_order=args.gpxq_act_order)
apply_gpfq(
model,
calibration_loader,
act_order=args.gpxq_act_order,
block_name=args.gpxq_block_name)
print("GPFQ applied.")

if args.bias_corr:
Expand Down Expand Up @@ -340,11 +347,11 @@ def parse_args(args):
default='wikitext2',
help='Dataset to use for quantization (default: %(default)s)')
parser.add_argument(
'--gptq-block-name',
'--gpxq-block-name',
type=str,
default=None,
help=
'Block name for faster GPTQ optimization. It works only if FX is not needed (default: %(default)s)'
'Block name for faster GPxQ optimization. It works only if FX is not needed (default: %(default)s)'
)
parser.add_argument(
'--weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.')
Expand Down

0 comments on commit da1a249

Please sign in to comment.