From 9181cff58022f8782764d55e522607ffe1a7ffcf Mon Sep 17 00:00:00 2001
From: Nick Fraser
Date: Tue, 25 Jun 2024 21:14:18 +0100
Subject: [PATCH 01/24] Fix (examples/sdxl): Fix issue setting device when
 checkpoint is loaded.

---
 src/brevitas_examples/stable_diffusion/main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py
index d09ee8fde..295cc4b87 100644
--- a/src/brevitas_examples/stable_diffusion/main.py
+++ b/src/brevitas_examples/stable_diffusion/main.py
@@ -406,7 +406,7 @@ def input_zp_stats_type():
         with load_quant_model_mode(pipe.unet):
             pipe = pipe.to('cpu')
             pipe.unet.load_state_dict(torch.load(args.load_checkpoint, map_location='cpu'))
-            pipe = pipe.to(args.device)
+        pipe = pipe.to(args.device)
     elif not args.dry_run:
         if (args.linear_input_bit_width is not None or
                 args.conv_input_bit_width is not None) and args.input_scale_type == 'static':

From 4ff1b4377a9dc5c4cbcb319dfd3adaca58f863cd Mon Sep 17 00:00:00 2001
From: Nick Fraser
Date: Tue, 25 Jun 2024 21:15:15 +0100
Subject: [PATCH 02/24] Fix (example/sdxl): Added argument for linear output
 bitwidth.

---
 src/brevitas_examples/stable_diffusion/main.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py
index 295cc4b87..95763060c 100644
--- a/src/brevitas_examples/stable_diffusion/main.py
+++ b/src/brevitas_examples/stable_diffusion/main.py
@@ -648,6 +648,11 @@ def input_zp_stats_type():
         type=int,
         default=0,
         help='Input bit width. Default: 0 (not quantized).')
+    parser.add_argument(
+        '--linear-output-bit-width',
+        type=int,
+        default=0,
+        help='Output bit width. Default: 0 (not quantized).')
     parser.add_argument(
         '--weight-param-method',
         type=str,

From ce38f862ea1d4b0fe24973fdb9ca352855a6a57e Mon Sep 17 00:00:00 2001
From: Nick Fraser
Date: Tue, 25 Jun 2024 21:17:20 +0100
Subject: [PATCH 03/24] Fix (example/sdxl): Fix when replacing
 `diffusers.models.lora.LoRACompatibleLinear`

---
 src/brevitas/graph/base.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py
index 620312641..e980e02fc 100644
--- a/src/brevitas/graph/base.py
+++ b/src/brevitas/graph/base.py
@@ -120,16 +120,19 @@ def _module_attributes(self, module):
             attrs['bias'] = module.bias
         return attrs
 
-    def _evaluate_new_kwargs(self, new_kwargs, old_module):
+    def _evaluate_new_kwargs(self, new_kwargs, old_module, name):
         update_dict = dict()
         for k, v in self.new_module_kwargs.items():
             if islambda(v):
-                v = v(old_module)
+                if name is not None:
+                    v = v(old_module, name)
+                else:
+                    v = v(old_module)
             update_dict[k] = v
         new_kwargs.update(update_dict)
         return new_kwargs
 
-    def _init_new_module(self, old_module: Module):
+    def _init_new_module(self, old_module: Module, name=None):
         # get attributes of original module
         new_kwargs = self._module_attributes(old_module)
         # transforms attribute of original module, e.g.
bias Parameter -> bool @@ -138,7 +141,7 @@ def _init_new_module(self, old_module: Module): new_module_signature_keys = signature_keys(self.new_module_class) new_kwargs = {k: v for k, v in new_kwargs.items() if k in new_module_signature_keys} # update with kwargs passed to the rewriter - new_kwargs = self._evaluate_new_kwargs(new_kwargs, old_module) + new_kwargs = self._evaluate_new_kwargs(new_kwargs, old_module, name) # init the new module new_module = self.new_module_class(**new_kwargs) return new_module @@ -204,10 +207,10 @@ def __init__(self, old_module_instance, new_module_class, **kwargs): self.old_module_instance = old_module_instance def apply(self, model: GraphModule) -> GraphModule: - for old_module in model.modules(): + for name, old_module in model.named_modules(): if old_module is self.old_module_instance: # init the new module based on the old one - new_module = self._init_new_module(old_module) + new_module = self._init_new_module(old_module, name) self._replace_old_module(model, old_module, new_module) break return model From 975c9ee4e987e0862769bf9878b7bd0299578f52 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 26 Jun 2024 10:15:50 +0100 Subject: [PATCH 04/24] Fix (example/sdxl): print output directory. --- src/brevitas_examples/stable_diffusion/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 95763060c..859b97f2d 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -159,6 +159,7 @@ def main(args): str_ts = ts.strftime("%Y%m%d_%H%M%S") output_dir = os.path.join(args.output_path, f'{str_ts}') os.mkdir(output_dir) + print(f"Saving results in {output_dir}") # Dump args to json with open(os.path.join(output_dir, 'args.json'), 'w') as fp: From b3ed0d81f07701f024b9a837e9945fdd39dd0a5d Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 26 Jun 2024 10:47:18 +0100 Subject: [PATCH 05/24] Feat (example/sdxl): add extra option to quantize conv layers like SDP --- src/brevitas_examples/stable_diffusion/main.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 859b97f2d..8d58ab4b9 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -384,6 +384,13 @@ def input_zp_stats_type(): quant_kwargs['output_quant'] = lambda module, name: input_quant if any(ending in name for ending in what_to_quantize) else None layer_map[torch.nn.Linear] = (layer_map[torch.nn.Linear][0], quant_kwargs) + if args.override_conv_quant_config: + print(f"Overriding Conv2d quantization to weights: {float_sdpa_quantizers[1]}, inputs: {float_sdpa_quantizers[2]}") + conv_qkwargs = layer_map[torch.nn.Conv2d][1] + conv_qkwargs['input_quant'] = float_sdpa_quantizers[2] + conv_qkwargs['weight_quant'] = float_sdpa_quantizers[1] + layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs) + pipe.unet = layerwise_quantize( model=pipe.unet, compute_layer_map=layer_map, name_blacklist=blacklist) print("Model quantization applied.") @@ -781,6 +788,7 @@ def input_zp_stats_type(): help='Generate a quantized model without any calibration. Default: Disabled') add_bool_arg(parser, 'quantize-sdp-1', default=False, help='Quantize SDP. Default: Disabled') add_bool_arg(parser, 'quantize-sdp-2', default=False, help='Quantize SDP. 
Default: Disabled') + add_bool_arg(parser, 'override-conv-quant-config', default=False, help='Quantize Convolutions in the same way as SDP (i.e., FP8). Default: Disabled') args = parser.parse_args() print("Args: " + str(vars(args))) main(args) From 11f9dc8b4608a2e61f4271d6c0eba389ec0d0e2b Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 26 Jun 2024 10:51:11 +0100 Subject: [PATCH 06/24] fix (example/sdxl): Updated usage README. --- .../stable_diffusion/README.md | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md index 1685bd4a9..e7019b3b5 100644 --- a/src/brevitas_examples/stable_diffusion/README.md +++ b/src/brevitas_examples/stable_diffusion/README.md @@ -77,6 +77,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--conv-input-bit-width CONV_INPUT_BIT_WIDTH] [--act-eq-alpha ACT_EQ_ALPHA] [--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH] + [--linear-output-bit-width LINEAR_OUTPUT_BIT_WIDTH] [--weight-param-method {stats,mse}] [--input-param-method {stats,mse}] [--input-scale-stats-op {minmax,percentile}] @@ -96,15 +97,16 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--quantize-input-zero-point | --no-quantize-input-zero-point] [--export-cpu-float32 | --no-export-cpu-float32] [--use-mlperf-inference | --no-use-mlperf-inference] - [--use-ocp | --no-use-ocp] [--use-nfuz | --no-use-nfuz] + [--use-ocp | --no-use-ocp] [--use-fnuz | --no-use-fnuz] [--use-negative-prompts | --no-use-negative-prompts] [--dry-run | --no-dry-run] [--quantize-sdp-1 | --no-quantize-sdp-1] [--quantize-sdp-2 | --no-quantize-sdp-2] + [--override-conv-quant-config | --no-override-conv-quant-config] Stable Diffusion quantization -options: +optional arguments: -h, --help show this help message and exit -m MODEL, --model MODEL Path or name of the model. @@ -176,6 +178,8 @@ options: Alpha for activation equalization. Default: 0.9 --linear-input-bit-width LINEAR_INPUT_BIT_WIDTH Input bit width. Default: 0 (not quantized). + --linear-output-bit-width LINEAR_OUTPUT_BIT_WIDTH + Input bit width. Default: 0 (not quantized). --weight-param-method {stats,mse} How scales/zero-point are determined. Default: stats. --input-param-method {stats,mse} @@ -241,9 +245,9 @@ options: True --no-use-ocp Disable Use OCP format for float quantization. Default: True - --use-nfuz Enable Use NFUZ format for float quantization. + --use-fnuz Enable Use FNUZ format for float quantization. Default: True - --no-use-nfuz Disable Use NFUZ format for float quantization. + --no-use-fnuz Disable Use FNUZ format for float quantization. Default: True --use-negative-prompts Enable Use negative prompts during @@ -259,5 +263,10 @@ options: --no-quantize-sdp-1 Disable Quantize SDP. Default: Disabled --quantize-sdp-2 Enable Quantize SDP. Default: Disabled --no-quantize-sdp-2 Disable Quantize SDP. Default: Disabled - + --override-conv-quant-config + Enable Quantize Convolutions in the same way as SDP + (i.e., FP8). Default: Disabled + --no-override-conv-quant-config + Disable Quantize Convolutions in the same way as SDP + (i.e., FP8). Default: Disabled ``` From f77bf6ca038525fac7f12082dc153970c9169d60 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 26 Jun 2024 15:34:09 +0100 Subject: [PATCH 07/24] Fix (example/sdxl): print which checkpoint is loaded. 
--- src/brevitas_examples/stable_diffusion/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 8d58ab4b9..1100f2194 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -413,7 +413,9 @@ def input_zp_stats_type(): if args.load_checkpoint is not None: with load_quant_model_mode(pipe.unet): pipe = pipe.to('cpu') + print(f"Loading checkpoint: {args.load_checkpoint}... ", end="") pipe.unet.load_state_dict(torch.load(args.load_checkpoint, map_location='cpu')) + print(f"Checkpoint loaded!") pipe = pipe.to(args.device) elif not args.dry_run: if (args.linear_input_bit_width is not None or From 606ddec70676f5897f2c03449fdbef800604826c Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 26 Jun 2024 16:34:06 +0100 Subject: [PATCH 08/24] Fix (example/sdxl): Move to CPU before 'param_only' export --- src/brevitas_examples/stable_diffusion/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 1100f2194..2f5fbc2f6 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -540,6 +540,7 @@ def input_zp_stats_type(): export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node) export_onnx(pipe, trace_inputs, output_dir, export_manager) if args.export_target == 'params_only': + pipe.to('cpu') export_quant_params(pipe, output_dir) From fb2fc8703c7492566e0748138cfb046692e9b059 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 28 Jun 2024 08:57:49 +0000 Subject: [PATCH 09/24] Fix (example/sdxl): Added pandas requirement with specific version. 
--- .../stable_diffusion/mlperf_evaluation/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt index 690f7b0b0..e50c44f52 100644 --- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt @@ -7,3 +7,4 @@ scipy==1.9.1 torchmetrics[image]==1.2.0 tqdm transformers==4.33.2 +pandas==2.2.2 From cb3593b803449da548832b2474e06a895ccaf5ef Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 28 Jun 2024 17:21:05 +0100 Subject: [PATCH 10/24] Fix (example/sdxl): pre-commit --- src/brevitas_examples/stable_diffusion/main.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 2f5fbc2f6..0c6992323 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -385,7 +385,9 @@ def input_zp_stats_type(): layer_map[torch.nn.Linear] = (layer_map[torch.nn.Linear][0], quant_kwargs) if args.override_conv_quant_config: - print(f"Overriding Conv2d quantization to weights: {float_sdpa_quantizers[1]}, inputs: {float_sdpa_quantizers[2]}") + print( + f"Overriding Conv2d quantization to weights: {float_sdpa_quantizers[1]}, inputs: {float_sdpa_quantizers[2]}" + ) conv_qkwargs = layer_map[torch.nn.Conv2d][1] conv_qkwargs['input_quant'] = float_sdpa_quantizers[2] conv_qkwargs['weight_quant'] = float_sdpa_quantizers[1] @@ -791,7 +793,11 @@ def input_zp_stats_type(): help='Generate a quantized model without any calibration. Default: Disabled') add_bool_arg(parser, 'quantize-sdp-1', default=False, help='Quantize SDP. Default: Disabled') add_bool_arg(parser, 'quantize-sdp-2', default=False, help='Quantize SDP. Default: Disabled') - add_bool_arg(parser, 'override-conv-quant-config', default=False, help='Quantize Convolutions in the same way as SDP (i.e., FP8). Default: Disabled') + add_bool_arg( + parser, + 'override-conv-quant-config', + default=False, + help='Quantize Convolutions in the same way as SDP (i.e., FP8). Default: Disabled') args = parser.parse_args() print("Args: " + str(vars(args))) main(args) From fe30b66ebf129c1ce71a47f933d88108e1a1a010 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 28 Jun 2024 17:28:07 +0100 Subject: [PATCH 11/24] Fix (example/sdxl): pre-commit fix to requirements. 
--- .../stable_diffusion/mlperf_evaluation/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt index e50c44f52..871c88554 100644 --- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt @@ -2,9 +2,9 @@ accelerate==0.23.0 diffusers==0.21.2 open-clip-torch==2.7.0 opencv-python==4.8.1.78 +pandas==2.2.2 pycocotools==2.0.7 scipy==1.9.1 torchmetrics[image]==1.2.0 tqdm transformers==4.33.2 -pandas==2.2.2 From 4546ffedfef95033f75993934a5d4530c9966fb9 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 3 Jul 2024 10:52:22 +0100 Subject: [PATCH 12/24] Fix model loading --- src/brevitas_examples/stable_diffusion/main.py | 8 ++++++-- .../stable_diffusion/mlperf_evaluation/accuracy.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 0c6992323..853d49940 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -12,6 +12,7 @@ from dependencies import value from diffusers import DiffusionPipeline +from diffusers import EulerDiscreteScheduler from diffusers import StableDiffusionXLPipeline from diffusers.models.attention_processor import Attention from diffusers.models.attention_processor import AttnProcessor @@ -37,7 +38,6 @@ from brevitas.utils.torch_utils import KwargsForwardHook from brevitas_examples.common.generative.quantize import generate_quant_maps from brevitas_examples.common.generative.quantize import generate_quantizers -from brevitas_examples.common.generative.quantize import quantize_model from brevitas_examples.common.parse_utils import add_bool_arg from brevitas_examples.common.parse_utils import quant_format_validator from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager @@ -170,7 +170,11 @@ def main(args): # Load model from float checkpoint print(f"Loading model from {args.model}...") - pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype) + variant = 'fp16' if dtype == torch.float16 else None + pipe = DiffusionPipeline.from_pretrained( + args.model, torch_dtype=dtype, variant=variant, use_safetensors=True) + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + pipe.vae.config.force_upcast = True print(f"Model loaded from {args.model}.") # Move model to target device diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py index 8e10f107e..093b068d4 100644 --- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py @@ -615,7 +615,7 @@ def compute_mlperf_fid( if model_to_replace is not None: model.pipe = model_to_replace - + model.pipe.vae.config.force_upcast = True ds = Coco( data_path=path_to_coco, name="coco-1024", From 99fc16f41296cd6e18e00cb3b9b367d409358eda Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 3 Jul 2024 10:52:41 +0100 Subject: [PATCH 13/24] Fix latents dtype --- src/brevitas_examples/stable_diffusion/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py 
b/src/brevitas_examples/stable_diffusion/main.py
index 853d49940..9499caf7e 100644
--- a/src/brevitas_examples/stable_diffusion/main.py
+++ b/src/brevitas_examples/stable_diffusion/main.py
@@ -152,7 +152,7 @@ def main(args):
 
     latents = None
     if args.path_to_latents is not None:
-        latents = torch.load(args.path_to_latents).to(torch.float16)
+        latents = torch.load(args.path_to_latents).to(dtype)
 
     # Create output dir. Move to tmp if None
     ts = datetime.fromtimestamp(time.time())

From b8b31f8304a318865355274d1d8a6d5a0754a6c8 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 3 Jul 2024 10:53:17 +0100
Subject: [PATCH 14/24] Fix bitwidth

---
 src/brevitas_examples/stable_diffusion/main.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py
index 9499caf7e..64159902e 100644
--- a/src/brevitas_examples/stable_diffusion/main.py
+++ b/src/brevitas_examples/stable_diffusion/main.py
@@ -217,7 +217,7 @@ def main(args):
 
     if args.activation_equalization:
         pipe.set_progress_bar_config(disable=True)
-        with activation_equalization_mode(
+        with torch.no_grad(), activation_equalization_mode(
                 pipe.unet,
                 alpha=args.act_eq_alpha,
                 layerwise=True,
@@ -266,8 +266,6 @@ def input_bit_width(module):
             return args.linear_input_bit_width
         elif isinstance(module, nn.Conv2d):
             return args.conv_input_bit_width
-        elif isinstance(module, QuantIdentity):
-            return args.quant_identity_bit_width
         else:
             raise RuntimeError(f"Module {module} not supported.")
 
@@ -350,7 +348,7 @@ def input_zp_stats_type():
             weight_group_size=args.weight_group_size,
             quantize_weight_zero_point=args.quantize_weight_zero_point,
             quantize_input_zero_point=args.quantize_input_zero_point,
-            input_bit_width=input_bit_width,
+            input_bit_width=args.linear_output_bit_width,
             input_quant_format='e4m3',
             input_scale_type=args.input_scale_type,
             input_scale_precision=args.input_scale_precision,
@@ -363,7 +361,6 @@ def input_zp_stats_type():
         # We generate all quantizers, but we are only interested in activation quantization for
         # the output of softmax and the output of QKV
         input_quant = float_sdpa_quantizers[0]
-        input_quant = input_quant.let(**{'bit_width': args.linear_output_bit_width})
         if args.quantize_sdp_2:
             rewriter = ModuleToModuleByClass(
                 Attention,
@@ -379,14 +376,13 @@ def input_zp_stats_type():
             config.IGNORE_MISSING_KEYS = False
             pipe.unet = pipe.unet.to(args.device)
             pipe.unet = pipe.unet.to(dtype)
-        quant_kwargs = layer_map[torch.nn.Linear][1]
+        quant_kwargs = layer_map['diffusers.models.lora.LoRACompatibleLinear'][1]
         what_to_quantize = []
         if args.quantize_sdp_1:
             what_to_quantize.extend(['to_q', 'to_k'])
         if args.quantize_sdp_2:
             what_to_quantize.extend(['to_v'])
         quant_kwargs['output_quant'] = lambda module, name: input_quant if any(ending in name for ending in what_to_quantize) else None
-        layer_map[torch.nn.Linear] = (layer_map[torch.nn.Linear][0], quant_kwargs)
 
     if args.override_conv_quant_config:
         print(
@@ -424,8 +420,8 @@ def input_zp_stats_type():
             print(f"Checkpoint loaded!")
         pipe = pipe.to(args.device)
     elif not args.dry_run:
-        if (args.linear_input_bit_width is not None or
-                args.conv_input_bit_width is not None) and args.input_scale_type == 'static':
+        if (args.linear_input_bit_width > 0 or args.conv_input_bit_width > 0 or
+                args.linear_output_bit_width > 0) and args.input_scale_type == 'static':
             print("Applying activation calibration")
             with torch.no_grad(), calibration_mode(pipe.unet):
                 run_val_inference(
@@ -463,7 +459,7 @@ def
input_zp_stats_type(): torch.cuda.empty_cache() if args.bias_correction: print("Applying bias correction") - with bias_correction_mode(pipe.unet): + with torch.no_grad(), bias_correction_mode(pipe.unet): run_val_inference( pipe, args.resolution, From 4fd4998d690e4b95a3ba5b408db0ea8a1010ae27 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 3 Jul 2024 10:53:29 +0100 Subject: [PATCH 15/24] Fix export --- src/brevitas_examples/stable_diffusion/sd_quant/export.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index 64bcac34f..9535b8b0d 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -93,11 +93,13 @@ def export_quant_params(pipe, output_dir): elif isinstance( module, QuantWeightBiasInputOutputLayer) and id(module) not in handled_quant_layers: + full_name = name layer_dict = dict() layer_dict = handle_quant_param(module, layer_dict) quant_params[full_name] = layer_dict handled_quant_layers.add(id(module)) elif isinstance(module, QuantNonLinearActLayer): + full_name = name layer_dict = dict() act_scale = module.act_quant.export_handler.symbolic_kwargs[ 'dequantize_symbolic_kwargs']['scale'].data From 6f0d2f96fd73eeb7e3f0457a3a2b0cdd63027938 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 3 Jul 2024 11:05:14 +0100 Subject: [PATCH 16/24] Fix tests --- tests/brevitas/graph/test_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/brevitas/graph/test_transforms.py b/tests/brevitas/graph/test_transforms.py index c58d9d828..875d5a52c 100644 --- a/tests/brevitas/graph/test_transforms.py +++ b/tests/brevitas/graph/test_transforms.py @@ -287,6 +287,6 @@ def forward(self, x): model = TestModel() assert model.conv.stride == (1, 1) - kwargs = {'stride': lambda module: 2 if module.in_channels == 3 else 1} + kwargs = {'stride': lambda module, name: 2 if module.in_channels == 3 else 1} model = ModuleToModuleByInstance(model.conv, nn.Conv2d, **kwargs).apply(model) assert model.conv.stride == (2, 2) From 41bba9ce2d09ea3597cb846263659dfd4d9356e1 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 10 Jul 2024 16:28:30 +0100 Subject: [PATCH 17/24] Feat (example/sdxl): Added fix for VAE @ FP16 --- .../stable_diffusion/main.py | 35 +++++++++++++++++-- .../mlperf_evaluation/accuracy.py | 5 +-- .../stable_diffusion/sd_quant/export.py | 13 ++++--- 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 64159902e..aa746d8eb 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -472,15 +472,41 @@ def input_zp_stats_type(): test_latents=latents, guidance_scale=args.guidance_scale) + if args.vae_fp16_fix: + vae_fix_scale = 128 + layer_whitelist = [ + "decoder.up_blocks.2.upsamplers.0.conv", + "decoder.up_blocks.3.resnets.0.conv2", + "decoder.up_blocks.3.resnets.1.conv2", + "decoder.up_blocks.3.resnets.2.conv2"] + #layer_whitelist = [ + # "decoder.up_blocks.3.resnets.0.conv_shortcut", + # "decoder.up_blocks.3.resnets.0.conv2", + # "decoder.up_blocks.3.resnets.1.conv2", + # "decoder.up_blocks.3.resnets.2.conv2"] + corrected_layers = [] + with torch.no_grad(): + for name, module in pipe.vae.named_modules(): + if name in layer_whitelist: + corrected_layers.append(name) + module.weight /= 
vae_fix_scale + if module.bias is not None: + module.bias /= vae_fix_scale + print(f"Corrected layers in VAE: {corrected_layers}") + if args.checkpoint_name is not None and args.load_checkpoint is None: torch.save(pipe.unet.state_dict(), os.path.join(output_dir, args.checkpoint_name)) + if args.vae_fp16_fix: + torch.save( + pipe.vae.state_dict(), os.path.join(output_dir, f"vae_{args.checkpoint_name}")) # Perform inference if args.prompt > 0 and not args.dry_run: # with brevitas_proxy_inference_mode(pipe.unet): if args.use_mlperf_inference: print(f"Computing accuracy with MLPerf pipeline") - compute_mlperf_fid(args.model, args.path_to_coco, pipe, args.prompt, output_dir) + compute_mlperf_fid( + args.model, args.path_to_coco, pipe, args.prompt, output_dir, not args.vae_fp16_fix) else: print(f"Computing accuracy on default prompt") testing_prompts = TESTING_PROMPTS[:args.prompt] @@ -543,7 +569,7 @@ def input_zp_stats_type(): export_onnx(pipe, trace_inputs, output_dir, export_manager) if args.export_target == 'params_only': pipe.to('cpu') - export_quant_params(pipe, output_dir) + export_quant_params(pipe, output_dir, export_vae=args.vae_fp16_fix) if __name__ == "__main__": @@ -798,6 +824,11 @@ def input_zp_stats_type(): 'override-conv-quant-config', default=False, help='Quantize Convolutions in the same way as SDP (i.e., FP8). Default: Disabled') + add_bool_arg( + parser, + 'vae-fp16-fix', + default=False, + help='Rescale the VAE to not go NaN with FP16. Default: Disabled') args = parser.parse_args() print("Args: " + str(vars(args))) main(args) diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py index 093b068d4..f7c797c28 100644 --- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py @@ -601,7 +601,8 @@ def compute_mlperf_fid( path_to_coco, model_to_replace=None, samples_to_evaluate=500, - output_dir=None): + output_dir=None, + vae_force_upcast=True): assert os.path.isfile(path_to_coco + '/tools/val2014.npz'), "Val2014.npz file required. 
Check the MLPerf directory for instructions" @@ -615,7 +616,7 @@ def compute_mlperf_fid( if model_to_replace is not None: model.pipe = model_to_replace - model.pipe.vae.config.force_upcast = True + model.pipe.vae.config.force_upcast = vae_force_upcast ds = Coco( data_path=path_to_coco, name="coco-1024", diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index 9535b8b0d..09c331951 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -53,10 +53,13 @@ def handle_quant_param(layer, layer_dict): return layer_dict -def export_quant_params(pipe, output_dir): +def export_quant_params(pipe, output_dir, export_vae=False): quant_output_path = os.path.join(output_dir, 'quant_params.json') - output_path = os.path.join(output_dir, 'params.safetensors') - print(f"Saving unet to {output_path} ...") + unet_output_path = os.path.join(output_dir, 'params.safetensors') + print(f"Saving unet to {unet_output_path} ...") + if export_vae: + vae_output_path = os.path.join(output_dir, 'vae.safetensors') + print(f"Saving vae to {vae_output_path} ...") from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager quant_params = dict() state_dict = pipe.unet.state_dict() @@ -114,4 +117,6 @@ def export_quant_params(pipe, output_dir): handled_quant_layers.add(id(module)) with open(quant_output_path, 'w') as file: json.dump(quant_params, file, indent=" ") - save_file(state_dict, output_path) + save_file(state_dict, unet_output_path) + if export_vae: + save_file(pipe.vae.state_dict(), vae_output_path) From 3241a5a2b5679852ec89d088f2d16d7655ca2cef Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 10 Jul 2024 16:31:11 +0100 Subject: [PATCH 18/24] Fix (example/sdxl): Only apply VAE fix for SDXL --- src/brevitas_examples/stable_diffusion/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index aa746d8eb..9681ccbb3 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -472,7 +472,7 @@ def input_zp_stats_type(): test_latents=latents, guidance_scale=args.guidance_scale) - if args.vae_fp16_fix: + if args.vae_fp16_fix and is_sd_xl: vae_fix_scale = 128 layer_whitelist = [ "decoder.up_blocks.2.upsamplers.0.conv", From 307b12894099caba2c9448148ef1a4ae7b405aaf Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 10 Jul 2024 16:31:59 +0100 Subject: [PATCH 19/24] Docs (example/sdxl): Updated usage --- src/brevitas_examples/stable_diffusion/README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md index e7019b3b5..a51a06df5 100644 --- a/src/brevitas_examples/stable_diffusion/README.md +++ b/src/brevitas_examples/stable_diffusion/README.md @@ -103,6 +103,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--quantize-sdp-1 | --no-quantize-sdp-1] [--quantize-sdp-2 | --no-quantize-sdp-2] [--override-conv-quant-config | --no-override-conv-quant-config] + [--vae-fp16-fix | --no-vae-fp16-fix] Stable Diffusion quantization @@ -269,4 +270,8 @@ optional arguments: --no-override-conv-quant-config Disable Quantize Convolutions in the same way as SDP (i.e., FP8). Default: Disabled + --vae-fp16-fix Enable Rescale the VAE to not go NaN with FP16. 
+ Default: Disabled + --no-vae-fp16-fix Disable Rescale the VAE to not go NaN with FP16. + Default: Disabled ``` From b9cc9c1a2abb2e02739778037134fb025126bd67 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 12 Jul 2024 15:58:20 +0100 Subject: [PATCH 20/24] Fix (example/generative): Added missing `use_fnuz` arg. --- src/brevitas_examples/common/generative/quantize.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 36bac29d5..a86de3b76 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -473,6 +473,7 @@ def quantize_model( quantize_input_zero_point=False, quantize_embedding=False, use_ocp=False, + use_fnuz=False, device=None, weight_kwargs=None, input_kwargs=None): @@ -497,6 +498,7 @@ def quantize_model( input_group_size, quantize_input_zero_point, use_ocp, + use_fnuz, device, weight_kwargs, input_kwargs) From fdfcafc479cb2c1236fdb378a38c3f6c75638e26 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 15 Jul 2024 11:37:46 +0100 Subject: [PATCH 21/24] Update --- .../stable_diffusion/main.py | 58 ++-- .../mlperf_evaluation/accuracy.py | 5 +- .../stable_diffusion/sd_quant/nn.py | 269 +++++++++++++++++- 3 files changed, 302 insertions(+), 30 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 9681ccbb3..d629174e0 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -7,6 +7,7 @@ from datetime import datetime from functools import partial import json +import math import os import time @@ -46,7 +47,9 @@ from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params +from brevitas_examples.stable_diffusion.sd_quant.nn import AttnProcessor2_0 from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention +from brevitas_examples.stable_diffusion.sd_quant.nn import QuantizableAttention from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs @@ -175,6 +178,18 @@ def main(args): args.model, torch_dtype=dtype, variant=variant, use_safetensors=True) pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) pipe.vae.config.force_upcast = True + if args.share_qkv_quant: + rewriter = ModuleToModuleByClass( + Attention, + QuantizableAttention, + query_dim=lambda module: module.to_q.in_features, + dim_head=lambda module: math.ceil(1 / (module.scale ** 2)), + bias=lambda module: hasattr(module.to_q, 'bias') and module.to_q.bias is not None, + processor=AttnProcessor2_0(), + dtype=dtype, + norm_num_groups=lambda module: None + if module.group_norm is None else module.group_norm.num_groups) + rewriter.apply(pipe.unet) print(f"Model loaded from {args.model}.") # Move model to target device @@ -335,7 +350,9 @@ def input_zp_stats_type(): 'weight_quant'] layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs) - if args.quantize_sdp_1 or args.quantize_sdp_2: + if args.quantize_sdp: + assert args.share_qkv_quant, "Currently SDPA quantization is supported 
only with shared QKV quantization" + # TODO: reformat this float_sdpa_quantizers = generate_quantizers( dtype=dtype, device=args.device, @@ -361,28 +378,20 @@ def input_zp_stats_type(): # We generate all quantizers, but we are only interested in activation quantization for # the output of softmax and the output of QKV input_quant = float_sdpa_quantizers[0] - if args.quantize_sdp_2: - rewriter = ModuleToModuleByClass( - Attention, - QuantAttention, - softmax_output_quant=input_quant, - query_dim=lambda module: module.to_q.in_features, - dim_head=lambda module: int(1 / (module.scale ** 2)), - processor=AttnProcessor(), - is_equalized=args.activation_equalization) - import brevitas.config as config - config.IGNORE_MISSING_KEYS = True - pipe.unet = rewriter.apply(pipe.unet) - config.IGNORE_MISSING_KEYS = False - pipe.unet = pipe.unet.to(args.device) - pipe.unet = pipe.unet.to(dtype) - quant_kwargs = layer_map['diffusers.models.lora.LoRACompatibleLinear'][1] - what_to_quantize = [] - if args.quantize_sdp_1: - what_to_quantize.extend(['to_q', 'to_k']) - if args.quantize_sdp_2: - what_to_quantize.extend(['to_v']) - quant_kwargs['output_quant'] = lambda module, name: input_quant if any(ending in name for ending in what_to_quantize) else None + rewriter = ModuleToModuleByClass( + Attention, + QuantAttention, + matmul_input_quant=input_quant, + query_dim=lambda module: module.to_q.in_features, + dim_head=lambda module: math.ceil(1 / (module.scale ** 2)), + processor=AttnProcessor(), + is_equalized=args.activation_equalization) + import brevitas.config as config + config.IGNORE_MISSING_KEYS = True + pipe.unet = rewriter.apply(pipe.unet) + config.IGNORE_MISSING_KEYS = False + pipe.unet = pipe.unet.to(args.device) + pipe.unet = pipe.unet.to(dtype) if args.override_conv_quant_config: print( @@ -817,8 +826,7 @@ def input_zp_stats_type(): 'dry-run', default=False, help='Generate a quantized model without any calibration. Default: Disabled') - add_bool_arg(parser, 'quantize-sdp-1', default=False, help='Quantize SDP. Default: Disabled') - add_bool_arg(parser, 'quantize-sdp-2', default=False, help='Quantize SDP. Default: Disabled') + add_bool_arg(parser, 'quantize-sdp', default=False, help='Quantize SDP. 
Default: Disabled')
     add_bool_arg(
         parser,
         'override-conv-quant-config',
diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py
index f7c797c28..6fb987967 100644
--- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py
+++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py
@@ -615,7 +615,10 @@ def compute_mlperf_fid(
     model.load()
 
     if model_to_replace is not None:
-        model.pipe = model_to_replace
+        model.pipe.unet = model_to_replace.unet
+        if not vae_force_upcast:
+            model.pipe.vae = model_to_replace.vae
 
+    model.pipe.vae.config.force_upcast = vae_force_upcast
     ds = Coco(
         data_path=path_to_coco,
         name="coco-1024",
diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/nn.py b/src/brevitas_examples/stable_diffusion/sd_quant/nn.py
index e240c3a36..c6647e566 100644
--- a/src/brevitas_examples/stable_diffusion/sd_quant/nn.py
+++ b/src/brevitas_examples/stable_diffusion/sd_quant/nn.py
@@ -1,7 +1,9 @@
-from typing import Optional
+from typing import Any, Mapping, Optional
 
 from diffusers.models.attention_processor import Attention
+from diffusers.models.lora import LoRACompatibleLinear
 import torch
+import torch.nn.functional as F
 
 from brevitas.graph.base import ModuleInstanceToModuleInstance
 from brevitas.nn.equalized_layer import EqualizedModule
@@ -10,7 +12,97 @@ from brevitas.quant_tensor import _unpack_quant_tensor
 
 
class QuantizableAttention(Attention):

    def __init__(
            self,
            query_dim: int,
            cross_attention_dim: Optional[int] = None,
            heads: int = 8,
            dim_head: int = 64,
            dropout: float = 0.0,
            bias=False,
            upcast_attention: bool = False,
            upcast_softmax: bool = False,
            cross_attention_norm: Optional[str] = None,
            cross_attention_norm_num_groups: int = 32,
            added_kv_proj_dim: Optional[int] = None,
            norm_num_groups: Optional[int] = None,
            spatial_norm_dim: Optional[int] = None,
            out_bias: bool = True,
            scale_qk: bool = True,
            only_cross_attention: bool = False,
            eps: float = 1e-5,
            rescale_output_factor: float = 1.0,
            residual_connection: bool = False,
            _from_deprecated_attn_block=False,
            dtype=torch.float32,
            processor: Optional["AttnProcessor"] = None):

        super().__init__(
            query_dim,
            cross_attention_dim,
            heads,
            dim_head,
            dropout,
            bias,
            upcast_attention,
            upcast_softmax,
            cross_attention_norm,
            cross_attention_norm_num_groups,
            added_kv_proj_dim,
            norm_num_groups,
            spatial_norm_dim,
            out_bias,
            scale_qk,
            only_cross_attention,
            eps,
            rescale_output_factor,
            residual_connection,
            _from_deprecated_attn_block,
            processor,
        )
        if self.to_q.weight.shape == self.to_k.weight.shape:
            self.to_qkv = LoRACompatibleLinear(
                query_dim, 3 * self.inner_dim, bias=bias, dtype=dtype)

            del self.to_q
            del self.to_k
            del self.to_v

        else:
            self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias, dtype=dtype)
            self.to_kv = LoRACompatibleLinear(
                self.cross_attention_dim, 2 * self.inner_dim, bias=bias, dtype=dtype)

            del self.to_k
            del self.to_v

        self.to_out = torch.nn.ModuleList([])
        self.to_out.append(
            LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias, dtype=dtype))
        self.to_out.append(torch.nn.Dropout(dropout))

    def load_state_dict(
            self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
        if hasattr(self, 'to_qkv') and 'to_q.weight' in state_dict:
            new_weights = torch.cat(
                [state_dict['to_q.weight'], state_dict['to_k.weight'],
state_dict['to_v.weight']], + dim=0) + state_dict['to_qkv.weight'] = new_weights + + del state_dict['to_q.weight'] + del state_dict['to_k.weight'] + del state_dict['to_v.weight'] + elif hasattr(self, 'to_kv') and 'to_k.weight' in state_dict: + new_weights = torch.cat([state_dict['to_k.weight'], state_dict['to_v.weight']], dim=0) + state_dict['to_kv.weight'] = new_weights + del state_dict['to_k.weight'] + del state_dict['to_v.weight'] + return super().load_state_dict(state_dict, strict, assign) + + +class QuantAttention(QuantizableAttention): def __init__( self, @@ -35,8 +127,10 @@ def __init__( residual_connection: bool = False, _from_deprecated_attn_block=False, processor: Optional["AttnProcessor"] = None, - softmax_output_quant=None, + matmul_input_quant=None, + dtype=torch.float32, is_equalized=False): + super().__init__( query_dim, cross_attention_dim, @@ -58,10 +152,14 @@ def __init__( rescale_output_factor, residual_connection, _from_deprecated_attn_block, + dtype, processor, ) - self.output_softmax_quant = QuantIdentity(softmax_output_quant) + self.output_softmax_quant = QuantIdentity(matmul_input_quant) + self.out_q = QuantIdentity(matmul_input_quant) + self.out_k = QuantIdentity(matmul_input_quant) + self.out_v = QuantIdentity(matmul_input_quant) if is_equalized: replacements = [] for n, m in self.named_modules(): @@ -125,3 +223,166 @@ def get_attention_scores( attention_probs = _unpack_quant_tensor(self.output_softmax_quant(attention_probs)) return attention_probs + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale=1.0, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + if encoder_hidden_states is None: + assert attn.norm_cross is None, "Not supported" + query, key, value = attn.to_qkv(hidden_states, scale=scale).chunk(3, dim=-1) + + else: + assert not hasattr(attn, 'to_qkv'), 'Model not created correctly' + query = attn.to_q(hidden_states, scale=scale) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + key, value = attn.to_kv(encoder_hidden_states, scale=scale).chunk(2, dim=-1) + if hasattr(attn, 'out_q'): + query = _unpack_quant_tensor(attn.out_q(query)) + if hasattr(attn, 'out_k'): + key = _unpack_quant_tensor(attn.out_k(key)) + if hasattr(attn, 'out_v'): + value = _unpack_quant_tensor(attn.out_v(value)) + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = 
attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, + -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale: float = 1.0, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + if encoder_hidden_states is None: + assert attn.norm_cross is None, "Not supported" + query, key, value = attn.to_qkv(hidden_states, scale=scale).chunk(3, dim=-1) + + else: + assert not hasattr(attn, 'to_qkv'), 'Model not created correctly' + query = attn.to_q(hidden_states, scale=scale) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + key, value = attn.to_kv(encoder_hidden_states, scale=scale).chunk(2, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, + -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states From ade7c6b1904d3de8b8723694b7a234cc55f29b3a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 15 Jul 2024 12:08:23 +0100 Subject: [PATCH 22/24] Lambda inspection --- src/brevitas/graph/base.py | 7 ++++++- 1 file changed, 6 
insertions(+), 1 deletion(-) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index e980e02fc..04fbd88d3 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -3,6 +3,7 @@ from abc import ABC from abc import abstractmethod +import inspect from inspect import getcallargs import torch @@ -127,7 +128,11 @@ def _evaluate_new_kwargs(self, new_kwargs, old_module, name): if name is not None: v = v(old_module, name) else: - v = v(old_module) + # Two types of lambdas are admitted now, with/without the name of the module as input + if len(inspect.getfullargspec(v).args) == 2: + v = v(old_module, name) + elif len(inspect.getfullargspec(v).args) == 1: + v = v(old_module) update_dict[k] = v new_kwargs.update(update_dict) return new_kwargs From ce2993ef334a96dae03e162f3b93388cfecabf7b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 15 Jul 2024 12:10:20 +0100 Subject: [PATCH 23/24] fix --- src/brevitas/graph/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index 04fbd88d3..def3f7070 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -126,13 +126,13 @@ def _evaluate_new_kwargs(self, new_kwargs, old_module, name): for k, v in self.new_module_kwargs.items(): if islambda(v): if name is not None: - v = v(old_module, name) - else: # Two types of lambdas are admitted now, with/without the name of the module as input if len(inspect.getfullargspec(v).args) == 2: v = v(old_module, name) elif len(inspect.getfullargspec(v).args) == 1: v = v(old_module) + else: + v = v(old_module) update_dict[k] = v new_kwargs.update(update_dict) return new_kwargs From 5b0beb688ef71968102abcd6575d60221720d29b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 17 Jul 2024 10:56:27 +0100 Subject: [PATCH 24/24] Add license --- src/brevitas_examples/stable_diffusion/main.py | 11 ++++++++++- .../stable_diffusion/sd_quant/nn.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index d629174e0..cb1f42920 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -220,10 +220,19 @@ def main(args): # Extract list of layers to avoid blacklist = [] + non_blacklist = dict() for name, _ in pipe.unet.named_modules(): if 'time_emb' in name: blacklist.append(name.split('.')[-1]) - print(f"Blacklisted layers: {blacklist}") + else: + if isinstance(_, (torch.nn.Linear, torch.nn.Conv2d)): + name_to_add = name.split('.')[-1] + if name_to_add not in non_blacklist: + non_blacklist[name_to_add] = 1 + else: + non_blacklist[name_to_add] += 1 + print(f"Blacklisted layers: {set(blacklist)}") + print(f"Non blacklisted layers: {non_blacklist}") # Make sure there all LoRA layers are fused first, otherwise raise an error for m in pipe.unet.modules(): diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/nn.py b/src/brevitas_examples/stable_diffusion/sd_quant/nn.py index c6647e566..5a6c23ab9 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/nn.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/nn.py @@ -1,3 +1,20 @@ +# This code was taken and modified from the Hugging Face Diffusers repository under the following +# LICENSE: + +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Any, Mapping, Optional from diffusers.models.attention_processor import Attention
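
Usage sketch: after patches 03, 22 and 23, the rewrite kwargs passed to `ModuleToModuleByInstance` may be lambdas of either one argument (the module being replaced) or two (the module and its dotted name). The following minimal example shows both forms; it is adapted from the updated test in tests/brevitas/graph/test_transforms.py, and the `ToyModel` class is illustrative only, not taken from the patches.

import torch.nn as nn

from brevitas.graph.base import ModuleToModuleByInstance


class ToyModel(nn.Module):
    # Hypothetical single-conv model used only to exercise the rewriter

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


model = ToyModel()

# One-argument lambda: receives only the module being replaced
kwargs = {'stride': lambda module: 2 if module.in_channels == 3 else 1}
model = ModuleToModuleByInstance(model.conv, nn.Conv2d, **kwargs).apply(model)
assert model.conv.stride == (2, 2)

# Two-argument lambda: also receives the module's name within the model,
# e.g. 'conv' here, or 'to_q'/'to_k' in the SDXL per-layer quantizer selection
kwargs = {'stride': lambda module, name: 2 if 'conv' in name else 1}
model = ModuleToModuleByInstance(model.conv, nn.Conv2d, **kwargs).apply(model)
assert model.conv.stride == (2, 2)

The one-argument form keeps existing rewriters working unchanged, while the two-argument form enables name-dependent decisions such as the per-layer `output_quant` assignment used in the stable diffusion example.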