diff --git a/src/brevitas/nn/utils.py b/src/brevitas/nn/utils.py
index d718dfea9..ed5e87302 100644
--- a/src/brevitas/nn/utils.py
+++ b/src/brevitas/nn/utils.py
@@ -1,13 +1,10 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from typing import Optional
-
 import torch
 from torch import Tensor
 from torch.nn import Parameter
 
-from brevitas.function.ops_ste import ceil_ste
 from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
 
@@ -79,62 +76,3 @@ def check_tensors_same_ptr(tensor_list):
         else:
             return False
     return all(p == pointers[0] for p in pointers)
-
-
-def calculate_min_accumulator_bit_width(
-        input_bit_width: Tensor,
-        input_is_signed: bool,
-        weight_max_l1_norm: Optional[Tensor] = None,
-        weight_bit_width: Optional[Tensor] = None,
-        n_elements: Optional[Tensor] = None,
-        min_val: Optional[float] = 1e-10,
-        zero_centered_weights: bool = False):
-    """Using the closed-form bounds on accumulator bit-width as derived in `A2Q: Accumulator-Aware Quantization with
-    Guaranteed Overflow Avoidance`. This function returns the minimum accumulator bit-width that can be used without
-    risk of overflow. It supports both the data-type bound as well as the weight-level bound.
-
-    If `zero_centered_weights=True` and `weight_max_l1_norm` is not None, then the function uses the bounds derived in
-    `A2Q+: Improving Accumulator-Aware Weight Quantization`.
-
-    Args:
-        input_bit_width (Tensor): the bit-width of the inputs to the layer.
-        input_is_signed (bool): calculate statistics for normalizing weight parameter.
-        weight_max_l1_norm (Tensor): the maximum per-channel l1-norm of the weights.
-        weight_bit_width (Tensor): the bit-width of the weights to the layer.
-        n_elements (Tensor): the number of elements in the dot product.
-        min_val (float): the minimum value used for the l1-norm, used to avoid log2(0). Default: 1e-8.
-        zero_centered_weights (bool): if the weights are zero-centered: Default: false.
-
-    Example (data-type bound):
-        >> acc_bit_width = calculate_min_accumulator_bit_width(input_bit_width, input_is_signed, weight_bit_width, n_elements)
-
-    Example (weight-level bound):
-        >> acc_bit_width = calculate_min_accumulator_bit_width(input_bit_width, input_is_signed, weight_max_l1_norm)
-    """
-    input_is_signed = float(input_is_signed)
-    # if the l1-norm of the weights is specified, then use the weight-level bound
-    if weight_max_l1_norm is not None:
-        assert isinstance(weight_max_l1_norm, (float, Tensor)), "The l1-norm of the weights needs to be a float or a torch.Tensor instance."
-        if isinstance(weight_max_l1_norm, Tensor):
-            assert weight_max_l1_norm.numel() == 1, "The minimum accumulator bit-width calculation currently only supports scalars."
-        weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, min_val)
-        # if the weights are zero-centered, then use the improved bound
-        if zero_centered_weights:
-            input_range = pow(2., input_bit_width) - 1.  # 2^N - 1.
-            min_bit_width = torch.log2(weight_max_l1_norm * input_range + 2.)
-            min_bit_width = ceil_ste(min_bit_width)
-            return min_bit_width
-        input_is_signed = float(input_is_signed)
-        alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed
-    # else use the data-type bound
-    else:
-        assert isinstance(weight_bit_width, (float, Tensor)), "If weight_max_l1_norm is un-specified, weight_bit_width needs to be specified."
-        assert isinstance(n_elements, (float, Tensor)), "If weight_max_l1_norm is un-specified, n_elements needs to be specified."
-        if isinstance(n_elements, Tensor):
-            assert n_elements.numel() == 1, "The minimum accumulator bit-width calculation currently only supports scalars."
-        assert n_elements > 0, "There needs to be at least one element considered in this evaluation."
-        alpha = torch.log2(n_elements) + input_bit_width + weight_bit_width - input_is_signed - 1.
-    phi = lambda x: torch.log2(1. + pow(2., -x))
-    min_bit_width = alpha + phi(alpha) + 1.
-    min_bit_width = ceil_ste(min_bit_width)
-    return min_bit_width  # returns the minimum accumulator that can be used without risk of overflow
diff --git a/tests/brevitas/nn/test_a2q.py b/tests/brevitas/nn/test_a2q.py
index f2464f36c..8c34f390c 100644
--- a/tests/brevitas/nn/test_a2q.py
+++ b/tests/brevitas/nn/test_a2q.py
@@ -1,10 +1,13 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from typing import Optional
+
 import pytest_cases
 from pytest_cases import get_case_id
+import torch
+from torch import Tensor
 
-from brevitas.nn.utils import calculate_min_accumulator_bit_width
 from brevitas.quant_tensor import QuantTensor
 
 from .nn_quantizers_fixture import case_model_a2q
@@ -24,6 +27,45 @@ def parse_args(args):
     return kwargs
 
 
+def calc_a2q_acc_bit_width(
+        weight_max_l1_norm: Tensor,
+        input_bit_width: Tensor,
+        input_is_signed: bool,
+        min_val: Optional[float] = 1e-10):
+    """Using the closed-form bounds on accumulator bit-width as derived in `A2Q: Accumulator-Aware Quantization with
+    Guaranteed Overflow Avoidance`. This function returns the minimum accumulator bit-width that can be used without
+    risk of overflow."""
+    assert weight_max_l1_norm.numel() == 1
+    input_is_signed = float(input_is_signed)
+    weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, min_val)
+    alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed
+    phi = lambda x: torch.log2(1. + pow(2., -x))
+    min_bit_width = alpha + phi(alpha) + 1.
+    min_bit_width = torch.ceil(min_bit_width)
+    return min_bit_width
+
+
+def calc_a2q_plus_acc_bit_width(
+        weight_max_l1_norm: Tensor,
+        input_bit_width: Tensor,
+        input_is_signed: bool,
+        min_val: Optional[float] = 1e-10):
+    """Using the closed-form bounds on accumulator bit-width as derived in `A2Q+:
+    Improving Accumulator-Aware Weight Quantization`. This function returns the
+    minimum accumulator bit-width that can be used without risk of overflow,
+    assuming that the floating-point weights are zero-centered."""
+    input_is_signed = float(input_is_signed)
+    assert weight_max_l1_norm.numel() == 1
+    weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, min_val)
+    input_range = pow(2., input_bit_width) - 1.  # 2^N - 1.
+    min_bit_width = torch.log2(weight_max_l1_norm * input_range + 2.)
+    min_bit_width = torch.ceil(min_bit_width)
+    return min_bit_width
+
+
+calc_fnc = {"quant_a2q": calc_a2q_acc_bit_width, "quant_a2q_plus": calc_a2q_plus_acc_bit_width}
+
+
 @pytest_cases.parametrize_with_cases('model_input', cases=case_model_a2q)
 def test_quant_wbiol_a2q(model_input, current_cases):
     """This test only verifies that the accumulator-aware weight quantization constraints the l1-norm of
@@ -35,7 +77,7 @@ def test_quant_wbiol_a2q(model_input, current_cases):
     case_id = get_case_id(cases_generator_func)
     args = case_id.split('-')[1:]  # Exclude first argument
     kwargs = parse_args(args)
-    zero_centered_weights = kwargs['weight_quant'] == "quant_a2q_plus"  # A2Q+ zero-centers weights
+    fnc = calc_fnc[kwargs['weight_quant']]
 
     # A2Q needs to have a quantized input, which can be done by input quantizer or returning
     # a quantized tensor from the preceding layer
@@ -70,11 +112,8 @@ def test_quant_wbiol_a2q(model_input, current_cases):
         raise NotImplementedError(f"Check for {kwargs['model_type']} is not yet implemented.")
 
     # using the closed-form bounds on accumulator bit-width
-    cur_acc_bit_width = calculate_min_accumulator_bit_width(
-        input_bit_width,
-        input_is_signed,
-        quant_weight_per_channel_l1_norm.max(),
-        zero_centered_weights=zero_centered_weights)
+    cur_acc_bit_width = fnc(
+        quant_weight_per_channel_l1_norm.max(), input_bit_width, input_is_signed)
     exp_acc_bit_width = kwargs['accumulator_bit_width']
     assert cur_acc_bit_width <= exp_acc_bit_width, \
         f"Model does not satisfy accumulator bit-width bounds. Expected {exp_acc_bit_width}, got {cur_acc_bit_width}"
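For reference, a minimal worked example of the two relocated helpers, assuming they are in scope (e.g., defined as in tests/brevitas/nn/test_a2q.py above); the unsigned 8-bit input and the per-channel l1-norm of 32 are hypothetical values chosen only to make the arithmetic easy to follow:

import torch

# Hypothetical inputs, for illustration only: an unsigned 8-bit input (N = 8)
# and a maximum per-channel l1-norm of 32 for the quantized weights.
input_bit_width = torch.tensor(8.)
input_is_signed = False
weight_max_l1_norm = torch.tensor(32.)

# A2Q bound: ceil(alpha + phi(alpha) + 1) with alpha = log2(||w||_1) + N - signed.
# Here alpha = log2(32) + 8 - 0 = 13 and phi(13) ~ 0.0002, so the bound is 15 bits.
a2q_bits = calc_a2q_acc_bit_width(weight_max_l1_norm, input_bit_width, input_is_signed)
assert a2q_bits == 15.

# A2Q+ bound: ceil(log2(||w||_1 * (2^N - 1) + 2)).
# Here log2(32 * 255 + 2) = log2(8162) ~ 12.99, so the bound tightens to 13 bits.
a2q_plus_bits = calc_a2q_plus_acc_bit_width(weight_max_l1_norm, input_bit_width, input_is_signed)
assert a2q_plus_bits == 13.

At the same l1-norm the A2Q+ bound is tighter than the A2Q one, which is why the test maps each weight quantizer ("quant_a2q" vs. "quant_a2q_plus") to its matching bound via calc_fnc.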