Skip to content

Commit

Permalink
Feat (quant_tensor): support for float QuantTensor and proxies (#919)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored May 21, 2024
1 parent 3464ec7 commit eeffd2b
Show file tree
Hide file tree
Showing 31 changed files with 1,493 additions and 490 deletions.
84 changes: 46 additions & 38 deletions notebooks/01_quant_tensor_quant_conv2d_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,21 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/_tensor.py:1394: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1708025842427/work/c10/core/TensorImpl.h:1908.)\n",
" return super().rename(names)\n",
"/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1708025842427/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n",
" return F.conv2d(input, weight, bias, self.stride,\n"
]
},
{
"data": {
"text/plain": [
Expand Down Expand Up @@ -255,7 +265,7 @@
{
"data": {
"text/plain": [
"QuantTensor(value=tensor([[[[-0.0018, 0.1273, -0.1937],\n",
"IntQuantTensor(value=tensor([[[[-0.0018, 0.1273, -0.1937],\n",
" [-0.1734, -0.0904, 0.0627],\n",
" [-0.0055, 0.1863, -0.0203]],\n",
"\n",
Expand Down Expand Up @@ -377,8 +387,6 @@
}
],
"source": [
"from brevitas.quant_tensor import QuantTensor\n",
"\n",
"quant_act = QuantIdentity(return_quant_tensor=True)\n",
"\n",
"out_tensor_0 = quant_act(torch.randn(1,2,5,5))\n",
Expand Down Expand Up @@ -407,7 +415,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"QuantTensor(value=tensor([[[[-0.1106, 1.1945, -0.4972, -2.0968, 0.7175],\n",
"IntQuantTensor(value=tensor([[[[-0.1106, 1.1945, -0.4972, -2.0968, 0.7175],\n",
" [-2.5901, 0.0588, -0.2014, 2.1486, 1.6435],\n",
" [ 0.9067, -2.5212, 2.2193, 0.2352, -0.8395],\n",
" [-0.8351, 0.6341, -0.5551, 0.1040, -3.3151],\n",
Expand Down Expand Up @@ -467,7 +475,7 @@
{
"data": {
"text/plain": [
"QuantTensor(value=tensor([[[[0.5191, 0.6402],\n",
"IntQuantTensor(value=tensor([[[[0.5191, 0.6402],\n",
" [2.1455, 0.5883]],\n",
"\n",
" [[2.0417, 0.5883],\n",
Expand Down Expand Up @@ -506,7 +514,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_4048/1377665000.py:1: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1699449183005/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n",
"/tmp/ipykernel_528161/1377665000.py:1: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1708025842427/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n",
" torch.tanh(quant_tensor)\n"
]
},
Expand Down Expand Up @@ -555,7 +563,7 @@
{
"data": {
"text/plain": [
"QuantTensor(value=tensor([[[[-0.3568, -0.1883, 0.3589],\n",
"IntQuantTensor(value=tensor([[[[-0.3568, -0.1883, 0.3589],\n",
" [-0.4470, 0.1039, -0.3945],\n",
" [-0.4190, 0.3723, 0.8384]],\n",
"\n",
Expand All @@ -565,7 +573,7 @@
"\n",
" [[ 0.2734, 0.7268, -0.0249],\n",
" [-0.1732, 0.5197, 1.1158],\n",
" [ 0.3771, -0.3810, 0.2008]]]], grad_fn=<ConvolutionBackward0>), scale=tensor([[[[3.1958e-05]]]], grad_fn=<MulBackward0>), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))"
" [ 0.3771, -0.3810, 0.2008]]]], grad_fn=<ConvolutionBackward0>), scale=tensor([[[[3.1958e-05]]]], grad_fn=<MulBackward0>), zero_point=tensor([0.]), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))"
]
},
"execution_count": 14,
Expand Down Expand Up @@ -618,39 +626,39 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"QuantTensor(value=tensor([[[[ 7.2000e-03, -3.7000e-03, 7.7000e-03, -2.4000e-03, -8.9000e-03],\n",
" [-1.2000e-02, -8.1000e-03, 7.2000e-03, -1.1300e-02, -9.7000e-03],\n",
" [-1.0000e-03, 1.0100e-02, 3.8000e-03, -1.1900e-02, 6.9000e-03],\n",
" [ 8.3000e-03, 1.0000e-04, -6.9000e-03, 3.9000e-03, -5.4000e-03],\n",
" [ 1.1300e-02, -6.0000e-03, 9.7000e-03, 0.0000e+00, 1.0900e-02]],\n",
"IntQuantTensor(value=tensor([[[[-9.9000e-03, -7.1000e-03, -4.7000e-03, 5.0000e-03, -1.2300e-02],\n",
" [-8.2000e-03, 8.5000e-03, -1.2000e-03, -1.2500e-02, 4.4000e-03],\n",
" [ 4.3000e-03, -6.3000e-03, -9.4000e-03, 1.0400e-02, -1.2100e-02],\n",
" [ 1.1700e-02, -3.6000e-03, 5.3000e-03, -1.1700e-02, -4.3000e-03],\n",
" [-8.8000e-03, 1.0900e-02, -8.3000e-03, -2.9000e-03, 1.2400e-02]],\n",
"\n",
" [[-1.0900e-02, 1.1400e-02, -6.4000e-03, 9.2000e-03, 7.1000e-03],\n",
" [-6.0000e-04, 9.2000e-03, -8.5000e-03, 5.0000e-03, 6.5000e-03],\n",
" [-8.3000e-03, -1.2000e-03, 7.4000e-03, 9.2000e-03, -6.0000e-04],\n",
" [-2.1000e-03, 9.5000e-03, 3.0000e-04, -2.9000e-03, -6.5000e-03],\n",
" [-1.1800e-02, -4.8000e-03, 5.4000e-03, -2.5000e-03, 9.0000e-04]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))"
" [[ 9.3000e-03, -8.5000e-03, 6.5000e-03, -2.7000e-03, -3.4000e-03],\n",
" [-1.0000e-04, -1.1000e-02, 8.3000e-03, 1.9000e-03, -9.8000e-03],\n",
" [ 4.3000e-03, -8.5000e-03, 1.1000e-02, 5.3000e-03, 3.4000e-03],\n",
" [ 8.1000e-03, 9.8000e-03, 6.8000e-03, 1.5000e-03, 6.3000e-03],\n",
" [ 5.7000e-03, -8.5000e-03, 5.2000e-03, -3.0000e-04, 4.9000e-03]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))"
]
},
"execution_count": 16,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from brevitas.quant_tensor import QuantTensor\n",
"from brevitas.quant_tensor import IntQuantTensor\n",
"\n",
"scale = 0.0001\n",
"bit_width = 8\n",
"zero_point = 0.\n",
"int_value = torch.randint(low=- 2 ** (bit_width - 1), high=2 ** (bit_width - 1) - 1, size=(1, 2, 5, 5))\n",
"quant_value = (int_value - zero_point) * scale\n",
"quant_tensor_input = QuantTensor(\n",
"quant_tensor_input = IntQuantTensor(\n",
" quant_value, \n",
" scale=torch.tensor(scale), \n",
" zero_point=torch.tensor(zero_point), \n",
Expand All @@ -662,7 +670,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -688,7 +696,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -721,7 +729,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -745,7 +753,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -784,7 +792,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -820,7 +828,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -856,7 +864,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"metadata": {
"tags": [
"raises-exception"
Expand Down Expand Up @@ -897,7 +905,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -935,7 +943,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -968,7 +976,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1007,7 +1015,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": null,
"metadata": {
"tags": [
"raises-exception"
Expand Down Expand Up @@ -1049,7 +1057,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1093,7 +1101,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1131,7 +1139,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -1155,7 +1163,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ pytest-xdist
pytest_cases
scipy
torchvision
tqdm
2 changes: 1 addition & 1 deletion src/brevitas/core/function_wrapper/clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,4 @@ def forward(
"Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified"
)

return x
return x, self.saturating, self.inf_values, self.nan_values
4 changes: 2 additions & 2 deletions src/brevitas/core/quant/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def dequantize(self, y, scale):
def forward(self, x):
y, scale = self.quantize(x)
# after quantizing, clamp to special cases like NaN/inf if they are set
y = self.float_clamp_impl(
y, saturating, inf_values, nan_values = self.float_clamp_impl(
y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
y = self.dequantize(y, scale)
# This is to respect the current interface of proxies
return y, scale, self.zero_point_impl(), self.bit_width()
return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values
4 changes: 2 additions & 2 deletions src/brevitas/fx/value_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
import torch.utils._pytree as pytree

from brevitas import torch_version
from brevitas.quant_tensor import QuantTensorBase
from brevitas.quant_tensor import QuantTensor

from . import *
from . import _assert_is_none
Expand All @@ -82,7 +82,7 @@
from . import ScopeContextManager

_UNSET = object()
extended_base_types = base_types + (QuantTensorBase,)
extended_base_types = base_types + (QuantTensor,)

FRAME_FILES = [
'fx/brevitas_tracer.py',
Expand Down
8 changes: 4 additions & 4 deletions src/brevitas/graph/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from brevitas.nn import QuantHardTanh
from brevitas.nn import QuantLinear
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ClampQuantProxyFromInjector
from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector
Expand All @@ -29,9 +29,9 @@
'calibration_mode',
'load_quant_model_mode']

_PARAM_PROXIES = (WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)
_PARAM_PROXIES = (WeightQuantProxyFromInjectorBase, BiasQuantProxyFromInjectorBase)

_BIAS_PROXIES = (BiasQuantProxyFromInjector)
_BIAS_PROXIES = (BiasQuantProxyFromInjectorBase)

_ACC_PROXIES = (TruncQuantProxyFromInjector, ClampQuantProxyFromInjector)

Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/graph/gpfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def single_layer_update(self):
# raise error in case no quant-input is here
if self.quant_metadata is None:
raise ValueError('Expected self.quant_metadata to calculate L1-norm upper bound, but recevied None. ' + \
'Make sure that either the input to the model is a QuantTensor or the layer has an input quant enabled. ' \
'Make sure that either the input to the model is a IntQuantTensor or the layer has an input quant enabled. ' \
'Also, check if `use_quant_activations=True` in `gpfq_mode` when `accumulator_bit_width` is specified. ' + \
'Alternatively, provide a custom `a2q_layer_filter_fnc` to `gpfq_mode` to filter layers without a quant_tensor input.')
weight = self.layer.weight.data
Expand Down
6 changes: 3 additions & 3 deletions src/brevitas/graph/gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import restore_return_quant_tensor
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor
from brevitas.quant_tensor import IntQuantTensor
from brevitas.utils.quant_utils import _CachedIO

SUPPORTED_CONV_OP = (
Expand Down Expand Up @@ -227,9 +227,9 @@ def process_input(self, inp):

is_quant_enabled = self.layer.weight_quant.is_quant_enabled

# If using quantized activations, inp could be QuantTensor. In
# If using quantized activations, inp could be IntQuantTensor. In
# this case, we overwrite the metadata.
if isinstance(inp, QuantTensor):
if isinstance(inp, IntQuantTensor):
if is_quant_enabled and self.quant_metadata is None:
self.quant_metadata = _CachedIO(inp, metadata_only=True)
inp = inp.value
Expand Down
Loading

0 comments on commit eeffd2b

Please sign in to comment.