From eeffd2bdfe955b21c690836a3c8bf6b876577014 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 21 May 2024 10:12:40 +0200 Subject: [PATCH] Feat (quant_tensor): support for float QuantTensor and proxies (#919) --- ...1_quant_tensor_quant_conv2d_overview.ipynb | 84 +++-- requirements/requirements-test.txt | 1 + src/brevitas/core/function_wrapper/clamp.py | 2 +- src/brevitas/core/quant/float.py | 4 +- src/brevitas/fx/value_tracer.py | 4 +- src/brevitas/graph/calibrate.py | 8 +- src/brevitas/graph/gpfq.py | 2 +- src/brevitas/graph/gpxq.py | 6 +- src/brevitas/nn/mixin/base.py | 24 +- src/brevitas/nn/utils.py | 9 +- src/brevitas/proxy/__init__.py | 3 + src/brevitas/proxy/float_parameter_quant.py | 175 ++++++++++ src/brevitas/proxy/float_runtime_quant.py | 102 ++++++ src/brevitas/proxy/parameter_quant.py | 90 +++-- src/brevitas/proxy/runtime_quant.py | 35 +- src/brevitas/quant/experimental/float_base.py | 6 +- src/brevitas/quant_tensor/__init__.py | 5 +- .../quant_tensor/base_quant_tensor.py | 109 +++++- .../quant_tensor/float_quant_tensor.py | 325 ++++++++++++++++++ .../quant_tensor/float_torch_handler.py | 152 ++++++++ src/brevitas/quant_tensor/int_quant_tensor.py | 143 +++----- .../quant_tensor/int_torch_handler.py | 291 ++++++++++++++++ src/brevitas/quant_tensor/torch_handler.py | 270 +-------------- src/brevitas/utils/quant_utils.py | 51 ++- tests/brevitas/core/test_clamp.py | 2 +- tests/brevitas/core/test_float_quant.py | 4 +- tests/brevitas/nn/nn_quantizers_fixture.py | 3 +- tests/brevitas/nn/test_linear.py | 4 +- .../quant_tensor/test_quant_tensor.py | 26 +- tests/brevitas/test_quant_tensor.py | 14 + .../test_torchvision_models.py | 29 +- 31 files changed, 1493 insertions(+), 490 deletions(-) create mode 100644 src/brevitas/proxy/float_parameter_quant.py create mode 100644 src/brevitas/proxy/float_runtime_quant.py create mode 100644 src/brevitas/quant_tensor/float_quant_tensor.py create mode 100644 src/brevitas/quant_tensor/float_torch_handler.py create mode 100644 src/brevitas/quant_tensor/int_torch_handler.py create mode 100644 tests/brevitas/test_quant_tensor.py diff --git a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb index d2982f14b..c0ee56321 100644 --- a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb +++ b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb @@ -157,11 +157,21 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 4, "metadata": { "scrolled": true }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/_tensor.py:1394: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1708025842427/work/c10/core/TensorImpl.h:1908.)\n", + " return super().rename(names)\n", + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1708025842427/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n", + " return F.conv2d(input, weight, bias, self.stride,\n" + ] + }, { "data": { "text/plain": [ @@ -255,7 +265,7 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.0018, 0.1273, -0.1937],\n", + "IntQuantTensor(value=tensor([[[[-0.0018, 0.1273, -0.1937],\n", " [-0.1734, -0.0904, 0.0627],\n", " [-0.0055, 0.1863, -0.0203]],\n", "\n", @@ -377,8 +387,6 @@ } ], "source": [ - "from brevitas.quant_tensor import QuantTensor\n", - "\n", "quant_act = QuantIdentity(return_quant_tensor=True)\n", "\n", "out_tensor_0 = quant_act(torch.randn(1,2,5,5))\n", @@ -407,7 +415,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "QuantTensor(value=tensor([[[[-0.1106, 1.1945, -0.4972, -2.0968, 0.7175],\n", + "IntQuantTensor(value=tensor([[[[-0.1106, 1.1945, -0.4972, -2.0968, 0.7175],\n", " [-2.5901, 0.0588, -0.2014, 2.1486, 1.6435],\n", " [ 0.9067, -2.5212, 2.2193, 0.2352, -0.8395],\n", " [-0.8351, 0.6341, -0.5551, 0.1040, -3.3151],\n", @@ -467,7 +475,7 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[0.5191, 0.6402],\n", + "IntQuantTensor(value=tensor([[[[0.5191, 0.6402],\n", " [2.1455, 0.5883]],\n", "\n", " [[2.0417, 0.5883],\n", @@ -506,7 +514,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_4048/1377665000.py:1: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1699449183005/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + "/tmp/ipykernel_528161/1377665000.py:1: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1708025842427/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n", " torch.tanh(quant_tensor)\n" ] }, @@ -555,7 +563,7 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.3568, -0.1883, 0.3589],\n", + "IntQuantTensor(value=tensor([[[[-0.3568, -0.1883, 0.3589],\n", " [-0.4470, 0.1039, -0.3945],\n", " [-0.4190, 0.3723, 0.8384]],\n", "\n", @@ -565,7 +573,7 @@ "\n", " [[ 0.2734, 0.7268, -0.0249],\n", " [-0.1732, 0.5197, 1.1158],\n", - " [ 0.3771, -0.3810, 0.2008]]]], grad_fn=), scale=tensor([[[[3.1958e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [ 0.3771, -0.3810, 0.2008]]]], grad_fn=), scale=tensor([[[[3.1958e-05]]]], grad_fn=), zero_point=tensor([0.]), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 14, @@ -618,39 +626,39 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 7.2000e-03, -3.7000e-03, 7.7000e-03, -2.4000e-03, -8.9000e-03],\n", - " [-1.2000e-02, -8.1000e-03, 7.2000e-03, -1.1300e-02, -9.7000e-03],\n", - " [-1.0000e-03, 1.0100e-02, 3.8000e-03, -1.1900e-02, 6.9000e-03],\n", - " [ 8.3000e-03, 1.0000e-04, -6.9000e-03, 3.9000e-03, -5.4000e-03],\n", - " [ 1.1300e-02, -6.0000e-03, 9.7000e-03, 0.0000e+00, 1.0900e-02]],\n", + "IntQuantTensor(value=tensor([[[[-9.9000e-03, -7.1000e-03, -4.7000e-03, 5.0000e-03, -1.2300e-02],\n", + " [-8.2000e-03, 8.5000e-03, -1.2000e-03, -1.2500e-02, 4.4000e-03],\n", + " [ 4.3000e-03, -6.3000e-03, -9.4000e-03, 1.0400e-02, -1.2100e-02],\n", + " [ 1.1700e-02, -3.6000e-03, 5.3000e-03, -1.1700e-02, -4.3000e-03],\n", + " [-8.8000e-03, 1.0900e-02, -8.3000e-03, -2.9000e-03, 1.2400e-02]],\n", "\n", - " [[-1.0900e-02, 1.1400e-02, -6.4000e-03, 9.2000e-03, 7.1000e-03],\n", - " [-6.0000e-04, 9.2000e-03, -8.5000e-03, 5.0000e-03, 6.5000e-03],\n", - " [-8.3000e-03, -1.2000e-03, 7.4000e-03, 9.2000e-03, -6.0000e-04],\n", - " [-2.1000e-03, 9.5000e-03, 3.0000e-04, -2.9000e-03, -6.5000e-03],\n", - " [-1.1800e-02, -4.8000e-03, 5.4000e-03, -2.5000e-03, 9.0000e-04]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 9.3000e-03, -8.5000e-03, 6.5000e-03, -2.7000e-03, -3.4000e-03],\n", + " [-1.0000e-04, -1.1000e-02, 8.3000e-03, 1.9000e-03, -9.8000e-03],\n", + " [ 4.3000e-03, -8.5000e-03, 1.1000e-02, 5.3000e-03, 3.4000e-03],\n", + " [ 8.1000e-03, 9.8000e-03, 6.8000e-03, 1.5000e-03, 6.3000e-03],\n", + " [ 5.7000e-03, -8.5000e-03, 5.2000e-03, -3.0000e-04, 4.9000e-03]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 16, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from brevitas.quant_tensor import QuantTensor\n", + "from brevitas.quant_tensor import IntQuantTensor\n", "\n", "scale = 0.0001\n", "bit_width = 8\n", "zero_point = 0.\n", "int_value = torch.randint(low=- 2 ** (bit_width - 1), high=2 ** (bit_width - 1) - 1, size=(1, 2, 5, 5))\n", "quant_value = (int_value - zero_point) * scale\n", - "quant_tensor_input = QuantTensor(\n", + "quant_tensor_input = IntQuantTensor(\n", " quant_value, \n", " scale=torch.tensor(scale), \n", " zero_point=torch.tensor(zero_point), \n", @@ -662,7 +670,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 
null, "metadata": {}, "outputs": [ { @@ -688,7 +696,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -721,7 +729,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -745,7 +753,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -784,7 +792,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -820,7 +828,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -856,7 +864,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": { "tags": [ "raises-exception" @@ -897,7 +905,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -935,7 +943,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -968,7 +976,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1007,7 +1015,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": { "tags": [ "raises-exception" @@ -1049,7 +1057,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1093,7 +1101,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1131,7 +1139,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1155,7 +1163,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, "outputs": [ { diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index bed59f81d..5d6c0ffab 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -7,3 +7,4 @@ pytest-xdist pytest_cases scipy torchvision +tqdm diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index 5822d4284..cca7da087 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -149,4 +149,4 @@ def forward( "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified" ) - return x + return x, self.saturating, self.inf_values, self.nan_values diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index d5e3d06d9..371c5551c 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -85,8 +85,8 @@ def dequantize(self, y, scale): def forward(self, x): y, scale = self.quantize(x) # after quantizing, clamp to special cases like NaN/inf if they are set - y = self.float_clamp_impl( + y, saturating, inf_values, nan_values = self.float_clamp_impl( y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) y = self.dequantize(y, scale) # This is to respect the current interface of proxies - return y, scale, self.zero_point_impl(), self.bit_width() + return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values diff --git a/src/brevitas/fx/value_tracer.py b/src/brevitas/fx/value_tracer.py index 6cab89767..55ea0d93b 100644 --- 
a/src/brevitas/fx/value_tracer.py +++ b/src/brevitas/fx/value_tracer.py @@ -57,7 +57,7 @@ import torch.utils._pytree as pytree from brevitas import torch_version -from brevitas.quant_tensor import QuantTensorBase +from brevitas.quant_tensor import QuantTensor from . import * from . import _assert_is_none @@ -82,7 +82,7 @@ from . import ScopeContextManager _UNSET = object() -extended_base_types = base_types + (QuantTensorBase,) +extended_base_types = base_types + (QuantTensor,) FRAME_FILES = [ 'fx/brevitas_tracer.py', diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index 06c365a67..8ac55caaa 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -13,8 +13,8 @@ from brevitas.nn import QuantHardTanh from brevitas.nn import QuantLinear from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL -from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector -from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector +from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase +from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector from brevitas.proxy.runtime_quant import ClampQuantProxyFromInjector from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector @@ -29,9 +29,9 @@ 'calibration_mode', 'load_quant_model_mode'] -_PARAM_PROXIES = (WeightQuantProxyFromInjector, BiasQuantProxyFromInjector) +_PARAM_PROXIES = (WeightQuantProxyFromInjectorBase, BiasQuantProxyFromInjectorBase) -_BIAS_PROXIES = (BiasQuantProxyFromInjector) +_BIAS_PROXIES = (BiasQuantProxyFromInjectorBase) _ACC_PROXIES = (TruncQuantProxyFromInjector, ClampQuantProxyFromInjector) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index ef720d092..17d6a9c89 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -316,7 +316,7 @@ def single_layer_update(self): # raise error in case no quant-input is here if self.quant_metadata is None: raise ValueError('Expected self.quant_metadata to calculate L1-norm upper bound, but recevied None. ' + \ - 'Make sure that either the input to the model is a QuantTensor or the layer has an input quant enabled. ' \ + 'Make sure that either the input to the model is a IntQuantTensor or the layer has an input quant enabled. ' \ 'Also, check if `use_quant_activations=True` in `gpfq_mode` when `accumulator_bit_width` is specified. ' + \ 'Alternatively, provide a custom `a2q_layer_filter_fnc` to `gpfq_mode` to filter layers without a quant_tensor input.') weight = self.layer.weight.data diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index b85ac1188..fdbaee52f 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -18,7 +18,7 @@ from brevitas.graph.calibrate import DisableEnableQuantization from brevitas.graph.calibrate import restore_return_quant_tensor import brevitas.nn as qnn -from brevitas.quant_tensor import QuantTensor +from brevitas.quant_tensor import IntQuantTensor from brevitas.utils.quant_utils import _CachedIO SUPPORTED_CONV_OP = ( @@ -227,9 +227,9 @@ def process_input(self, inp): is_quant_enabled = self.layer.weight_quant.is_quant_enabled - # If using quantized activations, inp could be QuantTensor. In + # If using quantized activations, inp could be IntQuantTensor. In # this case, we overwrite the metadata. 
- if isinstance(inp, QuantTensor): + if isinstance(inp, IntQuantTensor): if is_quant_enabled and self.quant_metadata is None: self.quant_metadata = _CachedIO(inp, metadata_only=True) inp = inp.value diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index a70394e07..bbdd77ac7 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -18,8 +18,9 @@ from brevitas.inject import ExtendedInjector from brevitas.inject import Injector from brevitas.quant_tensor import _unpack_quant_tensor +from brevitas.quant_tensor import FloatQuantTensor +from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor -from brevitas.utils.torch_utils import compute_channel_view_shape from .utils import filter_kwargs @@ -69,13 +70,22 @@ def channelwise_separable(self) -> bool: def _set_global_is_quant_layer(self, value): config._IS_INSIDE_QUANT_LAYER = value + def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]): + quant_tensor_classes = [IntQuantTensor, FloatQuantTensor] + for qt_class in quant_tensor_classes: + if len(inp) == len(qt_class._fields): + return qt_class + return None + def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: self._set_global_is_quant_layer(True) # Hack to recognize a QuantTensor that has decayed to a tuple # when used as input to tracing (e.g. during ONNX export) if (torch._C._get_tracing_state() is not None and isinstance(inp, tuple) and - len(inp) == len(QuantTensor._fields) and all([isinstance(t, Tensor) for t in inp])): - inp = QuantTensor(*inp) + all([isinstance(t, Tensor) for t in inp])): + qt_class = self.get_quant_tensor_class(inp) + if qt_class is not None: + inp = qt_class(*inp) if not torch._C._get_tracing_state(): if isinstance(inp, QuantTensor): inp = inp.set(value=inp.value.rename(None)) @@ -181,7 +191,7 @@ def pack_quant_outputs(self, quant_outputs): # inner layers in a deep network overrides it, so we check again. if self.export_mode: if self.return_quant_tensor and self.io_quant.is_quant_enabled: - return QuantTensor( + return IntQuantTensor( quant_outputs, self.io_quant.scale(), self.io_quant.zero_point(), @@ -193,7 +203,7 @@ def pack_quant_outputs(self, quant_outputs): seq_dim = 1 if self.cell.batch_first else 0 if self.return_quant_tensor and self.io_quant.is_quant_enabled: outputs = [ - QuantTensor( + IntQuantTensor( torch.unsqueeze(quant_output[0], dim=seq_dim), quant_output[1], quant_output[2], @@ -212,7 +222,7 @@ def pack_quant_state(self, quant_state, quant): # inner layers in a deep network overrides it, so we check again. 
if self.export_mode: if self.return_quant_tensor and quant.is_quant_enabled: - quant_state = QuantTensor( + quant_state = IntQuantTensor( torch.unsqueeze(quant_state, dim=0), quant.scale(), quant.zero_point(), @@ -223,7 +233,7 @@ def pack_quant_state(self, quant_state, quant): quant_state = torch.unsqueeze(quant_state, dim=0) else: if self.return_quant_tensor and quant.is_quant_enabled: - quant_state = QuantTensor( + quant_state = IntQuantTensor( torch.unsqueeze(quant_state[0], dim=0), quant_state[1], quant_state[2], diff --git a/src/brevitas/nn/utils.py b/src/brevitas/nn/utils.py index fccfbc4de..d718bcc86 100644 --- a/src/brevitas/nn/utils.py +++ b/src/brevitas/nn/utils.py @@ -15,8 +15,8 @@ def mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias): def merge_bn(layer, bn, output_channel_dim=0): - from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector - from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector + from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase + from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase out = mul_add_from_bn( bn_mean=bn.running_mean, bn_var=bn.running_var, @@ -33,9 +33,10 @@ def merge_bn(layer, bn, output_channel_dim=0): else: layer.bias = Parameter(add_factor) if (hasattr(layer, 'weight_quant') and - isinstance(layer.weight_quant, WeightQuantProxyFromInjector)): + isinstance(layer.weight_quant, WeightQuantProxyFromInjectorBase)): layer.weight_quant.init_tensor_quant() - if (hasattr(layer, 'bias_quant') and isinstance(layer.bias_quant, BiasQuantProxyFromInjector)): + if (hasattr(layer, 'bias_quant') and + isinstance(layer.bias_quant, BiasQuantProxyFromInjectorBase)): layer.bias_quant.init_tensor_quant() diff --git a/src/brevitas/proxy/__init__.py b/src/brevitas/proxy/__init__.py index ecf98afc8..57770749d 100644 --- a/src/brevitas/proxy/__init__.py +++ b/src/brevitas/proxy/__init__.py @@ -2,9 +2,12 @@ # SPDX-License-Identifier: BSD-3-Clause from .parameter_quant import BiasQuantProxyFromInjector +from .parameter_quant import BiasQuantProxyFromInjectorBase from .parameter_quant import DecoupledWeightQuantProxyFromInjector from .parameter_quant import DecoupledWeightQuantWithInputProxyFromInjector from .parameter_quant import WeightQuantProxyFromInjector +from .parameter_quant import WeightQuantProxyFromInjectorBase from .runtime_quant import ActQuantProxyFromInjector +from .runtime_quant import ActQuantProxyFromInjectorBase from .runtime_quant import ClampQuantProxyFromInjector from .runtime_quant import TruncQuantProxyFromInjector diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py new file mode 100644 index 000000000..835caf647 --- /dev/null +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -0,0 +1,175 @@ +from typing import Optional, Union +from warnings import warn + +import torch +from torch import Tensor +import torch.nn as nn + +from brevitas.inject import BaseInjector as Injector +from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase +from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase +from brevitas.quant_tensor import FloatQuantTensor +from brevitas.utils.quant_utils import _CachedIOFloat + + +class WeightFloatQuantProxyFromInjector(WeightQuantProxyFromInjectorBase): + + def scale(self): + if not self.is_quant_enabled: + return None + scale = self.__call__(self.tracked_parameter_list[0]).scale + return scale + + def zero_point(self): + if not self.is_quant_enabled: + return 
None + zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point + return zero_point + + def exponent_bit_width(self): + if not self.is_quant_enabled: + return None + exponent_bit_width = self.__call__(self.tracked_parameter_list[0]).exponent_bit_width + return exponent_bit_width + + def mantissa_bit_width(self): + if not self.is_quant_enabled: + return None + mantissa_bit_width = self.__call__(self.tracked_parameter_list[0]).mantissa_bit_width + return mantissa_bit_width + + def exponent_bias(self): + if not self.is_quant_enabled: + return None + exponent_bias = self.__call__(self.tracked_parameter_list[0]).exponent_bias + return exponent_bias + + def is_saturating(self): + if not self.is_quant_enabled: + return None + saturating = self.__call__(self.tracked_parameter_list[0]).saturating + return saturating + + def inf_values(self): + if not self.is_quant_enabled: + return None + inf_values = self.__call__(self.tracked_parameter_list[0]).inf_values + return inf_values + + def nan_values(self): + if not self.is_quant_enabled: + return None + nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values + return nan_values + + def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]: + if self.is_quant_enabled: + impl = self.export_handler if self.export_mode else self.tensor_quant + out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) + return FloatQuantTensor( + out, + scale, + zero_point, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + self.is_signed, + self.training) + else: # quantization disabled + return x + + +class BiasFloatQuantProxyFromInjector(BiasQuantProxyFromInjectorBase): + + def scale(self): + if not self.is_quant_enabled: + return None + if self.requires_input_scale: + cache = self.get_cached('scale') + return cache + zhs = self._zero_hw_sentinel() + scale = self.__call__(self.tracked_parameter_list[0], zhs).scale + return scale + + def zero_point(self): + if not self.is_quant_enabled: + return None + zhs = self._zero_hw_sentinel() + zero_point = self.__call__(self.tracked_parameter_list[0], zhs).zero_point + return zero_point + + def exponent_bit_width(self): + if not self.is_quant_enabled: + return None + exponent_bit_width = self.__call__(self.tracked_parameter_list[0]).exponent_bit_width + return exponent_bit_width + + def mantissa_bit_width(self): + if not self.is_quant_enabled: + return None + mantissa_bit_width = self.__call__(self.tracked_parameter_list[0]).mantissa_bit_width + return mantissa_bit_width + + def exponent_bias(self): + if not self.is_quant_enabled: + return None + exponent_bias = self.__call__(self.tracked_parameter_list[0]).exponent_bias + return exponent_bias + + def is_saturating(self): + if not self.is_quant_enabled: + return None + saturating = self.__call__(self.tracked_parameter_list[0]).saturating + return saturating + + def inf_values(self): + if not self.is_quant_enabled: + return None + inf_values = self.__call__(self.tracked_parameter_list[0]).inf_values + return inf_values + + def nan_values(self): + if not self.is_quant_enabled: + return None + nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values + return nan_values + + def forward(self, + x: Tensor, + input_scale: Optional[Tensor] = None) -> Union[Tensor, FloatQuantTensor]: + out = x + if self.is_quant_enabled: + impl = self.export_handler if self.export_mode else self.tensor_quant + if 
self.requires_input_scale and input_scale is None: + input_scale = self.scale() + if input_scale is None: + raise RuntimeError("Input scale required") + + if self.requires_input_scale: + input_scale = input_scale.view(-1) + out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x, input_scale) + else: + out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) + + out = FloatQuantTensor( + out, + scale, + zero_point, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + self.is_signed, + self.training) + else: + out = x + if isinstance(out, + FloatQuantTensor) and not self.training and self.cache_inference_quant_bias: + cached_bias = _CachedIOFloat(out.detach(), metadata_only=False) + self._cached_bias = cached_bias + return out diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py new file mode 100644 index 000000000..5fc1f2411 --- /dev/null +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -0,0 +1,102 @@ +from typing import Optional, Union +from warnings import warn + +import torch +from torch import Tensor +import torch.nn as nn + +from brevitas.inject import BaseInjector as Injector +from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase +from brevitas.quant_tensor import FloatQuantTensor +from brevitas.utils.quant_utils import _CachedIOFloat + + +class ActFloatQuantProxyFromInjector(ActQuantProxyFromInjectorBase): + + def scale(self, force_eval=True): + if self.is_quant_enabled: + current_status = self.training + if force_eval: + self.eval() + out = self.__call__(self._zero_hw_sentinel()) + self.train(current_status) + return out.scale + elif self._cached_act is not None: + return self._cached_act.scale + elif self._cached_act is None: + return None + + def zero_point(self, force_eval=True): + if self.is_quant_enabled: + current_status = self.training + if force_eval: + self.eval() + out = self.__call__(self._zero_hw_sentinel()) + self.train(current_status) + return out.zero_point + elif self._cached_act is not None: + return self._cached_act.zero_point + elif self._cached_act is None: + return None + + def bit_width(self, force_eval=True): + if self.is_quant_enabled: + current_status = self.training + if force_eval: + self.eval() + out = self.__call__(self._zero_hw_sentinel()) + self.train(current_status) + return out.bit_width + elif self._cached_act is not None: + return self._cached_act.bit_width + elif self._cached_act is None: + return None + + def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]: + out = x + if self.fused_activation_quant_proxy is not None: + y = x + if isinstance(y, FloatQuantTensor): + y = y.value + + if self.export_mode: + y = self.fused_activation_quant_proxy.activation_impl(y) + y = self.export_handler(y) + elif not self.is_quant_enabled: + y = self.fused_activation_quant_proxy.activation_impl(y) + else: + y = self.fused_activation_quant_proxy(y) + # If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy, + # otherwise return a simple Tensor + if isinstance(y, tuple) and not any(map(lambda f: f is None, y)): + out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training) + elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant + if isinstance(y, tuple): + y = y[0] + if isinstance(x, FloatQuantTensor): + out = 
FloatQuantTensor( + y, + x.scale, + x.zero_point, + x.mantissa_bit_width, + x.exponent_bit_width, + x.exponent_bias, + x.signed, + self.training, + x.saturating, + x.inf_values, + x.nan_values) + else: + out = y + else: + if isinstance(y, tuple): + y = y[0] + out = y + else: + # If fused activation quant proxy is not enabled, return the input + out = x + if not self.training and self.cache_inference_quant_act and isinstance(out, + FloatQuantTensor): + cached_out = _CachedIOFloat(out.detach(), self.cache_quant_io_metadata_only) + self._cached_act = cached_out + return out diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 4a95b5e0f..06e181f31 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from abc import ABC from abc import ABCMeta from abc import abstractmethod from typing import Optional, Union @@ -15,6 +16,7 @@ from brevitas import config from brevitas.function import max_int from brevitas.inject import BaseInjector as Injector +from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO from brevitas.utils.torch_utils import compute_channel_view_shape @@ -24,7 +26,9 @@ __all__ = [ 'WeightQuantProxyFromInjector', + 'WeightQuantProxyFromInjectorBase', 'BiasQuantProxyFromInjector', + 'BiasQuantProxyFromInjectorBase', 'WeightQuantProxyProtocol', 'BiasQuantProxyProtocol'] @@ -76,7 +80,49 @@ def max_uint_value(self, bit_width): return max_int(False, self.is_narrow_range, bit_width) -class WeightQuantProxyFromInjector(ParameterQuantProxyFromInjector, WeightQuantProxyProtocol): +class WeightQuantProxyFromInjectorBase(ParameterQuantProxyFromInjector, + WeightQuantProxyProtocol, + ABC): + + @property + def tracked_parameter_list(self): + return [m.weight for m in self.tracked_module_list if m.weight is not None] + + @property + def requires_quant_input(self): + return False + + +class BiasQuantProxyFromInjectorBase(ParameterQuantProxyFromInjector, BiasQuantProxyProtocol, ABC): + + def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: + super().__init__(quant_layer, quant_injector) + self._cached_bias = None + self.cache_inference_quant_bias = False + + @property + def tracked_parameter_list(self): + return [m.bias for m in self.tracked_module_list if m.bias is not None] + + @property + def requires_input_scale(self) -> bool: + if self.is_quant_enabled: + return self.quant_injector.requires_input_scale + else: + return False + + def get_cached(self, attr): + if self._cached_bias is None: + warn( + "No quant bias cache found, set cache_inference_quant_bias=True and run an " + "inference pass first") + return None + if self.training: + warn("Cached quant bias scale is being used in training mode.") + return getattr(self._cached_bias, attr) + + +class WeightQuantProxyFromInjector(WeightQuantProxyFromInjectorBase): @property def tracked_parameter_list(self): @@ -104,11 +150,11 @@ def bit_width(self): bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width return bit_width - def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: + def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant out, scale, zero_point, bit_width = impl(x) - return 
QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) + return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) else: # quantization disabled return x @@ -129,11 +175,11 @@ def pre_zero_point(self): out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple return pre_zero_point - def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: + def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x) - return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) + return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) else: # quantization disabled return x @@ -158,11 +204,12 @@ def pre_zero_point(self): raise NotImplementedError def forward( - self, - x: torch.Tensor, - quant_input: Optional[Union[Tensor, QuantTensor]] = None) -> Union[Tensor, QuantTensor]: + self, + x: torch.Tensor, + quant_input: Optional[Union[Tensor, + IntQuantTensor]] = None) -> Union[Tensor, IntQuantTensor]: if isinstance(quant_input, - QuantTensor) and not self.training and self.cache_inference_quant_act: + IntQuantTensor) and not self.training and self.cache_inference_quant_act: cached_inp = _CachedIO(quant_input.detach(), self.cache_quant_io_metadata_only) self._cached_act = cached_inp @@ -171,19 +218,19 @@ def forward( assert self._cached_act is not None, "No cached quant input found. Enable caching and perform a forward pass" quant_input = self._cached_act else: - assert isinstance(quant_input, QuantTensor), "Input must be quantized" + assert isinstance(quant_input, IntQuantTensor), "Input must be quantized" input_bit_width = quant_input.bit_width input_is_signed = quant_input.signed impl = self.export_handler if self.export_mode else self.tensor_quant out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed) - return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) + return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) else: # quantization disabled return x -class BiasQuantProxyFromInjector(ParameterQuantProxyFromInjector, BiasQuantProxyProtocol): +class BiasQuantProxyFromInjector(BiasQuantProxyFromInjectorBase): def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: super().__init__(quant_layer, quant_injector) @@ -236,7 +283,7 @@ def bit_width(self): return bit_width def quant_output_scale_impl( - self, input: QuantTensor, weight: QuantTensor, module: torch.nn.Module) -> Tensor: + self, input: IntQuantTensor, weight: IntQuantTensor, module: torch.nn.Module) -> Tensor: channel_dim = -1 if isinstance(module, torch.nn.Linear) else 1 output_scale_shape = compute_channel_view_shape(input, channel_dim=channel_dim) output_scale = weight.scale.view(output_scale_shape) @@ -245,11 +292,11 @@ def quant_output_scale_impl( def compute_bias_scale( self, - input: Optional[Union[Tensor, QuantTensor]], - weight: Optional[Union[Tensor, QuantTensor]]) -> Optional[Tensor]: + input: Optional[Union[Tensor, IntQuantTensor]], + weight: Optional[Union[Tensor, IntQuantTensor]]) -> Optional[Tensor]: if not self.requires_input_scale: return None - if not isinstance(input, QuantTensor) or not isinstance(weight, QuantTensor): + if not isinstance(input, IntQuantTensor) or not isinstance(weight, 
IntQuantTensor): return None if len(self.tracked_module_list) > 1: if not all( @@ -263,8 +310,9 @@ def compute_bias_scale( def forward( self, x: Tensor, - input: Optional[Union[Tensor, QuantTensor]] = None, - weight: Optional[Union[Tensor, QuantTensor]] = None) -> Union[Tensor, QuantTensor]: + input: Optional[Union[Tensor, IntQuantTensor]] = None, + weight: Optional[Union[Tensor, + IntQuantTensor]] = None) -> Union[Tensor, IntQuantTensor]: out = x input_scale = self.compute_bias_scale(input, weight) if self.is_quant_enabled: @@ -280,10 +328,12 @@ def forward( else: out, out_scale, out_zp, out_bit_width = impl(x) - out = QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) + out = IntQuantTensor( + out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) else: out = x - if isinstance(out, QuantTensor) and not self.training and self.cache_inference_quant_bias: + if isinstance(out, + IntQuantTensor) and not self.training and self.cache_inference_quant_bias: cached_bias = _CachedIO(out.detach(), metadata_only=False) self._cached_bias = cached_bias return out diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 4dd8417a9..2457298b1 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from abc import ABC from typing import Optional, Tuple, Union from torch import nn @@ -10,6 +11,7 @@ from typing_extensions import runtime_checkable import brevitas +from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO @@ -80,11 +82,10 @@ def __init__(self, activation_impl, tensor_quant): @brevitas.jit.script_method def forward(self, x): x = self.activation_impl(x) - x, output_scale, output_zp, output_bit_width = self.tensor_quant(x) - return x, output_scale, output_zp, output_bit_width + return self.tensor_quant(x) -class ActQuantProxyFromInjector(QuantProxyFromInjector, ActQuantProxyProtocol): +class ActQuantProxyFromInjectorBase(QuantProxyFromInjector, ActQuantProxyProtocol, ABC): def __init__(self, quant_layer, quant_injector): QuantProxyFromInjector.__init__(self, quant_layer, quant_injector) @@ -127,6 +128,9 @@ def init_tensor_quant(self): else: self.fused_activation_quant_proxy = None + +class ActQuantProxyFromInjector(ActQuantProxyFromInjectorBase): + def scale(self, force_eval=True): if self.is_quant_enabled: current_status = self.training @@ -166,11 +170,11 @@ def bit_width(self, force_eval=True): elif self._cached_act is None: return None - def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: + def forward(self, x: Union[Tensor, IntQuantTensor]) -> Union[Tensor, IntQuantTensor]: out = x if self.fused_activation_quant_proxy is not None: y = x - if isinstance(y, QuantTensor): + if isinstance(y, IntQuantTensor): y = y.value if self.export_mode: @@ -180,15 +184,15 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: y = self.fused_activation_quant_proxy.activation_impl(y) else: y = self.fused_activation_quant_proxy(y) - # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, + # If y is an empty IntQuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor if isinstance(y, tuple) and not any(map(lambda f: f is None, y)): - out = QuantTensor(*y, 
signed=self.is_signed, training=self.training) + out = IntQuantTensor(*y, signed=self.is_signed, training=self.training) elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant if isinstance(y, tuple): y = y[0] - if isinstance(x, QuantTensor): - out = QuantTensor( + if isinstance(x, IntQuantTensor): + out = IntQuantTensor( y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) else: out = y @@ -199,7 +203,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: else: # If fused activation quant proxy is not enabled, return the input out = x - if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor): + if not self.training and self.cache_inference_quant_act and isinstance(out, IntQuantTensor): cached_out = _CachedIO(out.detach(), self.cache_quant_io_metadata_only) self._cached_act = cached_out return out @@ -216,11 +220,11 @@ def zero_point(self, force_eval=True): class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol): - def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]: + def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: if self.is_quant_enabled: out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width) out_value, out_scale, out_zp, out_bit_width = out_tuple - return QuantTensor( + return IntQuantTensor( out_value, out_scale, out_zp, out_bit_width, self.is_signed, self.training) return x @@ -232,11 +236,11 @@ def bit_width(self): return None zhs = self._zero_hw_sentinel() # Signed might or might not be defined. We just care about retrieving the bitwidth - empty_imp = QuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training) + empty_imp = IntQuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training) bit_width = self.__call__(empty_imp).bit_width return bit_width - def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]: + def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: if self.is_quant_enabled: if self.export_mode: out_tuple = self.export_handler( @@ -244,7 +248,8 @@ def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]: else: out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width) out_value, out_scale, out_zp, out_bit_width = out_tuple - return QuantTensor(out_value, out_scale, out_zp, out_bit_width, x.signed, self.training) + return IntQuantTensor( + out_value, out_scale, out_zp, out_bit_width, x.signed, self.training) else: return x diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index 9a2893039..46e46e41e 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -7,6 +7,8 @@ from brevitas.core.scaling.float_scaling import FloatScaling from brevitas.inject import ExtendedInjector from brevitas.inject import value +from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector +from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector from brevitas.quant.solver import ActQuantSolver @@ -28,11 +30,11 @@ def exponent_bias(exponent_bit_width): class FloatWeightBase(FloatBase): - proxy_class = WeightQuantProxyFromInjector + proxy_class = WeightFloatQuantProxyFromInjector class FloatActBase(FloatBase): - proxy_class = ActQuantProxyFromInjector + proxy_class = 
ActFloatQuantProxyFromInjector class ScaledFloatWeightBase(FloatWeightBase, WeightQuantSolver): diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index 7e58bc551..be7e10bc0 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from .base_quant_tensor import * from .base_quant_tensor import _unpack_quant_tensor -from .base_quant_tensor import QuantTensorBase -from .int_quant_tensor import QuantTensor +from .float_quant_tensor import * +from .int_quant_tensor import * diff --git a/src/brevitas/quant_tensor/base_quant_tensor.py b/src/brevitas/quant_tensor/base_quant_tensor.py index e8d1439d7..6239b8324 100644 --- a/src/brevitas/quant_tensor/base_quant_tensor.py +++ b/src/brevitas/quant_tensor/base_quant_tensor.py @@ -1,9 +1,98 @@ -from typing import NamedTuple +from typing import List, NamedTuple, Optional from torch import Tensor -class QuantTensorBase(NamedTuple): +# Base class for all QuantTensor. +# Only assumptions made by these methods are: +# - `self` is a NamedTuple with a `_fields` attribute +# - `self` has a `value` attribute +class QuantTensor: + + def detach_(self): + for field in self._fields: + getattr(self, field).detach_() + + def detach(self): + qt_type = type(self) + values = [] + for field in self._fields: + value = getattr(self, field) + if isinstance(value, Tensor): + value = value.detach() + values.append(value) + return qt_type(*values) + + def contiguous(self): + qt_type = type(self) + values = [] + for field in self._fields: + value = getattr(self, field) + if isinstance(value, Tensor): + value = value.contiguous() + values.append(value) + return qt_type(*values) + + def set(self, **kwargs): + return self._replace(**kwargs) + + @property + def shape(self): + return self.value.shape + + def dim(self): + return self.value.dim() + + def add(self, other): + return self + other + + def to(self, *args, **kwargs): + qt_type = type(self) + values = [] + for field in self._fields: + value = getattr(self, field) + if isinstance(value, Tensor): + value = value.to(*args, **kwargs) + values.append(value) + return qt_type(*values) + + def cuda(self, *args, **kwargs): + qt_type = type(self) + values = [] + for field in self._fields: + value = getattr(self, field) + if isinstance(value, Tensor): + value = value.cuda(*args, **kwargs) + values.append(value) + return qt_type(*values) + + def cpu(self, *args, **kwargs): + qt_type = type(self) + values = [] + for field in self._fields: + value = getattr(self, field) + if isinstance(value, Tensor): + value = value.cpu(*args, **kwargs) + values.append(value) + return qt_type(*values) + + def __radd__(self, other): + return self.__add__(other) + + def __rmul__(self, other): + return self.__mul__(other) + + def __sub__(self, other): + return self.__add__(-other) + + def __pos__(self): + return self + + def size(self, *args, **kwargs): + return self.value.size(*args, **kwargs) + + +class IntQuantTensorBase(NamedTuple): value: Tensor scale: Tensor zero_point: Tensor @@ -12,8 +101,22 @@ class QuantTensorBase(NamedTuple): training_t: Tensor +class FloatQuantTensorBase(NamedTuple): + value: Tensor + scale: Tensor + zero_point: Tensor + exponent_bit_width: Tensor + mantissa_bit_width: Tensor + exponent_bias: Tensor + saturating_t: Tensor + inf_values: List[str] + nan_values: List[str] + signed_t: Tensor + training_t: Tensor + + def 
_unpack_quant_tensor(input_data): - if isinstance(input_data, QuantTensorBase): + if isinstance(input_data, QuantTensor): return input_data.value elif isinstance(input_data, tuple): return tuple([_unpack_quant_tensor(v) for v in input_data]) diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py new file mode 100644 index 000000000..c2bb99900 --- /dev/null +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -0,0 +1,325 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import torch + +from brevitas.function.ops_ste import round_ste +from brevitas.quant_tensor import _unpack_quant_tensor +from brevitas.quant_tensor import FloatQuantTensorBase +from brevitas.quant_tensor import QuantTensor + +from .float_torch_handler import FLOAT_QUANT_TENSOR_FN_HANDLER +from .torch_handler import QUANT_TENSOR_FN_HANDLER + +IS_VALID_ATOL = 2e-1 +BFLOAT16_IS_VALID_ATOL = 0.5 + + +class FloatQuantTensor(FloatQuantTensorBase, QuantTensor): + + def __new__( + cls, + value, + scale, + zero_point, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + signed, + training): + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, dtype=torch.float) + if not isinstance(zero_point, torch.Tensor): + zero_point = torch.tensor(zero_point, dtype=torch.float) + if not isinstance(exponent_bit_width, torch.Tensor): + exponent_bit_width = torch.tensor(exponent_bit_width, dtype=torch.float) + if not isinstance(mantissa_bit_width, torch.Tensor): + mantissa_bit_width = torch.tensor(mantissa_bit_width, dtype=torch.float) + if not isinstance(exponent_bias, torch.Tensor): + exponent_bias = torch.tensor(exponent_bias, dtype=torch.float) + if not isinstance(saturating, torch.Tensor): + saturating = torch.tensor(saturating, dtype=torch.bool) + if not isinstance(signed, torch.Tensor): + signed = torch.tensor(signed, dtype=torch.bool) + if not isinstance(training, torch.Tensor): + training = torch.tensor(training, dtype=torch.bool) + quant_tensor = super().__new__( + cls, + value, + scale, + zero_point, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + signed, + training) + return quant_tensor + + @property + def signed(self): + return self.signed_t.item() + + @property + def training(self): + return self.training_t.item() + + @property + def saturating(self): + return self.saturating_t.item() + + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func in FLOAT_QUANT_TENSOR_FN_HANDLER: + return FLOAT_QUANT_TENSOR_FN_HANDLER[func](*args, **kwargs) + elif func in QUANT_TENSOR_FN_HANDLER: + return QUANT_TENSOR_FN_HANDLER[func](*args, **kwargs) + else: + args = _unpack_quant_tensor(args) + kwargs = _unpack_quant_tensor(kwargs) + return func(*args, **kwargs) + + @property + def tensor(self): + return self.value + + @property + def _pre_round_float_value(self): + value = self.value + scale = self.scale + zero_point = self.zero_point + if self.scale.dtype == torch.bfloat16: + value = self.value.type(torch.float32) + scale = self.scale.type(torch.float32) + zero_point = self.zero_point.type(torch.float32) + minifloat_value = value / scale + minifloat_value = minifloat_value + zero_point + return minifloat_value + + @property + def is_valid(self): + with torch.no_grad(): + pre_round_minifloat_value = self._pre_round_float_value + 
rounded_minifloat_value = torch.round(pre_round_minifloat_value) + max_abs_diff = torch.max(torch.abs(pre_round_minifloat_value - rounded_minifloat_value)) + atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL + is_minifloat = max_abs_diff < atol + # We are missing the checks about self being contained between max and min value + # given by mantissa, exponent, inf, nan, and saturating + return is_minifloat + + @property + def device(self): + value_device = self.value.device + is_same_device = True + for t in [self.scale, + self.zero_point, + self.exponent_bit_width, + self.mantissa_bit_width, + self.exponent_bias]: + is_same_device &= value_device == t.device + if not is_same_device: + raise RuntimeError("Value and metadata are on different devices") + return value_device + + def minifloat(self, float_datatype=True): + assert float_datatype, "Minifloat quant returns only higher precision dtype" + + if self.is_valid: + float_value = self._pre_round_float_value + return float_value.type(self.scale.dtype) + else: + raise RuntimeError(f"FloatQuantTensor not valid.") + + @staticmethod + def check_input_type(tensor): + if not isinstance(tensor, FloatQuantTensor): + raise RuntimeError("Tensor is not a FloatQuantTensor") + + @staticmethod + def is_zero_zero_point(tensor): + FloatQuantTensor.check_input_type(tensor) + return (tensor.zero_point == 0.).all() + + def check_scaling_factors_same(self, other): + if self.training: + return True + if not torch.allclose(self.scale, other.scale): + raise RuntimeError("Scaling factors are different") + + def check_zero_points_same(self, other): + if self.training: + return True + if not torch.allclose(self.zero_point, other.zero_point): + raise RuntimeError("Zero points are different") + + def check_bit_width_same(self, other): + if not torch.allclose(self.exponent_bit_width, + other.exponent_bit_width) and not torch.allclose( + self.mantissa_bit_width, other.mantissa_bit_width): + raise RuntimeError("Bit widths are different") + + def check_exponent_bias(self, other): + if not torch.allclose(self.exponent_bias, other.exponent_bias): + raise RuntimeError("Bit widths are different") + + def check_inf_nan_same(self, other): + if not (set(self.inf_values) == set(other.inf_values)) and not (set(self.nan_values) == set( + other.nan_values)): + raise RuntimeError("Floating point representations are different") + + def check_sign_same(self, other): + if not self.signed == other.signed: + raise RuntimeError("Signs are different") + + def view(self, *args, **kwargs): + return self.set(value=self.value.view(*args, **kwargs)) + + def reshape(self, *args, **kwargs): + return self.set(value=self.value.reshape(*args, **kwargs)) + + def flatten(self, *args, **kwargs): + return self.set(value=self.value.flatten(*args, **kwargs)) + + def transpose(self, *args, **kwargs): + value = self.value.transpose(*args, **kwargs) + tensor_meta = {'scale': self.scale, 'zero_point': self.zero_point} + for k, tm in tensor_meta.items(): + if len(value.shape) == len(tm.shape): + tensor_meta[k] = tm.transpose(*args, **kwargs) + return self.set(value=value, **tensor_meta) + + def permute(self, *args, **kwargs): + value = self.value.permute(*args, **kwargs) + tensor_meta = {'scale': self.scale, 'zero_point': self.zero_point} + for k, tm in tensor_meta.items(): + if len(value.shape) == len(tm.shape): + tensor_meta[k] = tm.permute(*args, **kwargs) + return self.set(value=value, **tensor_meta) + + @staticmethod + def cat(tensors, dim, out=None): + if out is not 
None: + raise RuntimeError("Out not supported.") + if len(tensors) < 2: + return tensors[0] + else: + first_qt = tensors[0] + if all([isinstance(qt, FloatQuantTensor) for qt in tensors]): + for qt in tensors[1:]: + first_qt.check_scaling_factors_same(qt) + first_qt.check_zero_points_same(qt) + first_qt.check_bit_width_same(qt) + first_qt.check_exponent_bias(qt) + first_qt.check_inf_nan_same(qt) + first_qt.check_sign_same(qt) + output_value = torch.cat([qt.value for qt in tensors], dim=dim) + output_training = any([qt.training for qt in tensors]) + if output_training: + output_scale = sum([qt.scale for qt in tensors]) / len(tensors) + output_zero_point = sum([qt.zero_point for qt in tensors]) / len(tensors) + output_exponent_bit_width = sum([qt.exponent_bit_width for qt in tensors + ]) / len(tensors) + output_mantissa_bit_width = sum([qt.mantissa_bit_width for qt in tensors + ]) / len(tensors) + output_exponent_bias = sum([qt.exponent_bias for qt in tensors]) / len(tensors) + else: # at eval time, they are the same + output_scale = first_qt.scale + output_zero_point = first_qt.zero_point + output_exponent_bit_width = first_qt.exponent_bit_width + output_mantissa_bit_width = first_qt.mantissa_bit_width + output_exponent_bias = first_qt.exponent_bias + output_signed = first_qt.signed # they are the same + output_saturating = first_qt.saturating # they are the same + output_inf_values = first_qt.inf_values # they are the same + output_nan_values = first_qt.nan_values # they are the same + return FloatQuantTensor( + value=output_value, + scale=output_scale, + zero_point=output_zero_point, + exponent_bit_width=output_exponent_bit_width, + mantissa_bit_width=output_mantissa_bit_width, + exponent_bias=output_exponent_bias, + signed=output_signed, + training=output_training, + saturating=output_saturating, + inf_values=output_inf_values, + nan_values=output_nan_values) + else: + tensors = [_unpack_quant_tensor(qt) for qt in tensors] + output_value = torch.cat(tensors, dim=dim) + return output_value + + # Reference: https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types + + def __neg__(self): + neg_value = (-self.minifloat(float_datatype=True) - self.zero_point) * self.scale + # In case the dtype of self.minifloat is different from the one of the scale + neg_value = neg_value.type(self.scale.dtype) + if self.signed: + return FloatQuantTensor( + value=neg_value, + scale=self.scale, + zero_point=self.zero_point, + exponent_bit_width=self.exponent_bit_width, + mantissa_bit_width=self.mantissa_bit_width, + exponent_bias=self.exponent_bias, + signed=self.signed, + training=self.training, + saturating=self.saturating, + inf_values=self.inf_values, + nan_values=self.nan_values) + else: + # TODO: implement + raise NotImplementedError + + def __add__(self, other): + if isinstance(other, QuantTensor): + return self.value + other.value + else: + output = self.value + other + return output + + def __mul__(self, other): + if isinstance(other, QuantTensor): + return self.value * other.value + else: + output = self.value * other + return output + + def __str__(self): + return f"FloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" + + def __truediv__(self, other): + if isinstance(other, QuantTensor): + return self.value / other.value + else: + output = self.value / other + return output + + def __abs__(self): + if self.signed: + abs_value = ( + 
torch.abs(self.minifloat(float_datatype=True)) - self.zero_point) * self.scale + # In case the dtype of self.minifloat is different from the one of the scale + abs_value = abs_value.type(self.scale.dtype) + return FloatQuantTensor( + value=abs_value, + scale=self.scale, + zero_point=self.zero_point, + exponent_bit_width=self.exponent_bit_width, + mantissa_bit_width=self.mantissa_bit_width, + signed=False, + training=self.training, + saturating=self.saturating, + inf_values=self.inf_values, + nan_values=self.nan_values) + else: + return self diff --git a/src/brevitas/quant_tensor/float_torch_handler.py b/src/brevitas/quant_tensor/float_torch_handler.py new file mode 100644 index 000000000..05386733a --- /dev/null +++ b/src/brevitas/quant_tensor/float_torch_handler.py @@ -0,0 +1,152 @@ +import functools +import math +import warnings + +import torch +import torch.nn.functional as F + +from brevitas.function.ops import max_int +from brevitas.function.ops_ste import ceil_ste +from brevitas.utils.torch_utils import compute_channel_view_shape + +FLOAT_QUANT_TENSOR_FN_HANDLER = {} + + +def implements_float_qt(torch_function): + + @functools.wraps(torch_function) + def decorator(func): + FLOAT_QUANT_TENSOR_FN_HANDLER[torch_function] = func + return func + + return decorator + + +@implements_float_qt(torch.cat) +def cat_handler(*args, **kwargs): + from brevitas.quant_tensor import FloatQuantTensor + return FloatQuantTensor.cat(*args, **kwargs) + + +@implements_float_qt(F.conv1d) +def conv1d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv1d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_float_qt(F.conv2d) +def conv2d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv2d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_float_qt(F.conv3d) +def conv3d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv3d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_float_qt(F.conv_transpose1d) +def conv_transpose1d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv_transpose1d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_float_qt(F.conv_transpose2d) +def conv_transpose2d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv_transpose2d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_float_qt(F.conv_transpose3d) +def conv_transpose3d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv_transpose3d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_float_qt(F.linear) +def linear_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.linear, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_float_qt(F.embedding) +def embedding_handler(input, quant_weight, *args, **kwargs): + from brevitas.quant_tensor import _unpack_quant_tensor + from brevitas.quant_tensor import FloatQuantTensor + + quant_weight_value = _unpack_quant_tensor(quant_weight) + out = F.embedding(input, quant_weight_value, *args, **kwargs) + + if isinstance(quant_weight, FloatQuantTensor): + scale = quant_weight.scale + zero_point = quant_weight.zero_point + exponent_bit_width = quant_weight.exponent_bit_width + mantissa_bit_width 
= quant_weight.mantissa_bit_width + exponent_bias = quant_weight.exponent_bias + inf_values = quant_weight.inf_values + nan_values = quant_weight.nan_values + if any(t.numel() > 1 for t in [scale, zero_point]): + raise RuntimeError("Only per-tensor quantization is supported.") + signed = quant_weight.signed + training = quant_weight.training + saturating = quant_weight.saturating + out = FloatQuantTensor( + out, + scale, + zero_point, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + signed, + training, + saturating, + inf_values, + nan_values) + return out + + +@implements_float_qt(F.avg_pool2d) +def avg_pool2d_handler( + quant_input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override): + from brevitas.quant_tensor import _unpack_quant_tensor + + x = F.avg_pool2d( + _unpack_quant_tensor(quant_input), + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override) + + return x + + +@implements_float_qt(F.adaptive_avg_pool2d) +def adaptive_avg_pool2d_handler(quant_input, output_shape): + from brevitas.quant_tensor import _unpack_quant_tensor + + x = F.adaptive_avg_pool2d(_unpack_quant_tensor(quant_input), output_shape) + return x + + +def quant_layer(fn, quant_input, quant_weight, bias, *args, **kwargs): + from brevitas.quant_tensor import _unpack_quant_tensor + + if bias is None: + output = fn( + _unpack_quant_tensor(quant_input), + _unpack_quant_tensor(quant_weight), + None, + *args, + **kwargs) + else: + output = fn( + _unpack_quant_tensor(quant_input), + _unpack_quant_tensor(quant_weight), + _unpack_quant_tensor(bias), + *args, + **kwargs) + + return output diff --git a/src/brevitas/quant_tensor/int_quant_tensor.py b/src/brevitas/quant_tensor/int_quant_tensor.py index 62c501250..3572a8900 100644 --- a/src/brevitas/quant_tensor/int_quant_tensor.py +++ b/src/brevitas/quant_tensor/int_quant_tensor.py @@ -8,15 +8,17 @@ from brevitas.function.ops_ste import ceil_ste from brevitas.function.ops_ste import round_ste from brevitas.quant_tensor import _unpack_quant_tensor -from brevitas.quant_tensor import QuantTensorBase +from brevitas.quant_tensor import IntQuantTensorBase +from brevitas.quant_tensor import QuantTensor +from .int_torch_handler import INT_QUANT_TENSOR_FN_HANDLER from .torch_handler import QUANT_TENSOR_FN_HANDLER IS_VALID_ATOL = 2e-1 BFLOAT16_IS_VALID_ATOL = 0.5 -class QuantTensor(QuantTensorBase): +class IntQuantTensor(IntQuantTensorBase, QuantTensor): def __new__(cls, value, scale, zero_point, bit_width, signed, training): @@ -35,26 +37,23 @@ def __new__(cls, value, scale, zero_point, bit_width, signed, training): @property def signed(self): - if self.signed_t is not None: - return self.signed_t.item() - else: - return None + return self.signed_t.item() @property def training(self): - if self.training_t is not None: - return self.training_t.item() - else: - return None + return self.training_t.item() def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - if func not in QUANT_TENSOR_FN_HANDLER: + if func in INT_QUANT_TENSOR_FN_HANDLER: + return INT_QUANT_TENSOR_FN_HANDLER[func](*args, **kwargs) + elif func in QUANT_TENSOR_FN_HANDLER: + return QUANT_TENSOR_FN_HANDLER[func](*args, **kwargs) + else: args = _unpack_quant_tensor(args) kwargs = _unpack_quant_tensor(kwargs) return func(*args, **kwargs) - return QUANT_TENSOR_FN_HANDLER[func](*args, **kwargs) @property def tensor(self): @@ -102,39 +101,11 @@ def device(self): value_device = self.value.device is_same_device = True 
for t in [self.scale, self.zero_point, self.bit_width]: - if t is not None: - is_same_device &= value_device == t.device + is_same_device &= value_device == t.device if not is_same_device: raise RuntimeError("Value and metadata are on different devices") return value_device - def set(self, **kwargs): - return self._replace(**kwargs) - - def detach_(self): - self.value.detach_() - self.scale.detach_() - self.zero_point.detach_() - self.bit_width.detach_() - - def detach(self): - return QuantTensor( - self.value.detach(), - self.scale.detach(), - self.zero_point.detach(), - self.bit_width.detach(), - self.signed, - self.training) - - def contiguous(self): - return QuantTensor( - self.value.contiguous(), - self.scale.contiguous(), - self.zero_point.contiguous(), - self.bit_width.contiguous(), - self.signed, - self.training) - def int(self, float_datatype=False): if self.is_valid: int_value = round_ste(self._pre_round_int_value) @@ -153,26 +124,26 @@ def int(self, float_datatype=False): else: return int_value.to(torch.int32) else: - raise RuntimeError(f"QuantTensor not valid.") + raise RuntimeError(f"IntQuantTensor not valid.") @staticmethod def check_input_type(tensor): - if not isinstance(tensor, QuantTensor): - raise RuntimeError("Tensor is not a QuantTensor") + if not isinstance(tensor, IntQuantTensor): + raise RuntimeError("Tensor is not a IntQuantTensor") @staticmethod def is_zero_zero_point(tensor): - QuantTensor.check_input_type(tensor) + IntQuantTensor.check_input_type(tensor) return (tensor.zero_point == 0.).all() def check_scaling_factors_same(self, other): - if self.training is not None and self.training: + if self.training: return True if not torch.allclose(self.scale, other.scale): raise RuntimeError("Scaling factors are different") def check_zero_points_same(self, other): - if self.training is not None and self.training: + if self.training: return True if not torch.allclose(self.zero_point, other.zero_point): raise RuntimeError("Zero points are different") @@ -199,7 +170,7 @@ def transpose(self, *args, **kwargs): tensor_meta = { 'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width} for k, tm in tensor_meta.items(): - if tm is not None and len(value.shape) == len(tm.shape): + if len(value.shape) == len(tm.shape): tensor_meta[k] = tm.transpose(*args, **kwargs) return self.set(value=value, **tensor_meta) @@ -208,7 +179,7 @@ def permute(self, *args, **kwargs): tensor_meta = { 'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width} for k, tm in tensor_meta.items(): - if tm is not None and len(value.shape) == len(tm.shape): + if len(value.shape) == len(tm.shape): tensor_meta[k] = tm.permute(*args, **kwargs) return self.set(value=value, **tensor_meta) @@ -240,7 +211,7 @@ def cat(tensors, dim, out=None): return tensors[0] else: first_qt = tensors[0] - if all([isinstance(qt, QuantTensor) for qt in tensors]): + if all([isinstance(qt, IntQuantTensor) for qt in tensors]): for qt in tensors[1:]: first_qt.check_scaling_factors_same(qt) first_qt.check_zero_points_same(qt) @@ -257,7 +228,7 @@ def cat(tensors, dim, out=None): output_zero_point = first_qt.zero_point output_bit_width = first_qt.bit_width output_signed = first_qt.signed # they are the same - return QuantTensor( + return IntQuantTensor( value=output_value, scale=output_scale, zero_point=output_zero_point, @@ -265,7 +236,7 @@ def cat(tensors, dim, out=None): signed=output_signed, training=output_training) else: - tensors = [qt.value if isinstance(qt, QuantTensor) else qt for qt in 
tensors] + tensors = [_unpack_quant_tensor(qt) for qt in tensors] output_value = torch.cat(tensors, dim=dim) return output_value @@ -276,7 +247,7 @@ def __neg__(self): # In case the dtype of self.int is different from the one of the scale neg_value = neg_value.type(self.scale.dtype) if self.signed: - return QuantTensor( + return IntQuantTensor( value=neg_value, scale=self.scale, zero_point=self.zero_point, @@ -284,7 +255,7 @@ def __neg__(self): signed=self.signed, training=self.training) else: - return QuantTensor( + return IntQuantTensor( value=neg_value, scale=self.scale, zero_point=self.zero_point, @@ -292,35 +263,8 @@ def __neg__(self): signed=True, training=self.training) - def to(self, *args, **kwargs): - return QuantTensor( - self.value.to(*args, **kwargs), - self.scale.to(*args, **kwargs), - self.zero_point.to(*args, **kwargs), - self.bit_width.to(*args, **kwargs), - self.signed, - self.training) - - def cuda(self, *args, **kwargs): - return QuantTensor( - self.value.cuda(*args, **kwargs), - self.scale.cuda(*args, **kwargs), - self.zero_point.cuda(*args, **kwargs), - self.bit_width.cuda(*args, **kwargs), - self.signed, - self.training) - - def cpu(self, *args, **kwargs): - return QuantTensor( - self.value.cpu(*args, **kwargs), - self.scale.cpu(*args, **kwargs), - self.zero_point.cpu(*args, **kwargs), - self.bit_width.cpu(*args, **kwargs), - self.signed, - self.training) - def __add__(self, other): - if isinstance(other, QuantTensor): + if isinstance(other, IntQuantTensor): self.check_scaling_factors_same(other) output_value = self.value + other.value output_scale = (self.scale + other.scale) / 2 @@ -332,33 +276,29 @@ def __add__(self, other): output_bit_width = ceil_ste(torch.log2(max_val - min_val)) output_signed = self.signed or other.signed output_training = self.training or other.training - output = QuantTensor( + output = IntQuantTensor( value=output_value, scale=output_scale, zero_point=output_zero_point, bit_width=output_bit_width, signed=output_signed, training=output_training) + elif isinstance(other, QuantTensor): + output = self.value + _unpack_quant_tensor(other) else: # When adding a QT with a normal Tensor, we use the zero_point as a way to preserve # and return a QT. 
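A minimal sketch of the identity behind the zero-point comment above (assuming the default 8-bit activation quantizer, as in the tests further down; not part of the patch): since value = (int - zero_point) * scale, adding a plain tensor c can be absorbed by shifting the zero point by c / scale, so the integer payload is unchanged.

import torch
from brevitas.nn import QuantIdentity

qt = QuantIdentity(bit_width=8, return_quant_tensor=True)(torch.randn(4))
shifted = qt + 0.5  # plain-float addend, still returns an IntQuantTensor

# Integer payload is preserved; only the zero point moves by -0.5 / scale
assert torch.allclose(shifted.int(float_datatype=True), qt.int(float_datatype=True))
assert torch.allclose(shifted.zero_point, qt.zero_point - 0.5 / qt.scale)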
- output = QuantTensor( - value=self.value + other, + output = IntQuantTensor( + value=self.value + _unpack_quant_tensor(other), scale=self.scale, - zero_point=self.zero_point - other / self.scale, + zero_point=self.zero_point - _unpack_quant_tensor(other) / self.scale, bit_width=self.bit_width, signed=self.signed, training=self.training) return output - def __radd__(self, other): - return self.__add__(other) - - def __rmul__(self, other): - return self.__mul__(other) - def __mul__(self, other): - if isinstance(other, QuantTensor): + if isinstance(other, IntQuantTensor): output_value = self.value * other.value output_scale = self.scale * other.scale output_bit_width = self.bit_width + other.bit_width @@ -368,7 +308,7 @@ def __mul__(self, other): output_zero_point = self.zero_point * other.zero_point else: raise RuntimeError("Zero-points of mul operands are non-zero, not supported.") - output = QuantTensor( + output = IntQuantTensor( value=output_value, scale=output_scale, zero_point=output_zero_point, @@ -376,17 +316,14 @@ def __mul__(self, other): signed=output_signed, training=output_training) else: - output = self.value * other + output = self.value * _unpack_quant_tensor(other) return output - def __sub__(self, other): - return self.__add__(-other) - def __str__(self): - return f"QuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" + return f"IntQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" def __truediv__(self, other): - if isinstance(other, QuantTensor): + if isinstance(other, IntQuantTensor): output_tensor = self.value / other.value # Note, output tensor not guaranteed to pass self.is_valid() max_int_denominator = 2 ** (other.bit_width - int(other.signed)) output_scale = self.scale / (other.scale * max_int_denominator) @@ -397,7 +334,7 @@ def __truediv__(self, other): output_zero_point = self.zero_point * other.zero_point # Output zero_point is a new, zero-valued tensor else: raise RuntimeError("Zero-points of div operands are non-zero, not supported.") - output = QuantTensor( + output = IntQuantTensor( value=output_tensor, scale=output_scale, zero_point=output_zero_point, @@ -405,7 +342,7 @@ def __truediv__(self, other): signed=output_signed, training=output_training) else: - output = self.value / other + output = self.value / _unpack_quant_tensor(other) return output def __abs__(self): @@ -413,7 +350,7 @@ def __abs__(self): abs_value = (torch.abs(self.int(float_datatype=True)) - self.zero_point) * self.scale # In case the dtype of self.int is different from the one of the scale abs_value = abs_value.type(self.scale.dtype) - return QuantTensor( + return IntQuantTensor( value=abs_value, scale=self.scale, zero_point=self.zero_point, diff --git a/src/brevitas/quant_tensor/int_torch_handler.py b/src/brevitas/quant_tensor/int_torch_handler.py new file mode 100644 index 000000000..8882bd097 --- /dev/null +++ b/src/brevitas/quant_tensor/int_torch_handler.py @@ -0,0 +1,291 @@ +import functools +import math +from typing import Callable +import warnings + +import torch +from torch import Tensor +import torch.nn.functional as F + +from brevitas.function.ops import max_int +from brevitas.function.ops_ste import ceil_ste +from brevitas.utils.torch_utils import compute_channel_view_shape + +INT_QUANT_TENSOR_FN_HANDLER = {} + + +def 
implements_int_qt(torch_function): + + @functools.wraps(torch_function) + def decorator(func): + INT_QUANT_TENSOR_FN_HANDLER[torch_function] = func + return func + + return decorator + + +@implements_int_qt(torch.cat) +def cat_handler(*args, **kwargs): + from brevitas.quant_tensor import IntQuantTensor + return IntQuantTensor.cat(*args, **kwargs) + + +@implements_int_qt(F.conv1d) +def conv1d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv1d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_int_qt(F.conv2d) +def conv2d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv2d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_int_qt(F.conv3d) +def conv3d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv3d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_int_qt(F.conv_transpose1d) +def conv_transpose1d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv_transpose1d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_int_qt(F.conv_transpose2d) +def conv_transpose2d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv_transpose2d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_int_qt(F.conv_transpose3d) +def conv_transpose3d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.conv_transpose3d, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_int_qt(F.linear) +def linear_handler(quant_input, quant_weight, bias=None, *args, **kwargs): + output = quant_layer(F.linear, quant_input, quant_weight, bias, *args, **kwargs) + return output + + +@implements_int_qt(F.embedding) +def embedding_handler(input, quant_weight, *args, **kwargs): + from brevitas.quant_tensor import _unpack_quant_tensor + from brevitas.quant_tensor import IntQuantTensor + + quant_weight_value = _unpack_quant_tensor(quant_weight) + out = F.embedding(input, quant_weight_value, *args, **kwargs) + + if isinstance(quant_weight, IntQuantTensor): + scale = quant_weight.scale + zero_point = quant_weight.zero_point + bit_width = quant_weight.bit_width + if any(t.numel() > 1 for t in [scale, zero_point, bit_width]): + raise RuntimeError("Only per-tensor quantization is supported.") + signed = quant_weight.signed + training = quant_weight.training + out = IntQuantTensor(out, scale, zero_point, bit_width, signed, training) + return out + + +@implements_int_qt(F.avg_pool2d) +def avg_pool2d_handler( + quant_input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override): + from brevitas.quant_tensor import _unpack_quant_tensor + + x = F.avg_pool2d( + _unpack_quant_tensor(quant_input), + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override) + + max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[F.avg_pool2d] + # remove avg scaling + if isinstance(kernel_size, tuple): + avg_scaling = kernel_size[0] * kernel_size[1] + else: + avg_scaling = kernel_size * kernel_size + + quant_input = quant_input.set(value=x) + quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, avg_scaling)) + return quant_input + + +@implements_int_qt(F.adaptive_avg_pool2d) +def adaptive_avg_pool2d_handler(quant_input, output_shape): + from functools import 
reduce + from operator import mul + + from brevitas.nn.quant_avg_pool import TruncAdaptiveAvgPool2d + from brevitas.quant_tensor import _unpack_quant_tensor + + x = F.adaptive_avg_pool2d(_unpack_quant_tensor(quant_input), output_shape) + k_size, stride = TruncAdaptiveAvgPool2d.compute_kernel_size_stride(quant_input.value.shape[2:], x.shape[2:]) + + max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[F.avg_pool2d] + reduce_size = reduce(mul, k_size, 1) + + quant_input = quant_input.set(value=x) + quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, reduce_size)) + return quant_input + + +def quant_layer(fn, quant_input, quant_weight, bias, *args, **kwargs): + from brevitas.quant_tensor import _unpack_quant_tensor + from brevitas.quant_tensor import IntQuantTensor + + output_scale = None + output_bit_width = None + output_zero_point = None + output_signed = None + max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[fn] + + compute_output_quant_tensor = isinstance(quant_input, IntQuantTensor) and isinstance( + quant_weight, IntQuantTensor) + + if bias is None: + output = fn( + _unpack_quant_tensor(quant_input), + _unpack_quant_tensor(quant_weight), + None, + *args, + **kwargs) + else: + output = fn( + _unpack_quant_tensor(quant_input), + _unpack_quant_tensor(quant_weight), + _unpack_quant_tensor(bias), + *args, + **kwargs) + + if isinstance(quant_input, IntQuantTensor) and isinstance(quant_weight, IntQuantTensor): + output_bit_width = max_acc_bit_width( + quant_input.bit_width, + quant_weight.bit_width, + quant_weight.value.shape, + *args, + **kwargs) + output_scale = quant_output_scale_impl( + fn, quant_input.value, quant_input.scale, quant_weight.scale) + output_signed = quant_input.signed or quant_weight.signed + output_training = quant_input.training or quant_weight.training + + if bias is not None: + if output_scale is not None: + if (isinstance(bias, IntQuantTensor) and + not torch.allclose(bias.scale, output_scale)) or not isinstance(bias, + IntQuantTensor): + channel_dim = -1 if isinstance(fn, torch.nn.Linear) else 1 + output_scale_broadcast_shape = compute_channel_view_shape( + quant_input, channel_dim=channel_dim) + output_zero_point = -_unpack_quant_tensor(bias).view( + output_scale_broadcast_shape) / output_scale + if output_bit_width is not None and isinstance(bias, IntQuantTensor): + output_bit_width = torch.where( + bias.bit_width > output_bit_width, bias.bit_width, output_bit_width) + output_bit_width = output_bit_width + 1 + + if compute_output_quant_tensor: + if (isinstance(quant_input, IntQuantTensor) and + (quant_input.zero_point != 0.0).any()) or (isinstance(quant_weight, IntQuantTensor) and + (quant_weight.zero_point != 0.0).any()): + warnings.warn("Computing zero point of output accumulator not supported yet.") + compute_output_quant_tensor = False + + if compute_output_quant_tensor: + if output_zero_point is None: + output_zero_point = torch.zeros(1).type_as(output) + return create_int_quant_tensor( + output, + output_scale, + output_bit_width, + output_zero_point, + output_signed, + output_training) + else: + return output + + +def create_int_quant_tensor(tensor, scale, bit_width, zero_point, signed, training): + from brevitas.quant_tensor import IntQuantTensor + return IntQuantTensor( + tensor, + scale=scale, + zero_point=zero_point, + bit_width=bit_width, + signed=signed, + training=training) + + +def quant_output_scale_impl( + fn: Callable, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor): + channel_dim = -1 if fn == F.linear else 1 + 
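A hedged usage sketch of the int dispatch above (layer sizes are illustrative, not from the patch): calling F.linear directly on an IntQuantTensor input and an IntQuantTensor weight is routed through linear_handler and quant_layer, so the result carries a scale built from the input and weight scales and a widened accumulator bit width per the rule defined just below.

import torch
import torch.nn.functional as F
from brevitas.nn import QuantIdentity, QuantLinear

quant_inp = QuantIdentity(bit_width=8, return_quant_tensor=True)(torch.randn(2, 16))
quant_weight = QuantLinear(16, 8, bias=False, weight_bit_width=8).quant_weight()

out = F.linear(quant_inp, quant_weight)  # intercepted by __torch_function__
print(type(out).__name__)   # IntQuantTensor
print(out.scale.shape)      # broadcastable product of input and weight scales
print(out.bit_width)        # accumulator bit width, larger than 8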
output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim) + + quant_weight_scale = quant_weight_scale.view(output_scale_shape) + if len(quant_input_scale.shape) == 0: + quant_input_scale = quant_input_scale.view(output_scale_shape) + + output_scale = quant_weight_scale * quant_input_scale + return output_scale + + +def max_acc_bit_width_convNd(input_bit_width, weight_bit_width, weight_shape, *args, **kwargs): + max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False) + in_channel = weight_shape[1] + kernel_size = math.prod(weight_shape[2:]) + max_uint_output = max_uint_input * max_kernel_val * kernel_size * in_channel + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width + + +def max_acc_bit_width_linear(input_bit_width, weight_bit_width, weight_shape, *args, **kwargs): + max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False) + in_channel = weight_shape[1] + max_uint_output = max_uint_input * max_kernel_val * in_channel + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width + + +def max_acc_bit_width_convtransposeNd( + input_bit_width, weight_bit_width, weight_shape, *args, **kwargs): + stride = kwargs['stride'] + max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False) + out_channel = weight_shape[1] + kernel_shape = weight_shape[2:] + + patch_size = 0 + for s, k in zip(stride, kernel_shape): + patch_size += max(math.ceil(k / s), 1) + + max_uint_output = max_uint_input * max_kernel_val * patch_size * out_channel + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width + + +def max_acc_bit_width_avg_pool2d(input_bit_width, avg_scaling): + max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) + max_uint_output = max_uint_input * avg_scaling + max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) + return max_output_bit_width + + +FN_ACC_BITWIDTH_MAPPING = { + F.linear: max_acc_bit_width_linear, + F.conv1d: max_acc_bit_width_convNd, + F.conv2d: max_acc_bit_width_convNd, + F.conv3d: max_acc_bit_width_convNd, + F.conv_transpose1d: max_acc_bit_width_convtransposeNd, + F.conv_transpose2d: max_acc_bit_width_convtransposeNd, + F.conv_transpose3d: max_acc_bit_width_convtransposeNd, + F.avg_pool2d: max_acc_bit_width_avg_pool2d} diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py index 79934864f..413cbbbc5 100644 --- a/src/brevitas/quant_tensor/torch_handler.py +++ b/src/brevitas/quant_tensor/torch_handler.py @@ -14,6 +14,8 @@ from brevitas.function.ops_ste import ceil_ste from brevitas.utils.torch_utils import compute_channel_view_shape +INT_QUANT_TENSOR_FN_HANDLER = {} +FLOAT_QUANT_TENSOR_FN_HANDLER = {} QUANT_TENSOR_FN_HANDLER = {} @@ -47,12 +49,6 @@ def transpose_handler(inp, *args, **kwargs): return inp.transpose(*args, **kwargs) -@implements(torch.cat) -def cat_handler(*args, **kwargs): - from brevitas.quant_tensor import QuantTensor - return QuantTensor.cat(*args, **kwargs) - - @implements(F.pad) def pad_handler(*args, **kwargs): # TODO check padding value is legal @@ -162,265 +158,3 @@ def pixel_shuffle_handler(*args, 
**kwargs): @implements(F.pixel_unshuffle) def pixel_unshuffle_handler(*args, **kwargs): return quant_invariant_handler(F.pixel_unshuffle, *args, **kwargs) - - -@implements(F.conv1d) -def conv1d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): - output = quant_layer(F.conv1d, quant_input, quant_weight, bias, *args, **kwargs) - return output - - -@implements(F.conv2d) -def conv2d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): - output = quant_layer(F.conv2d, quant_input, quant_weight, bias, *args, **kwargs) - return output - - -@implements(F.conv3d) -def conv3d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): - output = quant_layer(F.conv3d, quant_input, quant_weight, bias, *args, **kwargs) - return output - - -@implements(F.conv_transpose1d) -def conv_transpose1d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): - output = quant_layer(F.conv_transpose1d, quant_input, quant_weight, bias, *args, **kwargs) - return output - - -@implements(F.conv_transpose2d) -def conv_transpose2d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): - output = quant_layer(F.conv_transpose2d, quant_input, quant_weight, bias, *args, **kwargs) - return output - - -@implements(F.conv_transpose3d) -def conv_transpose3d_handler(quant_input, quant_weight, bias=None, *args, **kwargs): - output = quant_layer(F.conv_transpose3d, quant_input, quant_weight, bias, *args, **kwargs) - return output - - -@implements(F.linear) -def linear_handler(quant_input, quant_weight, bias=None, *args, **kwargs): - output = quant_layer(F.linear, quant_input, quant_weight, bias, *args, **kwargs) - return output - - -@implements(F.embedding) -def embedding_handler(input, quant_weight, *args, **kwargs): - from brevitas.quant_tensor import _unpack_quant_tensor - from brevitas.quant_tensor import QuantTensor - - quant_weight_value = _unpack_quant_tensor(quant_weight) - out = F.embedding(input, quant_weight_value, *args, **kwargs) - - if isinstance(quant_weight, QuantTensor): - scale = quant_weight.scale - zero_point = quant_weight.zero_point - bit_width = quant_weight.bit_width - if any(t.numel() > 1 for t in [scale, zero_point, bit_width]): - raise RuntimeError("Only per-tensor quantization is supported.") - signed = quant_weight.signed - training = quant_weight.training - out = QuantTensor(out, scale, zero_point, bit_width, signed, training) - return out - - -@implements(F.avg_pool2d) -def avg_pool2d_handler( - quant_input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override): - from brevitas.quant_tensor import _unpack_quant_tensor - - x = F.avg_pool2d( - _unpack_quant_tensor(quant_input), - kernel_size, - stride, - padding, - ceil_mode, - count_include_pad, - divisor_override) - - max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[F.avg_pool2d] - # remove avg scaling - if isinstance(kernel_size, tuple): - avg_scaling = kernel_size[0] * kernel_size[1] - else: - avg_scaling = kernel_size * kernel_size - - quant_input = quant_input.set(value=x) - quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, avg_scaling)) - return quant_input - - -@implements(F.adaptive_avg_pool2d) -def adaptive_avg_pool2d_handler(quant_input, output_shape): - from functools import reduce - from operator import mul - - from brevitas.nn.quant_avg_pool import TruncAdaptiveAvgPool2d - from brevitas.quant_tensor import _unpack_quant_tensor - - x = F.adaptive_avg_pool2d(_unpack_quant_tensor(quant_input), output_shape) - k_size, stride = 
TruncAdaptiveAvgPool2d.compute_kernel_size_stride(quant_input.value.shape[2:], x.shape[2:]) - - max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[F.avg_pool2d] - reduce_size = reduce(mul, k_size, 1) - - quant_input = quant_input.set(value=x) - quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, reduce_size)) - return quant_input - - -def quant_layer(fn, quant_input, quant_weight, bias, *args, **kwargs): - from brevitas.quant_tensor import _unpack_quant_tensor - from brevitas.quant_tensor import QuantTensor - - output_scale = None - output_bit_width = None - output_zero_point = None - output_signed = None - max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[fn] - - compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance( - quant_weight, QuantTensor) - - if bias is None: - output = fn( - _unpack_quant_tensor(quant_input), - _unpack_quant_tensor(quant_weight), - None, - *args, - **kwargs) - else: - output = fn( - _unpack_quant_tensor(quant_input), - _unpack_quant_tensor(quant_weight), - _unpack_quant_tensor(bias), - *args, - **kwargs) - - if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): - output_bit_width = max_acc_bit_width( - quant_input.bit_width, - quant_weight.bit_width, - quant_weight.value.shape, - *args, - **kwargs) - output_scale = quant_output_scale_impl( - fn, quant_input.value, quant_input.scale, quant_weight.scale) - output_signed = quant_input.signed or quant_weight.signed - output_training = quant_input.training or quant_weight.training - - if bias is not None: - if output_scale is not None: - if (isinstance(bias, QuantTensor) and - not torch.allclose(bias.scale, output_scale)) or not isinstance(bias, - QuantTensor): - channel_dim = -1 if isinstance(fn, torch.nn.Linear) else 1 - output_scale_broadcast_shape = compute_channel_view_shape( - quant_input, channel_dim=channel_dim) - output_zero_point = -_unpack_quant_tensor(bias).view( - output_scale_broadcast_shape) / output_scale - if output_bit_width is not None and isinstance(bias, QuantTensor): - output_bit_width = torch.where( - bias.bit_width > output_bit_width, bias.bit_width, output_bit_width) - output_bit_width = output_bit_width + 1 - - if compute_output_quant_tensor: - if (isinstance(quant_input, QuantTensor) and - (quant_input.zero_point != 0.0).any()) or (isinstance(quant_weight, QuantTensor) and - (quant_weight.zero_point != 0.0).any()): - warnings.warn("Computing zero point of output accumulator not supported yet.") - compute_output_quant_tensor = False - - if compute_output_quant_tensor: - if output_zero_point is None: - output_zero_point = torch.zeros(1).type_as(output) - - return create_quant_tensor( - output, - output_scale, - output_bit_width, - output_zero_point, - output_signed, - output_training) - else: - return output - - -def create_quant_tensor(tensor, scale, bit_width, zero_point, signed, training): - from brevitas.quant_tensor import QuantTensor - return QuantTensor( - tensor, - scale=scale, - zero_point=zero_point, - bit_width=bit_width, - signed=signed, - training=training) - - -def quant_output_scale_impl( - fn: Callable, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor): - channel_dim = -1 if fn == F.linear else 1 - output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim) - - quant_weight_scale = quant_weight_scale.view(output_scale_shape) - if len(quant_input_scale.shape) == 0: - quant_input_scale = quant_input_scale.view(output_scale_shape) - - output_scale = quant_weight_scale * 
quant_input_scale - return output_scale - - -def max_acc_bit_width_convNd(input_bit_width, weight_bit_width, weight_shape, *args, **kwargs): - max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) - max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False) - in_channel = weight_shape[1] - kernel_size = math.prod(weight_shape[2:]) - max_uint_output = max_uint_input * max_kernel_val * kernel_size * in_channel - max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) - return max_output_bit_width - - -def max_acc_bit_width_linear(input_bit_width, weight_bit_width, weight_shape, *args, **kwargs): - max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) - max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False) - in_channel = weight_shape[1] - max_uint_output = max_uint_input * max_kernel_val * in_channel - max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) - return max_output_bit_width - - -def max_acc_bit_width_convtransposeNd( - input_bit_width, weight_bit_width, weight_shape, *args, **kwargs): - stride = kwargs['stride'] - max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) - max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False) - out_channel = weight_shape[1] - kernel_shape = weight_shape[2:] - - patch_size = 0 - for s, k in zip(stride, kernel_shape): - patch_size += max(math.ceil(k / s), 1) - - max_uint_output = max_uint_input * max_kernel_val * patch_size * out_channel - max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) - return max_output_bit_width - - -def max_acc_bit_width_avg_pool2d(input_bit_width, avg_scaling): - max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) - max_uint_output = max_uint_input * avg_scaling - max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) - return max_output_bit_width - - -FN_ACC_BITWIDTH_MAPPING = { - F.linear: max_acc_bit_width_linear, - F.conv1d: max_acc_bit_width_convNd, - F.conv2d: max_acc_bit_width_convNd, - F.conv3d: max_acc_bit_width_convNd, - F.conv_transpose1d: max_acc_bit_width_convtransposeNd, - F.conv_transpose2d: max_acc_bit_width_convtransposeNd, - F.conv_transpose3d: max_acc_bit_width_convtransposeNd, - F.avg_pool2d: max_acc_bit_width_avg_pool2d} diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index 5b8bf648f..8df77a99e 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -5,12 +5,13 @@ from brevitas.core.function_wrapper import * from brevitas.core.quant import RescalingIntQuant from brevitas.inject.enum import FloatToIntImplType -from brevitas.quant_tensor import QuantTensor +from brevitas.quant_tensor import FloatQuantTensor +from brevitas.quant_tensor import IntQuantTensor class _CachedIO: - def __init__(self, quant_tensor: QuantTensor, metadata_only: bool): + def __init__(self, quant_tensor: IntQuantTensor, metadata_only: bool): self.shape = quant_tensor.value.shape if metadata_only: self.quant_tensor = quant_tensor.set(value=None) @@ -34,6 +35,52 @@ def signed(self): return self.quant_tensor.signed +class _CachedIOFloat: + + def __init__(self, quant_tensor: FloatQuantTensor, metadata_only: bool): + self.shape = quant_tensor.value.shape + if metadata_only: + self.quant_tensor = quant_tensor.set(value=None) + else: + self.quant_tensor = quant_tensor + + @property + def scale(self): + return 
self.quant_tensor.scale + + @property + def zero_point(self): + return self.quant_tensor.zero_point + + @property + def exponent_bit_width(self): + return self.quant_tensor.exponent_bit_width + + @property + def mantissa_bit_width(self): + return self.quant_tensor.mantissa_bit_width + + @property + def exponent_bias(self): + return self.quant_tensor.exponent_bias + + @property + def saturating(self): + return self.quant_tensor.saturating + + @property + def inf_values(self): + return self.quant_tensor.inf_values + + @property + def nan_values(self): + return self.quant_tensor.nan_values + + @property + def signed(self): + return self.quant_tensor.signed + + def has_learned_weight_bit_width(module): from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector diff --git a/tests/brevitas/core/test_clamp.py b/tests/brevitas/core/test_clamp.py index 4654f5054..4b13032f2 100644 --- a/tests/brevitas/core/test_clamp.py +++ b/tests/brevitas/core/test_clamp.py @@ -60,7 +60,7 @@ def test_float_clamp(inp, fp8_clamp): over_limit_mask = inp.abs() > max_val # clamp inp - inp = fp8_clamp.float_clamp_impl( + inp, *_ = fp8_clamp.float_clamp_impl( inp, torch.tensor(fp8_clamp.exponent_bit_width), torch.tensor(fp8_clamp.mantissa_bit_width), diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 5c1b0cdc3..43c090e6c 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -95,7 +95,7 @@ def test_float_to_quant_float(inp, minifloat_format): out_quant, scale = float_quant.quantize(inp) exponent_bit_width, mantissa_bit_width, exponent_bias = torch.tensor(exponent_bit_width, dtype=torch.float), torch.tensor(mantissa_bit_width, dtype=torch.float), torch.tensor(exponent_bias, dtype=torch.float) - out_quant = float_quant.float_clamp_impl( + out_quant, *_ = float_quant.float_clamp_impl( out_quant, exponent_bit_width, mantissa_bit_width, exponent_bias) assert torch.allclose(expected_out, out_quant * scale) @@ -202,7 +202,7 @@ def test_inner_scale(inp, minifloat_format, scale): # dequantize manually out = val_fp_quant * scale - expected_out, expected_scale, _, _ = float_quant(inp) + expected_out, expected_scale, *_ = float_quant(inp) assert scale == expected_scale if scale == 0.0: diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index 819bceb10..7b3183b94 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -33,6 +33,7 @@ from brevitas.quant.scaled_int import Uint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor SEED = 123456 @@ -169,7 +170,7 @@ def forward(self, x): raise RuntimeError("Unsupported operation") if input_quantized: - quant_inp = QuantTensor( + quant_inp = IntQuantTensor( torch.randint(-128, 127, in_size) * 0.128, 0.128, 0., 8., True, is_training) else: quant_inp = torch.randn(in_size) diff --git a/tests/brevitas/nn/test_linear.py b/tests/brevitas/nn/test_linear.py index 62799281b..4141b4498 100644 --- a/tests/brevitas/nn/test_linear.py +++ b/tests/brevitas/nn/test_linear.py @@ -5,7 +5,7 @@ from brevitas.nn import QuantLinear from brevitas.quant import Int32Bias -from brevitas.quant_tensor import QuantTensor +from brevitas.quant_tensor import IntQuantTensor OUTPUT_FEATURES = 10 
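For reference, a hand-built IntQuantTensor in the style the fixture above uses (shape and the 0.128 scale are arbitrary; a sketch, not part of the patch): keeping value an exact integer multiple of scale is what makes is_valid hold.

import torch
from brevitas.quant_tensor import IntQuantTensor

int_payload = torch.randint(-128, 127, (2, 3))
qt = IntQuantTensor(int_payload * 0.128, 0.128, 0., 8., True, False)

assert qt.is_valid  # value / scale + zero_point rounds to itself
assert torch.allclose(qt.int(float_datatype=True), int_payload.float())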
INPUT_FEATURES = 5 @@ -57,7 +57,7 @@ def test_forward_bias_int(self): in_features=INPUT_FEATURES, bias=True, bias_quant=Int32Bias) - x = QuantTensor( + x = IntQuantTensor( torch.rand(size=(3, INPUT_FEATURES)), torch.tensor(1.0), torch.tensor(0.0), diff --git a/tests/brevitas/quant_tensor/test_quant_tensor.py b/tests/brevitas/quant_tensor/test_quant_tensor.py index c7544a1f3..4e3401057 100644 --- a/tests/brevitas/quant_tensor/test_quant_tensor.py +++ b/tests/brevitas/quant_tensor/test_quant_tensor.py @@ -2,12 +2,15 @@ # SPDX-License-Identifier: BSD-3-Clause from enum import Enum +from packaging import version import pytest import torch -from brevitas.inject.enum import QuantType +from brevitas import torch_version from brevitas.nn import QuantIdentity -from brevitas.quant_tensor import QuantTensor +from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat +from brevitas.quant_tensor import FloatQuantTensor +from brevitas.quant_tensor import IntQuantTensor class Operator(Enum): @@ -18,11 +21,17 @@ class Operator(Enum): MATMUL = 4 -def to_quant_tensor(input: torch.Tensor) -> QuantTensor: +def to_quant_tensor(input: torch.Tensor) -> IntQuantTensor: mod = QuantIdentity(bit_width=8, return_quant_tensor=True) return mod(input) +def to_float_quant_tensor(input: torch.Tensor) -> FloatQuantTensor: + mod = QuantIdentity( + bit_width=8, return_quant_tensor=True, act_quant=Fp8e5m2OCPActPerTensorFloat) + return mod(input) + + def qdq(normal_tensor, quant_tensor): return ( torch.round(normal_tensor / quant_tensor.scale + quant_tensor.zero_point) - @@ -38,15 +47,20 @@ def test_quant_tensor_init(): @pytest.mark.parametrize( 'op', [Operator.ADD, Operator.SUBTRACT, Operator.DIVIDE, Operator.MULTIPLY, Operator.MATMUL]) -def test_quant_tensor_operators(op): +@pytest.mark.parametrize('quant_fn', [to_quant_tensor, to_float_quant_tensor]) +def test_quant_tensor_operators(op, quant_fn): + + if quant_fn == to_float_quant_tensor and torch_version < version.parse('1.13'): + pytest.skip("Torch 1.13 is required for JIT to be compatible with FloatQuantTensor") + # Avoid 0 values x = 1 + torch.rand(4, 4) a = torch.Tensor(x) b = torch.Tensor(x) - qa = to_quant_tensor(a) - qb = to_quant_tensor(b) + qa = quant_fn(a) + qb = quant_fn(b) # to factor in quantisation error e_a = a - qa diff --git a/tests/brevitas/test_quant_tensor.py b/tests/brevitas/test_quant_tensor.py new file mode 100644 index 000000000..e388b4830 --- /dev/null +++ b/tests/brevitas/test_quant_tensor.py @@ -0,0 +1,14 @@ +import torch + +from brevitas.quant_tensor import IntQuantTensor +from brevitas.quant_tensor import QuantTensor + + +def test_qt_structure(): + qt = IntQuantTensor( + torch.randn(10), torch.randn(1), torch.tensor(0.), torch.tensor(8.), True, False) + assert isinstance(qt, IntQuantTensor) + assert isinstance(qt, QuantTensor) + assert isinstance(qt, tuple) + assert hasattr(qt, '_fields') + assert len(qt._fields) == 6 diff --git a/tests/brevitas_end_to_end/test_torchvision_models.py b/tests/brevitas_end_to_end/test_torchvision_models.py index e8c34d961..0d76ae2db 100644 --- a/tests/brevitas_end_to_end/test_torchvision_models.py +++ b/tests/brevitas_end_to_end/test_torchvision_models.py @@ -18,6 +18,7 @@ from brevitas.graph.quantize import quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize from brevitas.graph.target.flexml import quantize_flexml +from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model from tests.marker import requires_pt_ge BATCH = 1 @@ -53,6 
+54,20 @@ def forward(self, inp): return out['out'] +def quantize_float(model): + return quantize_model( + model, + weight_bit_width=8, + act_bit_width=8, + bias_bit_width=None, + weight_quant_granularity='per_tensor', + act_quant_percentile=99.999, + act_quant_type='sym', + scale_factor_type='float_scale', + backend='layerwise', + quant_format='float') + + @fixture @parametrize('model_name', MODEL_LIST) @parametrize('quantize_fn', [quantize, quantize_flexml, layerwise_quantize]) @@ -102,12 +117,24 @@ def test_torchvision_graph_quantization_flexml_qcdq_onnx(torchvision_model, requ if torchvision_model is None: pytest.skip('Model not instantiated') inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) - export_onnx_qcdq(torchvision_model, args=inp) + test_id = request.node.callspec.id + + quantize_fn_name = test_id.split("-")[0] + torchvision_model(inp) + if quantize_fn_name != 'quantize_float': + export_onnx_qcdq(torchvision_model, args=inp) @requires_pt_ge('1.9.1') def test_torchvision_graph_quantization_flexml_qcdq_torch(torchvision_model, request): if torchvision_model is None: pytest.skip('Model not instantiated') + + test_id = request.node.callspec.id + quantize_fn_name = test_id.split("-")[0] + if quantize_fn_name == 'quantize_float': + return inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) + + torchvision_model(inp) export_torch_qcdq(torchvision_model, args=inp)
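To tie the pieces of this patch together, a hedged end-to-end sketch (tensor sizes are illustrative): the OCP FP8 activation quantizer used in the tests above yields a FloatQuantTensor, and the float torch handlers dequantize it when it flows into F.linear, so the result is a plain Tensor rather than a quantized one.

import torch
import torch.nn.functional as F
from brevitas.nn import QuantIdentity
from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat
from brevitas.quant_tensor import FloatQuantTensor

quant_act = QuantIdentity(
    bit_width=8, act_quant=Fp8e5m2OCPActPerTensorFloat, return_quant_tensor=True)
qx = quant_act(torch.randn(2, 16))

assert isinstance(qx, FloatQuantTensor)
print(qx.exponent_bit_width, qx.mantissa_bit_width, qx.exponent_bias)

out = F.linear(qx, torch.randn(8, 16))        # routed through the float handlers
assert not isinstance(out, FloatQuantTensor)  # float path returns a plain Tensor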