From 9932b92adfe7312f68bac8d989989f142ee0f804 Mon Sep 17 00:00:00 2001 From: nickfraser Date: Thu, 12 Sep 2024 11:41:31 +0100 Subject: [PATCH] Feat (examples/sdxl): Updates to SDXL entry-point (#1020) --------- Co-authored-by: Giuseppe Franco --- .../stable_diffusion/main.py | 239 ++++++++++++++---- .../stable_diffusion/sd_quant/export.py | 15 +- .../stable_diffusion/sd_quant/nn.py | 140 +++++++++- 3 files changed, 336 insertions(+), 58 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index a1c4fef53..3fe24a321 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -12,12 +12,14 @@ import time from dependencies import value +import diffusers from diffusers import DiffusionPipeline from diffusers import EulerDiscreteScheduler from diffusers import StableDiffusionXLPipeline from diffusers.models.attention_processor import Attention -from diffusers.models.attention_processor import AttnProcessor import numpy as np +import packaging +import packaging.version import pandas as pd import torch from torch import nn @@ -35,7 +37,6 @@ from brevitas.graph.quantize import layerwise_quantize from brevitas.inject.enum import StatsOp from brevitas.nn.equalized_layer import EqualizedModule -from brevitas.nn.quant_activation import QuantIdentity from brevitas.utils.torch_utils import KwargsForwardHook from brevitas_examples.common.generative.quantize import generate_quant_maps from brevitas_examples.common.generative.quantize import generate_quantizers @@ -47,14 +48,17 @@ from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params +from brevitas_examples.stable_diffusion.sd_quant.nn import AttnProcessor from brevitas_examples.stable_diffusion.sd_quant.nn import AttnProcessor2_0 from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention +from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttentionLast from brevitas_examples.stable_diffusion.sd_quant.nn import QuantizableAttention from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs from brevitas_examples.stable_diffusion.sd_quant.utils import unet_input_shape +diffusers_version = packaging.version.parse(diffusers.__version__) TEST_SEED = 123456 torch.manual_seed(TEST_SEED) @@ -149,7 +153,7 @@ def main(args): calibration_prompts = CALIBRATION_PROMPTS if args.calibration_prompt_path is not None: calibration_prompts = load_calib_prompts(args.calibration_prompt_path) - print(args.calibration_prompt, len(calibration_prompts)) + assert args.calibration_prompt <= len(calibration_prompts) , f"Only {len(calibration_prompts)} prompts are available" calibration_prompts = calibration_prompts[:args.calibration_prompt] @@ -178,18 +182,29 @@ def main(args): args.model, torch_dtype=dtype, variant=variant, use_safetensors=True) pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) pipe.vae.config.force_upcast = True - if args.share_qkv_quant: - rewriter = ModuleToModuleByClass( - Attention, - QuantizableAttention, - query_dim=lambda module: module.to_q.in_features, - dim_head=lambda module: math.ceil(1 / 
(module.scale ** 2)), - bias=lambda module: hasattr(module.to_q, 'bias') and module.to_q.bias is not None, - processor=AttnProcessor2_0(), - dtype=dtype, - norm_num_groups=lambda module: None - if module.group_norm is None else module.group_norm.num_groups) - rewriter.apply(pipe.unet) + is_mlperf_diffusers = diffusers_version == packaging.version.parse('0.21.2') + + AttClass = Attention + if is_mlperf_diffusers: + QuantAttClass = QuantAttention + if args.share_qkv_quant: + AttClass = QuantizableAttention + rewriter = ModuleToModuleByClass( + Attention, + QuantizableAttention, + query_dim=lambda module: module.to_q.in_features, + dim_head=lambda module: math.ceil(1 / (module.scale ** 2)), + bias=lambda module: hasattr(module.to_q, 'bias') and module.to_q.bias is not None, + processor=AttnProcessor2_0(), + dtype=dtype, + norm_num_groups=lambda module: None + if module.group_norm is None else module.group_norm.num_groups) + rewriter.apply(pipe.unet) + else: + QuantAttClass = QuantAttentionLast + if args.share_qkv_quant: + pipe.fuse_qkv_projections() + print(f"Model loaded from {args.model}.") # Move model to target device @@ -222,7 +237,7 @@ def main(args): blacklist = [] non_blacklist = dict() for name, _ in pipe.unet.named_modules(): - if 'time_emb' in name: + if any(map(lambda x: x in name, args.quant_blacklist)): blacklist.append(name) else: if isinstance(_, (torch.nn.Linear, torch.nn.Conv2d)): @@ -232,7 +247,7 @@ def main(args): else: non_blacklist[name_to_add] += 1 print(f"Blacklisted layers: {set(blacklist)}") - print(f"Non blacklisted layers: {non_blacklist}") + print(f"Non blacklisted layers: {set(non_blacklist.keys())}") # Make sure there all LoRA layers are fused first, otherwise raise an error for m in pipe.unet.modules(): @@ -316,6 +331,29 @@ def input_zp_stats_type(): input_kwargs['zero_point_stats_impl'] = input_zp_stats_type + sdpa_kwargs = dict() + if args.sdpa_scale_stats_op == 'minmax': + + @value + def sdpa_scale_stats_type(): + if args.sdpa_quant_type == 'asym': + sdpa_scaling_stats_op = StatsOp.MIN_MAX + else: + sdpa_scaling_stats_op = StatsOp.MAX + return sdpa_scaling_stats_op + + sdpa_kwargs['scaling_stats_op'] = sdpa_scale_stats_type + + if args.sdpa_zp_stats_op == 'minmax': + + @value + def sdpa_zp_stats_type(): + if args.sdpa_quant_type == 'asym': + zero_point_stats_impl = NegativeMinOrZero + return zero_point_stats_impl + + sdpa_kwargs['zero_point_stats_impl'] = sdpa_zp_stats_type + print("Applying model quantization...") quantizers = generate_quantizers( dtype=dtype, @@ -357,40 +395,52 @@ def input_zp_stats_type(): 'weight_quant'] layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs) - if args.quantize_sdp: - assert args.share_qkv_quant, "Currently SDPA quantization is supported only with shared QKV quantization" - # TODO: reformat this - float_sdpa_quantizers = generate_quantizers( + if args.sdpa_bit_width > 0: + # `args.weight_quant_granularity` must be compatible with `args.sdpa_quant_format` + sdpa_quantizers = generate_quantizers( dtype=dtype, device=args.device, - weight_bit_width=weight_bit_width, - weight_quant_format='float_ocp_e4m3', - weight_quant_type='sym', + weight_bit_width=args.sdpa_bit_width, + weight_quant_format=args.sdpa_quant_format, + weight_quant_type=args.sdpa_quant_type, weight_param_method=args.weight_param_method, weight_scale_precision=args.weight_scale_precision, weight_quant_granularity=args.weight_quant_granularity, weight_group_size=args.weight_group_size, 
quantize_weight_zero_point=args.quantize_weight_zero_point, - quantize_input_zero_point=args.quantize_input_zero_point, - input_bit_width=args.linear_output_bit_width, - input_quant_format='float_ocp_e4m3', - input_scale_type=args.input_scale_type, - input_scale_precision=args.input_scale_precision, - input_param_method=args.input_param_method, - input_quant_type='sym', - input_quant_granularity=args.input_quant_granularity, - input_kwargs=input_kwargs) + quantize_input_zero_point=args.quantize_sdpa_zero_point, + input_bit_width=args.sdpa_bit_width, + input_quant_format=args.sdpa_quant_format, + input_scale_type=args.sdpa_scale_type, + input_scale_precision=args.sdpa_scale_precision, + input_param_method=args.sdpa_param_method, + input_quant_type=args.sdpa_quant_type, + input_quant_granularity=args.sdpa_quant_granularity, + input_kwargs=sdpa_kwargs) # We generate all quantizers, but we are only interested in activation quantization for # the output of softmax and the output of QKV - input_quant = float_sdpa_quantizers[0] + input_quant = sdpa_quantizers[0] + if is_mlperf_diffusers: + extra_kwargs = {} + query_lambda = lambda module: module.to_qkv.in_features if hasattr( + module, 'to_qkv') else module.to_q.in_features + else: + extra_kwargs = { + 'fuse_qkv': + args.share_qkv_quant, + 'cross_attention_dim': + lambda module: module.cross_attention_dim + if module.is_cross_attention else None} + query_lambda = lambda module: module.query_dim rewriter = ModuleToModuleByClass( - Attention, - QuantAttention, + AttClass, + QuantAttClass, matmul_input_quant=input_quant, - query_dim=lambda module: module.to_q.in_features, + query_dim=query_lambda, dim_head=lambda module: math.ceil(1 / (module.scale ** 2)), processor=AttnProcessor(), - is_equalized=args.activation_equalization) + is_equalized=args.activation_equalization, + **extra_kwargs) import brevitas.config as config config.IGNORE_MISSING_KEYS = True pipe.unet = rewriter.apply(pipe.unet) @@ -400,11 +450,11 @@ def input_zp_stats_type(): if args.override_conv_quant_config: print( - f"Overriding Conv2d quantization to weights: {float_sdpa_quantizers[1]}, inputs: {float_sdpa_quantizers[2]}" + f"Overriding Conv2d quantization to weights: {sdpa_quantizers[1]}, inputs: {sdpa_quantizers[2]}" ) conv_qkwargs = layer_map[torch.nn.Conv2d][1] - conv_qkwargs['input_quant'] = float_sdpa_quantizers[2] - conv_qkwargs['weight_quant'] = float_sdpa_quantizers[1] + conv_qkwargs['input_quant'] = sdpa_quantizers[2] + conv_qkwargs['weight_quant'] = sdpa_quantizers[1] layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs) pipe.unet = layerwise_quantize( @@ -434,8 +484,18 @@ def input_zp_stats_type(): print(f"Checkpoint loaded!") pipe = pipe.to(args.device) elif not args.dry_run: - if (args.linear_input_bit_width > 0 or args.conv_input_bit_width > 0 or - args.linear_output_bit_width > 0) and args.input_scale_type == 'static': + # Model needs calibration if any of its activation quantizers are 'static' + activation_bw = [ + args.linear_input_bit_width, + args.conv_input_bit_width, + args.sdpa_bit_width,] + activation_st = [ + args.input_scale_type, + args.input_scale_type, + args.sdpa_scale_type,] + needs_calibration = any( + map(lambda b, st: (b > 0) and st == 'static', activation_bw, activation_st)) + if needs_calibration: print("Applying activation calibration") with torch.no_grad(), calibration_mode(pipe.unet): run_val_inference( @@ -520,7 +580,13 @@ def input_zp_stats_type(): if args.use_mlperf_inference: print(f"Computing accuracy with MLPerf 
pipeline") compute_mlperf_fid( - args.model, args.path_to_coco, pipe, args.prompt, output_dir, not args.vae_fp16_fix) + args.model, + args.path_to_coco, + pipe, + args.prompt, + output_dir, + args.device, + not args.vae_fp16_fix) else: print(f"Computing accuracy on default prompt") testing_prompts = TESTING_PROMPTS[:args.prompt] @@ -643,7 +709,7 @@ def input_zp_stats_type(): help='Resolution along height and width dimension. Default: 512.') parser.add_argument('--guidance-scale', type=float, default=7.5, help='Guidance scale.') parser.add_argument( - '--calibration-steps', type=float, default=8, help='Steps used during calibration') + '--calibration-steps', type=int, default=8, help='Steps used during calibration') add_bool_arg( parser, 'output-path', @@ -701,11 +767,6 @@ def input_zp_stats_type(): type=int, default=0, help='Input bit width. Default: 0 (not quantized).') - parser.add_argument( - '--linear-output-bit-width', - type=int, - default=0, - help='Input bit width. Default: 0 (not quantized).') parser.add_argument( '--weight-param-method', type=str, @@ -791,6 +852,78 @@ def input_zp_stats_type(): type=int, default=16, help='Group size for per_group weight quantization. Default: 16.') + parser.add_argument( + '--sdpa-bit-width', + type=int, + default=0, + help='Scaled dot product attention bit width. Default: 0 (not quantized).') + parser.add_argument( + '--sdpa-param-method', + type=str, + default='stats', + choices=['stats', 'mse'], + help= + 'How scales/zero-point are determined for scaled dot product attention. Default: %(default)s.' + ) + parser.add_argument( + '--sdpa-scale-stats-op', + type=str, + default='minmax', + choices=['minmax', 'percentile'], + help= + 'Define what statistics op to use for scaled dot product attention scale. Default: %(default)s.' + ) + parser.add_argument( + '--sdpa-zp-stats-op', + type=str, + default='minmax', + choices=['minmax', 'percentile'], + help= + 'Define what statistics op to use for scaled dot product attention zero point. Default: %(default)s.' + ) + parser.add_argument( + '--sdpa-scale-precision', + type=str, + default='float_scale', + choices=['float_scale', 'po2_scale'], + help= + 'Whether the scaled dot product attention scale is a float value or a po2. Default: %(default)s.' + ) + parser.add_argument( + '--sdpa-quant-type', + type=str, + default='sym', + choices=['sym', 'asym'], + help='Scaled dot product attention quantization type. Default: %(default)s.') + parser.add_argument( + '--sdpa-quant-format', + type=quant_format_validator, + default='int', + help= + 'Scaled dot product attention quantization format. Either int or eXmY, with X+Y==input_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: %(default)s.' + ) + parser.add_argument( + '--sdpa-quant-granularity', + type=str, + default='per_tensor', + choices=['per_tensor'], + help= + 'Granularity for scales/zero-point of scaled dot product attention. Default: %(default)s.') + parser.add_argument( + '--sdpa-scale-type', + type=str, + default='static', + choices=['static', 'dynamic'], + help= + 'Whether to do static or dynamic scaled dot product attention quantization. Default: %(default)s.' + ) + parser.add_argument( + '--quant-blacklist', + type=str, + default=['time_emb'], + nargs='*', + metavar='NAME', + help='A list of module names to exclude from quantization. 
Default: %(default)s') add_bool_arg( parser, 'quantize-weight-zero-point', @@ -806,6 +939,11 @@ def input_zp_stats_type(): 'quantize-input-zero-point', default=False, help='Quantize input zero-point. Default: Enabled') + add_bool_arg( + parser, + 'quantize-sdpa-zero-point', + default=False, + help='Quantize scaled dot product attention zero-point. Default: %(default)s') add_bool_arg( parser, 'export-cpu-float32', default=False, help='Export FP32 on CPU. Default: Disabled') add_bool_arg( @@ -823,7 +961,6 @@ def input_zp_stats_type(): 'dry-run', default=False, help='Generate a quantized model without any calibration. Default: Disabled') - add_bool_arg(parser, 'quantize-sdp', default=False, help='Quantize SDP. Default: Disabled') add_bool_arg( parser, 'override-conv-quant-config', diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index 89d846a79..a42a35204 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -40,18 +40,18 @@ def handle_quant_param(layer, layer_dict): layer_dict['output_scale_shape'] = output_scale.shape layer_dict['input_scale'] = input_scale.numpy().tolist() layer_dict['input_scale_shape'] = input_scale.shape - layer_dict['input_zp'] = input_zp.numpy().tolist() + layer_dict['input_zp'] = input_zp.to(torch.float32).cpu().numpy().tolist() layer_dict['input_zp_shape'] = input_zp.shape - layer_dict['input_zp_dtype'] = str(torch.int8) + layer_dict['input_zp_dtype'] = str(input_zp.dtype) layer_dict['weight_scale'] = weight_scale.cpu().numpy().tolist() nelems = layer.weight.shape[0] weight_scale_shape = [nelems] + [1] * (layer.weight.data.ndim - 1) layer_dict['weight_scale_shape'] = weight_scale_shape - if torch.sum(weight_zp) != 0.: + if torch.sum(weight_zp.to(torch.float32)) != 0.: weight_zp = weight_zp - 128. 
# apply offset to have signed z - layer_dict['weight_zp'] = weight_zp.cpu().numpy().tolist() + layer_dict['weight_zp'] = weight_zp.to(torch.float32).cpu().numpy().tolist() layer_dict['weight_zp_shape'] = weight_scale_shape - layer_dict['weight_zp_dtype'] = str(torch.int8) + layer_dict['weight_zp_dtype'] = str(weight_zp.dtype) return layer_dict @@ -63,6 +63,9 @@ def export_quant_params(pipe, output_dir, export_vae=False): vae_output_path = os.path.join(output_dir, 'vae.safetensors') print(f"Saving vae to {vae_output_path} ...") from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager + export_manager = StdQCDQONNXManager + export_manager.change_weight_export( + export_weight_q_node=True) # We're exporting FP weights + quantization parameters quant_params = dict() state_dict = pipe.unet.state_dict() state_dict = {k: v for (k, v) in state_dict.items() if 'tensor_quant' not in k} @@ -70,7 +73,7 @@ def export_quant_params(pipe, output_dir, export_vae=False): state_dict = {k.replace('.layer.', '.'): v for (k, v) in state_dict.items()} handled_quant_layers = set() - with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, StdQCDQONNXManager): + with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager): for name, module in pipe.unet.named_modules(): if isinstance(module, EqualizedModule): if id(module.layer) in handled_quant_layers: diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/nn.py b/src/brevitas_examples/stable_diffusion/sd_quant/nn.py index 5a6c23ab9..299fb557b 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/nn.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/nn.py @@ -17,8 +17,11 @@ from typing import Any, Mapping, Optional +import diffusers from diffusers.models.attention_processor import Attention from diffusers.models.lora import LoRACompatibleLinear +import packaging +import packaging.version import torch import torch.nn.functional as F @@ -119,6 +122,142 @@ def load_state_dict( return super().load_state_dict(state_dict, strict, assign) +class QuantAttentionLast(Attention): + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, + pre_only=False, + matmul_input_quant=None, + is_equalized=False, + fuse_qkv=False): + + super().__init__( + query_dim, + cross_attention_dim, + heads, + kv_heads, + dim_head, + dropout, + bias, + upcast_attention, + upcast_softmax, + cross_attention_norm, + cross_attention_norm_num_groups, + qk_norm, + added_kv_proj_dim, + added_proj_bias, + norm_num_groups, + spatial_norm_dim, + out_bias, + scale_qk, + only_cross_attention, + eps, + rescale_output_factor, + residual_connection, + _from_deprecated_attn_block, + processor, + out_dim, + context_pre_only, + pre_only, + ) + 
if fuse_qkv: + self.fuse_projections() + + self.output_softmax_quant = QuantIdentity(matmul_input_quant) + self.out_q = QuantIdentity(matmul_input_quant) + self.out_k = QuantIdentity(matmul_input_quant) + self.out_v = QuantIdentity(matmul_input_quant) + if is_equalized: + replacements = [] + for n, m in self.named_modules(): + if isinstance(m, torch.nn.Linear): + in_channels = m.in_features + eq_m = EqualizedModule(ScaleBias(in_channels, False, (1, 1, -1)), m) + r = ModuleInstanceToModuleInstance(m, eq_m) + replacements.append(r) + for r in replacements: + r.apply(self) + + def get_attention_scores( + self, + query: torch.Tensor, + key: torch.Tensor, + attention_mask: torch.Tensor = None) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + attention_probs = _unpack_quant_tensor(self.output_softmax_quant(attention_probs)) + return attention_probs + + class QuantAttention(QuantizableAttention): def __init__( @@ -172,7 +311,6 @@ def __init__( dtype, processor, ) - self.output_softmax_quant = QuantIdentity(matmul_input_quant) self.out_q = QuantIdentity(matmul_input_quant) self.out_k = QuantIdentity(matmul_input_quant)
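
Usage sketch for the new SDPA flags (illustrative; the invocation path and model id below are assumptions, while the flags themselves are the ones added in this patch). The removed --quantize-sdp / --linear-output-bit-width pair is superseded by the --sdpa-* family, and the previously hard-coded 'time_emb' exclusion is now simply the default of --quant-blacklist. An FP8 SDPA run might look like:

    python src/brevitas_examples/stable_diffusion/main.py \
        --model stabilityai/stable-diffusion-xl-base-1.0 \
        --sdpa-bit-width 8 \
        --sdpa-quant-format float_ocp_e4m3 \
        --sdpa-quant-type sym \
        --sdpa-scale-type static \
        --quant-blacklist time_emb

On diffusers versions other than 0.21.2, the version gate added here rewrites attention into QuantAttentionLast, and --share-qkv-quant is honoured via pipe.fuse_qkv_projections() rather than the QuantizableAttention rewrite.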