Feat (examples/sdxl): Updates to SDXL entry-point (#1020)
---------

Co-authored-by: Giuseppe Franco <[email protected]>
nickfraser and Giuseppe5 authored Sep 12, 2024
1 parent f58f64b commit 9932b92
Showing 3 changed files with 336 additions and 58 deletions.
239 changes: 188 additions & 51 deletions src/brevitas_examples/stable_diffusion/main.py
@@ -12,12 +12,14 @@
import time

from dependencies import value
import diffusers
from diffusers import DiffusionPipeline
from diffusers import EulerDiscreteScheduler
from diffusers import StableDiffusionXLPipeline
from diffusers.models.attention_processor import Attention
from diffusers.models.attention_processor import AttnProcessor
import numpy as np
import packaging
import packaging.version
import pandas as pd
import torch
from torch import nn
@@ -35,7 +37,6 @@
from brevitas.graph.quantize import layerwise_quantize
from brevitas.inject.enum import StatsOp
from brevitas.nn.equalized_layer import EqualizedModule
from brevitas.nn.quant_activation import QuantIdentity
from brevitas.utils.torch_utils import KwargsForwardHook
from brevitas_examples.common.generative.quantize import generate_quant_maps
from brevitas_examples.common.generative.quantize import generate_quantizers
@@ -47,14 +48,17 @@
from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE
from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx
from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params
from brevitas_examples.stable_diffusion.sd_quant.nn import AttnProcessor
from brevitas_examples.stable_diffusion.sd_quant.nn import AttnProcessor2_0
from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention
from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttentionLast
from brevitas_examples.stable_diffusion.sd_quant.nn import QuantizableAttention
from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents
from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs
from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs
from brevitas_examples.stable_diffusion.sd_quant.utils import unet_input_shape

diffusers_version = packaging.version.parse(diffusers.__version__)
TEST_SEED = 123456
torch.manual_seed(TEST_SEED)
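# Note: 0.21.2 is apparently the diffusers version targeted by the MLPerf
# reference pipeline (hence the `is_mlperf_diffusers` flag below), and the
# attention rewrite branches on it.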

@@ -149,7 +153,7 @@ def main(args):
calibration_prompts = CALIBRATION_PROMPTS
if args.calibration_prompt_path is not None:
calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
print(args.calibration_prompt, len(calibration_prompts))

assert args.calibration_prompt <= len(calibration_prompts), f"Only {len(calibration_prompts)} prompts are available"
calibration_prompts = calibration_prompts[:args.calibration_prompt]

@@ -178,18 +182,29 @@ def main(args):
args.model, torch_dtype=dtype, variant=variant, use_safetensors=True)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.vae.config.force_upcast = True
if args.share_qkv_quant:
rewriter = ModuleToModuleByClass(
Attention,
QuantizableAttention,
query_dim=lambda module: module.to_q.in_features,
dim_head=lambda module: math.ceil(1 / (module.scale ** 2)),
bias=lambda module: hasattr(module.to_q, 'bias') and module.to_q.bias is not None,
processor=AttnProcessor2_0(),
dtype=dtype,
norm_num_groups=lambda module: None
if module.group_norm is None else module.group_norm.num_groups)
rewriter.apply(pipe.unet)
is_mlperf_diffusers = diffusers_version == packaging.version.parse('0.21.2')

AttClass = Attention
if is_mlperf_diffusers:
QuantAttClass = QuantAttention
if args.share_qkv_quant:
AttClass = QuantizableAttention
rewriter = ModuleToModuleByClass(
Attention,
QuantizableAttention,
query_dim=lambda module: module.to_q.in_features,
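# Attention.scale is dim_head ** -0.5, so dim_head is recovered as ceil(1 / scale ** 2).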
dim_head=lambda module: math.ceil(1 / (module.scale ** 2)),
bias=lambda module: hasattr(module.to_q, 'bias') and module.to_q.bias is not None,
processor=AttnProcessor2_0(),
dtype=dtype,
norm_num_groups=lambda module: None
if module.group_norm is None else module.group_norm.num_groups)
rewriter.apply(pipe.unet)
else:
QuantAttClass = QuantAttentionLast
if args.share_qkv_quant:
pipe.fuse_qkv_projections()
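# For context (illustrative sketch, not part of this commit): fusing QKV
# concatenates the three projection weights into one linear layer, so a single
# shared quantizer covers Q, K and V. A minimal standalone version of the idea,
# assuming `nn.Linear` projections with a common input dimension:
def _fuse_qkv_sketch(to_q, to_k, to_v):
    has_bias = to_q.bias is not None
    out_features = to_q.out_features + to_k.out_features + to_v.out_features
    fused = nn.Linear(to_q.in_features, out_features, bias=has_bias)
    # Stack weights along the output dim: fused(x) = concat([Wq x, Wk x, Wv x])
    fused.weight.data = torch.cat(
        [to_q.weight.data, to_k.weight.data, to_v.weight.data], dim=0)
    if has_bias:
        fused.bias.data = torch.cat(
            [to_q.bias.data, to_k.bias.data, to_v.bias.data], dim=0)
    return fused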

print(f"Model loaded from {args.model}.")

# Move model to target device
@@ -222,7 +237,7 @@ def main(args):
blacklist = []
non_blacklist = dict()
for name, _ in pipe.unet.named_modules():
if 'time_emb' in name:
if any(map(lambda x: x in name, args.quant_blacklist)):
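# i.e. blacklist a module if any blacklist pattern is a substring of its name,
# equivalently: any(pattern in name for pattern in args.quant_blacklist)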
blacklist.append(name)
else:
if isinstance(_, (torch.nn.Linear, torch.nn.Conv2d)):
@@ -232,7 +247,7 @@
else:
non_blacklist[name_to_add] += 1
print(f"Blacklisted layers: {set(blacklist)}")
print(f"Non blacklisted layers: {non_blacklist}")
print(f"Non blacklisted layers: {set(non_blacklist.keys())}")

# Make sure all LoRA layers are fused first, otherwise raise an error
for m in pipe.unet.modules():
@@ -316,6 +331,29 @@ def input_zp_stats_type():

input_kwargs['zero_point_stats_impl'] = input_zp_stats_type

sdpa_kwargs = dict()
if args.sdpa_scale_stats_op == 'minmax':

@value
def sdpa_scale_stats_type():
if args.sdpa_quant_type == 'asym':
sdpa_scaling_stats_op = StatsOp.MIN_MAX
else:
sdpa_scaling_stats_op = StatsOp.MAX
return sdpa_scaling_stats_op

sdpa_kwargs['scaling_stats_op'] = sdpa_scale_stats_type

if args.sdpa_zp_stats_op == 'minmax':

@value
def sdpa_zp_stats_type():
if args.sdpa_quant_type == 'asym':
zero_point_stats_impl = NegativeMinOrZero
return zero_point_stats_impl

sdpa_kwargs['zero_point_stats_impl'] = sdpa_zp_stats_type
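# The two `@value` hooks above defer the stats-op choice to quantizer-building
# time. Their combined logic reduces approximately to this (illustrative
# restatement only, not part of this commit):
def _sdpa_stats_choice_sketch(quant_type):
    # Asymmetric quantization derives scale and zero point from the full
    # [min, max] range; symmetric quantization only needs the absolute max.
    scale_op = StatsOp.MIN_MAX if quant_type == 'asym' else StatsOp.MAX
    zp_impl = NegativeMinOrZero if quant_type == 'asym' else None
    return scale_op, zp_impl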

print("Applying model quantization...")
quantizers = generate_quantizers(
dtype=dtype,
@@ -357,40 +395,52 @@ def input_zp_stats_type():
'weight_quant']
layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs)

if args.quantize_sdp:
assert args.share_qkv_quant, "Currently SDPA quantization is supported only with shared QKV quantization"
# TODO: reformat this
float_sdpa_quantizers = generate_quantizers(
if args.sdpa_bit_width > 0:
# `args.weight_quant_granularity` must be compatible with `args.sdpa_quant_format`
sdpa_quantizers = generate_quantizers(
dtype=dtype,
device=args.device,
weight_bit_width=weight_bit_width,
weight_quant_format='float_ocp_e4m3',
weight_quant_type='sym',
weight_bit_width=args.sdpa_bit_width,
weight_quant_format=args.sdpa_quant_format,
weight_quant_type=args.sdpa_quant_type,
weight_param_method=args.weight_param_method,
weight_scale_precision=args.weight_scale_precision,
weight_quant_granularity=args.weight_quant_granularity,
weight_group_size=args.weight_group_size,
quantize_weight_zero_point=args.quantize_weight_zero_point,
quantize_input_zero_point=args.quantize_input_zero_point,
input_bit_width=args.linear_output_bit_width,
input_quant_format='float_ocp_e4m3',
input_scale_type=args.input_scale_type,
input_scale_precision=args.input_scale_precision,
input_param_method=args.input_param_method,
input_quant_type='sym',
input_quant_granularity=args.input_quant_granularity,
input_kwargs=input_kwargs)
quantize_input_zero_point=args.quantize_sdpa_zero_point,
input_bit_width=args.sdpa_bit_width,
input_quant_format=args.sdpa_quant_format,
input_scale_type=args.sdpa_scale_type,
input_scale_precision=args.sdpa_scale_precision,
input_param_method=args.sdpa_param_method,
input_quant_type=args.sdpa_quant_type,
input_quant_granularity=args.sdpa_quant_granularity,
input_kwargs=sdpa_kwargs)
# We generate all quantizers, but we are only interested in activation quantization for
# the output of softmax and the output of QKV
input_quant = float_sdpa_quantizers[0]
input_quant = sdpa_quantizers[0]
if is_mlperf_diffusers:
extra_kwargs = {}
query_lambda = lambda module: module.to_qkv.in_features if hasattr(
module, 'to_qkv') else module.to_q.in_features
else:
extra_kwargs = {
'fuse_qkv':
args.share_qkv_quant,
'cross_attention_dim':
lambda module: module.cross_attention_dim
if module.is_cross_attention else None}
query_lambda = lambda module: module.query_dim
rewriter = ModuleToModuleByClass(
Attention,
QuantAttention,
AttClass,
QuantAttClass,
matmul_input_quant=input_quant,
query_dim=lambda module: module.to_q.in_features,
query_dim=query_lambda,
dim_head=lambda module: math.ceil(1 / (module.scale ** 2)),
processor=AttnProcessor(),
is_equalized=args.activation_equalization)
is_equalized=args.activation_equalization,
**extra_kwargs)
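# ModuleToModuleByClass swaps every `AttClass` instance for `QuantAttClass`;
# kwargs given as lambdas are re-evaluated per matched module (e.g. `query_dim`
# is read off each Attention instance). IGNORE_MISSING_KEYS is enabled below,
# presumably because the quantized replacements introduce parameters that have
# no counterpart in the original state dict.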
import brevitas.config as config
config.IGNORE_MISSING_KEYS = True
pipe.unet = rewriter.apply(pipe.unet)
@@ -400,11 +450,11 @@ def input_zp_stats_type():

if args.override_conv_quant_config:
print(
f"Overriding Conv2d quantization to weights: {float_sdpa_quantizers[1]}, inputs: {float_sdpa_quantizers[2]}"
f"Overriding Conv2d quantization to weights: {sdpa_quantizers[1]}, inputs: {sdpa_quantizers[2]}"
)
conv_qkwargs = layer_map[torch.nn.Conv2d][1]
conv_qkwargs['input_quant'] = float_sdpa_quantizers[2]
conv_qkwargs['weight_quant'] = float_sdpa_quantizers[1]
conv_qkwargs['input_quant'] = sdpa_quantizers[2]
conv_qkwargs['weight_quant'] = sdpa_quantizers[1]
layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs)

pipe.unet = layerwise_quantize(
@@ -434,8 +484,18 @@ def input_zp_stats_type():
print(f"Checkpoint loaded!")
pipe = pipe.to(args.device)
elif not args.dry_run:
if (args.linear_input_bit_width > 0 or args.conv_input_bit_width > 0 or
args.linear_output_bit_width > 0) and args.input_scale_type == 'static':
# Model needs calibration if any of its activation quantizers are 'static'
activation_bw = [
args.linear_input_bit_width,
args.conv_input_bit_width,
args.sdpa_bit_width,]
activation_st = [
args.input_scale_type,
args.input_scale_type,
args.sdpa_scale_type,]
needs_calibration = any(
map(lambda b, st: (b > 0) and st == 'static', activation_bw, activation_st))
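# Equivalent formulation of the two-iterable map above (illustrative):
#   needs_calibration = any(
#       bw > 0 and st == 'static' for bw, st in zip(activation_bw, activation_st))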
if needs_calibration:
print("Applying activation calibration")
with torch.no_grad(), calibration_mode(pipe.unet):
run_val_inference(
@@ -520,7 +580,13 @@ def input_zp_stats_type():
if args.use_mlperf_inference:
print(f"Computing accuracy with MLPerf pipeline")
compute_mlperf_fid(
args.model, args.path_to_coco, pipe, args.prompt, output_dir, not args.vae_fp16_fix)
args.model,
args.path_to_coco,
pipe,
args.prompt,
output_dir,
args.device,
not args.vae_fp16_fix)
else:
print(f"Computing accuracy on default prompt")
testing_prompts = TESTING_PROMPTS[:args.prompt]
@@ -643,7 +709,7 @@ def input_zp_stats_type():
help='Resolution along height and width dimension. Default: 512.')
parser.add_argument('--guidance-scale', type=float, default=7.5, help='Guidance scale.')
parser.add_argument(
'--calibration-steps', type=float, default=8, help='Steps used during calibration')
'--calibration-steps', type=int, default=8, help='Steps used during calibration')
add_bool_arg(
parser,
'output-path',
Expand Down Expand Up @@ -701,11 +767,6 @@ def input_zp_stats_type():
type=int,
default=0,
help='Input bit width. Default: 0 (not quantized).')
parser.add_argument(
'--linear-output-bit-width',
type=int,
default=0,
help='Input bit width. Default: 0 (not quantized).')
parser.add_argument(
'--weight-param-method',
type=str,
@@ -791,6 +852,78 @@ def input_zp_stats_type():
type=int,
default=16,
help='Group size for per_group weight quantization. Default: 16.')
parser.add_argument(
'--sdpa-bit-width',
type=int,
default=0,
help='Scaled dot product attention bit width. Default: 0 (not quantized).')
parser.add_argument(
'--sdpa-param-method',
type=str,
default='stats',
choices=['stats', 'mse'],
help=
'How scales/zero-point are determined for scaled dot product attention. Default: %(default)s.'
)
parser.add_argument(
'--sdpa-scale-stats-op',
type=str,
default='minmax',
choices=['minmax', 'percentile'],
help=
'Define what statistics op to use for scaled dot product attention scale. Default: %(default)s.'
)
parser.add_argument(
'--sdpa-zp-stats-op',
type=str,
default='minmax',
choices=['minmax', 'percentile'],
help=
'Define what statistics op to use for scaled dot product attention zero point. Default: %(default)s.'
)
parser.add_argument(
'--sdpa-scale-precision',
type=str,
default='float_scale',
choices=['float_scale', 'po2_scale'],
help=
'Whether the scaled dot product attention scale is a float value or a po2. Default: %(default)s.'
)
parser.add_argument(
'--sdpa-quant-type',
type=str,
default='sym',
choices=['sym', 'asym'],
help='Scaled dot product attention quantization type. Default: %(default)s.')
parser.add_argument(
'--sdpa-quant-format',
type=quant_format_validator,
default='int',
help=
'Scaled dot product attention quantization format. Either int or eXmY, with X+Y==input_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: %(default)s.'
)
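# For example, 'float_ocp_e4m3' selects the 8-bit OCP FP8 format (1 sign,
# 4 exponent and 3 mantissa bits), matching --sdpa-bit-width 8.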
parser.add_argument(
'--sdpa-quant-granularity',
type=str,
default='per_tensor',
choices=['per_tensor'],
help=
'Granularity for scales/zero-point of scaled dot product attention. Default: %(default)s.')
parser.add_argument(
'--sdpa-scale-type',
type=str,
default='static',
choices=['static', 'dynamic'],
help=
'Whether to do static or dynamic scaled dot product attention quantization. Default: %(default)s.'
)
parser.add_argument(
'--quant-blacklist',
type=str,
default=['time_emb'],
nargs='*',
metavar='NAME',
help='A list of module names to exclude from quantization. Default: %(default)s')
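# A hypothetical invocation exercising the new SDPA flags and the blacklist
# (model id, module names and flag values are illustrative only):
#   python -m brevitas_examples.stable_diffusion.main \
#       --model stabilityai/stable-diffusion-xl-base-1.0 \
#       --sdpa-bit-width 8 --sdpa-quant-format float_ocp_e4m3 \
#       --sdpa-quant-type sym --sdpa-scale-type static \
#       --quant-blacklist time_emb conv_in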
add_bool_arg(
parser,
'quantize-weight-zero-point',
@@ -806,6 +939,11 @@
'quantize-input-zero-point',
default=False,
help='Quantize input zero-point. Default: Disabled')
add_bool_arg(
parser,
'quantize-sdpa-zero-point',
default=False,
help='Quantize scaled dot product attention zero-point. Default: %(default)s')
add_bool_arg(
parser, 'export-cpu-float32', default=False, help='Export FP32 on CPU. Default: Disabled')
add_bool_arg(
@@ -823,7 +961,6 @@
'dry-run',
default=False,
help='Generate a quantized model without any calibration. Default: Disabled')
add_bool_arg(parser, 'quantize-sdp', default=False, help='Quantize SDP. Default: Disabled')
add_bool_arg(
parser,
'override-conv-quant-config',
(Diff for the remaining two changed files not shown.)
