Feat: Support for Groupwise (MX) quantization #971

Merged · 37 commits · Aug 20, 2024

Commits
51e8b12
Feat: Support for Groupwise (MX) quantization
Giuseppe5 May 31, 2024
6afe824
Fix
Giuseppe5 Jun 17, 2024
21a1ce2
default groupsize
Giuseppe5 Jun 17, 2024
e4be937
Quantizers
Giuseppe5 Jun 17, 2024
b1f9a4b
update
Giuseppe5 Jun 18, 2024
b2a1f14
Fix for keepdim
Giuseppe5 Jun 19, 2024
d2f48cf
Missing file
Giuseppe5 Jun 19, 2024
5d207b7
Update
Giuseppe5 Jun 19, 2024
ac07eee
notebook
Giuseppe5 Jun 19, 2024
cb7adaa
Fix tests
Giuseppe5 Jun 19, 2024
d0e2977
more quantizers
Giuseppe5 Jun 19, 2024
eb6be38
New enum for groupwise scaling
Giuseppe5 Jul 4, 2024
82283d0
fix
Giuseppe5 Jul 4, 2024
26fd171
fix
Giuseppe5 Jul 4, 2024
633a1f7
fix notebooks
Giuseppe5 Jul 4, 2024
b539e3f
Fix order
Giuseppe5 Jul 4, 2024
cd571de
update notebook
Giuseppe5 Jul 10, 2024
b76a54e
Fix
Giuseppe5 Jul 30, 2024
a7f772c
Better retrocompatibility
Giuseppe5 Jul 30, 2024
1d3c152
Better condition
Giuseppe5 Jul 30, 2024
9dccd47
Solving order
Giuseppe5 Jul 30, 2024
d3412fa
Fix dep inj
Giuseppe5 Jul 30, 2024
8573ce2
Fix for MSE
Giuseppe5 Jul 30, 2024
2a777e5
fix concat dim
Giuseppe5 Jul 30, 2024
ad24bc6
fix max_ave
Giuseppe5 Aug 1, 2024
f7d2736
last fix
Giuseppe5 Aug 1, 2024
d4be537
Rename
Giuseppe5 Aug 1, 2024
2626de1
GroupwiseInt + removed comments
Giuseppe5 Aug 2, 2024
d6c1b95
Missing file
Giuseppe5 Aug 2, 2024
fbc9e55
Fix for brevitas_examples
Giuseppe5 Aug 12, 2024
ad9eae3
Updated examples
Giuseppe5 Aug 12, 2024
bdbfd17
Updated groupwise int quant tensor and notebook
Giuseppe5 Aug 13, 2024
145f6da
Review, notebook missing
Giuseppe5 Aug 20, 2024
6a62d3d
Updated notebooks
Giuseppe5 Aug 20, 2024
ab8bf37
Typo + ocp/fnuz quantizers
Giuseppe5 Aug 20, 2024
8ed14c6
Updated README
Giuseppe5 Aug 20, 2024
5aeba52
Update minifloat_mx_tutorial.ipynb
Giuseppe5 Aug 20, 2024
170 changes: 85 additions & 85 deletions notebooks/Brevitas_TVMCon2021.ipynb

Large diffs are not rendered by default.

285 changes: 285 additions & 0 deletions notebooks/minifloat_mx_tutorial.ipynb
@@ -0,0 +1,285 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Minifloat and Groupwise quantization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook shows some practical use cases for minifloat and groupwise quantization.\n",
"\n",
"Brevitas supports a wide combination of float quantization, including the OCP and FNUZ FP8 standard.\n",
"It is possible to define any combination of exponent/mantissa bitwidth, as well as exponent bias.\n",
"\n",
"Similarly, MX quantization is supported as general groupwise quantization on top of integer/minifloat datatype.\n",
"This allows to any general groupwise quantization, including MXInt and MXFloat standards.\n",
"\n",
"This tutorial shows how to instantiate and use some of the most interesting quantizers for minifloat and groupwise quantization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Minifloat (FP8 and lower)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Brevitas offers some pre-defined quantizers for minifloat quantization, including OCP and FNUZ standards, which can be further customized according to the specific use case.\n",
"The general naming structure for the quantizers is the following:\n",
"\n",
"`Fp\\<Bitwidth\\>\\<Standard\\>Weight\\<Scaling\\>Float`\n",
"\n",
"Where `Bitwidth` can be either empty or `e4m3`/`e5m2`, `Standard` can be empty or `OCP`/`FNUZ`, `Scaling` can be empty or `PerTensor`/`PerChannel`.\n",
"\n",
"If `Bitwidth` is empty, the user must set it with kwargs or by subclassing the quantizers. Once the bitwidth is defined, the correct values for inf/nan are automatically defined based on the `Standard`.\n",
"If a non-valid OCP bitwidth is set (e.g., e6m1), then no inf/nan values will be selected and the corresponding quantizer is not standard-compliant.\n",
"\n",
"`Standard` allows to pick among the two main FP8 standard options; moreover, if not specified, Brevitas offers the possibility of doing minifloat quantization without necessarily reserving values for inf/nan representation.\n",
"This allows to use the maximum available range, since often in quantization, values that exceed the quantization range saturate to maximum rather than going to inf/nan.\n",
"FNUZ quantizers need to have `saturating=True`.\n",
"\n",
"The `Scaling` options defines whether the quantization is _scaled_ or _unscaled_.\n",
"In the unscaled case, the scale factor for quantization is fixed to one, otherwise it can be set using any of the methods that Brevitas includes (e.g., statistics, learned, etc.)\n",
"\n",
"\n",
"Please keep in mind that not all combinations of the above options might be pre-defined and this serves mostly as indications of what Brevitas supports.\n",
"It is possible, following the same structure of the available quantizers, to define new ones that fit your needs.\n",
"\n",
"\n",
"Similar considerations can be extended for activation quantization."
]
},
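{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, here is a minimal sketch of a non-OCP-compliant e6m1 weight quantizer obtained by subclassing. This is not one of the pre-defined quantizers; the attribute names are assumed by analogy with the e4m3 definitions:\n",
"```python\n",
"from brevitas.quant.experimental.float_quant_ocp import FpOCPWeightPerTensorFloat\n",
"\n",
"class Fp8e6m1Weight(FpOCPWeightPerTensorFloat):\n",
"    # e6m1 is not a valid OCP bitwidth, so no inf/nan values are reserved\n",
"    # and the resulting quantizer is not OCP-compliant.\n",
"    bit_width = 8\n",
"    exponent_bit_width = 6\n",
"    mantissa_bit_width = 1\n",
"```"
]
},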
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n",
" return F.conv2d(input, weight, bias, self.stride,\n"
]
}
],
"source": [
"from brevitas.quant.experimental.float_base import Fp8e4m3Mixin\n",
"from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight\n",
"from brevitas.quant.experimental.float_quant_ocp import FpOCPWeightPerTensorFloat, FpOCPActPerTensorFloat\n",
"from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act\n",
"import brevitas.nn as qnn\n",
"import torch.nn as nn\n",
"import torch\n",
"from brevitas.quant_tensor import FloatQuantTensor\n",
"\n",
"class OCPFP8Weight(FpOCPWeightPerTensorFloat, Fp8e4m3Mixin):\n",
" pass\n",
"\n",
"\n",
"class OCPFP8Act(FpOCPActPerTensorFloat, Fp8e4m3Mixin):\n",
" pass\n",
"\n",
"\n",
"class FP8Model(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=OCPFP8Weight, input_quant=OCPFP8Act)\n",
" \n",
" def forward(self, x):\n",
" return self.conv(x)\n",
"\n",
"ocp_fp8_model = FP8Model()\n",
"x = torch.randn(1, 32, 8, 8)\n",
"ocp_fp8_model.eval()\n",
"o = ocp_fp8_model(x)\n",
"\n",
"intermediate_input = ocp_fp8_model.conv.input_quant(x)\n",
"assert isinstance(intermediate_input, FloatQuantTensor)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Groupwise quantization (MXInt/MXFloat)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Groupwise quantization is built on top of integer/minifloat quantization, with special considerations to accomodate for the groupwise scaling.\n",
"\n",
"Compared to Int/Float QuantTensor, the main difference of their groupwise equivalent is that value, scale, and zero_point are not direct attributes anymore but properties. The new attributes are value_, scale_, and zero_point_.\n",
"\n",
"The reason for this is shaping. When quantizing a tensor with shapes [O, I], where O is output channel and I is input channel, with groupsize k, groupwise quantization is normally represented as follow:\n",
"\n",
"- Tensor with shapes [O, k, I/k]\n",
"- Scales with shapes [O, k, 1]\n",
"- Zero point same as scale\n",
"\n",
"The alternative to this representation is to have all three tensors with shapes [O,I], with a massive increase in memory utilization, especially with QAT + gradients.\n",
"\n",
"The underscored attributes will have the compressed shapes, while the properties (non-underscored naming) will dynamically compute the expanded version of the property. This means:\n",
"```python\n",
"quant_tensor.scale_.shape\n",
"# This will print [O, k, 1]\n",
"quant_tensor.scale.shape\n",
"# This will print [O, I]\n",
"```\n",
"\n",
"With respect to pre-defined quantizers, Brevitas offers several Groupwise and MX options.\n",
"The main difference between the two is that MX is restricted to group_size=32 and the scale factor must be a power-of-2.\n",
"The user can override these settings but the corresponding output won't be MX compliant.\n",
"\n",
"Another difference is that MXFloat relies on the OCP format as underlying data type, while generic groupwise float relies on the non-standard minifloat representation explained above.\n",
"\n",
"Finally, the general groupwise scaling relies on float scales."
]
},
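{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a shape-only sketch in plain PyTorch (no Brevitas involved; an absmax-style per-group scale is assumed purely for illustration):\n",
"```python\n",
"import torch\n",
"\n",
"O, I, k = 64, 32, 4  # output channels, input channels, group size\n",
"w = torch.randn(O, I)\n",
"w_grouped = w.view(O, I // k, k)                     # one group of k elements each\n",
"scale_ = w_grouped.abs().amax(dim=-1, keepdim=True)  # compressed shape [O, I/k, 1]\n",
"scale = scale_.expand(O, I // k, k).reshape(O, I)    # expanded shape [O, I]\n",
"```"
]
},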
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from brevitas.quant_tensor import GroupwiseFloatQuantTensor\n",
"\n",
"\n",
"class MXFloat8Weight(MXFloat8e4m3Weight, Fp8e4m3Mixin):\n",
" # The group dimension for the weights it is automatically identified based on the layer type\n",
" # If a new layer type is used, it can be manually specified\n",
" pass\n",
"\n",
"class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_dim = 1\n",
"\n",
"\n",
"class MXModel(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n",
" \n",
" def forward(self, x):\n",
" return self.conv(x)\n",
"\n",
"mx_model = MXModel()\n",
"x = torch.randn(1, 32, 8, 8)\n",
"mx_model.eval()\n",
"o = mx_model(x)\n",
"\n",
"intermediate_input = mx_model.conv.input_quant(x)\n",
"assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)"
]
},
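{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick follow-up, the compressed and expanded scales described above can be inspected directly on the groupwise tensor (a hedged sketch; the exact shapes depend on the layer and group size):\n",
"```python\n",
"print(intermediate_input.scale_.shape)  # compressed: one entry per group\n",
"print(intermediate_input.scale.shape)   # expanded back to the input shape\n",
"```\n",
"The next cell repeats the example with MSE-based statistics for the weight scale (`MXFloat8e4m3WeightMSE`)."
]
},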
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE\n",
"from brevitas.quant_tensor import GroupwiseFloatQuantTensor\n",
"\n",
"\n",
"class MXFloat8Weight(MXFloat8e4m3WeightMSE, Fp8e4m3Mixin):\n",
" # The group dimension for the weights it is automatically identified based on the layer type\n",
" # If a new layer type is used, it can be manually specified\n",
" pass\n",
"\n",
"class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_dim = 1\n",
"\n",
"\n",
"class MXModel(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n",
" \n",
" def forward(self, x):\n",
" return self.conv(x)\n",
"\n",
"mx_model = MXModel()\n",
"x = torch.randn(1, 32, 8, 8)\n",
"mx_model.eval()\n",
"o = mx_model(x)\n",
"\n",
"intermediate_input = mx_model.conv.input_quant(x)\n",
"assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from brevitas.quant_tensor import GroupwiseIntQuantTensor\n",
"from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight\n",
"from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act\n",
"import torch.nn as nn\n",
"import brevitas.nn as qnn\n",
"import torch\n",
"\n",
"class MXFloat8Weight(MXInt8Weight):\n",
" # The group dimension for the weights it is automatically identified based on the layer type\n",
" # If a new layer type is used, it can be manually specified\n",
" bit_width = 8\n",
"\n",
"class MXFloat8Act(MXInt8Act):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_dim = 1\n",
" bit_width = 8\n",
"\n",
"class MXModel(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n",
" \n",
" def forward(self, x):\n",
" return self.conv(x)\n",
"\n",
"mx_model = MXModel()\n",
"x = torch.randn(1, 32, 8, 8)\n",
"mx_model.eval()\n",
"o = mx_model(x)\n",
"\n",
"intermediate_input = mx_model.conv.input_quant(x)\n",
"assert isinstance(intermediate_input, GroupwiseIntQuantTensor)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "brevitas_dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
19 changes: 19 additions & 0 deletions src/brevitas/core/function_wrapper/shape.py
@@ -153,6 +153,24 @@ def forward(self, x: torch.Tensor):
return y.reshape(shape)


class OverSubChannelBlockView(brevitas.jit.ScriptModule):
__constants__ = ['expanded_scaling_shape']

def __init__(self, expanded_scaling_shape, permute_dims: Optional[Tuple[int, ...]]) -> None:
super(OverSubChannelBlockView, self).__init__()
self.expanded_scaling_shape = expanded_scaling_shape
if permute_dims is not None:
self.permute_impl = PermuteDims(permute_dims)
else:
self.permute_impl = torch.nn.Identity()

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
y = self.permute_impl(x)
y = y.view(self.expanded_scaling_shape)
return y


class StatsInputViewShapeImpl(object):
"""
Enum-like object to collect pointers to variants of ScriptModules that perform a view on a tensor.
@@ -163,3 +181,4 @@ class StatsInputViewShapeImpl(object):
OVER_BATCH_OVER_TENSOR = OverBatchOverTensorView
OVER_BATCH_OVER_OUTPUT_CHANNELS = OverBatchOverOutputChannelView
OVER_OUTPUT_FEATURES = OverOutputFeaturesView
OVER_SUBCHANNEL_BLOCK = OverSubChannelBlockView
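A minimal usage sketch of the new view (constructed directly here for illustration; inside Brevitas it is normally wired in through dependency injection):

```python
import torch

from brevitas.core.function_wrapper.shape import OverSubChannelBlockView

# View a [64, 32] weight as [64, 8, 4], i.e. 8 blocks of group size 4 per channel.
view = OverSubChannelBlockView(expanded_scaling_shape=(64, 8, 4), permute_dims=None)
w = torch.randn(64, 32)
print(view(w).shape)  # torch.Size([64, 8, 4])
```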
2 changes: 2 additions & 0 deletions src/brevitas/core/scaling/__init__.py
@@ -1,6 +1,8 @@
from brevitas.inject.enum import ScalingImplType
+from brevitas.inject.enum import ScalingPerOutputType

assert ScalingImplType
+assert ScalingPerOutputType

from brevitas.core.stats import SCALAR_SHAPE

35 changes: 35 additions & 0 deletions src/brevitas/core/scaling/runtime.py
@@ -158,3 +158,38 @@ def _load_from_state_dict(
missing_keys.remove(affine_weight_key)
if config.IGNORE_MISSING_KEYS and affine_bias_key in missing_keys:
missing_keys.remove(affine_bias_key)


class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):

def __init__(
self,
group_size: int,
group_dim: int,
scaling_stats_impl: torch.nn.Module,
scaling_min_val: Optional[float],
restrict_scaling_impl: Optional[torch.nn.Module]) -> None:
super(RuntimeDynamicGroupStatsScaling, self).__init__()
self.group_size = group_size
self.group_dim = group_dim
self.scaling_stats_impl = scaling_stats_impl
self.scaling_min_val = scaling_min_val
self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)

@brevitas.jit.script_method
def group_scaling_reshape(self, stats_input):
tensor_shape = stats_input.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
block_dim = self.group_dim + 1 if self.group_dim != -1 else -1
tensor_shape_list.insert(block_dim, self.group_size)
stats_input = stats_input.view(tensor_shape_list)
return stats_input

@brevitas.jit.script_method
def forward(self, stats_input) -> torch.Tensor:
stats_input_reshaped = self.group_scaling_reshape(stats_input)
out = self.scaling_stats_impl(stats_input_reshaped)
# Restrict the scale (e.g., to power-of-two) and clamp it to scaling_min_val
out = self.restrict_clamp_scaling(out)
return out
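A hedged end-to-end sketch of the dynamic group scaling (the absmax stats module is a stand-in rather than the Brevitas implementation, and `restrict_scaling_impl=None` is assumed to fall back to an identity restriction):

```python
import torch

from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling

class AbsMaxStats(torch.nn.Module):
    # Stand-in stats impl: absolute maximum over the inserted block dimension.
    def forward(self, x):
        return x.abs().amax(dim=-1, keepdim=True)

scaling = RuntimeDynamicGroupStatsScaling(
    group_size=4,
    group_dim=1,
    scaling_stats_impl=AbsMaxStats(),
    scaling_min_val=1e-8,
    restrict_scaling_impl=None)

x = torch.randn(2, 32)
print(scaling(x).shape)  # torch.Size([2, 8, 1]): one scale per group of 4
```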
3 changes: 2 additions & 1 deletion src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -23,6 +23,7 @@
from brevitas.export.common.handler.qcdq import QMixin
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.handler import QuantLSTMLayerHandler
+from brevitas.inject.enum import ScalingPerOutputType

from ..function import CastFn
from ..function import DequantizeLinearFn
@@ -133,7 +134,7 @@ def validate(self, module):
# Below 8b quantization is not supported.
self.validate_8b_bit_width(module.bit_width(), le_then=False)
# Only per tensor quantization is supported
-assert not module.quant_injector.scaling_per_output_channel, "Only per tensor scaling supported"
+assert module.quant_injector.scaling_per_output == ScalingPerOutputType.TENSOR, "Only per tensor scaling supported"

def quantize_fn(self, x, dtype):
return DynamicQuantizeLinearFn.apply(x, dtype)
2 changes: 0 additions & 2 deletions src/brevitas/graph/quantize.py
@@ -13,7 +13,6 @@
from brevitas.graph.fixed_point import CollapseConsecutiveConcats
from brevitas.graph.fixed_point import MergeBatchNorm
from brevitas.graph.fixed_point import MoveSplitBatchNormBeforeCat
-from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool
from brevitas.graph.quantize_impl import act_handler
from brevitas.graph.quantize_impl import add_output_quant_handler
from brevitas.graph.quantize_impl import inp_placeholder_handler
@@ -25,7 +24,6 @@
from brevitas.graph.standardize import MeanMethodToAdaptiveAvgPool2d
from brevitas.graph.standardize import RemoveStochasticModules
from brevitas.graph.standardize import TorchFunctionalToModule
-from brevitas.nn import quant_layer
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int8ActPerTensorFloatMinMaxInit