Skip to content

Commit

Permalink
Fix (GPFA2Q): move upper_bound function
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Nov 20, 2023
1 parent 36727ae commit 0bea4e3
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/brevitas/core/scaling/pre_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from brevitas.core.stats import SCALAR_SHAPE
from brevitas.core.stats.stats_wrapper import _Stats
from brevitas.function import abs_binary_sign_grad
from brevitas.nn.utils import get_upper_bound_on_l1_norm
from brevitas.function import get_upper_bound_on_l1_norm

__all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"]

Expand Down
14 changes: 14 additions & 0 deletions src/brevitas/function/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,17 @@ def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_b
device=mantissa_bit_width.device)))
max_val = max_mantissa * (2 ** max_exponent)
return max_val


def get_upper_bound_on_l1_norm(
        accumulator_bit_width: int, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
    """Calculate the upper bound on the l1-norm of the weights using the derivations from
    `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
    by I.Colbert, A.Pappalardo, and J.Petri-Koenig.

    Returns (2^{P-1} - 1) / 2^{N - sign}, where P is the accumulator bit-width,
    N is the input bit-width, and sign is 1 for signed inputs else 0.
    """
    assert input_bit_width is not None, "A2Q relies on input bit-width."
    assert input_is_signed is not None, "A2Q relies on input sign."
    signed_offset = float(input_is_signed)  # 1. if signed else 0.
    # Largest representable accumulator magnitude for P bits: 2^{P-1} - 1.
    acc_magnitude = pow(2., accumulator_bit_width - 1.) - 1.
    # Reciprocal of the largest input magnitude: 2^{sign - N}.
    inv_input_magnitude = pow(2., signed_offset - input_bit_width)
    return acc_magnitude * inv_input_magnitude
4 changes: 1 addition & 3 deletions src/brevitas/graph/gpfq.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from copy import deepcopy
from typing import List, Optional

import numpy as np
import torch
from torch import Tensor
import unfoldNd

from brevitas.function import get_upper_bound_on_l1_norm
from brevitas.graph.gpxq import GPxQ
from brevitas.graph.gpxq import gpxq_mode
from brevitas.graph.gpxq import StopFwdException
from brevitas.graph.gpxq import SUPPORTED_CONV_OP
import brevitas.nn as qnn
from brevitas.nn.utils import get_upper_bound_on_l1_norm


class gpfq_mode(gpxq_mode):
Expand Down
14 changes: 0 additions & 14 deletions src/brevitas/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,3 @@ def calculate_min_accumulator_bit_width(
min_bit_width = alpha + phi(alpha) + 1.
min_bit_width = ceil_ste(min_bit_width)
return min_bit_width # returns the minimum accumulator that can be used without risk of overflow


def get_upper_bound_on_l1_norm(
        accumulator_bit_width: int, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
    """Calculate the upper bound on the l1-norm of the weights using the derivations from
    `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
    by I.Colbert, A.Pappalardo, and J.Petri-Koenig."""
    assert input_bit_width is not None, "A2Q relies on input bit-width."
    assert input_is_signed is not None, "A2Q relies on input sign."
    # Bound is (2^{P-1} - 1) * 2^{sign - N}: the largest accumulator magnitude
    # for P = accumulator_bit_width bits, divided by the largest input
    # magnitude for N = input_bit_width bits (sign is 1. if signed else 0.).
    sign_bit = float(input_is_signed)
    return (2. ** (accumulator_bit_width - 1.) - 1.) * 2. ** (sign_bit - input_bit_width)

0 comments on commit 0bea4e3

Please sign in to comment.