Skip to content

Commit

Permalink
Better structure for QDQ weights
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 5, 2024
1 parent cae0b49 commit 7ebe2dd
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 86 deletions.
78 changes: 40 additions & 38 deletions src/brevitas/export/common/handler/qcdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,32 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
'scale_orig_shape': scale_orig_shape}


class CDQCastWeightQuantProxyHandlerMixin(CDQCastProxyHandlerMixin, ABC):
handled_layer = WeightQuantProxyFromInjector
class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin):

def prepare_quantize_for_export(self, module):
def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
# compute axis before redefining scale
axis = cls.quant_axis(scale)
scale = to_0dim_if_scalar(scale.flatten())
zp = to_0dim_if_scalar(zero_point.flatten())
# expand_as must go after 0-dim check
zp = zp.expand_as(scale)
zp = cls.zero_point_with_dtype(is_signed, bit_width, zp)
if cls.itemize_quantize_scalar_params:
scale = to_item_if_0dim(scale)
zp = to_item_if_0dim(zp)
dtype = cls.signed_dtype(bit_width, is_signed)
return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis}

def prepare_quantize_from_floating_point(self, module):
quant_weight = module.tracked_module_list[0].quant_weight()
scale = quant_weight.scale
self.scale_dtype = scale.dtype
if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)
self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)

def prepare_quantize_from_integer(self, module):
int_weights = {
tm.weight.data_ptr(): tm.quant_weight().int(float_datatype=False)
for tm in module.tracked_module_list}
Expand All @@ -145,7 +167,10 @@ def prepare_quantize_for_export(self, module):
def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
self.prepare_quantize_for_export(module)
if self._export_q_node:
self.prepare_quantize_from_floating_point(module)
else:
self.prepare_quantize_from_integer(module)
# Get the first quant weight as representative
quant_weight = module.tracked_module_list[0].quant_weight()

Expand All @@ -165,12 +190,20 @@ def prepare_for_export(self, module):
else:
self.symbolic_kwargs = None

def quantize(self, x: Tensor):
def quantize_from_floating_point(self, x: Tensor):
quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
return x

def quantize_from_integer(self, x: Tensor):
return self.symbolic_kwargs['int_weights'][x.data_ptr()]

def symbolic_execution(self, x: Tensor):
assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
x = self.quantize(x)
if self._export_q_node:
x = self.quantize_from_floating_point(x)
else:
x = self.quantize_from_integer(x)
clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs']
# Copy dict to allow for popping kwargs even on shared quantizers
dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs'])
Expand All @@ -193,38 +226,7 @@ def symbolic_execution(self, x: Tensor):
return x, scale, zero_point, bit_width


class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastWeightQuantProxyHandlerMixin):

def prepare_quantize_for_export(self, module):
quant_weight = module.tracked_module_list[0].quant_weight()
scale = quant_weight.scale
self.scale_dtype = scale.dtype
if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)
self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)

def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
# compute axis before redefining scale
axis = cls.quant_axis(scale)
scale = to_0dim_if_scalar(scale.flatten())
zp = to_0dim_if_scalar(zero_point.flatten())
# expand_as must go after 0-dim check
zp = zp.expand_as(scale)
zp = cls.zero_point_with_dtype(is_signed, bit_width, zp)
if cls.itemize_quantize_scalar_params:
scale = to_item_if_0dim(scale)
zp = to_item_if_0dim(zp)
dtype = cls.signed_dtype(bit_width, is_signed)
return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis}

def quantize(self, x: Tensor):
quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
return x


class CDQCastDecoupledWeightQuantProxyHandlerMixin(CDQCastWeightQuantProxyHandlerMixin, ABC):
class CDQCastDecoupledWeightQuantProxyHandlerMixin(QCDQCastWeightQuantProxyHandlerMixin, ABC):
handled_layer = DecoupledWeightQuantProxyFromInjector

def symbolic_execution(self, x: Tensor):
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/export/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def _restore_requires_grad(m: Module, previous_state):
class BaseManager(ABC):

target_name = None
handlers = set()
handlers = []
_fn_to_cache = []
_fn_cache = []
_cached_io_handler_map = {}
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/export/onnx/qonnx/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ class QONNXManager(ONNXBaseManager):
"extract_constant_to_initializer", # remove unused graph inputs & initializers
"eliminate_unused_initializer"]

handlers = {
handlers = [
BrevitasActQuantProxyHandler,
BrevitasBiasQuantProxyHandler,
BrevitasWeightQuantProxyHandler,
BrevitasDecoupledWeightQuantProxyHandler,
BrevitasDecoupledWeightQuantWithInputProxyHandler,
BrevitasTruncQuantProxyHandler,
BrevitasQuantLSTMLayerHandler}
BrevitasQuantLSTMLayerHandler]

custom_fns = [
DebugMarkerFunction,
Expand Down
21 changes: 7 additions & 14 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
import torch

from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import \
CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastMixin
from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import DynamicQDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DynamicQMixin
from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin
Expand Down Expand Up @@ -110,28 +109,22 @@ def quantize_fn(self, x, dtype):
return DynamicQuantizeLinearFn.apply(x, dtype)


class StdCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
CDQCastWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQCastONNXWeightQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQCastWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass
_export_q_node: bool = False


class StdCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
CDQCastDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass
class StdQCDQCastONNXDecoupledWeightQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQCastDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
_export_q_node: bool = False


class StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler(
StdCDQCastONNXMixin, CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin,
ONNXBaseHandler):
pass
_export_q_node: bool = False


class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
Expand Down
23 changes: 12 additions & 11 deletions src/brevitas/export/onnx/standard/qcdq/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from .handler import StdCDQCastONNXBiasQuantProxyHandler
from .handler import StdCDQCastONNXDecoupledWeightQuantProxyHandler
from .handler import StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
from .handler import StdCDQCastONNXWeightQuantProxyHandler
from .handler import StdDynamicQDQCastONNXActQuantProxyHandler
from .handler import StdQCDQCastONNXActQuantProxyHandler
from .handler import StdQCDQCastONNXQuantLSTMLayerHandler
Expand All @@ -35,15 +34,15 @@ class StdQCDQONNXManager(StdONNXBaseManager):
"extract_constant_to_initializer", # remove unused graph inputs & initializers
"eliminate_unused_initializer"]

handlers = {
StdCDQCastONNXWeightQuantProxyHandler,
handlers = [
StdQCDQCastONNXWeightQuantProxyHandler,
StdCDQCastONNXBiasQuantProxyHandler,
StdQCDQCastONNXActQuantProxyHandler,
StdCDQCastONNXDecoupledWeightQuantProxyHandler,
StdDynamicQDQCastONNXActQuantProxyHandler,
StdQCDQCastONNXTruncQuantProxyHandler,
StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler,
StdQCDQCastONNXQuantLSTMLayerHandler}
StdQCDQCastONNXQuantLSTMLayerHandler]

custom_fns = [
DebugMarkerFunction,
Expand All @@ -64,10 +63,12 @@ def set_export_handler(cls, module: Module):
_set_recurrent_layer_export_handler(cls, module)

@classmethod
def change_weight_handler(cls, export_quantize_node_weight: bool = False):
if export_quantize_node_weight:
cls.handlers.discard(StdCDQCastONNXWeightQuantProxyHandler)
cls.handlers.add(StdQCDQCastONNXWeightQuantProxyHandler)
else:
cls.handlers.discard(StdQCDQCastONNXWeightQuantProxyHandler)
cls.handlers.add(StdCDQCastONNXWeightQuantProxyHandler)
def export_onnx(cls, *args, export_weight_q_node: bool = False, **kwargs):
cls.change_weight_export(export_weight_q_node)
super().export_onnx(*args, **kwargs)

@classmethod
def change_weight_export(cls, export_weight_q_node: bool = False):
for handler in cls.handlers:
if hasattr(handler, '_export_q_node'):
handler._export_weight_q_node = export_weight_q_node
4 changes: 2 additions & 2 deletions src/brevitas/export/onnx/standard/qoperator/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class StdQOpONNXManager(StdONNXBaseManager):
F.adaptive_max_pool2d,
F.adaptive_max_pool3d,]

handlers = {
handlers = [
StdQOpONNXQuantConv1dHandler,
StdQOpONNXQuantConv2dHandler,
StdQOpONNXQuantLinearHandler,
Expand All @@ -55,7 +55,7 @@ class StdQOpONNXManager(StdONNXBaseManager):
StdQOpONNXQuantTanhHandler,
StdQOpONNXQuantSigmoidHandler,
StdQOpONNXQuantMaxPool1d,
StdQOpONNXQuantMaxPool2d}
StdQOpONNXQuantMaxPool2d]

onnx_passes = [
# remove unused graph inputs & initializers
Expand Down
17 changes: 9 additions & 8 deletions src/brevitas/export/torch/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import \
CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastMixin
from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin


Expand Down Expand Up @@ -95,19 +95,20 @@ def forward(self, *args, **kwargs):
return self.symbolic_execution(*args, **kwargs)


class TorchCDQCastWeightQuantProxyHandler(TorchCDQCastMixin,
CDQCastWeightQuantProxyHandlerMixin,
TorchQCDQHandler):
class TorchQCDQCastWeightQuantProxyHandler(TorchQCDQCastMixin,
QCDQCastWeightQuantProxyHandlerMixin,
TorchQCDQHandler):
_export_q_node = False

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
clip_args = super().int_clip_symbolic_kwargs(narrow, signed, bit_width)
return _itemize_clip_bounds(clip_args)


class TorchCDQCastDecoupledWeightQuantProxyHandler(TorchCDQCastMixin,
CDQCastDecoupledWeightQuantProxyHandlerMixin,
TorchQCDQHandler):
class TorchQCDQCastDecoupledWeightQuantProxyHandler(TorchQCDQCastMixin,
QCDQCastDecoupledWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
Expand Down
8 changes: 4 additions & 4 deletions src/brevitas/export/torch/qcdq/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@
from .handler import TorchCDQCastBiasQuantProxyHandler
from .handler import TorchCDQCastDecoupledWeightQuantProxyHandler
from .handler import TorchCDQCastDecoupledWeightQuantWithInputProxyHandler
from .handler import TorchCDQCastWeightQuantProxyHandler
from .handler import TorchQCDQCastActQuantProxyHandler
from .handler import TorchQCDQCastTruncQuantProxyHandler
from .handler import TorchQCDQCastWeightQuantProxyHandler


class TorchQCDQManager(BaseManager):
target_name = 'torch'

handlers = {
TorchCDQCastWeightQuantProxyHandler,
handlers = [
TorchQCDQCastWeightQuantProxyHandler,
TorchCDQCastDecoupledWeightQuantProxyHandler,
TorchCDQCastDecoupledWeightQuantWithInputProxyHandler,
TorchQCDQCastActQuantProxyHandler,
TorchCDQCastBiasQuantProxyHandler,
TorchQCDQCastTruncQuantProxyHandler}
TorchQCDQCastTruncQuantProxyHandler]

@classmethod
def set_export_mode(cls, model: Module, enabled: bool):
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/export/torch/qoperator/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@
class TorchQOpManager(BaseManager):
target_name = 'torch'

handlers = {
handlers = [
PytorchQuantMaxPool1d,
PytorchQuantMaxPool2d,
PytorchQuantHardTanhHandler,
PytorchQuantIdentityHandler,
PytorchQuantReLUHandler,
PytorchQuantConv1dHandler,
PytorchQuantConv2dHandler,
PytorchQuantLinearHandler}
PytorchQuantLinearHandler]

@classmethod
def set_export_mode(cls, module: Module, enabled: bool):
Expand Down
15 changes: 11 additions & 4 deletions src/brevitas_examples/llm/llm_quant/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def prepare_for_export(self, module):
assert self.bit_width <= 8., "Only 8b or lower is supported."
quant_layer = module.tracked_module_list[0]
quant_weight = quant_layer.quant_weight()
self.int_weight = quant_weight.int().detach()
self.dtype = quant_weight.value.dtype
self.scale = self.export_scale(module, self.bit_width).detach()
self.expanded_scaling_shape = self.scaling_impl(module).expanded_scaling_shape
Expand All @@ -93,16 +92,24 @@ def forward(self, x):
scale = self.scale.expand(self.expanded_scaling_shape).contiguous()
# contiguous above is to avoid the reshape below being mapped to a unsafe view
scale = scale.view(self.reshaped_scaling_shape)
int_weight = self.int_weight

# Explicitly export custom Q/DQ to avoid aggressive constant folding
x = x / scale
if self.zero_point is not None:
zero_point = self.zero_point.expand(self.expanded_zero_point_shape).contiguous()
# contiguous above is to avoid the reshape below being mapped to a unsafe view
zero_point = zero_point.view(self.reshaped_zero_point_shape)
# avoid unsigned subtraction
int_weight = int_weight.to(self.dtype) - zero_point.to(self.dtype)
x = x.to(self.dtype) + zero_point.to(self.dtype)
else:
zero_point = torch.zeros_like(scale)

int_weight = torch.round(x)
if self.zero_point is not None:
int_weight = int_weight.to(self.dtype) - zero_point.to(self.dtype)

quant_weight = int_weight * scale

return quant_weight, scale, zero_point, self.bit_width


Expand Down Expand Up @@ -192,7 +199,7 @@ def forward(self, x):

class BlockQuantProxyLevelManager(BaseManager):

handlers = {WeightBlockQuantProxyHandler}
handlers = [WeightBlockQuantProxyHandler]

@classmethod
def set_export_handler(cls, module):
Expand Down

0 comments on commit 7ebe2dd

Please sign in to comment.