From 3472e7c4888f059c5ea24b507be04e19db0fa5c3 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 12 Jan 2024 15:45:50 +0000 Subject: [PATCH] Fix (export): add CastMixin --- src/brevitas/export/common/handler/qcdq.py | 28 ++++++++--------- .../export/onnx/standard/qcdq/handler.py | 30 ++++++++++--------- src/brevitas/export/torch/qcdq/handler.py | 25 ++++++++-------- 3 files changed, 43 insertions(+), 40 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index a63b6b5b5..eba9c51f1 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -58,7 +58,7 @@ def cast_fn(self, x, dtype): pass -class CDQMixin(DQMixin, ABC): +class CDQCastMixin(DQCastMixin, ABC): @abstractmethod def clip_fn(self, x, min_val, max_val): @@ -102,7 +102,7 @@ def signed_dtype(cls, bit_width, is_signed): return dtype -class CDQProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQMixin, ABC): +class CDQCastProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQCastMixin, ABC): def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): scale_orig_shape = scale.shape @@ -125,7 +125,7 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): 'scale_orig_shape': scale_orig_shape} -class QCDQWeightQuantProxyHandlerMixin(CDQProxyHandlerMixin, ABC): +class CDQCastWeightQuantProxyHandlerMixin(CDQCastProxyHandlerMixin, ABC): handled_layer = WeightQuantProxyFromInjector def prepare_for_export(self, module): @@ -179,7 +179,7 @@ def symbolic_execution(self, x: Tensor): return x, scale, zero_point, bit_width -class QCDQDecoupledWeightQuantProxyHandlerMixin(QCDQWeightQuantProxyHandlerMixin, ABC): +class CDQCastDecoupledWeightQuantProxyHandlerMixin(CDQCastWeightQuantProxyHandlerMixin, ABC): handled_layer = DecoupledWeightQuantProxyFromInjector def symbolic_execution(self, x: Tensor): @@ -188,15 +188,15 @@ def symbolic_execution(self, x: Tensor): return out, scale, zero_point, scale, zero_point, bit_width -class QCDQDecoupledWeightQuantWithInputProxyHandlerMixin(QCDQDecoupledWeightQuantProxyHandlerMixin, - ABC): +class CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin( + CDQCastDecoupledWeightQuantProxyHandlerMixin, ABC): handled_layer = DecoupledWeightQuantWithInputProxyFromInjector def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_signed: bool): return super().symbolic_execution(x) -class QCDQActQuantProxyHandlerMixin(QMixin, CDQProxyHandlerMixin, ABC): +class QCDQCastActQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin, ABC): handled_layer = ActQuantProxyFromInjector def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): @@ -265,7 +265,7 @@ def symbolic_execution(self, x: Tensor): return x, scale, zero_point, bit_width -class QCDQBiasQuantProxyHandlerMixin(DQMixin, QuantAxisMixin, ZeroPointHandlerMixin, ABC): +class CDQCastBiasQuantProxyHandlerMixin(DQCastMixin, QuantAxisMixin, ZeroPointHandlerMixin, ABC): handled_layer = BiasQuantProxyFromInjector def validate(self, module): @@ -325,12 +325,12 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None): return y, scale, zero_point, bit_width -class QCDQTruncQuantProxyHandlerMixin(QuantAxisMixin, - ClipMixin, - ZeroPointHandlerMixin, - QMixin, - CDQMixin, - ABC): +class QCDQCastTruncQuantProxyHandlerMixin(QuantAxisMixin, + ClipMixin, + ZeroPointHandlerMixin, + QMixin, + CDQCastMixin, + ABC): handled_layer = TruncQuantProxyFromInjector def prepare_for_export(self, module: TruncQuantProxyFromInjector): diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index 642ae9174..50f75df85 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -5,14 +5,15 @@ import torch -from brevitas.export.common.handler.qcdq import CDQMixin +from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import \ + CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin +from brevitas.export.common.handler.qcdq import CDQCastMixin +from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import DQCastMixin -from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QMixin from brevitas.export.onnx.handler import ONNXBaseHandler from brevitas.export.onnx.handler import QuantLSTMLayerHandler @@ -43,7 +44,7 @@ def validate(self, module): assert module.bit_width() > 1., 'Binary quant not supported' -class StdCDQCastONNXMixin(CDQMixin, StdDQCastONNXMixin, ABC): +class StdCDQCastONNXMixin(CDQCastMixin, StdDQCastONNXMixin, ABC): def clip_fn(self, x, min_val, max_val): return IntClipFn.apply(x, min_val, max_val) @@ -75,36 +76,37 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis): class StdQCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin, - QCDQWeightQuantProxyHandlerMixin, + CDQCastWeightQuantProxyHandlerMixin, ONNXBaseHandler): pass class StdQCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin, - QCDQDecoupledWeightQuantProxyHandlerMixin, + CDQCastDecoupledWeightQuantProxyHandlerMixin, ONNXBaseHandler): pass class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler( - StdCDQCastONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler): + StdCDQCastONNXMixin, CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin, + ONNXBaseHandler): pass class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin, - QCDQActQuantProxyHandlerMixin, + QCDQCastActQuantProxyHandlerMixin, ONNXBaseHandler): pass class StdQCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin, - QCDQBiasQuantProxyHandlerMixin, + CDQCastBiasQuantProxyHandlerMixin, ONNXBaseHandler): pass class StdQCDQCastONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin, - QCDQTruncQuantProxyHandlerMixin, + QCDQCastTruncQuantProxyHandlerMixin, ONNXBaseHandler): pass diff --git a/src/brevitas/export/torch/qcdq/handler.py b/src/brevitas/export/torch/qcdq/handler.py index b3474a246..ccbfc75d1 100644 --- a/src/brevitas/export/torch/qcdq/handler.py +++ b/src/brevitas/export/torch/qcdq/handler.py @@ -6,13 +6,14 @@ import torch from brevitas.export.common.handler.base import BaseHandler +from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import \ + CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin +from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import DQCastMixin -from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QMixin @@ -94,7 +95,7 @@ def forward(self, *args, **kwargs): class TorchQCDQCastWeightQuantProxyHandler(TorchCDQCastMixin, - QCDQWeightQuantProxyHandlerMixin, + CDQCastWeightQuantProxyHandlerMixin, TorchQCDQHandler): @classmethod @@ -104,7 +105,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): class TorchQCDQCastDecoupledWeightQuantProxyHandler(TorchCDQCastMixin, - QCDQDecoupledWeightQuantProxyHandlerMixin, + CDQCastDecoupledWeightQuantProxyHandlerMixin, TorchQCDQHandler): @classmethod @@ -114,7 +115,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): class TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler( - TorchCDQCastMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler): + TorchCDQCastMixin, CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler): @classmethod def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): @@ -123,7 +124,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): class TorchQCDQCastActQuantProxyHandler(TorchQCDQCastMixin, - QCDQActQuantProxyHandlerMixin, + QCDQCastActQuantProxyHandlerMixin, TorchQCDQHandler): @classmethod @@ -133,13 +134,13 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): class TorchQCDQCastBiasQuantProxyHandler(TorchDQCastMixin, - QCDQBiasQuantProxyHandlerMixin, + CDQCastBiasQuantProxyHandlerMixin, TorchQCDQHandler): pass class TorchQCDQCastTruncQuantProxyHandler(TorchQCDQCastMixin, - QCDQTruncQuantProxyHandlerMixin, + QCDQCastTruncQuantProxyHandlerMixin, TorchQCDQHandler): @classmethod