Various SDXL quantization fixes #977

Merged · 24 commits · Jul 17, 2024

Changes from 5 commits

Commits
9181cff
Fix (examples/sdxl): Fix issue setting device when checkpoint is loaded.
nickfraser Jun 25, 2024
4ff1b43
Fix (example/sdxl): Added argument for linear output bitwidth.
nickfraser Jun 25, 2024
ce38f86
Fix (example/sdxl): Fix when replacing `diffusers.models.lora.LoRACom…
nickfraser Jun 25, 2024
975c9ee
Fix (example/sdxl): print output directory.
nickfraser Jun 26, 2024
b3ed0d8
Feat (example/sdxl): add extra option to quantize conv layers like SDP
nickfraser Jun 26, 2024
11f9dc8
fix (example/sdxl): Updated usage README.
nickfraser Jun 26, 2024
f77bf6c
Fix (example/sdxl): print which checkpoint is loaded.
nickfraser Jun 26, 2024
606ddec
Fix (example/sdxl): Move to CPU before 'param_only' export
nickfraser Jun 26, 2024
fb2fc87
Fix (example/sdxl): Added pandas requirement with specific version.
nickfraser Jun 28, 2024
cb3593b
Fix (example/sdxl): pre-commit
nickfraser Jun 28, 2024
fe30b66
Fix (example/sdxl): pre-commit fix to requirements.
nickfraser Jun 28, 2024
4546ffe
Fix model loading
Giuseppe5 Jul 3, 2024
99fc16f
Fix latents dtype
Giuseppe5 Jul 3, 2024
b8b31f8
Fix bitwidth
Giuseppe5 Jul 3, 2024
4fd4998
Fix export
Giuseppe5 Jul 3, 2024
6f0d2f9
Fix tests
Giuseppe5 Jul 3, 2024
41bba9c
Feat (example/sdxl): Added fix for VAE @ FP16
nickfraser Jul 10, 2024
3241a5a
Fix (example/sdxl): Only apply VAE fix for SDXL
nickfraser Jul 10, 2024
307b128
Docs (example/sdxl): Updated usage
nickfraser Jul 10, 2024
b9cc9c1
Fix (example/generative): Added missing `use_fnuz` arg.
nickfraser Jul 12, 2024
fdfcafc
Update
Giuseppe5 Jul 15, 2024
ade7c6b
Lambda inspection
Giuseppe5 Jul 15, 2024
ce2993e
fix
Giuseppe5 Jul 15, 2024
5b0beb6
Add license
Giuseppe5 Jul 17, 2024
26 changes: 13 additions & 13 deletions src/brevitas_examples/stable_diffusion/main.py
@@ -12,6 +12,7 @@

from dependencies import value
from diffusers import DiffusionPipeline
from diffusers import EulerDiscreteScheduler
from diffusers import StableDiffusionXLPipeline
from diffusers.models.attention_processor import Attention
from diffusers.models.attention_processor import AttnProcessor
@@ -37,7 +38,6 @@
from brevitas.utils.torch_utils import KwargsForwardHook
from brevitas_examples.common.generative.quantize import generate_quant_maps
from brevitas_examples.common.generative.quantize import generate_quantizers
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.common.parse_utils import add_bool_arg
from brevitas_examples.common.parse_utils import quant_format_validator
from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager
@@ -152,7 +152,7 @@ def main(args):

latents = None
if args.path_to_latents is not None:
latents = torch.load(args.path_to_latents).to(torch.float16)
latents = torch.load(args.path_to_latents).to(dtype)

# Create output dir. Move to tmp if None
ts = datetime.fromtimestamp(time.time())
@@ -170,7 +170,11 @@

# Load model from float checkpoint
print(f"Loading model from {args.model}...")
pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype)
variant = 'fp16' if dtype == torch.float16 else None
pipe = DiffusionPipeline.from_pretrained(
args.model, torch_dtype=dtype, variant=variant, use_safetensors=True)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.vae.config.force_upcast = True
print(f"Model loaded from {args.model}.")

# Move model to target device
@@ -213,7 +217,7 @@ def main(args):

if args.activation_equalization:
pipe.set_progress_bar_config(disable=True)
with activation_equalization_mode(
with torch.no_grad(), activation_equalization_mode(
pipe.unet,
alpha=args.act_eq_alpha,
layerwise=True,
@@ -262,8 +266,6 @@ def input_bit_width(module):
return args.linear_input_bit_width
elif isinstance(module, nn.Conv2d):
return args.conv_input_bit_width
elif isinstance(module, QuantIdentity):
return args.quant_identity_bit_width
else:
raise RuntimeError(f"Module {module} not supported.")

@@ -346,7 +348,7 @@ def input_zp_stats_type():
weight_group_size=args.weight_group_size,
quantize_weight_zero_point=args.quantize_weight_zero_point,
quantize_input_zero_point=args.quantize_input_zero_point,
input_bit_width=input_bit_width,
input_bit_width=args.linear_output_bit_width,
input_quant_format='e4m3',
input_scale_type=args.input_scale_type,
input_scale_precision=args.input_scale_precision,
@@ -359,7 +361,6 @@ def input_zp_stats_type():
# We generate all quantizers, but we are only interested in activation quantization for
# the output of softmax and the output of QKV
input_quant = float_sdpa_quantizers[0]
input_quant = input_quant.let(**{'bit_width': args.linear_output_bit_width})
if args.quantize_sdp_2:
rewriter = ModuleToModuleByClass(
Attention,
@@ -375,14 +376,13 @@ def input_zp_stats_type():
config.IGNORE_MISSING_KEYS = False
pipe.unet = pipe.unet.to(args.device)
pipe.unet = pipe.unet.to(dtype)
quant_kwargs = layer_map[torch.nn.Linear][1]
quant_kwargs = layer_map['diffusers.models.lora.LoRACompatibleLinear'][1]
what_to_quantize = []
if args.quantize_sdp_1:
what_to_quantize.extend(['to_q', 'to_k'])
if args.quantize_sdp_2:
what_to_quantize.extend(['to_v'])
quant_kwargs['output_quant'] = lambda module, name: input_quant if any(ending in name for ending in what_to_quantize) else None
layer_map[torch.nn.Linear] = (layer_map[torch.nn.Linear][0], quant_kwargs)

if args.override_conv_quant_config:
print(
@@ -420,8 +420,8 @@ def input_zp_stats_type():
print(f"Checkpoint loaded!")
pipe = pipe.to(args.device)
elif not args.dry_run:
if (args.linear_input_bit_width is not None or
args.conv_input_bit_width is not None) and args.input_scale_type == 'static':
if (args.linear_input_bit_width > 0 or args.conv_input_bit_width > 0 or
args.linear_output_bit_width > 0) and args.input_scale_type == 'static':
print("Applying activation calibration")
with torch.no_grad(), calibration_mode(pipe.unet):
run_val_inference(
@@ -459,7 +459,7 @@ def input_zp_stats_type():
torch.cuda.empty_cache()
if args.bias_correction:
print("Applying bias correction")
with bias_correction_mode(pipe.unet):
with torch.no_grad(), bias_correction_mode(pipe.unet):
run_val_inference(
pipe,
args.resolution,
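The main.py loading changes above combine three of the fixes from the commit list: load the fp16 safetensors variant only when running at half precision, swap in an EulerDiscreteScheduler, and force the SDXL VAE to upcast internally so it stays numerically stable at fp16. A minimal standalone sketch of the same recipe — assuming the public stabilityai/stable-diffusion-xl-base-1.0 checkpoint in place of the example's args.model, and a CUDA device in place of args.device:

import torch
from diffusers import DiffusionPipeline, EulerDiscreteScheduler

dtype = torch.float16
# Request the fp16 weight variant only when running at half precision.
variant = "fp16" if dtype == torch.float16 else None

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # stand-in for args.model
    torch_dtype=dtype,
    variant=variant,
    use_safetensors=True)
# Same scheduler swap as in the diff above.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
# Mirrors the diff: keep the SDXL VAE upcasting internally at fp16.
pipe.vae.config.force_upcast = True
pipe = pipe.to("cuda")  # the example script uses args.device here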
@@ -615,7 +615,7 @@ def compute_mlperf_fid(

if model_to_replace is not None:
model.pipe = model_to_replace

model.pipe.vae.config.force_upcast = True
Giuseppe5 marked this conversation as resolved.
ds = Coco(
data_path=path_to_coco,
name="coco-1024",
2 changes: 2 additions & 0 deletions src/brevitas_examples/stable_diffusion/sd_quant/export.py
@@ -93,11 +93,13 @@ def export_quant_params(pipe, output_dir):
elif isinstance(
module,
QuantWeightBiasInputOutputLayer) and id(module) not in handled_quant_layers:
full_name = name
layer_dict = dict()
layer_dict = handle_quant_param(module, layer_dict)
quant_params[full_name] = layer_dict
handled_quant_layers.add(id(module))
elif isinstance(module, QuantNonLinearActLayer):
full_name = name
layer_dict = dict()
act_scale = module.act_quant.export_handler.symbolic_kwargs[
'dequantize_symbolic_kwargs']['scale'].data
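The export.py change above sets full_name = name in both branches so each layer is recorded under its own fully-qualified module name. The surrounding collection pattern — walk named_modules, key the output dict by the full name, and track id()s so a layer is only exported once — can be sketched as below; plain nn.Linear stands in for the Brevitas quant layers and the recorded fields are illustrative only, not the real export format:

from torch import nn

def collect_param_summary(model: nn.Module) -> dict:
    # Simplified stand-in for export_quant_params: one entry per layer,
    # keyed by the full module name, with duplicates skipped via id().
    summary = {}
    handled = set()
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and id(module) not in handled:
            full_name = name
            summary[full_name] = {
                "weight_shape": list(module.weight.shape),
                "has_bias": module.bias is not None,
            }
            handled.add(id(module))
    return summary

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
print(collect_param_summary(model))  # entries for '0' and '2'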
2 changes: 1 addition & 1 deletion tests/brevitas/graph/test_transforms.py
@@ -287,6 +287,6 @@ def forward(self, x):
model = TestModel()
assert model.conv.stride == (1, 1)

kwargs = {'stride': lambda module: 2 if module.in_channels == 3 else 1}
kwargs = {'stride': lambda module, name: 2 if module.in_channels == 3 else 1}
Collaborator Author: Check if anything in brevitas_examples will be affected by this change.

Collaborator: Lambdas get inspected to decide which signature is the correct one to use, so we keep backward compatibility.

model = ModuleToModuleByInstance(model.conv, nn.Conv2d, **kwargs).apply(model)
assert model.conv.stride == (2, 2)
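The reply above refers to the backward-compatibility mechanism behind this signature change: the transform inspects a user-supplied lambda to decide whether to call it with (module) or (module, name). A rough sketch of that dispatch, using a hypothetical call_with_supported_args helper rather than the actual Brevitas internals:

import inspect
from torch import nn

def call_with_supported_args(fn, module, name):
    # Call fn with (module, name) when it accepts two arguments,
    # otherwise fall back to the legacy (module,) signature.
    n_params = len(inspect.signature(fn).parameters)
    return fn(module, name) if n_params >= 2 else fn(module)

# Old-style lambda (one argument) keeps working:
legacy = lambda module: 2 if module.in_channels == 3 else 1
# New-style lambda (two arguments) also receives the module name:
current = lambda module, name: 2 if module.in_channels == 3 else 1

conv = nn.Conv2d(3, 8, kernel_size=3)
assert call_with_supported_args(legacy, conv, "conv") == 2
assert call_with_supported_args(current, conv, "conv") == 2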