From 4f8450f1234a010c917842d11233a6a30b37284f Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 5 Feb 2024 19:51:02 +0000
Subject: [PATCH] Better structure for QDQ weights

---
 src/brevitas/export/common/handler/qcdq.py    | 82 ++++++++++---------
 src/brevitas/export/manager.py                |  2 +-
 src/brevitas/export/onnx/qonnx/manager.py     |  4 +-
 .../export/onnx/standard/qcdq/handler.py      | 21 ++---
 .../export/onnx/standard/qcdq/manager.py      | 27 +++---
 .../export/onnx/standard/qoperator/manager.py |  4 +-
 src/brevitas/export/torch/qcdq/handler.py     | 17 ++--
 src/brevitas/export/torch/qcdq/manager.py     | 12 +--
 .../export/torch/qoperator/manager.py         |  4 +-
 src/brevitas_examples/llm/llm_quant/export.py | 15 +++-
 10 files changed, 98 insertions(+), 90 deletions(-)

diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py
index dce9f2f34..ea65e2540 100644
--- a/src/brevitas/export/common/handler/qcdq.py
+++ b/src/brevitas/export/common/handler/qcdq.py
@@ -133,10 +133,33 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
             'scale_orig_shape': scale_orig_shape}
 
 
-class CDQCastWeightQuantProxyHandlerMixin(CDQCastProxyHandlerMixin, ABC):
+class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin):
     handled_layer = WeightQuantProxyFromInjector
 
-    def prepare_quantize_for_export(self, module):
+    def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
+        # compute axis before redefining scale
+        axis = cls.quant_axis(scale)
+        scale = to_0dim_if_scalar(scale.flatten())
+        zp = to_0dim_if_scalar(zero_point.flatten())
+        # expand_as must go after 0-dim check
+        zp = zp.expand_as(scale)
+        zp = cls.zero_point_with_dtype(is_signed, bit_width, zp)
+        if cls.itemize_quantize_scalar_params:
+            scale = to_item_if_0dim(scale)
+            zp = to_item_if_0dim(zp)
+        dtype = cls.signed_dtype(bit_width, is_signed)
+        return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis}
+
+    def prepare_quantize_from_floating_point(self, module):
+        quant_weight = module.tracked_module_list[0].quant_weight()
+        scale = quant_weight.scale
+        self.scale_dtype = scale.dtype
+        if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
+            scale = self.cast_fn(scale, torch.float32)
+        self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
+            scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)
+
+    def prepare_quantize_from_integer(self, module):
         int_weights = {
             tm.weight.data_ptr(): tm.quant_weight().int(float_datatype=False)
             for tm in module.tracked_module_list}
@@ -145,7 +168,10 @@ def prepare_quantize_for_export(self, module):
 
     def prepare_for_export(self, module):
         if module.is_quant_enabled:
             self.validate(module)
-            self.prepare_quantize_for_export(module)
+            if self._export_q_node:
+                self.prepare_quantize_from_floating_point(module)
+            else:
+                self.prepare_quantize_from_integer(module)
             # Get the first quant weight as representative
             quant_weight = module.tracked_module_list[0].quant_weight()
@@ -165,12 +191,20 @@ def prepare_for_export(self, module):
         else:
             self.symbolic_kwargs = None
 
-    def quantize(self, x: Tensor):
+    def quantize_from_floating_point(self, x: Tensor):
+        quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
+        x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
+        return x
+
+    def quantize_from_integer(self, x: Tensor):
         return self.symbolic_kwargs['int_weights'][x.data_ptr()]
 
     def symbolic_execution(self, x: Tensor):
         assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
-        x = self.quantize(x)
+        if self._export_q_node:
+            x = self.quantize_from_floating_point(x)
+        else:
+            x = self.quantize_from_integer(x)
         clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs']
         # Copy dict to allow for popping kwargs even on shared quantizers
         dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs'])
@@ -193,38 +227,7 @@ def symbolic_execution(self, x: Tensor):
         return x, scale, zero_point, bit_width
 
 
-class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastWeightQuantProxyHandlerMixin):
-
-    def prepare_quantize_for_export(self, module):
-        quant_weight = module.tracked_module_list[0].quant_weight()
-        scale = quant_weight.scale
-        self.scale_dtype = scale.dtype
-        if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
-            scale = self.cast_fn(scale, torch.float32)
-        self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
-            scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)
-
-    def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
-        # compute axis before redefining scale
-        axis = cls.quant_axis(scale)
-        scale = to_0dim_if_scalar(scale.flatten())
-        zp = to_0dim_if_scalar(zero_point.flatten())
-        # expand_as must go after 0-dim check
-        zp = zp.expand_as(scale)
-        zp = cls.zero_point_with_dtype(is_signed, bit_width, zp)
-        if cls.itemize_quantize_scalar_params:
-            scale = to_item_if_0dim(scale)
-            zp = to_item_if_0dim(zp)
-        dtype = cls.signed_dtype(bit_width, is_signed)
-        return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis}
-
-    def quantize(self, x: Tensor):
-        quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
-        x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
-        return x
-
-
-class CDQCastDecoupledWeightQuantProxyHandlerMixin(CDQCastWeightQuantProxyHandlerMixin, ABC):
+class QCDQCastDecoupledWeightQuantProxyHandlerMixin(QCDQCastWeightQuantProxyHandlerMixin, ABC):
     handled_layer = DecoupledWeightQuantProxyFromInjector
 
     def symbolic_execution(self, x: Tensor):
@@ -234,9 +237,12 @@ def symbolic_execution(self, x: Tensor):
 
 
 class CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin(
-        CDQCastDecoupledWeightQuantProxyHandlerMixin, ABC):
+        QCDQCastDecoupledWeightQuantProxyHandlerMixin, ABC):
     handled_layer = DecoupledWeightQuantWithInputProxyFromInjector
 
+    def validate(self, module):
+        assert not self._export_q_node, "This proxy requires exporting integer weights"
+
     def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_signed: bool):
         return super().symbolic_execution(x)
 
diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py
index 889998cf9..f8f1189fd 100644
--- a/src/brevitas/export/manager.py
+++ b/src/brevitas/export/manager.py
@@ -160,7 +160,7 @@ def _restore_requires_grad(m: Module, previous_state):
 
 class BaseManager(ABC):
     target_name = None
-    handlers = set()
+    handlers = []
     _fn_to_cache = []
     _fn_cache = []
     _cached_io_handler_map = {}
diff --git a/src/brevitas/export/onnx/qonnx/manager.py b/src/brevitas/export/onnx/qonnx/manager.py
index 169455ae3..4975dced3 100644
--- a/src/brevitas/export/onnx/qonnx/manager.py
+++ b/src/brevitas/export/onnx/qonnx/manager.py
@@ -35,14 +35,14 @@ class QONNXManager(ONNXBaseManager):
         "extract_constant_to_initializer",
         # remove unused graph inputs & initializers
         "eliminate_unused_initializer"]
 
-    handlers = {
+    handlers = [
         BrevitasActQuantProxyHandler,
         BrevitasBiasQuantProxyHandler,
         BrevitasWeightQuantProxyHandler,
         BrevitasDecoupledWeightQuantProxyHandler,
         BrevitasDecoupledWeightQuantWithInputProxyHandler,
         BrevitasTruncQuantProxyHandler,
-        BrevitasQuantLSTMLayerHandler}
+        BrevitasQuantLSTMLayerHandler]
 
     custom_fns = [
         DebugMarkerFunction,
diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py
index 630f65c7b..5571fb12b 100644
--- a/src/brevitas/export/onnx/standard/qcdq/handler.py
+++ b/src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -6,15 +6,14 @@
 import torch
 
 from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin
-from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import \
     CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import CDQCastMixin
-from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import DQCastMixin
 from brevitas.export.common.handler.qcdq import DynamicQDQCastActQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import DynamicQMixin
 from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
+from brevitas.export.common.handler.qcdq import QCDQCastDecoupledWeightQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import QCDQCastWeightQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import QMixin
@@ -110,28 +109,22 @@ def quantize_fn(self, x, dtype):
         return DynamicQuantizeLinearFn.apply(x, dtype)
 
 
-class StdCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
-                                            CDQCastWeightQuantProxyHandlerMixin,
-                                            ONNXBaseHandler):
-    pass
-
-
 class StdQCDQCastONNXWeightQuantProxyHandler(StdQCDQCastONNXMixin,
                                              QCDQCastWeightQuantProxyHandlerMixin,
                                              ONNXBaseHandler):
-    pass
+    _export_q_node: bool = False
 
 
-class StdCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
-                                                     CDQCastDecoupledWeightQuantProxyHandlerMixin,
-                                                     ONNXBaseHandler):
-    pass
+class StdQCDQCastONNXDecoupledWeightQuantProxyHandler(StdQCDQCastONNXMixin,
+                                                      QCDQCastDecoupledWeightQuantProxyHandlerMixin,
+                                                      ONNXBaseHandler):
+    _export_q_node: bool = False
 
 
 class StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler(
         StdCDQCastONNXMixin, CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
-    pass
+    _export_q_node: bool = False
 
 
 class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
diff --git a/src/brevitas/export/onnx/standard/qcdq/manager.py b/src/brevitas/export/onnx/standard/qcdq/manager.py
index 00e15a53c..3e4d2308b 100644
--- a/src/brevitas/export/onnx/standard/qcdq/manager.py
+++ b/src/brevitas/export/onnx/standard/qcdq/manager.py
@@ -16,11 +16,10 @@
 from ..function import QuantizeLinearFn
 from ..manager import StdONNXBaseManager
 from .handler import StdCDQCastONNXBiasQuantProxyHandler
-from .handler import StdCDQCastONNXDecoupledWeightQuantProxyHandler
 from .handler import StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
-from .handler import StdCDQCastONNXWeightQuantProxyHandler
 from .handler import StdDynamicQDQCastONNXActQuantProxyHandler
 from .handler import StdQCDQCastONNXActQuantProxyHandler
+from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler
 from .handler import StdQCDQCastONNXQuantLSTMLayerHandler
 from .handler import StdQCDQCastONNXTruncQuantProxyHandler
 from .handler import StdQCDQCastONNXWeightQuantProxyHandler
@@ -35,15 +34,15 @@ class StdQCDQONNXManager(StdONNXBaseManager):
         "extract_constant_to_initializer",
         # remove unused graph inputs & initializers
         "eliminate_unused_initializer"]
 
-    handlers = {
-        StdCDQCastONNXWeightQuantProxyHandler,
+    handlers = [
+        StdQCDQCastONNXWeightQuantProxyHandler,
         StdCDQCastONNXBiasQuantProxyHandler,
         StdQCDQCastONNXActQuantProxyHandler,
-        StdCDQCastONNXDecoupledWeightQuantProxyHandler,
+        StdQCDQCastONNXDecoupledWeightQuantProxyHandler,
         StdDynamicQDQCastONNXActQuantProxyHandler,
         StdQCDQCastONNXTruncQuantProxyHandler,
         StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler,
-        StdQCDQCastONNXQuantLSTMLayerHandler}
+        StdQCDQCastONNXQuantLSTMLayerHandler]
 
     custom_fns = [
         DebugMarkerFunction,
@@ -64,10 +63,12 @@ def set_export_handler(cls, module: Module):
         _set_recurrent_layer_export_handler(cls, module)
 
     @classmethod
-    def change_weight_handler(cls, export_quantize_node_weight: bool = False):
-        if export_quantize_node_weight:
-            cls.handlers.discard(StdCDQCastONNXWeightQuantProxyHandler)
-            cls.handlers.add(StdQCDQCastONNXWeightQuantProxyHandler)
-        else:
-            cls.handlers.discard(StdQCDQCastONNXWeightQuantProxyHandler)
-            cls.handlers.add(StdCDQCastONNXWeightQuantProxyHandler)
+    def export_onnx(cls, *args, export_weight_q_node: bool = False, **kwargs):
+        cls.change_weight_export(export_weight_q_node)
+        return super().export_onnx(*args, **kwargs)
+
+    @classmethod
+    def change_weight_export(cls, export_weight_q_node: bool = False):
+        for handler in cls.handlers:
+            if hasattr(handler, '_export_q_node'):
+                handler._export_q_node = export_weight_q_node
diff --git a/src/brevitas/export/onnx/standard/qoperator/manager.py b/src/brevitas/export/onnx/standard/qoperator/manager.py
index 3f3718eed..4c45df04c 100644
--- a/src/brevitas/export/onnx/standard/qoperator/manager.py
+++ b/src/brevitas/export/onnx/standard/qoperator/manager.py
@@ -45,7 +45,7 @@ class StdQOpONNXManager(StdONNXBaseManager):
         F.adaptive_max_pool2d,
         F.adaptive_max_pool3d,]
 
-    handlers = {
+    handlers = [
         StdQOpONNXQuantConv1dHandler,
         StdQOpONNXQuantConv2dHandler,
         StdQOpONNXQuantLinearHandler,
@@ -55,7 +55,7 @@ class StdQOpONNXManager(StdONNXBaseManager):
         StdQOpONNXQuantTanhHandler,
         StdQOpONNXQuantSigmoidHandler,
         StdQOpONNXQuantMaxPool1d,
-        StdQOpONNXQuantMaxPool2d}
+        StdQOpONNXQuantMaxPool2d]
 
     onnx_passes = [
         # remove unused graph inputs & initializers
diff --git a/src/brevitas/export/torch/qcdq/handler.py b/src/brevitas/export/torch/qcdq/handler.py
index 41460a6fa..e0a2977e1 100644
--- a/src/brevitas/export/torch/qcdq/handler.py
+++ b/src/brevitas/export/torch/qcdq/handler.py
@@ -7,14 +7,14 @@
 
 from brevitas.export.common.handler.base import BaseHandler
 from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin
-from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import \
     CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import CDQCastMixin
-from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import DQCastMixin
 from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
+from brevitas.export.common.handler.qcdq import QCDQCastDecoupledWeightQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
+from brevitas.export.common.handler.qcdq import QCDQCastWeightQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import QMixin
@@ -95,9 +95,10 @@ def forward(self, *args, **kwargs):
         return self.symbolic_execution(*args, **kwargs)
 
 
-class TorchCDQCastWeightQuantProxyHandler(TorchCDQCastMixin,
-                                          CDQCastWeightQuantProxyHandlerMixin,
-                                          TorchQCDQHandler):
+class TorchQCDQCastWeightQuantProxyHandler(TorchQCDQCastMixin,
+                                           QCDQCastWeightQuantProxyHandlerMixin,
+                                           TorchQCDQHandler):
+    _export_q_node = False
 
     @classmethod
     def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
@@ -105,9 +106,9 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
         return _itemize_clip_bounds(clip_args)
 
 
-class TorchCDQCastDecoupledWeightQuantProxyHandler(TorchCDQCastMixin,
-                                                   CDQCastDecoupledWeightQuantProxyHandlerMixin,
-                                                   TorchQCDQHandler):
+class TorchQCDQCastDecoupledWeightQuantProxyHandler(TorchQCDQCastMixin,
+                                                    QCDQCastDecoupledWeightQuantProxyHandlerMixin,
+                                                    TorchQCDQHandler):
 
     @classmethod
     def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
diff --git a/src/brevitas/export/torch/qcdq/manager.py b/src/brevitas/export/torch/qcdq/manager.py
index bef1aac43..225076237 100644
--- a/src/brevitas/export/torch/qcdq/manager.py
+++ b/src/brevitas/export/torch/qcdq/manager.py
@@ -12,23 +12,23 @@
 from brevitas.export.manager import ExportContext
 
 from .handler import TorchCDQCastBiasQuantProxyHandler
-from .handler import TorchCDQCastDecoupledWeightQuantProxyHandler
 from .handler import TorchCDQCastDecoupledWeightQuantWithInputProxyHandler
-from .handler import TorchCDQCastWeightQuantProxyHandler
 from .handler import TorchQCDQCastActQuantProxyHandler
+from .handler import TorchQCDQCastDecoupledWeightQuantProxyHandler
 from .handler import TorchQCDQCastTruncQuantProxyHandler
+from .handler import TorchQCDQCastWeightQuantProxyHandler
 
 
 class TorchQCDQManager(BaseManager):
     target_name = 'torch'
 
-    handlers = {
-        TorchCDQCastWeightQuantProxyHandler,
-        TorchCDQCastDecoupledWeightQuantProxyHandler,
+    handlers = [
+        TorchQCDQCastWeightQuantProxyHandler,
+        TorchQCDQCastDecoupledWeightQuantProxyHandler,
         TorchCDQCastDecoupledWeightQuantWithInputProxyHandler,
         TorchQCDQCastActQuantProxyHandler,
         TorchCDQCastBiasQuantProxyHandler,
-        TorchQCDQCastTruncQuantProxyHandler}
+        TorchQCDQCastTruncQuantProxyHandler]
 
     @classmethod
     def set_export_mode(cls, model: Module, enabled: bool):
diff --git a/src/brevitas/export/torch/qoperator/manager.py b/src/brevitas/export/torch/qoperator/manager.py
index ad932e52b..aea97957c 100644
--- a/src/brevitas/export/torch/qoperator/manager.py
+++ b/src/brevitas/export/torch/qoperator/manager.py
@@ -27,7 +27,7 @@ class TorchQOpManager(BaseManager):
     target_name = 'torch'
 
-    handlers = {
+    handlers = [
         PytorchQuantMaxPool1d,
         PytorchQuantMaxPool2d,
         PytorchQuantHardTanhHandler,
@@ -35,7 +35,7 @@ class TorchQOpManager(BaseManager):
         PytorchQuantReLUHandler,
         PytorchQuantConv1dHandler,
         PytorchQuantConv2dHandler,
-        PytorchQuantLinearHandler}
+        PytorchQuantLinearHandler]
 
     @classmethod
     def set_export_mode(cls, module: Module, enabled: bool):
diff --git a/src/brevitas_examples/llm/llm_quant/export.py b/src/brevitas_examples/llm/llm_quant/export.py
index 882a29b89..b16c1aa8b 100644
--- a/src/brevitas_examples/llm/llm_quant/export.py
+++ b/src/brevitas_examples/llm/llm_quant/export.py
@@ -79,7 +79,6 @@ def prepare_for_export(self, module):
         assert self.bit_width <= 8., "Only 8b or lower is supported."
         quant_layer = module.tracked_module_list[0]
         quant_weight = quant_layer.quant_weight()
-        self.int_weight = quant_weight.int().detach()
         self.dtype = quant_weight.value.dtype
         self.scale = self.export_scale(module, self.bit_width).detach()
         self.expanded_scaling_shape = self.scaling_impl(module).expanded_scaling_shape
@@ -93,16 +92,24 @@ def forward(self, x):
         scale = self.scale.expand(self.expanded_scaling_shape).contiguous()
         # contiguous above is to avoid the reshape below being mapped to a unsafe view
         scale = scale.view(self.reshaped_scaling_shape)
-        int_weight = self.int_weight
+
+        # Explicitly export custom Q/DQ to avoid aggressive constant folding
+        x = x / scale
         if self.zero_point is not None:
             zero_point = self.zero_point.expand(self.expanded_zero_point_shape).contiguous()
             # contiguous above is to avoid the reshape below being mapped to a unsafe view
             zero_point = zero_point.view(self.reshaped_zero_point_shape)
             # avoid unsigned subtraction
-            int_weight = int_weight.to(self.dtype) - zero_point.to(self.dtype)
+            x = x.to(self.dtype) + zero_point.to(self.dtype)
         else:
             zero_point = torch.zeros_like(scale)
+
+        int_weight = torch.round(x)
+        if self.zero_point is not None:
+            int_weight = int_weight.to(self.dtype) - zero_point.to(self.dtype)
+
         quant_weight = int_weight * scale
+
         return quant_weight, scale, zero_point, self.bit_width
@@ -192,7 +199,7 @@ def forward(self, x):
 
 class BlockQuantProxyLevelManager(BaseManager):
 
-    handlers = {WeightBlockQuantProxyHandler}
+    handlers = [WeightBlockQuantProxyHandler]
 
     @classmethod
     def set_export_handler(cls, module):
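
Usage sketch (not part of the diff above): with this patch, the choice between exporting weights as pre-computed integers followed by DequantizeLinear (the default, _export_q_node = False) and exporting the floating-point weights through a full QuantizeLinear, optional Clip, and DequantizeLinear chain is made per call via the new export_weight_q_node flag on StdQCDQONNXManager.export_onnx. The toy QuantLinear model, input shape, and file names below are illustrative assumptions, and the (module, args, export_path) calling convention is assumed to be the one of the existing ONNX export managers.

    import torch

    from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
    from brevitas.nn import QuantLinear

    # Toy example: any module with a Brevitas weight quantizer would do here.
    model = QuantLinear(16, 8, bias=True)
    model.eval()
    sample_input = torch.randn(1, 16)

    # Default behaviour: weights are exported as pre-computed integer tensors
    # followed by DequantizeLinear only.
    StdQCDQONNXManager.export_onnx(model, args=sample_input, export_path='quant_linear_dq.onnx')

    # Flag added by this patch: keep weights in floating point and emit the full
    # QuantizeLinear (plus Clip where needed) and DequantizeLinear chain instead.
    StdQCDQONNXManager.export_onnx(
        model, args=sample_input, export_path='quant_linear_qdq.onnx', export_weight_q_node=True)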