diff --git a/src/brevitas/core/scaling/__init__.py b/src/brevitas/core/scaling/__init__.py index 06c65ecbd..6187bd262 100644 --- a/src/brevitas/core/scaling/__init__.py +++ b/src/brevitas/core/scaling/__init__.py @@ -7,6 +7,7 @@ from .int_scaling import IntScaling from .int_scaling import PowerOfTwoIntScaling from .pre_scaling import AccumulatorAwareParameterPreScaling +from .pre_scaling import AccumulatorAwareZeroCenterParameterPreScaling from .pre_scaling import ParameterPreScalingWeightNorm from .runtime import RuntimeStatsScaling from .runtime import StatsFromParameterScaling diff --git a/src/brevitas/core/scaling/pre_scaling.py b/src/brevitas/core/scaling/pre_scaling.py index dd125396d..d73c86461 100644 --- a/src/brevitas/core/scaling/pre_scaling.py +++ b/src/brevitas/core/scaling/pre_scaling.py @@ -13,10 +13,14 @@ from brevitas.core.restrict_val import _RestrictClampValue from brevitas.core.stats import SCALAR_SHAPE from brevitas.core.stats.stats_wrapper import _Stats +from brevitas.core.zero_point import PreZeroCenterZeroPoint from brevitas.function import abs_binary_sign_grad from brevitas.function import get_upper_bound_on_l1_norm -__all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"] +__all__ = [ + "ParameterPreScalingWeightNorm", + "AccumulatorAwareParameterPreScaling", + "AccumulatorAwareZeroCenterParameterPreScaling"] class ParameterPreScalingWeightNorm(brevitas.jit.ScriptModule): @@ -113,7 +117,7 @@ def _load_from_state_dict( class AccumulatorAwareParameterPreScaling(ParameterPreScalingWeightNorm): """ ScriptModule implementation of learned pre-clipping scaling factor to support - accumulator-aware quantizaion (A2Q) as proposed in `A2Q: Accumulator-Aware Quantization + accumulator-aware quantization (A2Q) as proposed in `A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig. 
The module parameterizes the pre-clipping scaling factor (i.e., `pre_scale`) of the @@ -150,16 +154,15 @@ class AccumulatorAwareParameterPreScaling(ParameterPreScalingWeightNorm): """ def __init__( - self, - scaling_impl: Module, - normalize_stats_impl: Module, - accumulator_bit_width_impl: Module, - scaling_stats_input_view_shape_impl: Module, - tracked_parameter_list: List[torch.nn.Parameter], - pre_scaling_shape: Optional[Tuple[int, ...]] = None, - restrict_pre_scaling_impl: Optional[Module] = None, - pre_scaling_min_val: Optional[float] = None, - ) -> None: + self, + scaling_impl: Module, + normalize_stats_impl: Module, + accumulator_bit_width_impl: Module, + scaling_stats_input_view_shape_impl: Module, + tracked_parameter_list: List[torch.nn.Parameter], + pre_scaling_shape: Optional[Tuple[int, ...]] = None, + restrict_pre_scaling_impl: Optional[Module] = None, + pre_scaling_min_val: Optional[float] = None) -> None: super().__init__( scaling_impl, normalize_stats_impl, @@ -167,21 +170,102 @@ def __init__( tracked_parameter_list, pre_scaling_shape, restrict_pre_scaling_impl, - pre_scaling_min_val, - ) + pre_scaling_min_val) self.accumulator_bit_width = accumulator_bit_width_impl @brevitas.jit.script_method - def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor: - """Takes weights as input and returns the pre-clipping scaling factor""" + def calc_max_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor: + accumulator_bit_width = self.accumulator_bit_width() + upper_bound = get_upper_bound_on_l1_norm( + accumulator_bit_width, input_bit_width, input_is_signed) + return upper_bound + + @brevitas.jit.script_method + def inner_forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool): weights = self.stats_input_view_shape_impl(weights) d_w = self.stats(weights) # denominator for weight normalization s = self.scaling_impl(weights) # s g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g - T = get_upper_bound_on_l1_norm( - self.accumulator_bit_width(), input_bit_width, input_is_signed) # T / s + T = self.calc_max_l1_norm(input_bit_width, input_is_signed) # T / s g = torch.clamp_max(g / s, T) value = d_w / g # calculating final pre-clipping scaling factor # re-apply clamp_min_ste from restrict_scaling_impl to the specified pre_scaling_min_val value = self.restrict_clamp_scaling.clamp_min_ste(value) return value + + @brevitas.jit.script_method + def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor: + """Takes weights, input bit-width, and input sign as input and returns the pre-clipping + scaling factor per-channel, which is $s \cdot \Vert v \Vert_1 / g$""" + value = self.inner_forward(weights, input_bit_width, input_is_signed) + return value + + +class AccumulatorAwareZeroCenterParameterPreScaling(AccumulatorAwareParameterPreScaling): + """ + ScriptModule implementation of learned pre-clipping scaling factor to support + A2Q+ as proposed in `A2Q+: Improving Accumulator-Aware Weight Quantization`. + + The module implements the zero-centering constraint as a pre-clipping zero-point + (i.e., `PreZeroCenterZeroPoint`) to relax the l1-norm constraint. + + Args: + scaling_impl (Module): post-clipping scaling factor. + pre_zero_point_impl (Module): pre-clipping zero-point. + normalize_stats_impl (Module): calculate statistics for normalizing weight parameter. + accumulator_bit_width_impl (Module): module that returns the accumulator bit-width. 
+ scaling_stats_input_view_shape_impl (Module): transforming scaling to a new shape. + tracked_parameter_list (List[torch.nn.Parameter]): list of tracked weight parameters + for tensor quantizer. + pre_scaling_shape (Tuple[int]): shape of pre-clipping scaling factor. Default: None. + restrict_pre_scaling_impl (Module): restrict pre_scaling_init according to some + criteria. Default: None. + pre_scaling_min_val (float): force a lower-bound on scaling_init. Default: None. + + Returns: + Tensor: scaling factor wrapped in a float torch.Tensor. + """ + + def __init__( + self, + scaling_impl: Module, + pre_zero_point_impl: Module, + normalize_stats_impl: Module, + accumulator_bit_width_impl: Module, + scaling_stats_input_view_shape_impl: Module, + tracked_parameter_list: List[Parameter], + pre_scaling_shape: Optional[Tuple[int, ...]] = None, + restrict_pre_scaling_impl: Optional[Module] = None, + pre_scaling_min_val: Optional[float] = None) -> None: + super().__init__( + scaling_impl, + normalize_stats_impl, + accumulator_bit_width_impl, + scaling_stats_input_view_shape_impl, + tracked_parameter_list, + pre_scaling_shape, + restrict_pre_scaling_impl, + pre_scaling_min_val) + assert isinstance( + pre_zero_point_impl, PreZeroCenterZeroPoint + ), "Error: A2Q+ requires a pre-clipping zero-centering zero-point." + self.pre_zero_point = pre_zero_point_impl + + @brevitas.jit.script_method + def calc_max_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor: + """ """ + assert input_bit_width is not None, "A2Q+ relies on input bit-width." + max_accumulator_bit_width = self.accumulator_bit_width() # P + max_accumulator_mag = pow(2.0, max_accumulator_bit_width) - 2.0 # 2^P - 2 + max_input_mag = pow(2.0, input_bit_width) - 1.0 # 2^N - 1 + return max_accumulator_mag / max_input_mag + + @brevitas.jit.script_method + def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor: + """Takes weights, input bit-width, and input sign as input and returns the pre-clipping + scaling factor per-channel, which is $s \cdot \Vert v - \mu_v \Vert_1 / g$""" + # NOTE: A2Q+ requires zero-centering the floating-point weights, which means that the + # calculation of the l1-norm needs to be done over the zero-centered weights. + z = self.pre_zero_point.get_zero_center(weights) + value = self.inner_forward(weights + z, input_bit_width, input_is_signed) + return value diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py index 7f8ad106d..58eb73820 100644 --- a/src/brevitas/core/zero_point.py +++ b/src/brevitas/core/zero_point.py @@ -24,7 +24,8 @@ 'StatsFromParameterZeroPoint', 'ParameterFromRuntimeZeroPoint', 'ParameterZeroPoint', - 'ParameterFromStatsFromParameterZeroPoint'] + 'ParameterFromStatsFromParameterZeroPoint', + 'PreZeroCenterZeroPoint'] class ZeroZeroPoint(brevitas.jit.ScriptModule): @@ -294,3 +295,33 @@ def _load_from_state_dict( self.init_done = True if config.IGNORE_MISSING_KEYS and value_key in missing_keys: missing_keys.remove(value_key) + + +class PreZeroCenterZeroPoint(brevitas.jit.ScriptModule): + """Experimental ScriptModule implementation of a pre-scaling zero-point that zero-centers + the incoming tensors. 
This is intended to be used with `DecoupledIntQuant`.""" + + def __init__( + self, + stats_reduce_dim: int, + pre_zero_point_stats_input_view_shape_impl: Module, + pre_zero_point_shape: Optional[Tuple[int, ...]] = None) -> None: + super(PreZeroCenterZeroPoint, self).__init__() + self.stats_reduce_dim = stats_reduce_dim + self.stats_output_shape = pre_zero_point_shape + self.stats_input_view_shape_impl = pre_zero_point_stats_input_view_shape_impl + + @brevitas.jit.script_method + def get_zero_center(self, x: Tensor) -> Tensor: + x = self.stats_input_view_shape_impl(x) + u = torch.mean(x, dim=self.stats_reduce_dim, keepdim=True) + z = -u.view(self.stats_output_shape) + return z + + @brevitas.jit.script_method + def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor: + # NOTE: `DecoupledIntQuant` adds the `pre_zero_point` value to the scaled tensor, + # so this needs to return the negative of the scaled average value to perform + # pre-zero centering before rounding and clipping + z = self.get_zero_center(x) / scale # need to scale the norm by s + return z diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index ec326602d..6751ab69c 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -205,9 +205,10 @@ def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_b def get_upper_bound_on_l1_norm( accumulator_bit_width: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor: - """Calculate the upper bound on the l1-norm of the weights using the derivations from - `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance` - by I.Colbert, A.Pappalardo, and J.Petri-Koenig.""" + """Calculate the upper bound on the l1-norm of the weights needed to guarantee overflow avoidance + for a given accumulator bit width and input representation using the derivations from + `A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance` by I.Colbert, + A.Pappalardo, and J.Petri-Koenig. Note that this assumes integer quantization.""" assert input_bit_width is not None, "A2Q relies on input bit-width." assert input_is_signed is not None, "A2Q relies on input sign." assert accumulator_bit_width is not None, "A2Q relies on accumulator bit-width." diff --git a/src/brevitas/nn/utils.py b/src/brevitas/nn/utils.py index 3e7b423ee..ed5e87302 100644 --- a/src/brevitas/nn/utils.py +++ b/src/brevitas/nn/utils.py @@ -1,13 +1,10 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from typing import Optional - import torch from torch import Tensor from torch.nn import Parameter -from brevitas.function.ops_ste import ceil_ste from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector @@ -79,51 +76,3 @@ def check_tensors_same_ptr(tensor_list): else: return False return all(p == pointers[0] for p in pointers) - - -def calculate_min_accumulator_bit_width( - input_bit_width: Tensor, - input_is_signed: bool, - weight_max_l1_norm: Optional[Tensor] = None, - weight_bit_width: Optional[Tensor] = None, - n_elements: Optional[Tensor] = None, - min_val: Optional[float] = 1e-10): - """Using the closed-form bounds on accumulator bit-width as derived in `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow - Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig. 
This function returns the minimum accumulator bit-width that can be used without risk of - overflow. It supports both the data-type bound as well as the weight-level bound. - - Args: - input_bit_width (Tensor): the bit-width of the inputs to the layer. - input_is_signed (bool): calculate statistics for normalizing weight parameter. - weight_max_l1_norm (Tensor): the maximum per-channel l1-norm of the weights. - weight_bit_width (Tensor): the bit-width of the weights to the layer. - n_elements (Tensor): the number of elements in the dot product. - min_val (float): the minimum value used for the l1-norm, used to avoid log2(0). Default: 1e-8. - - Example (data-type bound): - >> acc_bit_width = calculate_min_accumulator_bit_width(input_bit_width, input_is_signed, weight_bit_width, n_elements) - - Example (weight-level bound): - >> acc_bit_width = calculate_min_accumulator_bit_width(input_bit_width, input_is_signed, weight_max_l1_norm) - """ - input_is_signed = float(input_is_signed) - # if the l1-norm of the weights is specified, then use the weight-level bound - if weight_max_l1_norm is not None: - assert isinstance(weight_max_l1_norm, (float, Tensor)), "The l1-norm of the weights needs to be a float or a torch.Tensor instance." - if isinstance(weight_max_l1_norm, Tensor): - assert weight_max_l1_norm.numel() == 1, "The minimum accumulator bit-width calculation currently only supports scalars." - weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, min_val) - input_is_signed = float(input_is_signed) - alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed - # else use the data-type bound - else: - assert isinstance(weight_bit_width, (float, Tensor)), "If weight_max_l1_norm is un-specified, weight_bit_width needs to be specified." - assert isinstance(n_elements, (float, Tensor)), "If weight_max_l1_norm is un-specified, n_elements needs to be specified." - if isinstance(n_elements, Tensor): - assert n_elements.numel() == 1, "The minimum accumulator bit-width calculation currently only supports scalars." - assert n_elements > 0, "There needs to be at least one element considered in this evaluation." - alpha = torch.log2(n_elements) + input_bit_width + weight_bit_width - input_is_signed - 1. - phi = lambda x: torch.log2(1. + pow(2., -x)) - min_bit_width = alpha + phi(alpha) + 1. 
-    min_bit_width = ceil_ste(min_bit_width)
-    return min_bit_width  # returns the minimum accumulator that can be used without risk of overflow
diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py
index f72045dab..b509e6c16 100644
--- a/src/brevitas/quant/base.py
+++ b/src/brevitas/quant/base.py
@@ -20,6 +20,7 @@
 from brevitas.core.restrict_val import FloatRestrictValue
 from brevitas.core.restrict_val import LogFloatRestrictValue
 from brevitas.core.scaling import AccumulatorAwareParameterPreScaling
+from brevitas.core.scaling import AccumulatorAwareZeroCenterParameterPreScaling
 from brevitas.core.scaling import IntScaling
 from brevitas.core.scaling import ParameterFromStatsFromParameterScaling
 from brevitas.core.scaling import ParameterPreScalingWeightNorm
@@ -38,6 +39,7 @@
 from brevitas.core.utils import SingleArgStatelessBuffer
 from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint
 from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
+from brevitas.core.zero_point import PreZeroCenterZeroPoint
 from brevitas.core.zero_point import StatsFromParameterZeroPoint
 from brevitas.core.zero_point import ZeroZeroPoint
 from brevitas.inject import ExtendedInjector
@@ -76,6 +78,7 @@
     'BatchQuantStatsScaling1d',
     'BatchQuantStatsScaling2d',
     'AccumulatorAwareWeightQuant',
+    'AccumulatorAwareZeroCenterWeightQuant',
     'MSESymmetricScale',
     'MSEAsymmetricScale',
     'MSEWeightZeroPoint',
@@ -400,6 +403,20 @@ def accumulator_bit_width_impl(accumulator_bit_width):
     float_to_int_impl = RoundToZeroSte  # required to ensure no upwards rounding violates constraints
 
 
+class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
+    """Experimental zero-centered accumulator-aware weight quantizer based on:
+    `A2Q+: Improving Accumulator-Aware Weight Quantization`.
+
+    When compared to A2Q, A2Q+ changes the following:
+    (1) added zero-centering constraint on the weights (i.e., `PreZeroCenterZeroPoint`)
+    (2) a more relaxed l1-norm bound that is derived in the referenced paper
+    """
+    pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
+    pre_zero_point_impl = PreZeroCenterZeroPoint
+    pre_zero_point_shape = this.scaling_shape  # TODO: decouple zero_point from scaling
+    pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
+
+
 class MSESubInjectorBase(ExtendedInjector):
 
     @value
diff --git a/src/brevitas/quant/scaled_int.py b/src/brevitas/quant/scaled_int.py
index 8488865c9..5b7f16ad2 100644
--- a/src/brevitas/quant/scaled_int.py
+++ b/src/brevitas/quant/scaled_int.py
@@ -34,6 +34,7 @@
     'Uint8ActPerTensorFloatBatchQuant2d',
     'Int8ActPerTensorFloatBatchQuant2d',
     'Int8AccumulatorAwareWeightQuant',
+    'Int8AccumulatorAwareZeroCenterWeightQuant',
     'Int8WeightNormL2PerChannelFixedPoint']
 
 
@@ -406,9 +407,8 @@ class Int8WeightNormL2PerChannelFixedPoint(WeightNormPerChannelFloatDecoupled):
     Experimental 8-bit narrow signed integer quantizer with learned per-channel scaling factors and L2
     weight normalization based on `A2Q: Accumulator-Aware Quantization with Guaranteed Overflow
     Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig (https://arxiv.org/abs/2308.13504).
-    The quantizer learns scaling factors in the float domain and learns vector parameter g in the log
-    domain with the half-way rounding function. Suitable for retraining from floating-point depthwise
-    separable weights.
+    The quantizer learns scaling factors and norm parameter g in the log-float domain with the half-way
+    rounding function.
     Examples:
         >>> from brevitas.nn import QuantConv2d
@@ -423,9 +423,8 @@ class Int8AccumulatorAwareWeightQuant(AccumulatorAwareWeightQuant):
     Experimental 8-bit narrow signed accumulator-aware integer quantizer with learned per-channel
     scaling factors based on `A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance`
     by I.Colbert, A.Pappalardo, and J.Petri-Koenig (https://arxiv.org/abs/2308.13504). The quantizer
-    learns scaling factors in the float domain and learns vector parameter g in the log domain with
-    the round-to-zero rounding function. The norm is clamped according the the specified accumulator
-    bit-width. Suitable for retraining from floating-point depthwise separable weights.
+    learns scaling factors s and norm parameter g in the log-float domain with the round-to-zero
+    rounding function. The norm is clamped according to the specified accumulator bit-width.
 
     Examples:
         >>> from brevitas.nn import QuantConv2d
@@ -433,3 +432,20 @@ class Int8AccumulatorAwareWeightQuant(AccumulatorAwareWeightQuant):
         >>> conv.quant_weight()
     """
     bit_width = 8
+
+
+class Int8AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareZeroCenterWeightQuant):
+    """
+    Experimental 8-bit narrow signed zero-centered accumulator-aware integer weight quantizer with
+    learned per-channel scaling factors based on `A2Q+: Improving Accumulator-Aware Weight Quantization`
+    by I. Colbert, A. Pappalardo, J. Petri-Koenig, and Y. Umuroglu (https://arxiv.org/abs/2401.10432).
+    The quantizer learns scaling factors in the float domain and learns norm parameter g in the log domain
+    with the round-to-zero rounding function. The norm is clamped according to the specified accumulator
+    bit-width using zero-centered weights. The zero-centering is done before rounding and clipping.
+
+    Examples:
+        >>> from brevitas.nn import QuantConv2d
+        >>> conv = QuantConv2d(4, 4, 3, groups=4, weight_quant=Int8AccumulatorAwareZeroCenterWeightQuant)
+        >>> conv.quant_weight()
+    """
+    bit_width = 8
diff --git a/src/brevitas_examples/super_resolution/utils/evaluate.py b/src/brevitas_examples/super_resolution/utils/evaluate.py
index 39c05a434..2d317e9dc 100644
--- a/src/brevitas_examples/super_resolution/utils/evaluate.py
+++ b/src/brevitas_examples/super_resolution/utils/evaluate.py
@@ -1,31 +1,36 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import torch
 from torch import Tensor
 import torch.nn as nn
 
 import brevitas.nn as qnn
 from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
-from brevitas.nn.utils import calculate_min_accumulator_bit_width
-from brevitas_examples.super_resolution.models.espcn import QuantESPCN
+
+EPS = 1e-10
 
 
 def _calc_min_acc_bit_width(module: QuantWBIOL) -> Tensor:
+    assert isinstance(module, qnn.QuantConv2d), "Error: function only supports QuantConv2d."
+ # bit-width and sign need to come from the quant tensor of the preceding layer if no io_quant input_bit_width = module.quant_input_bit_width() - input_is_signed = module.is_quant_input_signed + input_is_signed = float(module.is_quant_input_signed) # the tensor quantizer requires a QuantTensor with specified bit-width and sign quant_weight = module.quant_weight() quant_weight = quant_weight.int().float() - if isinstance(module, - qnn.QuantConv2d): # shape = (out_channels, in_channels, kernel_size, kernel_size) - quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3)) + quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3)) # using the closed-form bounds on accumulator bit-width - cur_acc_bit_width = calculate_min_accumulator_bit_width( - input_bit_width, input_is_signed, quant_weight_per_channel_l1_norm.max()) - return cur_acc_bit_width + weight_max_l1_norm = quant_weight_per_channel_l1_norm.max() + weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, EPS) + alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed + phi = lambda x: torch.log2(1. + pow(2., -x)) + min_bit_width = alpha + phi(alpha) + 1. + min_bit_width = torch.ceil(min_bit_width) + return min_bit_width def evaluate_accumulator_bit_widths(model: nn.Module, inp: Tensor): diff --git a/src/brevitas_examples/super_resolution/utils/train.py b/src/brevitas_examples/super_resolution/utils/train.py index 94f9c36c3..53bf9e769 100644 --- a/src/brevitas_examples/super_resolution/utils/train.py +++ b/src/brevitas_examples/super_resolution/utils/train.py @@ -28,7 +28,7 @@ def acc_reg_penalty(module: AccumulatorAwareParameterPreScaling, inp, output): (weights, input_bit_width, input_is_signed) = inp s = module.scaling_impl(weights) # s g = abs_binary_sign_grad(module.restrict_clamp_scaling(module.value)) # g - T = module.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed) # T / s + T = module.calc_max_l1_norm(input_bit_width, input_is_signed) # T / s cur_penalty = torch.relu(g - (T * s)).sum() reg_penalty += cur_penalty return output diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index 875b34afa..538e836e8 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -20,6 +20,7 @@ from brevitas.nn.quant_rnn import QuantLSTM from brevitas.nn.quant_rnn import QuantRNN from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant +from brevitas.quant.scaled_int import Int8AccumulatorAwareZeroCenterWeightQuant from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloatBatchQuant1d from brevitas.quant.scaled_int import Int8ActPerTensorFloatBatchQuant2d @@ -45,12 +46,16 @@ 'quant_sym': Int8WeightPerTensorFloat, 'quant_asym': ShiftedUint8WeightPerTensorFloat} +A2Q_WBIOL_WEIGHT_QUANTIZER = { + 'quant_a2q': Int8AccumulatorAwareWeightQuant, + 'quant_a2q_plus': Int8AccumulatorAwareZeroCenterWeightQuant} + WBIOL_WEIGHT_QUANTIZER = { 'None': None, 'quant_sym': Int8WeightPerTensorFloat, 'quant_asym': ShiftedUint8WeightPerTensorFloat, 'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint, - 'quant_a2q': Int8AccumulatorAwareWeightQuant} + **A2Q_WBIOL_WEIGHT_QUANTIZER} WBIOL_IO_QUANTIZER = { 'None': None, @@ -107,7 +112,7 @@ def build_case_model( _, bias_quantizer = bias_quantizer _, io_quantizer = io_quantizer - if io_quantizer is None and not input_quantized and k == 'quant_a2q': + if io_quantizer is None and not 
input_quantized and k in A2Q_WBIOL_WEIGHT_QUANTIZER: pytest.skip( "A2Q uses an input-aware decoupled weight proxy that requires a quantized input tensor." ) @@ -215,11 +220,13 @@ def case_model( 'accumulator_bit_width', ACC_BIT_WIDTHS, ids=[f'accumulator_bit_width${bw}' for bw in ACC_BIT_WIDTHS]) -def case_model_a2q(io_quantizer, module, request, accumulator_bit_width): +@pytest_cases.parametrize( + 'weight_quantizer', + A2Q_WBIOL_WEIGHT_QUANTIZER.items(), + ids=[f'weight_quant${c}' for c, _ in A2Q_WBIOL_WEIGHT_QUANTIZER.items()]) +def case_model_a2q(io_quantizer, module, request, accumulator_bit_width, weight_quantizer): set_case_id(request.node.callspec.id, case_model_a2q) case_id = get_case_id(case_model_a2q) - # forcing test to only use accumulator-aware weight quantizer - weight_quantizer = ('quant_a2q', Int8AccumulatorAwareWeightQuant) # reducing coverage by fixing some case parameters return build_case_model( weight_quantizer, @@ -604,7 +611,7 @@ def case_mha( _, bias_quantizer = bias_quantizer _, io_quantizer = io_quantizer - if io_quantizer is None and k == 'quant_a2q': + if io_quantizer is None and k in A2Q_WBIOL_WEIGHT_QUANTIZER: # Can't rely on a QuantTensor input for quant_mha at this point pytest.skip( "A2Q uses an input-aware decoupled weight proxy that requires a quantized input tensor." diff --git a/tests/brevitas/nn/test_a2q.py b/tests/brevitas/nn/test_a2q.py index 2abcf9ef2..8c34f390c 100644 --- a/tests/brevitas/nn/test_a2q.py +++ b/tests/brevitas/nn/test_a2q.py @@ -1,11 +1,13 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from typing import Optional + import pytest_cases from pytest_cases import get_case_id import torch +from torch import Tensor -from brevitas.nn.utils import calculate_min_accumulator_bit_width from brevitas.quant_tensor import QuantTensor from .nn_quantizers_fixture import case_model_a2q @@ -25,6 +27,45 @@ def parse_args(args): return kwargs +def calc_a2q_acc_bit_width( + weight_max_l1_norm: Tensor, + input_bit_width: Tensor, + input_is_signed: bool, + min_val: Optional[float] = 1e-10): + """Using the closed-form bounds on accumulator bit-width as derived in `A2Q: Accumulator-Aware Quantization with + Guaranteed Overflow Avoidance`. This function returns the minimum accumulator bit-width that can be used without + risk of overflow.""" + assert weight_max_l1_norm.numel() == 1 + input_is_signed = float(input_is_signed) + weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, min_val) + alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed + phi = lambda x: torch.log2(1. + pow(2., -x)) + min_bit_width = alpha + phi(alpha) + 1. + min_bit_width = torch.ceil(min_bit_width) + return min_bit_width + + +def calc_a2q_plus_acc_bit_width( + weight_max_l1_norm: Tensor, + input_bit_width: Tensor, + input_is_signed: bool, + min_val: Optional[float] = 1e-10): + """Using the closed-form bounds on accumulator bit-width as derived in `A2Q+: + Improving Accumulator-Aware Weight Quantization`. This function returns the + minimum accumulator bit-width that can be used without risk of overflow, + assuming that the floating-point weights are zero-centered.""" + input_is_signed = float(input_is_signed) + assert weight_max_l1_norm.numel() == 1 + weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, min_val) + input_range = pow(2., input_bit_width) - 1. # 2^N - 1. + min_bit_width = torch.log2(weight_max_l1_norm * input_range + 2.) 
+ min_bit_width = torch.ceil(min_bit_width) + return min_bit_width + + +calc_fnc = {"quant_a2q": calc_a2q_acc_bit_width, "quant_a2q_plus": calc_a2q_plus_acc_bit_width} + + @pytest_cases.parametrize_with_cases('model_input', cases=case_model_a2q) def test_quant_wbiol_a2q(model_input, current_cases): """This test only verifies that the accumulator-aware weight quantization constraints the l1-norm of @@ -36,6 +77,7 @@ def test_quant_wbiol_a2q(model_input, current_cases): case_id = get_case_id(cases_generator_func) args = case_id.split('-')[1:] # Exclude first argument kwargs = parse_args(args) + fnc = calc_fnc[kwargs['weight_quant']] # A2Q needs to have a quantized input, which can be done by input quantizer or returning # a quantized tensor from the preceding layer @@ -70,8 +112,8 @@ def test_quant_wbiol_a2q(model_input, current_cases): raise NotImplementedError(f"Check for {kwargs['model_type']} is not yet implemented.") # using the closed-form bounds on accumulator bit-width - cur_acc_bit_width = calculate_min_accumulator_bit_width( - input_bit_width, input_is_signed, quant_weight_per_channel_l1_norm.max()) + cur_acc_bit_width = fnc( + quant_weight_per_channel_l1_norm.max(), input_bit_width, input_is_signed) exp_acc_bit_width = kwargs['accumulator_bit_width'] assert cur_acc_bit_width <= exp_acc_bit_width, \ f"Model does not satisfy accumulator bit-width bounds. Expected {exp_acc_bit_width}, got {cur_acc_bit_width}" diff --git a/tests/brevitas_examples/test_examples_import.py b/tests/brevitas_examples/test_examples_import.py index c80dd3161..e3796d03c 100644 --- a/tests/brevitas_examples/test_examples_import.py +++ b/tests/brevitas_examples/test_examples_import.py @@ -4,6 +4,7 @@ import pytest from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant +from brevitas.quant.scaled_int import Int8AccumulatorAwareZeroCenterWeightQuant from brevitas.quant.scaled_int import Int8WeightPerChannelFloat @@ -41,7 +42,11 @@ def test_import_stt(): @pytest.mark.parametrize("upscale_factor", [2, 3, 4]) @pytest.mark.parametrize("num_channels", [1, 3]) @pytest.mark.parametrize( - "weight_quant", [Int8WeightPerChannelFloat, Int8AccumulatorAwareWeightQuant]) + "weight_quant", + [ + Int8WeightPerChannelFloat, + Int8AccumulatorAwareWeightQuant, + Int8AccumulatorAwareZeroCenterWeightQuant]) def test_super_resolution_float_and_quant_models_match(upscale_factor, num_channels, weight_quant): import brevitas.config as config from brevitas_examples.super_resolution.models import float_espcn
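For reviewers who want to sanity-check the arithmetic behind the new quantizer without pulling in the test fixtures, the two closed-form accumulator bounds exercised by the tests above can be reproduced standalone. The sketch below is illustrative rather than part of this patch: it mirrors `calc_a2q_acc_bit_width` and `calc_a2q_plus_acc_bit_width` from `tests/brevitas/nn/test_a2q.py`, and the example values (a per-channel l1-norm of 40 and an unsigned 8-bit input) are assumptions chosen only to show the effect of zero-centering.

```python
import torch


def a2q_min_acc_bit_width(max_l1_norm: torch.Tensor, input_bit_width: float, input_is_signed: bool) -> torch.Tensor:
    # A2Q bound: P >= ceil(alpha + phi(alpha) + 1), with alpha = log2(||q||_1) + N - is_signed
    max_l1_norm = torch.clamp_min(max_l1_norm, 1e-10)  # avoid log2(0)
    alpha = torch.log2(max_l1_norm) + input_bit_width - float(input_is_signed)
    phi = lambda x: torch.log2(1. + pow(2., -x))
    return torch.ceil(alpha + phi(alpha) + 1.)


def a2q_plus_min_acc_bit_width(max_l1_norm: torch.Tensor, input_bit_width: float) -> torch.Tensor:
    # A2Q+ bound: P >= ceil(log2(||q||_1 * (2^N - 1) + 2)), valid when the floating-point weights
    # are zero-centered, which is what PreZeroCenterZeroPoint enforces before rounding and clipping
    max_l1_norm = torch.clamp_min(max_l1_norm, 1e-10)
    input_range = pow(2., input_bit_width) - 1.  # 2^N - 1
    return torch.ceil(torch.log2(max_l1_norm * input_range + 2.))


# assumed example: per-channel l1-norm of 40 on the integer weights, unsigned 8-bit input
l1_norm, n_bits = torch.tensor(40.), 8.
print(a2q_min_acc_bit_width(l1_norm, n_bits, input_is_signed=False))  # tensor(15.)
print(a2q_plus_min_acc_bit_width(l1_norm, n_bits))                    # tensor(14.)
```

At equal l1-norm, the zero-centered bound typically needs about one fewer accumulator bit; equivalently, at a fixed accumulator width it roughly doubles the admissible l1-norm, which is reflected in the relaxed limit (2^P - 2)/(2^N - 1) implemented by `AccumulatorAwareZeroCenterParameterPreScaling.calc_max_l1_norm` and is the motivation for `PreZeroCenterZeroPoint` and `AccumulatorAwareZeroCenterWeightQuant`.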