diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index 620312641..def3f7070 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -3,6 +3,7 @@ from abc import ABC from abc import abstractmethod +import inspect from inspect import getcallargs import torch @@ -120,16 +121,23 @@ def _module_attributes(self, module): attrs['bias'] = module.bias return attrs - def _evaluate_new_kwargs(self, new_kwargs, old_module): + def _evaluate_new_kwargs(self, new_kwargs, old_module, name): update_dict = dict() for k, v in self.new_module_kwargs.items(): if islambda(v): - v = v(old_module) + if name is not None: + # Two types of lambdas are admitted now, with/without the name of the module as input + if len(inspect.getfullargspec(v).args) == 2: + v = v(old_module, name) + elif len(inspect.getfullargspec(v).args) == 1: + v = v(old_module) + else: + v = v(old_module) update_dict[k] = v new_kwargs.update(update_dict) return new_kwargs - def _init_new_module(self, old_module: Module): + def _init_new_module(self, old_module: Module, name=None): # get attributes of original module new_kwargs = self._module_attributes(old_module) # transforms attribute of original module, e.g. bias Parameter -> bool @@ -138,7 +146,7 @@ def _init_new_module(self, old_module: Module): new_module_signature_keys = signature_keys(self.new_module_class) new_kwargs = {k: v for k, v in new_kwargs.items() if k in new_module_signature_keys} # update with kwargs passed to the rewriter - new_kwargs = self._evaluate_new_kwargs(new_kwargs, old_module) + new_kwargs = self._evaluate_new_kwargs(new_kwargs, old_module, name) # init the new module new_module = self.new_module_class(**new_kwargs) return new_module @@ -204,10 +212,10 @@ def __init__(self, old_module_instance, new_module_class, **kwargs): self.old_module_instance = old_module_instance def apply(self, model: GraphModule) -> GraphModule: - for old_module in model.modules(): + for name, old_module in model.named_modules(): if old_module is self.old_module_instance: # init the new module based on the old one - new_module = self._init_new_module(old_module) + new_module = self._init_new_module(old_module, name) self._replace_old_module(model, old_module, new_module) break return model diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 36bac29d5..a86de3b76 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -473,6 +473,7 @@ def quantize_model( quantize_input_zero_point=False, quantize_embedding=False, use_ocp=False, + use_fnuz=False, device=None, weight_kwargs=None, input_kwargs=None): @@ -497,6 +498,7 @@ def quantize_model( input_group_size, quantize_input_zero_point, use_ocp, + use_fnuz, device, weight_kwargs, input_kwargs) diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md index 1685bd4a9..a51a06df5 100644 --- a/src/brevitas_examples/stable_diffusion/README.md +++ b/src/brevitas_examples/stable_diffusion/README.md @@ -77,6 +77,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--conv-input-bit-width CONV_INPUT_BIT_WIDTH] [--act-eq-alpha ACT_EQ_ALPHA] [--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH] + [--linear-output-bit-width LINEAR_OUTPUT_BIT_WIDTH] [--weight-param-method {stats,mse}] [--input-param-method {stats,mse}] [--input-scale-stats-op {minmax,percentile}] @@ -96,15 +97,17 @@ usage: main.py [-h] [-m 
MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--quantize-input-zero-point | --no-quantize-input-zero-point] [--export-cpu-float32 | --no-export-cpu-float32] [--use-mlperf-inference | --no-use-mlperf-inference] - [--use-ocp | --no-use-ocp] [--use-nfuz | --no-use-nfuz] + [--use-ocp | --no-use-ocp] [--use-fnuz | --no-use-fnuz] [--use-negative-prompts | --no-use-negative-prompts] [--dry-run | --no-dry-run] [--quantize-sdp-1 | --no-quantize-sdp-1] [--quantize-sdp-2 | --no-quantize-sdp-2] + [--override-conv-quant-config | --no-override-conv-quant-config] + [--vae-fp16-fix | --no-vae-fp16-fix] Stable Diffusion quantization -options: +optional arguments: -h, --help show this help message and exit -m MODEL, --model MODEL Path or name of the model. @@ -176,6 +179,8 @@ options: Alpha for activation equalization. Default: 0.9 --linear-input-bit-width LINEAR_INPUT_BIT_WIDTH Input bit width. Default: 0 (not quantized). + --linear-output-bit-width LINEAR_OUTPUT_BIT_WIDTH + Output bit width. Default: 0 (not quantized). --weight-param-method {stats,mse} How scales/zero-point are determined. Default: stats. --input-param-method {stats,mse} @@ -241,9 +246,9 @@ options: True --no-use-ocp Disable Use OCP format for float quantization. Default: True - --use-nfuz Enable Use NFUZ format for float quantization. + --use-fnuz Enable Use FNUZ format for float quantization. Default: True - --no-use-nfuz Disable Use NFUZ format for float quantization. + --no-use-fnuz Disable Use FNUZ format for float quantization. Default: True --use-negative-prompts Enable Use negative prompts during @@ -259,5 +264,14 @@ options: --no-quantize-sdp-1 Disable Quantize SDP. Default: Disabled --quantize-sdp-2 Enable Quantize SDP. Default: Disabled --no-quantize-sdp-2 Disable Quantize SDP. Default: Disabled - + --override-conv-quant-config + Enable Quantize Convolutions in the same way as SDP + (i.e., FP8). Default: Disabled + --no-override-conv-quant-config + Disable Quantize Convolutions in the same way as SDP + (i.e., FP8). Default: Disabled + --vae-fp16-fix Enable Rescale the VAE to not go NaN with FP16. + Default: Disabled + --no-vae-fp16-fix Disable Rescale the VAE to not go NaN with FP16. 
+ Default: Disabled ``` diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index d09ee8fde..cb1f42920 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -7,11 +7,13 @@ from datetime import datetime from functools import partial import json +import math import os import time from dependencies import value from diffusers import DiffusionPipeline +from diffusers import EulerDiscreteScheduler from diffusers import StableDiffusionXLPipeline from diffusers.models.attention_processor import Attention from diffusers.models.attention_processor import AttnProcessor @@ -37,7 +39,6 @@ from brevitas.utils.torch_utils import KwargsForwardHook from brevitas_examples.common.generative.quantize import generate_quant_maps from brevitas_examples.common.generative.quantize import generate_quantizers -from brevitas_examples.common.generative.quantize import quantize_model from brevitas_examples.common.parse_utils import add_bool_arg from brevitas_examples.common.parse_utils import quant_format_validator from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager @@ -46,7 +47,9 @@ from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params +from brevitas_examples.stable_diffusion.sd_quant.nn import AttnProcessor2_0 from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention +from brevitas_examples.stable_diffusion.sd_quant.nn import QuantizableAttention from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs @@ -152,13 +155,14 @@ def main(args): latents = None if args.path_to_latents is not None: - latents = torch.load(args.path_to_latents).to(torch.float16) + latents = torch.load(args.path_to_latents).to(dtype) # Create output dir. 
Move to tmp if None ts = datetime.fromtimestamp(time.time()) str_ts = ts.strftime("%Y%m%d_%H%M%S") output_dir = os.path.join(args.output_path, f'{str_ts}') os.mkdir(output_dir) + print(f"Saving results in {output_dir}") # Dump args to json with open(os.path.join(output_dir, 'args.json'), 'w') as fp: @@ -169,7 +173,23 @@ def main(args): # Load model from float checkpoint print(f"Loading model from {args.model}...") - pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype) + variant = 'fp16' if dtype == torch.float16 else None + pipe = DiffusionPipeline.from_pretrained( + args.model, torch_dtype=dtype, variant=variant, use_safetensors=True) + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + pipe.vae.config.force_upcast = True + if args.share_qkv_quant: + rewriter = ModuleToModuleByClass( + Attention, + QuantizableAttention, + query_dim=lambda module: module.to_q.in_features, + dim_head=lambda module: math.ceil(1 / (module.scale ** 2)), + bias=lambda module: hasattr(module.to_q, 'bias') and module.to_q.bias is not None, + processor=AttnProcessor2_0(), + dtype=dtype, + norm_num_groups=lambda module: None + if module.group_norm is None else module.group_norm.num_groups) + rewriter.apply(pipe.unet) print(f"Model loaded from {args.model}.") # Move model to target device @@ -200,10 +220,19 @@ def main(args): # Extract list of layers to avoid blacklist = [] + non_blacklist = dict() for name, _ in pipe.unet.named_modules(): if 'time_emb' in name: blacklist.append(name.split('.')[-1]) - print(f"Blacklisted layers: {blacklist}") + else: + if isinstance(_, (torch.nn.Linear, torch.nn.Conv2d)): + name_to_add = name.split('.')[-1] + if name_to_add not in non_blacklist: + non_blacklist[name_to_add] = 1 + else: + non_blacklist[name_to_add] += 1 + print(f"Blacklisted layers: {set(blacklist)}") + print(f"Non blacklisted layers: {non_blacklist}") # Make sure there all LoRA layers are fused first, otherwise raise an error for m in pipe.unet.modules(): @@ -212,7 +241,7 @@ def main(args): if args.activation_equalization: pipe.set_progress_bar_config(disable=True) - with activation_equalization_mode( + with torch.no_grad(), activation_equalization_mode( pipe.unet, alpha=args.act_eq_alpha, layerwise=True, @@ -261,8 +290,6 @@ def input_bit_width(module): return args.linear_input_bit_width elif isinstance(module, nn.Conv2d): return args.conv_input_bit_width - elif isinstance(module, QuantIdentity): - return args.quant_identity_bit_width else: raise RuntimeError(f"Module {module} not supported.") @@ -332,7 +359,9 @@ def input_zp_stats_type(): 'weight_quant'] layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs) - if args.quantize_sdp_1 or args.quantize_sdp_2: + if args.quantize_sdp: + assert args.share_qkv_quant, "Currently SDPA quantization is supported only with shared QKV quantization" + # TODO: reformat this float_sdpa_quantizers = generate_quantizers( dtype=dtype, device=args.device, @@ -345,7 +374,7 @@ def input_zp_stats_type(): weight_group_size=args.weight_group_size, quantize_weight_zero_point=args.quantize_weight_zero_point, quantize_input_zero_point=args.quantize_input_zero_point, - input_bit_width=input_bit_width, + input_bit_width=args.linear_output_bit_width, input_quant_format='e4m3', input_scale_type=args.input_scale_type, input_scale_precision=args.input_scale_precision, @@ -358,30 +387,29 @@ def input_zp_stats_type(): # We generate all quantizers, but we are only interested in activation quantization for # the output of softmax and the 
output of QKV input_quant = float_sdpa_quantizers[0] - input_quant = input_quant.let(**{'bit_width': args.linear_output_bit_width}) - if args.quantize_sdp_2: - rewriter = ModuleToModuleByClass( - Attention, - QuantAttention, - softmax_output_quant=input_quant, - query_dim=lambda module: module.to_q.in_features, - dim_head=lambda module: int(1 / (module.scale ** 2)), - processor=AttnProcessor(), - is_equalized=args.activation_equalization) - import brevitas.config as config - config.IGNORE_MISSING_KEYS = True - pipe.unet = rewriter.apply(pipe.unet) - config.IGNORE_MISSING_KEYS = False - pipe.unet = pipe.unet.to(args.device) - pipe.unet = pipe.unet.to(dtype) - quant_kwargs = layer_map[torch.nn.Linear][1] - what_to_quantize = [] - if args.quantize_sdp_1: - what_to_quantize.extend(['to_q', 'to_k']) - if args.quantize_sdp_2: - what_to_quantize.extend(['to_v']) - quant_kwargs['output_quant'] = lambda module, name: input_quant if any(ending in name for ending in what_to_quantize) else None - layer_map[torch.nn.Linear] = (layer_map[torch.nn.Linear][0], quant_kwargs) + rewriter = ModuleToModuleByClass( + Attention, + QuantAttention, + matmul_input_quant=input_quant, + query_dim=lambda module: module.to_q.in_features, + dim_head=lambda module: math.ceil(1 / (module.scale ** 2)), + processor=AttnProcessor(), + is_equalized=args.activation_equalization) + import brevitas.config as config + config.IGNORE_MISSING_KEYS = True + pipe.unet = rewriter.apply(pipe.unet) + config.IGNORE_MISSING_KEYS = False + pipe.unet = pipe.unet.to(args.device) + pipe.unet = pipe.unet.to(dtype) + + if args.override_conv_quant_config: + print( + f"Overriding Conv2d quantization to weights: {float_sdpa_quantizers[1]}, inputs: {float_sdpa_quantizers[2]}" + ) + conv_qkwargs = layer_map[torch.nn.Conv2d][1] + conv_qkwargs['input_quant'] = float_sdpa_quantizers[2] + conv_qkwargs['weight_quant'] = float_sdpa_quantizers[1] + layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs) pipe.unet = layerwise_quantize( model=pipe.unet, compute_layer_map=layer_map, name_blacklist=blacklist) @@ -405,11 +433,13 @@ def input_zp_stats_type(): if args.load_checkpoint is not None: with load_quant_model_mode(pipe.unet): pipe = pipe.to('cpu') + print(f"Loading checkpoint: {args.load_checkpoint}... 
", end="") pipe.unet.load_state_dict(torch.load(args.load_checkpoint, map_location='cpu')) - pipe = pipe.to(args.device) + print(f"Checkpoint loaded!") + pipe = pipe.to(args.device) elif not args.dry_run: - if (args.linear_input_bit_width is not None or - args.conv_input_bit_width is not None) and args.input_scale_type == 'static': + if (args.linear_input_bit_width > 0 or args.conv_input_bit_width > 0 or + args.linear_output_bit_width > 0) and args.input_scale_type == 'static': print("Applying activation calibration") with torch.no_grad(), calibration_mode(pipe.unet): run_val_inference( @@ -447,7 +477,7 @@ def input_zp_stats_type(): torch.cuda.empty_cache() if args.bias_correction: print("Applying bias correction") - with bias_correction_mode(pipe.unet): + with torch.no_grad(), bias_correction_mode(pipe.unet): run_val_inference( pipe, args.resolution, @@ -460,15 +490,41 @@ def input_zp_stats_type(): test_latents=latents, guidance_scale=args.guidance_scale) + if args.vae_fp16_fix and is_sd_xl: + vae_fix_scale = 128 + layer_whitelist = [ + "decoder.up_blocks.2.upsamplers.0.conv", + "decoder.up_blocks.3.resnets.0.conv2", + "decoder.up_blocks.3.resnets.1.conv2", + "decoder.up_blocks.3.resnets.2.conv2"] + #layer_whitelist = [ + # "decoder.up_blocks.3.resnets.0.conv_shortcut", + # "decoder.up_blocks.3.resnets.0.conv2", + # "decoder.up_blocks.3.resnets.1.conv2", + # "decoder.up_blocks.3.resnets.2.conv2"] + corrected_layers = [] + with torch.no_grad(): + for name, module in pipe.vae.named_modules(): + if name in layer_whitelist: + corrected_layers.append(name) + module.weight /= vae_fix_scale + if module.bias is not None: + module.bias /= vae_fix_scale + print(f"Corrected layers in VAE: {corrected_layers}") + if args.checkpoint_name is not None and args.load_checkpoint is None: torch.save(pipe.unet.state_dict(), os.path.join(output_dir, args.checkpoint_name)) + if args.vae_fp16_fix: + torch.save( + pipe.vae.state_dict(), os.path.join(output_dir, f"vae_{args.checkpoint_name}")) # Perform inference if args.prompt > 0 and not args.dry_run: # with brevitas_proxy_inference_mode(pipe.unet): if args.use_mlperf_inference: print(f"Computing accuracy with MLPerf pipeline") - compute_mlperf_fid(args.model, args.path_to_coco, pipe, args.prompt, output_dir) + compute_mlperf_fid( + args.model, args.path_to_coco, pipe, args.prompt, output_dir, not args.vae_fp16_fix) else: print(f"Computing accuracy on default prompt") testing_prompts = TESTING_PROMPTS[:args.prompt] @@ -530,7 +586,8 @@ def input_zp_stats_type(): export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node) export_onnx(pipe, trace_inputs, output_dir, export_manager) if args.export_target == 'params_only': - export_quant_params(pipe, output_dir) + pipe.to('cpu') + export_quant_params(pipe, output_dir, export_vae=args.vae_fp16_fix) if __name__ == "__main__": @@ -648,6 +705,11 @@ def input_zp_stats_type(): type=int, default=0, help='Input bit width. Default: 0 (not quantized).') + parser.add_argument( + '--linear-output-bit-width', + type=int, + default=0, + help='Output bit width. Default: 0 (not quantized).') parser.add_argument( '--weight-param-method', type=str, @@ -773,8 +835,17 @@ def input_zp_stats_type(): 'dry-run', default=False, help='Generate a quantized model without any calibration. Default: Disabled') - add_bool_arg(parser, 'quantize-sdp-1', default=False, help='Quantize SDP. Default: Disabled') - add_bool_arg(parser, 'quantize-sdp-2', default=False, help='Quantize SDP. 
Default: Disabled') + add_bool_arg(parser, 'quantize-sdp', default=False, help='Quantize SDP. Default: Disabled') + add_bool_arg( + parser, + 'override-conv-quant-config', + default=False, + help='Quantize Convolutions in the same way as SDP (i.e., FP8). Default: Disabled') + add_bool_arg( + parser, + 'vae-fp16-fix', + default=False, + help='Rescale the VAE to not go NaN with FP16. Default: Disabled') args = parser.parse_args() print("Args: " + str(vars(args))) main(args) diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py index 8e10f107e..6fb987967 100644 --- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py @@ -601,7 +601,8 @@ def compute_mlperf_fid( path_to_coco, model_to_replace=None, samples_to_evaluate=500, - output_dir=None): + output_dir=None, + vae_force_upcast=True): assert os.path.isfile(path_to_coco + '/tools/val2014.npz'), "Val2014.npz file required. Check the MLPerf directory for instructions" @@ -614,8 +615,11 @@ model.load() if model_to_replace is not None: - model.pipe = model_to_replace + model.pipe.unet = model_to_replace.unet + if not vae_force_upcast: + model.pipe.vae = model_to_replace.vae + model.pipe.vae.config.force_upcast = vae_force_upcast ds = Coco( data_path=path_to_coco, name="coco-1024", diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt index 690f7b0b0..871c88554 100644 --- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt @@ -2,6 +2,7 @@ accelerate==0.23.0 diffusers==0.21.2 open-clip-torch==2.7.0 opencv-python==4.8.1.78 +pandas==2.2.2 pycocotools==2.0.7 scipy==1.9.1 torchmetrics[image]==1.2.0 diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index 64bcac34f..09c331951 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -53,10 +53,13 @@ def handle_quant_param(layer, layer_dict): return layer_dict -def export_quant_params(pipe, output_dir): +def export_quant_params(pipe, output_dir, export_vae=False): quant_output_path = os.path.join(output_dir, 'quant_params.json') - output_path = os.path.join(output_dir, 'params.safetensors') - print(f"Saving unet to {output_path} ...") + unet_output_path = os.path.join(output_dir, 'params.safetensors') + print(f"Saving unet to {unet_output_path} ...") + if export_vae: + vae_output_path = os.path.join(output_dir, 'vae.safetensors') + print(f"Saving vae to {vae_output_path} ...") from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager quant_params = dict() state_dict = pipe.unet.state_dict() @@ -93,11 +96,13 @@ def export_quant_params(pipe, output_dir): elif isinstance( module, QuantWeightBiasInputOutputLayer) and id(module) not in handled_quant_layers: + full_name = name layer_dict = dict() layer_dict = handle_quant_param(module, layer_dict) quant_params[full_name] = layer_dict handled_quant_layers.add(id(module)) elif isinstance(module, QuantNonLinearActLayer): + full_name = name layer_dict = dict() act_scale = module.act_quant.export_handler.symbolic_kwargs[ 'dequantize_symbolic_kwargs']['scale'].data @@ -112,4 
+117,6 @@ def export_quant_params(pipe, output_dir): handled_quant_layers.add(id(module)) with open(quant_output_path, 'w') as file: json.dump(quant_params, file, indent=" ") - save_file(state_dict, output_path) + save_file(state_dict, unet_output_path) + if export_vae: + save_file(pipe.vae.state_dict(), vae_output_path) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/nn.py b/src/brevitas_examples/stable_diffusion/sd_quant/nn.py index e240c3a36..5a6c23ab9 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/nn.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/nn.py @@ -1,7 +1,26 @@ -from typing import Optional +# This code was taken and modified from the Hugging Face Diffusers repository under the following +# LICENSE: + +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Mapping, Optional from diffusers.models.attention_processor import Attention +from diffusers.models.lora import LoRACompatibleLinear import torch +import torch.nn.functional as F from brevitas.graph.base import ModuleInstanceToModuleInstance from brevitas.nn.equalized_layer import EqualizedModule @@ -10,7 +29,97 @@ from brevitas.quant_tensor import _unpack_quant_tensor -class QuantAttention(Attention): +class QuantizableAttention(Attention): + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block=False, + dtype=torch.float32, + processor: Optional["AttnProcessor"] = None): + + super().__init__( + query_dim, + cross_attention_dim, + heads, + dim_head, + dropout, + bias, + upcast_attention, + upcast_softmax, + cross_attention_norm, + cross_attention_norm_num_groups, + added_kv_proj_dim, + norm_num_groups, + spatial_norm_dim, + out_bias, + scale_qk, + only_cross_attention, + eps, + rescale_output_factor, + residual_connection, + _from_deprecated_attn_block, + processor, + ) + if self.to_q.weight.shape == self.to_k.weight.shape: + self.to_qkv = LoRACompatibleLinear( + query_dim, 3 * self.inner_dim, bias=bias, dtype=dtype) + + del self.to_q + del self.to_k + del self.to_v + + else: + self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias, dtype=dtype) + self.to_kv = LoRACompatibleLinear( + self.cross_attention_dim, 2 * self.inner_dim, bias=bias, dtype=dtype) + + del self.to_k + del self.to_v + + self.to_out = torch.nn.ModuleList([]) + self.to_out.append( + LoRACompatibleLinear(self.inner_dim, query_dim, 
bias=out_bias, dtype=dtype)) + self.to_out.append(torch.nn.Dropout(dropout)) + + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): + if hasattr(self, 'to_qkv') and 'to_q.weight' in state_dict: + new_weights = torch.cat( + [state_dict['to_q.weight'], state_dict['to_k.weight'], state_dict['to_v.weight']], + dim=0) + state_dict['to_qkv.weight'] = new_weights + + del state_dict['to_q.weight'] + del state_dict['to_k.weight'] + del state_dict['to_v.weight'] + elif hasattr(self, 'to_kv') and 'to_k.weight' in state_dict: + new_weights = torch.cat([state_dict['to_k.weight'], state_dict['to_v.weight']], dim=0) + state_dict['to_kv.weight'] = new_weights + del state_dict['to_k.weight'] + del state_dict['to_v.weight'] + return super().load_state_dict(state_dict, strict, assign) + + +class QuantAttention(QuantizableAttention): def __init__( self, @@ -35,8 +144,10 @@ def __init__( residual_connection: bool = False, _from_deprecated_attn_block=False, processor: Optional["AttnProcessor"] = None, - softmax_output_quant=None, + matmul_input_quant=None, + dtype=torch.float32, is_equalized=False): + super().__init__( query_dim, cross_attention_dim, @@ -58,10 +169,14 @@ def __init__( rescale_output_factor, residual_connection, _from_deprecated_attn_block, + dtype, processor, ) - self.output_softmax_quant = QuantIdentity(softmax_output_quant) + self.output_softmax_quant = QuantIdentity(matmul_input_quant) + self.out_q = QuantIdentity(matmul_input_quant) + self.out_k = QuantIdentity(matmul_input_quant) + self.out_v = QuantIdentity(matmul_input_quant) if is_equalized: replacements = [] for n, m in self.named_modules(): @@ -125,3 +240,166 @@ def get_attention_scores( attention_probs = _unpack_quant_tensor(self.output_softmax_quant(attention_probs)) return attention_probs + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. 
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale=1.0, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + if encoder_hidden_states is None: + assert attn.norm_cross is None, "Not supported" + query, key, value = attn.to_qkv(hidden_states, scale=scale).chunk(3, dim=-1) + + else: + assert not hasattr(attn, 'to_qkv'), 'Model not created correctly' + query = attn.to_q(hidden_states, scale=scale) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + key, value = attn.to_kv(encoder_hidden_states, scale=scale).chunk(2, dim=-1) + if hasattr(attn, 'out_q'): + query = _unpack_quant_tensor(attn.out_q(query)) + if hasattr(attn, 'out_k'): + key = _unpack_quant_tensor(attn.out_k(key)) + if hasattr(attn, 'out_v'): + value = _unpack_quant_tensor(attn.out_v(value)) + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, + -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
+ """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale: float = 1.0, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + if encoder_hidden_states is None: + assert attn.norm_cross is None, "Not supported" + query, key, value = attn.to_qkv(hidden_states, scale=scale).chunk(3, dim=-1) + + else: + assert not hasattr(attn, 'to_qkv'), 'Model not created correctly' + query = attn.to_q(hidden_states, scale=scale) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + key, value = attn.to_kv(encoder_hidden_states, scale=scale).chunk(2, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, + -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/tests/brevitas/graph/test_transforms.py b/tests/brevitas/graph/test_transforms.py index c58d9d828..875d5a52c 100644 --- a/tests/brevitas/graph/test_transforms.py +++ b/tests/brevitas/graph/test_transforms.py @@ -287,6 +287,6 @@ def forward(self, x): model = TestModel() assert model.conv.stride == (1, 1) - kwargs = {'stride': lambda module: 2 if module.in_channels == 3 else 1} + kwargs = {'stride': lambda module, name: 2 if module.in_channels == 3 else 1} model = ModuleToModuleByInstance(model.conv, nn.Conv2d, **kwargs).apply(model) assert model.conv.stride == (2, 2)