Skip to content

Commit

Permalink
Fixing naming convention
Browse files Browse the repository at this point in the history
  • Loading branch information
i-colbert committed Jan 19, 2024
1 parent 0e10998 commit a318e2d
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/brevitas/core/scaling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .int_scaling import IntScaling
from .int_scaling import PowerOfTwoIntScaling
from .pre_scaling import AccumulatorAwareParameterPreScaling
from .pre_scaling import ImprovedAccumulatorAwareParameterPreScaling
from .pre_scaling import AccumulatorAwareZeroCenterParameterPreScaling
from .pre_scaling import ParameterPreScalingWeightNorm
from .runtime import RuntimeStatsScaling
from .runtime import StatsFromParameterScaling
Expand Down
12 changes: 7 additions & 5 deletions src/brevitas/core/scaling/pre_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from brevitas.function import abs_binary_sign_grad
from brevitas.function import get_upper_bound_on_l1_norm

__all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"]
__all__ = [
"ParameterPreScalingWeightNorm",
"AccumulatorAwareParameterPreScaling",
"AccumulatorAwareZeroCenterParameterPreScaling"]


class ParameterPreScalingWeightNorm(brevitas.jit.ScriptModule):
Expand Down Expand Up @@ -192,14 +195,13 @@ def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: boo
return value


class ImprovedAccumulatorAwareParameterPreScaling(AccumulatorAwareParameterPreScaling):
class AccumulatorAwareZeroCenterParameterPreScaling(AccumulatorAwareParameterPreScaling):
"""
ScriptModule implementation of learned pre-clipping scaling factor to support
A2Q+ as proposed in `A2Q+: Improving Accumulator-Aware Weight Quantization` by
I. Colbert, A. Pappalardo, J. Petri-Koenig, and Y. Umuroglu.
A2Q+ as proposed in `A2Q+: Improving Accumulator-Aware Weight Quantization`.
The module implements the zero-centering constraint as a pre-clipping zero-point
(i.e., `PreZeroCenterZeroPoint`) and updates the calculation on the l1-norm maximum.
(i.e., `PreZeroCenterZeroPoint`) to relax the l1-norm constraint.
Args:
scaling_impl (Module): post-clipping scaling factor.
Expand Down
14 changes: 7 additions & 7 deletions src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.restrict_val import LogFloatRestrictValue
from brevitas.core.scaling import AccumulatorAwareParameterPreScaling
from brevitas.core.scaling import ImprovedAccumulatorAwareParameterPreScaling
from brevitas.core.scaling import AccumulatorAwareZeroCenterParameterPreScaling
from brevitas.core.scaling import IntScaling
from brevitas.core.scaling import ParameterFromStatsFromParameterScaling
from brevitas.core.scaling import ParameterPreScalingWeightNorm
Expand Down Expand Up @@ -78,7 +78,7 @@
'BatchQuantStatsScaling1d',
'BatchQuantStatsScaling2d',
'AccumulatorAwareWeightQuant',
'ImprovedAccumulatorAwareWeightQuant',
'AccumulatorAwareZeroCenterWeightQuant',
'MSESymmetricScale',
'MSEAsymmetricScale',
'MSEWeightZeroPoint',
Expand Down Expand Up @@ -403,15 +403,15 @@ def accumulator_bit_width_impl(accumulator_bit_width):
float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints


class ImprovedAccumulatorAwareWeightQuant(AccumulatorAwareWeightQuant):
class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
    """Experimental improved accumulator-aware weight quantizer based on `A2Q+: Improving
Accumulator-Aware Weight Quantization` by I. Colbert, A. Pappalardo, J. Petri-Koenig,
and Y. Umuroglu. When compared to the standard accumulator-aware quantizer (A2Q), A2Q+
changes the following:
Accumulator-Aware Weight Quantization`.
When compared to A2Q, A2Q+ changes the following:
(1) an added zero-centering constraint on the weights (i.e., `PreZeroCenterZeroPoint`)
(2) an improved l1-norm bound that is derived in the referenced paper
"""
pre_scaling_impl = ImprovedAccumulatorAwareParameterPreScaling
pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
pre_zero_point_impl = PreZeroCenterZeroPoint
pre_zero_point_shape = this.scaling_shape # TODO: decouple zero_point from scaling
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
Expand Down
8 changes: 4 additions & 4 deletions src/brevitas/quant/scaled_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
'Uint8ActPerTensorFloatBatchQuant2d',
'Int8ActPerTensorFloatBatchQuant2d',
'Int8AccumulatorAwareWeightQuant',
'Int8WeightNormL2PerChannelFixedPoint',
'Int8ImprovedAccumulatorAwareWeightQuant']
'Int8AccumulatorAwareZeroCenterWeightQuant',
'Int8WeightNormL2PerChannelFixedPoint']


class Int8ActPerTensorFloatMinMaxInit(IntQuant,
Expand Down Expand Up @@ -434,7 +434,7 @@ class Int8AccumulatorAwareWeightQuant(AccumulatorAwareWeightQuant):
bit_width = 8


class Int8ImprovedAccumulatorAwareWeightQuant(ImprovedAccumulatorAwareWeightQuant):
class Int8AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareZeroCenterWeightQuant):
"""
Experimental 8-bit narrow signed improved accumulator-aware integer weight quantizer with learned
per-channel scaling factors based on `A2Q+: Improving Accumulator-Aware Weight Quantization` by
Expand All @@ -445,7 +445,7 @@ class Int8ImprovedAccumulatorAwareWeightQuant(ImprovedAccumulatorAwareWeightQuan
Examples:
>>> from brevitas.nn import QuantConv2d
>>> conv = QuantConv2d(4, 4, 3, groups=4, weight_quant=Int8ImprovedAccumulatorAwareWeightQuant)
>>> conv = QuantConv2d(4, 4, 3, groups=4, weight_quant=Int8AccumulatorAwareZeroCenterWeightQuant)
>>> conv.quant_weight()
"""
bit_width = 8
4 changes: 2 additions & 2 deletions tests/brevitas/nn/nn_quantizers_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
from brevitas.nn.quant_rnn import QuantLSTM
from brevitas.nn.quant_rnn import QuantRNN
from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant
from brevitas.quant.scaled_int import Int8AccumulatorAwareZeroCenterWeightQuant
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloatBatchQuant1d
from brevitas.quant.scaled_int import Int8ActPerTensorFloatBatchQuant2d
from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling
from brevitas.quant.scaled_int import Int8ImprovedAccumulatorAwareWeightQuant
from brevitas.quant.scaled_int import Int8WeightNormL2PerChannelFixedPoint
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
from brevitas.quant.scaled_int import Int16Bias
Expand All @@ -48,7 +48,7 @@

A2Q_WBIOL_WEIGHT_QUANTIZER = {
'quant_a2q': Int8AccumulatorAwareWeightQuant,
'quant_a2q_plus': Int8ImprovedAccumulatorAwareWeightQuant}
'quant_a2q_plus': Int8AccumulatorAwareZeroCenterWeightQuant}

WBIOL_WEIGHT_QUANTIZER = {
'None': None,
Expand Down

0 comments on commit a318e2d

Please sign in to comment.