Fix (a2q): correcting post-rounding scaling initialization (#659)
* Fixing scaling initialization

* Moving A2Q quantizer

* Fixing import for A2Q quantizer

* Pre-commit fixes

* Update eval_model.py

* Adding warning for QCDQ export support

* Fixing A2Q imports

* Update base.py

* Fixing import errors
i-colbert authored Jul 25, 2023
1 parent 4d0852d commit a236539
Showing 8 changed files with 65 additions and 50 deletions.
13 changes: 10 additions & 3 deletions src/brevitas/quant/base.py
@@ -23,6 +23,7 @@
from brevitas.core.scaling import IntScaling
from brevitas.core.scaling import ParameterFromStatsFromParameterScaling
from brevitas.core.scaling import ParameterPreScalingWeightNorm
+from brevitas.core.scaling import ParameterScaling
from brevitas.core.scaling import SCALAR_SHAPE
from brevitas.core.scaling import SCALING_STATS_REDUCE_DIM
from brevitas.core.scaling import StatsFromParameterScaling
@@ -340,11 +341,17 @@ class WeightNormPerChannelFloatDecoupled(SolveWeightScalingStatsInputDimsFromMod
    details on the arithmetic, see `ParameterPreScalingWeightNorm`. For further details
    on the weight normalization-based quantization technique, see the referenced paper."""

+    @value
+    def scaling_init(scaling_init_impl, bit_width):
+        scales = scaling_init_impl.parameter_list_stats() / (pow(2., bit_width - 1.) - 1.)
+        return scales
+
    proxy_class = DecoupledWeightQuantProxyFromInjector
    tensor_quant = DecoupledRescalingIntQuant
    decoupled_int_quant = DecoupledIntQuant
    tensor_clamp_impl = TensorClamp
-    scaling_impl = ParameterFromStatsFromParameterScaling
+    scaling_impl = ParameterScaling
+    scaling_init_impl = StatsFromParameterScaling
    restrict_scaling_impl = FloatRestrictValue
    scaling_stats_impl = AbsMax
    pre_scaling_impl = ParameterPreScalingWeightNorm
@@ -360,8 +367,8 @@ class WeightNormPerChannelFloatDecoupled(SolveWeightScalingStatsInputDimsFromMod
    scaling_stats_input_view_shape_impl = OverOutputChannelView
    stats_reduce_dim = SCALING_STATS_REDUCE_DIM
    scaling_per_output_channel = True
-    scaling_min_val = 1e-8
-    pre_scaling_min_val = 1e-8
+    scaling_min_val = 1e-10
+    pre_scaling_min_val = 1e-10


class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled):
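Note on the fix above: with `ParameterFromStatsFromParameterScaling`, the learned scale was initialized directly from the per-channel abs-max statistics, i.e. in the pre-rounding (float) domain. The new `scaling_init` divides those statistics by the narrow signed integer range, pow(2., bit_width - 1.) - 1., so that `ParameterScaling` starts from a post-rounding value. A minimal standalone sketch of the same arithmetic (the function name and tensor shapes are illustrative, not Brevitas API):

import torch

def post_rounding_scale_init(weight: torch.Tensor, bit_width: int = 8) -> torch.Tensor:
    # Per-output-channel abs-max statistic, as AbsMax over OverOutputChannelView computes
    stats = weight.abs().flatten(1).max(dim=1).values
    # Divide by the narrow signed range, e.g. 127. for 8 bits
    return stats / (2. ** (bit_width - 1.) - 1.)

weight = torch.randn(16, 3, 3, 3)  # (out_channels, in_channels, kH, kW)
print(post_rounding_scale_init(weight))  # one scale per output channel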
34 changes: 0 additions & 34 deletions src/brevitas/quant/fixed_point.py
@@ -196,37 +196,3 @@ class Int4WeightPerTensorFixedPointDecoupled(WeightPerTensorFloatDecoupledL2Para
    restrict_scaling_impl = PowerOfTwoRestrictValue
    int_scaling_impl = PowerOfTwoIntScaling
    restrict_value_float_to_int_impl = CeilSte
-
-
-class Int8WeightNormL2PerChannelFixedPoint(WeightNormPerChannelFloatDecoupled):
-    """
-    Experimental 8-bit narrow signed integer quantizer with learned per-channel scaling factors
-    and L2 weight normalization based on `Quantized Neural Networks for Low-Precision Accumulation
-    with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig
-    (https://arxiv.org/abs/2301.13376). The quantizer learns scaling factors in the float domain and
-    learns vector parameter g in the log domain with the half-way rounding function. Suitable for
-    retraining from floating-point depthwise separable weights.
-
-    Examples:
-        >>> from brevitas.nn import QuantConv2d
-        >>> conv = QuantConv2d(4, 4, 3, groups=4, weight_quant=Int8WeightNormL2PerChannelFixedPoint)
-        >>> conv.quant_weight()
-    """
-    bit_width = 8
-
-
-class Int8AccumulatorAwareWeightQuant(AccumulatorAwareWeightQuant):
-    """
-    Experimental 8-bit narrow signed accumulator-aware integer quantizer with learned per-channel
-    scaling factors based on `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed
-    Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig (https://arxiv.org/abs/2301.13376).
-    The quantizer learns scaling factors in the float domain and learns vector parameter g in the log
-    domain with the round-to-zero rounding function. The norm is clamped according to the specified
-    accumulator bit-width. Suitable for retraining from floating-point depthwise separable weights.
-
-    Examples:
-        >>> from brevitas.nn import QuantConv2d
-        >>> conv = QuantConv2d(4, 4, 3, groups=4, weight_quant=Int8AccumulatorAwareWeightQuant)
-        >>> conv.quant_weight()
-    """
-    bit_width = 8
39 changes: 37 additions & 2 deletions src/brevitas/quant/scaled_int.py
@@ -2,7 +2,6 @@
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.core.function_wrapper import TensorClamp
-from brevitas.core.scaling.standalone import ParameterFromRuntimeStatsScaling
from brevitas.quant.base import *
from brevitas.quant.solver.act import ActQuantSolver
from brevitas.quant.solver.bias import BiasQuantSolver
@@ -33,7 +32,9 @@
    'Uint8ActPerTensorFloatBatchQuant1d',
    'Int8ActPerTensorFloatBatchQuant1d',
    'Uint8ActPerTensorFloatBatchQuant2d',
-    'Int8ActPerTensorFloatBatchQuant2d']
+    'Int8ActPerTensorFloatBatchQuant2d',
+    'Int8AccumulatorAwareWeightQuant',
+    'Int8WeightNormL2PerChannelFixedPoint']


class Int8ActPerTensorFloatMinMaxInit(IntQuant,
@@ -398,3 +399,37 @@ class Int8ActPerTensorFloatBatchQuant1d(IntQuant,
    >>> act = QuantIdentity(act_quant=Int8ActPerTensorFloatBatchQuant1d)
    """
    pass
+
+
+class Int8WeightNormL2PerChannelFixedPoint(WeightNormPerChannelFloatDecoupled):
+    """
+    Experimental 8-bit narrow signed integer quantizer with learned per-channel scaling factors
+    and L2 weight normalization based on `Quantized Neural Networks for Low-Precision Accumulation
+    with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig
+    (https://arxiv.org/abs/2301.13376). The quantizer learns scaling factors in the float domain and
+    learns vector parameter g in the log domain with the half-way rounding function. Suitable for
+    retraining from floating-point depthwise separable weights.
+
+    Examples:
+        >>> from brevitas.nn import QuantConv2d
+        >>> conv = QuantConv2d(4, 4, 3, groups=4, weight_quant=Int8WeightNormL2PerChannelFixedPoint)
+        >>> conv.quant_weight()
+    """
+    bit_width = 8
+
+
+class Int8AccumulatorAwareWeightQuant(AccumulatorAwareWeightQuant):
+    """
+    Experimental 8-bit narrow signed accumulator-aware integer quantizer with learned per-channel
+    scaling factors based on `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed
+    Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig (https://arxiv.org/abs/2301.13376).
+    The quantizer learns scaling factors in the float domain and learns vector parameter g in the log
+    domain with the round-to-zero rounding function. The norm is clamped according to the specified
+    accumulator bit-width. Suitable for retraining from floating-point depthwise separable weights.
+
+    Examples:
+        >>> from brevitas.nn import QuantConv2d
+        >>> conv = QuantConv2d(4, 4, 3, groups=4, weight_quant=Int8AccumulatorAwareWeightQuant)
+        >>> conv.quant_weight()
+    """
+    bit_width = 8
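With the move, both A2Q quantizers are importable from `brevitas.quant.scaled_int` (and, via the updated `__all__`, from `brevitas.quant`, as the `common.py` change below shows). A quick usage check mirroring the docstring examples:

from brevitas.nn import QuantConv2d
from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant

# 4-channel depthwise convolution with accumulator-aware weight quantization
conv = QuantConv2d(4, 4, 3, groups=4, weight_quant=Int8AccumulatorAwareWeightQuant)
print(conv.quant_weight())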
5 changes: 4 additions & 1 deletion src/brevitas_examples/super_resolution/eval_model.py
@@ -3,6 +3,7 @@

import argparse
import json
+import os
import pprint
import random

@@ -25,7 +26,7 @@
desc = """Evaluating single-image super resolution models on the BSD300 dataset.
Example:
->> python eval_model.py --data_root=data --model-path=outputs/model.pth --model=quant_espcn_x2_w8a8_base --upscale-factor=2
+>> python eval_model.py --data_root=data --model=quant_espcn_x2_w8a8_a2q_16b --use_pretrained --export_to_qonnx
"""

parser = argparse.ArgumentParser(description='PyTorch BSD300 Validation')
@@ -64,6 +65,8 @@ def main():
    test_psnr = evaluate_avg_psnr(testloader, model)
    print(f"[{args.model}] test_psnr={test_psnr:.2f}")

+    os.makedirs(args.save_path, exist_ok=True)
+
    # evaluate accumulator bit widths
    if args.eval_acc_bw:
        inp = testloader.dataset[0][0].unsqueeze(0).to(device)
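For context on the `args.eval_acc_bw` path above: following the overflow-avoidance analysis of the referenced A2Q paper, the magnitude a dot product can accumulate is bounded by ||q||_1 * max|x|. A back-of-envelope estimate of the accumulator width this implies (an illustrative sketch, not the Brevitas implementation):

import math

import torch

def min_accumulator_bit_width(q: torch.Tensor, input_bit_width: int) -> int:
    max_input = 2. ** (input_bit_width - 1)  # magnitude of the most negative signed input
    worst_case = q.abs().sum().item() * max_input  # |sum_i q_i * x_i| <= ||q||_1 * max|x|
    return math.ceil(math.log2(worst_case + 1.)) + 1  # magnitude bits plus a sign bit

print(min_accumulator_bit_width(torch.randint(-127, 128, (64,)).float(), input_bit_width=8))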
2 changes: 1 addition & 1 deletion src/brevitas_examples/super_resolution/models/common.py
@@ -10,10 +10,10 @@
from brevitas.core.scaling import ScalingImplType
import brevitas.nn as qnn
from brevitas.nn.quant_layer import WeightQuantType
+from brevitas.quant import Int8AccumulatorAwareWeightQuant
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int8WeightPerTensorFloat
from brevitas.quant import Uint8ActPerTensorFloat
-from brevitas.quant.fixed_point import Int8AccumulatorAwareWeightQuant


class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
16 changes: 10 additions & 6 deletions src/brevitas_examples/super_resolution/utils/export.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from argparse import Namespace
+import warnings

import numpy as np
from torch import Tensor
@@ -38,12 +39,15 @@ def export(model: nn.Module, testloader: DataLoader, args: Namespace, opset_vers
            opset_version=opset_version)
        print(f"Saved QONNX model to {save_path}/qonnx_model.onnx")
    if args.export_to_qcdq_onnx:
-        export_onnx_qcdq(
-            model.cpu(),
-            input_t=inp.cpu(),
-            export_path=f"{save_path}/qcdq_onnx_model.onnx",
-            opset_version=opset_version)
-        print(f"Saved QCDQ ONNX model to {save_path}/qcdq_onnx_model.onnx")
+        if opset_version < 13:
+            warnings.warn("Need opset 13+ to support per-channel quantization.")
+        else:
+            export_onnx_qcdq(
+                model.cpu(),
+                input_t=inp.cpu(),
+                export_path=f"{save_path}/qcdq_onnx_model.onnx",
+                opset_version=opset_version)
+            print(f"Saved QCDQ ONNX model to {save_path}/qcdq_onnx_model.onnx")
    if args.export_to_qcdq_torch:
        export_torch_qcdq(
            model.cpu(), input_t=inp.cpu(), export_path=f"{save_path}/qcdq_torch_model.pt")
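The new guard exists because the ONNX QuantizeLinear/DequantizeLinear operators only accept a per-channel axis attribute from opset 13 onward, and these quantizers learn per-channel scales. The same pattern in isolation (the wrapper name is illustrative):

import warnings

def guarded_qcdq_export(export_fn, opset_version: int, **kwargs):
    # Per-channel scales need the axis attribute on Q/DQ nodes, added in opset 13
    if opset_version < 13:
        warnings.warn("Need opset 13+ to support per-channel quantization.")
        return None
    return export_fn(opset_version=opset_version, **kwargs)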
2 changes: 1 addition & 1 deletion tests/brevitas/export/quant_module_fixture.py
@@ -14,10 +14,10 @@
from brevitas.nn import QuantIdentity
from brevitas.nn import QuantLinear
from brevitas.nn import TruncAvgPool2d
-from brevitas.quant.fixed_point import Int8AccumulatorAwareWeightQuant
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPoint
+from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling
from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
4 changes: 2 additions & 2 deletions tests/brevitas/nn/nn_quantizers_fixture.py
@@ -19,12 +19,12 @@
from brevitas.nn.quant_mha import QuantMultiheadAttention
from brevitas.nn.quant_rnn import QuantLSTM
from brevitas.nn.quant_rnn import QuantRNN
-from brevitas.quant.fixed_point import Int8AccumulatorAwareWeightQuant
-from brevitas.quant.fixed_point import Int8WeightNormL2PerChannelFixedPoint
+from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloatBatchQuant1d
from brevitas.quant.scaled_int import Int8ActPerTensorFloatBatchQuant2d
from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling
+from brevitas.quant.scaled_int import Int8WeightNormL2PerChannelFixedPoint
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
from brevitas.quant.scaled_int import Int16Bias
from brevitas.quant.scaled_int import Uint8ActPerTensorFloat
