From 255408e0b1e47252b9c69a41816baa443e1fdb29 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 26 Feb 2024 22:37:56 +0000 Subject: [PATCH 1/2] Fix (proxy): fix for attributes retrieval --- src/brevitas/proxy/parameter_quant.py | 22 ++++++++++++++++++---- src/brevitas/proxy/runtime_quant.py | 22 +++++++++++++++------- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 2927b1662..79ad7e9ec 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -82,16 +82,22 @@ def requires_quant_input(self): return False def scale(self): + if not self.is_quant_enabled: + return None scale = self.__call__(self.tracked_parameter_list[0]).scale return scale def zero_point(self): + if not self.is_quant_enabled: + return None zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point return zero_point def bit_width(self): - bit_width_ = self.__call__(self.tracked_parameter_list[0]).bit_width - return bit_width_ + if not self.is_quant_enabled: + return None + bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width + return bit_width def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: @@ -105,11 +111,15 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: class DecoupledWeightQuantProxyFromInjector(WeightQuantProxyFromInjector): def pre_scale(self): + if not self.is_quant_enabled: + return None output_tuple = self.tensor_quant(self.tracked_parameter_list[0]) out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple return pre_scale def pre_zero_point(self): + if not self.is_quant_enabled: + return None output_tuple = self.tensor_quant(self.tracked_parameter_list[0]) out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple return pre_zero_point @@ -151,7 +161,7 @@ def forward(self, x: torch.Tensor, input_bit_width: torch.Tensor, out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed) return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) else: # quantization disabled - return QuantTensor(x, training=self.training) + return x class BiasQuantProxyFromInjector(ParameterQuantProxyFromInjector, BiasQuantProxyProtocol): @@ -168,18 +178,22 @@ def requires_input_scale(self) -> bool: return False def scale(self): - if self.requires_input_scale: + if self.requires_input_scale or not self.is_quant_enabled: return None zhs = self._zero_hw_sentinel() scale = self.__call__(self.tracked_parameter_list[0], zhs).scale return scale def zero_point(self): + if not self.is_quant_enabled: + return None zhs = self._zero_hw_sentinel() zero_point = self.__call__(self.tracked_parameter_list[0], zhs).zero_point return zero_point def bit_width(self): + if not self.is_quant_enabled: + return None zhs = self._zero_hw_sentinel() bit_width = self.__call__(self.tracked_parameter_list[0], zhs).bit_width return bit_width diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 0324465c1..fe7b29daf 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -118,6 +118,8 @@ def init_tensor_quant(self): self.fused_activation_quant_proxy = None def scale(self, force_eval=True): + if not self.is_quant_enabled: + return None current_status = self.training if force_eval: self.eval() @@ -126,6 +128,8 @@ def scale(self, force_eval=True): return scale def zero_point(self, force_eval=True): + if not self.is_quant_enabled: + return None current_status = self.training if force_eval: self.eval() @@ -133,9 +137,15 @@ def zero_point(self, force_eval=True): self.train(current_status) return zero_point - def bit_width(self): - scale = self.__call__(self._zero_hw_sentinel()).bit_width - return scale + def bit_width(self, force_eval=True): + if not self.is_quant_enabled: + return None + current_status = self.training + if force_eval: + self.eval() + bit_width = self.__call__(self._zero_hw_sentinel()).bit_width + self.train(current_status) + return bit_width def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: if self.fused_activation_quant_proxy is not None: @@ -179,10 +189,6 @@ def scale(self, force_eval=True): def zero_point(self, force_eval=True): raise RuntimeError("Zero point for Dynamic Act Quant is input-dependant") - def bit_width(self): - bit_width = self.__call__(self._zero_hw_sentinel()).bit_width - return bit_width - class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol): @@ -198,6 +204,8 @@ def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]: class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol): def bit_width(self): + if not self.is_quant_enabled: + return None zhs = self._zero_hw_sentinel() # Signed might or might not be defined. We just care about retrieving the bitwidth empty_imp = QuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training) From 9ec9a3ff06e83af037c90f98c97b306ef84870a0 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 29 Feb 2024 18:27:06 +0000 Subject: [PATCH 2/2] Feat (tests): add new tests for proxy --- tests/brevitas/proxy/test_proxy.py | 82 +++++++++++++++++++++ tests/brevitas/proxy/test_weight_scaling.py | 1 - tests/marker.py | 9 +++ 3 files changed, 91 insertions(+), 1 deletion(-) create mode 100644 tests/brevitas/proxy/test_proxy.py diff --git a/tests/brevitas/proxy/test_proxy.py b/tests/brevitas/proxy/test_proxy.py new file mode 100644 index 000000000..08b525a71 --- /dev/null +++ b/tests/brevitas/proxy/test_proxy.py @@ -0,0 +1,82 @@ +import pytest + +from brevitas.nn import QuantLinear +from brevitas.nn.quant_activation import QuantReLU +from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant +from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling +from brevitas.quant.scaled_int import Int8WeightPerChannelFloatDecoupled +from brevitas.quant.scaled_int import Int8WeightPerTensorFloat +from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat +from tests.marker import jit_disabled_for_dynamic_quant_act + + +class TestProxy: + + def test_bias_proxy(self): + model = QuantLinear(10, 5, bias_quant=Int8BiasPerTensorFloatInternalScaling) + assert model.bias_quant.scale() is not None + assert model.bias_quant.zero_point() is not None + assert model.bias_quant.bit_width() is not None + + model.bias_quant.disable_quant = True + assert model.bias_quant.scale() is None + assert model.bias_quant.zero_point() is None + assert model.bias_quant.bit_width() is None + + def test_weight_proxy(self): + model = QuantLinear(10, 5, weight_quant=Int8WeightPerTensorFloat) + assert model.weight_quant.scale() is not None + assert model.weight_quant.zero_point() is not None + assert model.weight_quant.bit_width() is not None + + model.weight_quant.disable_quant = True + assert model.weight_quant.scale() is None + assert model.weight_quant.zero_point() is None + assert model.weight_quant.bit_width() is None + + def test_weight_decoupled_proxy(self): + model = QuantLinear(10, 5, weight_quant=Int8WeightPerChannelFloatDecoupled) + assert model.weight_quant.pre_scale() is not None + assert model.weight_quant.pre_zero_point() is not None + + model.weight_quant.disable_quant = True + assert model.weight_quant.pre_scale() is None + assert model.weight_quant.pre_zero_point() is None + + def test_weight_decoupled_with_input_proxy(self): + model = QuantLinear(10, 5, weight_quant=Int8AccumulatorAwareWeightQuant) + with pytest.raises(NotImplementedError): + model.weight_quant.scale() + with pytest.raises(NotImplementedError): + model.weight_quant.zero_point() + + with pytest.raises(NotImplementedError): + model.weight_quant.pre_scale() + with pytest.raises(NotImplementedError): + model.weight_quant.pre_zero_point() + + def test_act_proxy(self): + model = QuantReLU() + assert model.act_quant.scale() is not None + assert model.act_quant.zero_point() is not None + assert model.act_quant.bit_width() is not None + + model.act_quant.disable_quant = True + assert model.act_quant.scale() is None + assert model.act_quant.zero_point() is None + assert model.act_quant.bit_width() is None + + @jit_disabled_for_dynamic_quant_act() + def test_dynamic_act_proxy(self): + model = QuantReLU(Int8DynamicActPerTensorFloat) + + with pytest.raises(RuntimeError, match="Scale for Dynamic Act Quant is input-dependant"): + model.act_quant.scale() + with pytest.raises(RuntimeError, + match="Zero point for Dynamic Act Quant is input-dependant"): + model.act_quant.zero_point() + + assert model.act_quant.bit_width() is not None + + model.act_quant.disable_quant = True + assert model.act_quant.bit_width() is None diff --git a/tests/brevitas/proxy/test_weight_scaling.py b/tests/brevitas/proxy/test_weight_scaling.py index 074ca7c61..49a7f20fe 100644 --- a/tests/brevitas/proxy/test_weight_scaling.py +++ b/tests/brevitas/proxy/test_weight_scaling.py @@ -1,7 +1,6 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -import pytest from torch import nn from brevitas import config diff --git a/tests/marker.py b/tests/marker.py index 59e76a7d2..f11dc7a4a 100644 --- a/tests/marker.py +++ b/tests/marker.py @@ -50,5 +50,14 @@ def skip_wrapper(f): return skip_wrapper +def jit_disabled_for_dynamic_quant_act(): + skip = config.JIT_ENABLED + + def skip_wrapper(f): + return pytest.mark.skipif(skip, reason=f'Dynamic Act Quant requires JIT to be disabled')(f) + + return skip_wrapper + + skip_on_macos_nox = pytest.mark.skipif( platform.system() == "Darwin", reason="Known issue with Nox and MacOS.")