Moving acc bit-width calculation to tests
i-colbert committed Jan 19, 2024
1 parent a318e2d commit 3dd7e43
Showing 2 changed files with 46 additions and 69 deletions.
62 changes: 0 additions & 62 deletions src/brevitas/nn/utils.py
@@ -1,13 +1,10 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from brevitas.function.ops_ste import ceil_ste
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector

@@ -79,62 +76,3 @@ def check_tensors_same_ptr(tensor_list):
else:
return False
return all(p == pointers[0] for p in pointers)


def calculate_min_accumulator_bit_width(
input_bit_width: Tensor,
input_is_signed: bool,
weight_max_l1_norm: Optional[Tensor] = None,
weight_bit_width: Optional[Tensor] = None,
n_elements: Optional[Tensor] = None,
min_val: Optional[float] = 1e-10,
zero_centered_weights: bool = False):
"""Using the closed-form bounds on accumulator bit-width as derived in `A2Q: Accumulator-Aware Quantization with
Guaranteed Overflow Avoidance`. This function returns the minimum accumulator bit-width that can be used without
risk of overflow. It supports both the data-type bound as well as the weight-level bound.
If `zero_centered_weights=True` and `weight_max_l1_norm` is not None, then the function uses the bounds derived in
`A2Q+: Improving Accumulator-Aware Weight Quantization`.
Args:
input_bit_width (Tensor): the bit-width of the inputs to the layer.
input_is_signed (bool): whether the inputs to the layer are signed.
weight_max_l1_norm (Tensor): the maximum per-channel l1-norm of the weights.
weight_bit_width (Tensor): the bit-width of the weights to the layer.
n_elements (Tensor): the number of elements in the dot product.
min_val (float): the minimum value used for the l1-norm, used to avoid log2(0). Default: 1e-10.
zero_centered_weights (bool): whether the weights are zero-centered. Default: False.
Example (data-type bound):
>> acc_bit_width = calculate_min_accumulator_bit_width(input_bit_width, input_is_signed, weight_bit_width, n_elements)
Example (weight-level bound):
>> acc_bit_width = calculate_min_accumulator_bit_width(input_bit_width, input_is_signed, weight_max_l1_norm)
"""
input_is_signed = float(input_is_signed)
# if the l1-norm of the weights is specified, then use the weight-level bound
if weight_max_l1_norm is not None:
assert isinstance(weight_max_l1_norm, (float, Tensor)), "The l1-norm of the weights needs to be a float or a torch.Tensor instance."
if isinstance(weight_max_l1_norm, Tensor):
assert weight_max_l1_norm.numel() == 1, "The minimum accumulator bit-width calculation currently only supports scalars."
weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, min_val)
# if the weights are zero-centered, then use the improved bound
if zero_centered_weights:
input_range = pow(2., input_bit_width) - 1. # 2^N - 1.
min_bit_width = torch.log2(weight_max_l1_norm * input_range + 2.)
min_bit_width = ceil_ste(min_bit_width)
return min_bit_width
alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed
# else use the data-type bound
else:
assert isinstance(weight_bit_width, (float, Tensor)), "If weight_max_l1_norm is un-specified, weight_bit_width needs to be specified."
assert isinstance(n_elements, (float, Tensor)), "If weight_max_l1_norm is un-specified, n_elements needs to be specified."
if isinstance(n_elements, Tensor):
assert n_elements.numel() == 1, "The minimum accumulator bit-width calculation currently only supports scalars."
assert n_elements > 0, "There needs to be at least one element considered in this evaluation."
alpha = torch.log2(n_elements) + input_bit_width + weight_bit_width - input_is_signed - 1.
phi = lambda x: torch.log2(1. + pow(2., -x))
min_bit_width = alpha + phi(alpha) + 1.
min_bit_width = ceil_ste(min_bit_width)
return min_bit_width # returns the minimum accumulator that can be used without risk of overflow
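
As a sanity check on the data-type bound above, the following minimal sketch (illustrative values only, not part of this commit) evaluates the formula directly. With unsigned 4-bit inputs, 4-bit weights, and a 128-element dot product, alpha = log2(128) + 4 + 4 - 0 - 1 = 14, so the bound rounds up to a 16-bit accumulator:

import torch

# Illustrative configuration: unsigned 4-bit inputs, 4-bit weights, 128-element dot product.
input_bit_width = torch.tensor(4.)
weight_bit_width = torch.tensor(4.)
n_elements = torch.tensor(128.)
input_is_signed = 0.  # unsigned inputs

alpha = torch.log2(n_elements) + input_bit_width + weight_bit_width - input_is_signed - 1.
phi = lambda x: torch.log2(1. + pow(2., -x))  # correction term, vanishes for large alpha
min_bit_width = torch.ceil(alpha + phi(alpha) + 1.)
print(min_bit_width)  # tensor(16.)

Since phi(alpha) decays toward zero as alpha grows, the bound is dominated by alpha + 1 for realistic layer sizes.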
53 changes: 46 additions & 7 deletions tests/brevitas/nn/test_a2q.py
@@ -1,10 +1,13 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional

import pytest_cases
from pytest_cases import get_case_id
import torch
from torch import Tensor

from brevitas.nn.utils import calculate_min_accumulator_bit_width
from brevitas.quant_tensor import QuantTensor

from .nn_quantizers_fixture import case_model_a2q
@@ -24,6 +27,45 @@ def parse_args(args):
return kwargs


def calc_a2q_acc_bit_width(
weight_max_l1_norm: Tensor,
input_bit_width: Tensor,
input_is_signed: bool,
min_val: Optional[float] = 1e-10):
"""Using the closed-form bounds on accumulator bit-width as derived in `A2Q: Accumulator-Aware Quantization with
Guaranteed Overflow Avoidance`. This function returns the minimum accumulator bit-width that can be used without
risk of overflow."""
assert weight_max_l1_norm.numel() == 1
input_is_signed = float(input_is_signed)
weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, min_val)
alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed
phi = lambda x: torch.log2(1. + pow(2., -x))
min_bit_width = alpha + phi(alpha) + 1.
min_bit_width = torch.ceil(min_bit_width)
return min_bit_width


def calc_a2q_plus_acc_bit_width(
weight_max_l1_norm: Tensor,
input_bit_width: Tensor,
input_is_signed: bool,
min_val: Optional[float] = 1e-10):
"""Using the closed-form bounds on accumulator bit-width as derived in `A2Q+:
Improving Accumulator-Aware Weight Quantization`. This function returns the
minimum accumulator bit-width that can be used without risk of overflow,
assuming that the floating-point weights are zero-centered."""
input_is_signed = float(input_is_signed)
assert weight_max_l1_norm.numel() == 1
weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, min_val)
input_range = pow(2., input_bit_width) - 1. # 2^N - 1.
min_bit_width = torch.log2(weight_max_l1_norm * input_range + 2.)
min_bit_width = torch.ceil(min_bit_width)
return min_bit_width


calc_fnc = {"quant_a2q": calc_a2q_acc_bit_width, "quant_a2q_plus": calc_a2q_plus_acc_bit_width}
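
For intuition, here is a minimal usage sketch (hypothetical layer statistics, assuming torch and the two helpers above are in scope) comparing the two bounds for the same layer; the zero-centered A2Q+ bound is at least as tight as the A2Q bound:

# Hypothetical statistics: 8-bit unsigned inputs, max per-channel l1-norm of 50.
max_l1_norm = torch.tensor(50.)
input_bit_width = torch.tensor(8.)
input_is_signed = False

a2q_bound = calc_a2q_acc_bit_width(max_l1_norm, input_bit_width, input_is_signed)            # tensor(15.)
a2q_plus_bound = calc_a2q_plus_acc_bit_width(max_l1_norm, input_bit_width, input_is_signed)  # tensor(14.)
assert a2q_plus_bound <= a2q_bound  # the zero-centered bound is never looser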


@pytest_cases.parametrize_with_cases('model_input', cases=case_model_a2q)
def test_quant_wbiol_a2q(model_input, current_cases):
"""This test only verifies that the accumulator-aware weight quantization constraints the l1-norm of
@@ -35,7 +77,7 @@ def test_quant_wbiol_a2q(model_input, current_cases):
case_id = get_case_id(cases_generator_func)
args = case_id.split('-')[1:] # Exclude first argument
kwargs = parse_args(args)
zero_centered_weights = kwargs['weight_quant'] == "quant_a2q_plus" # A2Q+ zero-centers weights
fnc = calc_fnc[kwargs['weight_quant']]

# A2Q needs to have a quantized input, which can be done by input quantizer or returning
# a quantized tensor from the preceding layer
@@ -70,11 +112,8 @@ def test_quant_wbiol_a2q(model_input, current_cases):
raise NotImplementedError(f"Check for {kwargs['model_type']} is not yet implemented.")

# using the closed-form bounds on accumulator bit-width
cur_acc_bit_width = calculate_min_accumulator_bit_width(
input_bit_width,
input_is_signed,
quant_weight_per_channel_l1_norm.max(),
zero_centered_weights=zero_centered_weights)
cur_acc_bit_width = fnc(
quant_weight_per_channel_l1_norm.max(), input_bit_width, input_is_signed)
exp_acc_bit_width = kwargs['accumulator_bit_width']
assert cur_acc_bit_width <= exp_acc_bit_width, \
f"Model does not satisfy accumulator bit-width bounds. Expected {exp_acc_bit_width}, got {cur_acc_bit_width}"
