
Commit

Feat (example/sdxl): Allow customization of SDPA quant via the commandline
nickfraser committed Sep 10, 2024
1 parent 28371a5 commit f7684c3
Showing 1 changed file with 105 additions and 28 deletions.
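
For context, a hypothetical invocation of the example script exercising the new options might look like the following (values are illustrative only; the --quantize-sdp and --share-qkv-quant flag names are inferred from the args.quantize_sdp and args.share_qkv_quant references in the diff, and required arguments such as the model path are omitted):

python src/brevitas_examples/stable_diffusion/main.py \
    --quantize-sdp --share-qkv-quant \
    --sdpa-bit-width 8 \
    --sdpa-quant-type sym \
    --sdpa-quant-format float_fnuz_e4m3 \
    --sdpa-param-method stats \
    --sdpa-scale-type static

Per the --sdpa-quant-format help text below, an eXmY format must satisfy X+Y == bit_width - 1, so with --sdpa-bit-width 8 the e4m3 layout above (4 + 3 == 7) is valid, optionally prefixed with float_ocp_ or float_fnuz_ to select the FP8 encoding.
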
133 changes: 105 additions & 28 deletions src/brevitas_examples/stable_diffusion/main.py
@@ -316,6 +316,29 @@ def input_zp_stats_type():

input_kwargs['zero_point_stats_impl'] = input_zp_stats_type

+ sdpa_kwargs = dict()
+ if args.sdpa_scale_stats_op == 'minmax':
+
+     @value
+     def sdpa_scale_stats_type():
+         if args.sdpa_quant_type == 'asym':
+             sdpa_scaling_stats_op = StatsOp.MIN_MAX
+         else:
+             sdpa_scaling_stats_op = StatsOp.MAX
+         return sdpa_scaling_stats_op
+
+     sdpa_kwargs['scaling_stats_op'] = sdpa_scale_stats_type
+
+ if args.sdpa_zp_stats_op == 'minmax':
+
+     @value
+     def sdpa_zp_stats_type():
+         if args.sdpa_quant_type == 'asym':
+             zero_point_stats_impl = NegativeMinOrZero
+             return zero_point_stats_impl
+
+     sdpa_kwargs['zero_point_stats_impl'] = sdpa_zp_stats_type
+
print("Applying model quantization...")
quantizers = generate_quantizers(
dtype=dtype,
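
(Note on the sdpa_kwargs hunk above: the @value-decorated functions are resolved lazily, when Brevitas instantiates the quantizer, so the statistics operators actually used depend on the quantization type chosen on the command line. A rough plain-Python sketch of that selection logic, for illustration only; the helper names below are hypothetical, and StatsOp / NegativeMinOrZero are the symbols main.py already imports:)

def pick_sdpa_scaling_stats_op(sdpa_quant_type):
    # Asymmetric quantization needs min and max statistics to derive both the
    # scale and the zero point; symmetric quantization only tracks the max.
    return StatsOp.MIN_MAX if sdpa_quant_type == 'asym' else StatsOp.MAX

def pick_sdpa_zero_point_stats_impl(sdpa_quant_type):
    # A zero-point statistic is only meaningful for asymmetric quantization.
    return NegativeMinOrZero if sdpa_quant_type == 'asym' else None
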
@@ -360,29 +383,29 @@ def input_zp_stats_type():
if args.quantize_sdp:
assert args.share_qkv_quant, "Currently SDPA quantization is supported only with shared QKV quantization"
# TODO: reformat this
- float_sdpa_quantizers = generate_quantizers(
+ sdpa_quantizers = generate_quantizers(
dtype=dtype,
device=args.device,
- weight_bit_width=weight_bit_width,
- weight_quant_format='float_fnuz_e4m3',
- weight_quant_type='sym',
- weight_param_method=args.weight_param_method,
- weight_scale_precision=args.weight_scale_precision,
- weight_quant_granularity=args.weight_quant_granularity,
- weight_group_size=args.weight_group_size,
- quantize_weight_zero_point=args.quantize_weight_zero_point,
- quantize_input_zero_point=args.quantize_input_zero_point,
- input_bit_width=args.linear_output_bit_width,
- input_quant_format='float_fnuz_e4m3',
- input_scale_type=args.input_scale_type,
- input_scale_precision=args.input_scale_precision,
- input_param_method=args.input_param_method,
- input_quant_type='sym',
- input_quant_granularity=args.input_quant_granularity,
- input_kwargs=input_kwargs)
+ weight_bit_width=args.sdpa_bit_width,
+ weight_quant_format=args.sdpa_quant_format,
+ weight_quant_type=args.sdpa_quant_type,
+ weight_param_method=args.sdpa_param_method,
+ weight_scale_precision=args.sdpa_scale_precision,
+ weight_quant_granularity=args.sdpa_quant_granularity,
+ weight_group_size=32, # Not used, since args.sdpa_quant_granularity == 'per_tensor'
+ quantize_weight_zero_point=args.quantize_sdpa_zero_point,
+ quantize_input_zero_point=args.quantize_sdpa_zero_point,
+ input_bit_width=args.sdpa_bit_width,
+ input_quant_format=args.sdpa_quant_format,
+ input_scale_type=args.sdpa_scale_type,
+ input_scale_precision=args.sdpa_scale_precision,
+ input_param_method=args.sdpa_param_method,
+ input_quant_type=args.sdpa_quant_type,
+ input_quant_granularity=args.sdpa_quant_granularity,
+ input_kwargs=sdpa_kwargs)
# We generate all quantizers, but we are only interested in activation quantization for
# the output of softmax and the output of QKV
- input_quant = float_sdpa_quantizers[0]
+ input_quant = sdpa_quantizers[0]
rewriter = ModuleToModuleByClass(
Attention,
QuantAttention,
@@ -400,11 +423,11 @@ def input_zp_stats_type():

if args.override_conv_quant_config:
print(
f"Overriding Conv2d quantization to weights: {float_sdpa_quantizers[1]}, inputs: {float_sdpa_quantizers[2]}"
f"Overriding Conv2d quantization to weights: {sdpa_quantizers[1]}, inputs: {sdpa_quantizers[2]}"
)
conv_qkwargs = layer_map[torch.nn.Conv2d][1]
- conv_qkwargs['input_quant'] = float_sdpa_quantizers[2]
- conv_qkwargs['weight_quant'] = float_sdpa_quantizers[1]
+ conv_qkwargs['input_quant'] = sdpa_quantizers[2]
+ conv_qkwargs['weight_quant'] = sdpa_quantizers[1]
layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs)

pipe.unet = layerwise_quantize(
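
(Side note on the indexing above, inferred from this diff rather than from any documented contract of generate_quantizers: element 0 of sdpa_quantizers appears to be the activation quantizer attached to the softmax and QKV outputs, element 1 the weight quantizer, and element 2 the input quantizer reused for the optional Conv2d override, i.e. roughly sdpa_act_quant, sdpa_weight_quant, sdpa_conv_input_quant = sdpa_quantizers[:3], where these unpacked names are illustrative only.)
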
@@ -435,7 +458,7 @@ def input_zp_stats_type():
pipe = pipe.to(args.device)
elif not args.dry_run:
if (args.linear_input_bit_width > 0 or args.conv_input_bit_width > 0 or
- args.linear_output_bit_width > 0) and args.input_scale_type == 'static':
+ args.sdpa_bit_width > 0) and args.input_scale_type == 'static':
print("Applying activation calibration")
with torch.no_grad(), calibration_mode(pipe.unet):
run_val_inference(
@@ -707,11 +730,6 @@ def input_zp_stats_type():
type=int,
default=0,
help='Input bit width. Default: 0 (not quantized).')
- parser.add_argument(
- '--linear-output-bit-width',
- type=int,
- default=0,
- help='Input bit width. Default: 0 (not quantized).')
parser.add_argument(
'--weight-param-method',
type=str,
@@ -797,6 +815,60 @@ def input_zp_stats_type():
type=int,
default=16,
help='Group size for per_group weight quantization. Default: 16.')
+ parser.add_argument(
+ '--sdpa-bit-width',
+ type=int,
+ default=0,
+ help='Scaled dot product attention bit width. Default: 0 (not quantized).')
+ parser.add_argument(
+ '--sdpa-param-method',
+ type=str,
+ default='stats',
+ choices=['stats', 'mse'],
+ help='How scales/zero-point are determined for scaled dot product attention. Default: %(default)s.')
+ parser.add_argument(
+ '--sdpa-scale-stats-op',
+ type=str,
+ default='minmax',
+ choices=['minmax', 'percentile'],
+ help='Define what statistics op to use for scaled dot product attention scale. Default: %(default)s.')
+ parser.add_argument(
+ '--sdpa-zp-stats-op',
+ type=str,
+ default='minmax',
+ choices=['minmax', 'percentile'],
+ help='Define what statistics op to use for scaled dot product attention zero point. Default: %(default)s.')
+ parser.add_argument(
+ '--sdpa-scale-precision',
+ type=str,
+ default='float_scale',
+ choices=['float_scale', 'po2_scale'],
+ help='Whether the scaled dot product attention scale is a float value or a po2. Default: %(default)s.')
+ parser.add_argument(
+ '--sdpa-quant-type',
+ type=str,
+ default='asym',
+ choices=['sym', 'asym'],
+ help='Scaled dot product attention quantization type. Default: %(default)s.')
+ parser.add_argument(
+ '--sdpa-quant-format',
+ type=quant_format_validator,
+ default='int',
+ help=
+ 'Scaled dot product attention quantization format. Either int or eXmY, with X+Y==input_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: %(default)s.'
+ )
+ parser.add_argument(
+ '--sdpa-quant-granularity',
+ type=str,
+ default='per_tensor',
+ choices=['per_tensor'],
+ help='Granularity for scales/zero-point of scaled dot product attention. Default: %(default)s.')
+ parser.add_argument(
+ '--sdpa-scale-type',
+ type=str,
+ default='static',
+ choices=['static', 'dynamic'],
+ help='Whether to do static or dynamic scaled dot product attention quantization. Default: %(default)s.')
parser.add_argument(
'--quant-blacklist',
type=str,
@@ -819,6 +891,11 @@ def input_zp_stats_type():
'quantize-input-zero-point',
default=False,
help='Quantize input zero-point. Default: Enabled')
+ add_bool_arg(
+ parser,
+ 'quantize-sdpa-zero-point',
+ default=False,
+ help='Quantize scaled dot product attention zero-point. Default: %(default)s')
add_bool_arg(
parser, 'export-cpu-float32', default=False, help='Export FP32 on CPU. Default: Disabled')
add_bool_arg(
