Feat (brevitas_examples/SDXL): expanded SDXL quantization (#977)
nickfraser authored Jul 17, 2024
1 parent 1f8e351 commit 072a02b
Showing 9 changed files with 448 additions and 63 deletions.
20 changes: 14 additions & 6 deletions src/brevitas/graph/base.py
@@ -3,6 +3,7 @@
 
 from abc import ABC
 from abc import abstractmethod
+import inspect
 from inspect import getcallargs
 
 import torch
@@ -120,16 +121,23 @@ def _module_attributes(self, module):
             attrs['bias'] = module.bias
         return attrs
 
-    def _evaluate_new_kwargs(self, new_kwargs, old_module):
+    def _evaluate_new_kwargs(self, new_kwargs, old_module, name):
         update_dict = dict()
         for k, v in self.new_module_kwargs.items():
             if islambda(v):
-                v = v(old_module)
+                if name is not None:
+                    # Two types of lambdas are admitted: with or without the module name as input
+                    if len(inspect.getfullargspec(v).args) == 2:
+                        v = v(old_module, name)
+                    elif len(inspect.getfullargspec(v).args) == 1:
+                        v = v(old_module)
+                else:
+                    v = v(old_module)
             update_dict[k] = v
         new_kwargs.update(update_dict)
         return new_kwargs
 
-    def _init_new_module(self, old_module: Module):
+    def _init_new_module(self, old_module: Module, name=None):
         # get attributes of original module
         new_kwargs = self._module_attributes(old_module)
         # transforms attribute of original module, e.g. bias Parameter -> bool
@@ -138,7 +146,7 @@ def _init_new_module(self, old_module: Module):
         new_module_signature_keys = signature_keys(self.new_module_class)
         new_kwargs = {k: v for k, v in new_kwargs.items() if k in new_module_signature_keys}
         # update with kwargs passed to the rewriter
-        new_kwargs = self._evaluate_new_kwargs(new_kwargs, old_module)
+        new_kwargs = self._evaluate_new_kwargs(new_kwargs, old_module, name)
         # init the new module
         new_module = self.new_module_class(**new_kwargs)
         return new_module
@@ -204,10 +212,10 @@ def __init__(self, old_module_instance, new_module_class, **kwargs):
         self.old_module_instance = old_module_instance
 
     def apply(self, model: GraphModule) -> GraphModule:
-        for old_module in model.modules():
+        for name, old_module in model.named_modules():
             if old_module is self.old_module_instance:
                 # init the new module based on the old one
-                new_module = self._init_new_module(old_module)
+                new_module = self._init_new_module(old_module, name)
                 self._replace_old_module(model, old_module, new_module)
                 break
         return model
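For context, the `base.py` change above means kwargs passed to the rewriter as lambdas may now also receive the dotted name of the module being replaced, dispatched via `inspect.getfullargspec`. A minimal sketch of the calling side (`ModuleToModuleByInstance` and `QuantLinear` are existing Brevitas API; the toy model and the name-based bit-width rule are illustrative assumptions, not taken from this commit):

```python
import torch.nn as nn

import brevitas.nn as qnn
from brevitas.graph.base import ModuleToModuleByInstance

# Toy model; '0' and '1' become the dotted module names.
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))

rewriter = ModuleToModuleByInstance(
    model[1],  # the instance to replace
    qnn.QuantLinear,
    # Two-arg lambda: receives (module, name), so the config can depend on
    # where the module sits in the hierarchy. One-arg lambdas still work.
    weight_bit_width=lambda module, name: 4 if name == '1' else 8)
model = rewriter.apply(model)
```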
2 changes: 2 additions & 0 deletions src/brevitas_examples/common/generative/quantize.py
@@ -473,6 +473,7 @@ def quantize_model(
         quantize_input_zero_point=False,
         quantize_embedding=False,
         use_ocp=False,
+        use_fnuz=False,
         device=None,
         weight_kwargs=None,
         input_kwargs=None):
@@ -497,6 +498,7 @@
             input_group_size,
             quantize_input_zero_point,
             use_ocp,
+            use_fnuz,
             device,
             weight_kwargs,
             input_kwargs)
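The new `use_fnuz` argument toggles between the OCP and FNUZ FP8 encodings during float quantization. As a quick reference (the PyTorch dtype names below are real; exactly how `quantize_model` maps the flag onto quantizers is not shown in this hunk):

```python
import torch

# OCP e4m3 (float8_e4m3fn): no infinities, NaN supported, max normal = 448.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0

# FNUZ e4m3 (float8_e4m3fnuz): no infinities, no negative zero, a single NaN
# encoding, and a shifted exponent bias, so max normal = 240.
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
```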
24 changes: 19 additions & 5 deletions src/brevitas_examples/stable_diffusion/README.md
@@ -77,6 +77,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
                [--conv-input-bit-width CONV_INPUT_BIT_WIDTH]
                [--act-eq-alpha ACT_EQ_ALPHA]
                [--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH]
+               [--linear-output-bit-width LINEAR_OUTPUT_BIT_WIDTH]
                [--weight-param-method {stats,mse}]
                [--input-param-method {stats,mse}]
                [--input-scale-stats-op {minmax,percentile}]
@@ -96,15 +97,17 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
                [--quantize-input-zero-point | --no-quantize-input-zero-point]
                [--export-cpu-float32 | --no-export-cpu-float32]
                [--use-mlperf-inference | --no-use-mlperf-inference]
-               [--use-ocp | --no-use-ocp] [--use-nfuz | --no-use-nfuz]
+               [--use-ocp | --no-use-ocp] [--use-fnuz | --no-use-fnuz]
                [--use-negative-prompts | --no-use-negative-prompts]
                [--dry-run | --no-dry-run]
                [--quantize-sdp-1 | --no-quantize-sdp-1]
                [--quantize-sdp-2 | --no-quantize-sdp-2]
+               [--override-conv-quant-config | --no-override-conv-quant-config]
+               [--vae-fp16-fix | --no-vae-fp16-fix]
 
 Stable Diffusion quantization
 
-options:
+optional arguments:
   -h, --help            show this help message and exit
   -m MODEL, --model MODEL
                         Path or name of the model.
@@ -176,6 +179,8 @@ options:
                         Alpha for activation equalization. Default: 0.9
   --linear-input-bit-width LINEAR_INPUT_BIT_WIDTH
                         Input bit width. Default: 0 (not quantized).
+  --linear-output-bit-width LINEAR_OUTPUT_BIT_WIDTH
+                        Output bit width. Default: 0 (not quantized).
   --weight-param-method {stats,mse}
                         How scales/zero-point are determined. Default: stats.
   --input-param-method {stats,mse}
Expand Down Expand Up @@ -241,9 +246,9 @@ options:
True
--no-use-ocp Disable Use OCP format for float quantization.
Default: True
--use-nfuz Enable Use NFUZ format for float quantization.
--use-fnuz Enable Use FNUZ format for float quantization.
Default: True
--no-use-nfuz Disable Use NFUZ format for float quantization.
--no-use-fnuz Disable Use FNUZ format for float quantization.
Default: True
--use-negative-prompts
Enable Use negative prompts during
@@ -259,5 +264,14 @@ options:
   --no-quantize-sdp-1   Disable Quantize SDP. Default: Disabled
   --quantize-sdp-2      Enable Quantize SDP. Default: Disabled
   --no-quantize-sdp-2   Disable Quantize SDP. Default: Disabled
-
+  --override-conv-quant-config
+                        Enable Quantize Convolutions in the same way as SDP
+                        (i.e., FP8). Default: Disabled
+  --no-override-conv-quant-config
+                        Disable Quantize Convolutions in the same way as SDP
+                        (i.e., FP8). Default: Disabled
+  --vae-fp16-fix        Enable Rescale the VAE to not go NaN with FP16.
+                        Default: Disabled
+  --no-vae-fp16-fix     Disable Rescale the VAE to not go NaN with FP16.
+                        Default: Disabled
 ```
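Putting the new README flags together, a hypothetical invocation might look as follows (flag names come from the help text above; the model name and bit widths are placeholders, not a recommended configuration):

```bash
python main.py -m stabilityai/stable-diffusion-xl-base-1.0 \
    --linear-input-bit-width 8 --linear-output-bit-width 8 \
    --no-use-ocp --use-fnuz \
    --override-conv-quant-config --vae-fp16-fix
```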