Feat (quant): initial support for fp8 variants
Showing 6 changed files with 312 additions and 0 deletions.
@@ -0,0 +1,85 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional

import torch
import torch.nn as nn

import brevitas
from brevitas.core.function_wrapper import RoundSte
from brevitas.core.scaling import ConstScaling
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_float
from brevitas.function.ops_ste import floor_ste


class FloatQuant(brevitas.jit.ScriptModule):
    __constants__ = ['signed']

    def __init__(
            self,
            bit_width: int,
            signed: bool,
            exponent_bit_width: int,
            mantissa_bit_width: int,
            exponent_bias: Optional[int] = None,
            scaling_impl: Optional[nn.Module] = None,
            float_scaling_impl: Optional[nn.Module] = None,
            float_to_int_impl: nn.Module = RoundSte(),
            device: Optional[str] = None,
            dtype: Optional[torch.dtype] = None):
        super(FloatQuant, self).__init__()
        if bit_width != exponent_bit_width + mantissa_bit_width + int(signed):
            raise RuntimeError("Mismatch between total bit-width, exponent, mantissa and sign.")
        self.bit_width = StatelessBuffer(torch.tensor(float(bit_width), device=device, dtype=dtype))
        self.signed: bool = signed
        self.float_to_int_impl = float_to_int_impl
        self.exponent_bit_width = StatelessBuffer(
            torch.tensor(float(exponent_bit_width), device=device, dtype=dtype))
        self.mantissa_bit_width = StatelessBuffer(
            torch.tensor(float(mantissa_bit_width), device=device, dtype=dtype))
        self.exponent_bias = StatelessBuffer(
            torch.tensor(float(exponent_bias), device=device, dtype=dtype))
        self.fp_max_val = StatelessBuffer(
            max_float(self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()))
        self.fp_internal_scale_min = StatelessBuffer(
            1. - self.exponent_bias() - self.mantissa_bit_width())
        if float_scaling_impl is None:
            float_scaling_impl = ConstScaling(1., device=device, dtype=dtype)
        if scaling_impl is None:
            scaling_impl = ConstScaling(1., device=device, dtype=dtype)
        # Zero-point is currently hardcoded to 0
        self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype))
        self.float_scaling_impl = float_scaling_impl
        self.scaling_impl = scaling_impl

    @brevitas.jit.script_method
    def internal_scale(self, x):
        internal_scale = floor_ste(torch.log2(torch.abs(x))) - self.mantissa_bit_width()
        internal_scale = torch.clamp_min(internal_scale, self.fp_internal_scale_min())
        internal_scale = torch.exp2(internal_scale)
        return internal_scale

    @brevitas.jit.script_method
    def quantize(self, x: torch.Tensor):
        scale = self.scaling_impl(x) / self.float_scaling_impl(x)
        scaled_x = x / scale
        internal_scale = self.internal_scale(scaled_x)
        val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale)
        if self.signed:
            val_fp_quant = torch.clip(val_fp_quant, -1. * self.fp_max_val(), self.fp_max_val())
        else:
            val_fp_quant = torch.clip(val_fp_quant, 0., self.fp_max_val())
        return val_fp_quant, scale

    @brevitas.jit.script_method
    def dequantize(self, y, scale):
        return y * scale

    @brevitas.jit.script_method
    def forward(self, x):
        y, scale = self.quantize(x)
        y = self.dequantize(y, scale)
        # This is to respect the current interface of proxies
        return y, scale, self.zero_point_impl(), self.bit_width()
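For reference, a minimal sketch (not part of the commit) of how this core quantizer can be exercised on its own, outside the injector-driven quantizers defined further down. The e4m3 parameters mirror Fp8e4m3Mixin below; with the default ConstScaling(1.) implementations the returned scale stays at 1 and only the FP8 rounding and clipping are applied.

import torch

from brevitas.core.quant.float import FloatQuant

# e4m3: 1 sign + 4 exponent + 3 mantissa bits, bias 7 (same values as Fp8e4m3Mixin)
float_quant = FloatQuant(
    bit_width=8, signed=True, exponent_bit_width=4, mantissa_bit_width=3, exponent_bias=7)

x = torch.randn(4, 4)
# Fake-quantized tensor plus the metadata expected by the proxy interface
y, scale, zero_point, bit_width = float_quant(x)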
@@ -0,0 +1,32 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional

import torch
from torch import Tensor

import brevitas
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_float


class FloatScaling(brevitas.jit.ScriptModule):

    def __init__(
            self,
            exponent_bit_width: int,
            mantissa_bit_width: int,
            exponent_bias: int,
            device: Optional[str] = None,
            dtype: Optional[torch.dtype] = None):
        super(FloatScaling, self).__init__()
        exponent_bit_width = torch.tensor(exponent_bit_width, device=device, dtype=dtype)
        mantissa_bit_width = torch.tensor(mantissa_bit_width, device=device, dtype=dtype)
        exponent_bias = torch.tensor(exponent_bias, device=device, dtype=dtype)
        self.max_float_val = StatelessBuffer(
            max_float(exponent_bit_width, mantissa_bit_width, exponent_bias))

    @brevitas.jit.script_method
    def forward(self, input: torch.Tensor) -> Tensor:
        return self.max_float_val()
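FloatScaling ignores its input and always returns the maximum representable magnitude of the target FP8 format, so inside FloatQuant.quantize() the statistics-based scale is divided by it and the observed tensor range is mapped onto the FP8 range. A hypothetical illustration of that composition, assuming a 'max'/'stats' scaling_impl and the naive (2 - 2^-m) * 2^(2^e - 1 - bias) convention for the maximum value; the exact figure returned by max_float() is the authoritative one and may differ depending on how special encodings are treated.

import torch

# Illustrates: scale = scaling_impl(x) / float_scaling_impl(x)
x = torch.randn(256) * 12.0

# Assumed e4m3 maximum magnitude under the naive convention above
max_fp8 = (2 - 2 ** -3) * 2.0 ** (2 ** 4 - 1 - 7)  # 480.0

absmax = x.abs().max()    # what a 'max' stats-based scaling_impl would collect
scale = absmax / max_fp8  # scaled_x = x / scale then spans roughly [-max_fp8, max_fp8]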
Empty file.
@@ -0,0 +1,126 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.quant.base import MSESymmetricScale
from brevitas.quant.experimental.float_base import FloatActBase
from brevitas.quant.experimental.float_base import FloatWeightBase
from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
from brevitas.quant.experimental.float_base import ScaledFloatActBase
from brevitas.quant.experimental.float_base import ScaledFloatWeightBase


class Fp8e4m3Weight(Fp8e4m3Mixin, FloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer.
    """
    pass


class Fp8e5m2Weight(Fp8e5m2Mixin, FloatWeightBase):
    """
    FP8 signed E5M2 weight quantizer.
    """
    pass


class Fp8e4m3Act(Fp8e4m3Mixin, FloatActBase):
    """
    FP8 signed E4M3 activation quantizer.
    """
    pass


class Fp8e5m2Act(Fp8e5m2Mixin, FloatActBase):
    """
    FP8 signed E5M2 activation quantizer.
    """
    pass


class Fp8e4m3WeightPerTensorFloat(Fp8e4m3Mixin, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-tensor absmax-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e5m2WeightPerTensorFloat(Fp8e5m2Mixin, ScaledFloatWeightBase):
    """
    FP8 signed E5M2 weight quantizer with per-tensor absmax-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e4m3ActPerTensorFloat(Fp8e4m3Mixin, ScaledFloatActBase):
    """
    FP8 signed E4M3 activation quantizer with per-tensor static percentile-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e5m2ActPerTensorFloat(Fp8e5m2Mixin, ScaledFloatActBase):
    """
    FP8 signed E5M2 activation quantizer with per-tensor static percentile-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e4m3WeightPerChannelFloat(Fp8e4m3Mixin, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-channel absmax-based scaling.
    """
    scaling_per_output_channel = True


class Fp8e5m2WeightPerChannelFloat(Fp8e5m2Mixin, ScaledFloatWeightBase):
    """
    FP8 signed E5M2 weight quantizer with per-channel absmax-based scaling.
    """
    scaling_per_output_channel = True


class Fp8e4m3ActPerChannelFloat2d(Fp8e4m3Mixin, ScaledFloatActBase):
    """
    FP8 signed E4M3 activation quantizer with per-channel static percentile-based scaling.
    """
    scaling_per_output_channel = True
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e5m2ActPerChannelFloat2d(Fp8e5m2Mixin, ScaledFloatActBase):
    """
    FP8 signed E5M2 activation quantizer with per-channel static percentile-based scaling.
    """
    scaling_per_output_channel = True
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e4m3ActPerTensorFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatActBase):
    """
    FP8 signed E4M3 activation quantizer with per-tensor static MSE-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e5m2ActPerTensorFloatMSE(Fp8e5m2Mixin, MSESymmetricScale, ScaledFloatActBase):
    """
    FP8 signed E5M2 activation quantizer with per-tensor static MSE-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e4m3ActPerChannelFloat2dMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatActBase):
    """
    FP8 signed E4M3 activation quantizer with per-channel static MSE-based scaling.
    """
    scaling_per_output_channel = True
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e5m2ActPerChannelFloat2dMSE(Fp8e5m2Mixin, MSESymmetricScale, ScaledFloatActBase):
    """
    FP8 signed E5M2 activation quantizer with per-channel static MSE-based scaling.
    """
    scaling_per_output_channel = True
    scaling_stats_permute_dims = (1, 0, 2, 3)
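A minimal sketch of how these injectors would typically be attached to a layer. brevitas.nn.QuantLinear and its weight_quant/input_quant keywords are assumed to behave as in the rest of the library and are not part of this commit; the import path of the quantizers is likewise an assumption, since the file name is not visible in this diff view. The percentile-based activation scale is a learned parameter initialized from statistics collected over collect_stats_steps forward passes, so some calibration or training is expected before export.

import torch

import brevitas.nn as qnn
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat  # assumed module path
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat

fp8_linear = qnn.QuantLinear(
    16, 32, bias=False,
    weight_quant=Fp8e4m3WeightPerTensorFloat,
    input_quant=Fp8e4m3ActPerTensorFloat)

out = fp8_linear(torch.randn(2, 16))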
@@ -0,0 +1,55 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.core.function_wrapper import RoundSte
from brevitas.core.quant.float import FloatQuant
from brevitas.core.scaling.float_scaling import FloatScaling
from brevitas.inject import ExtendedInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.quant.solver import ActQuantSolver
from brevitas.quant.solver import WeightQuantSolver


class FloatWeightBase(ExtendedInjector):
    proxy_class = WeightQuantProxyFromInjector
    tensor_quant = FloatQuant
    signed = True
    float_to_int_impl = RoundSte


class FloatActBase(ExtendedInjector):
    proxy_class = ActQuantProxyFromInjector
    tensor_quant = FloatQuant
    signed = True
    float_to_int_impl = RoundSte


class ScaledFloatWeightBase(FloatWeightBase, WeightQuantSolver):
    scaling_stats_op = 'max'
    scaling_impl_type = 'stats'
    restrict_scaling_type = 'fp'
    float_scaling_impl = FloatScaling


class ScaledFloatActBase(FloatActBase, ActQuantSolver):
    scaling_stats_op = 'percentile'
    scaling_impl_type = 'parameter_from_stats'
    restrict_scaling_type = 'fp'
    high_percentile_q = 99.999
    collect_stats_steps = 300
    float_scaling_impl = FloatScaling


class Fp8e4m3Mixin(ExtendedInjector):
    bit_width = 8
    exponent_bit_width = 4
    mantissa_bit_width = 3
    exponent_bias = 7


class Fp8e5m2Mixin(ExtendedInjector):
    bit_width = 8
    exponent_bit_width = 5
    mantissa_bit_width = 2
    exponent_bias = 15
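The mixins only pin down the bit layout of the format, so other FP8 splits can be composed the same way. A hypothetical sketch (not part of the commit) of an e3m4 weight quantizer built from the pieces above; the bias of 3 follows the 2^(e-1) - 1 convention used by the e4m3/e5m2 mixins and is an assumption.

from brevitas.inject import ExtendedInjector
from brevitas.quant.experimental.float_base import ScaledFloatWeightBase


class Fp8e3m4Mixin(ExtendedInjector):
    # 1 sign + 3 exponent + 4 mantissa bits; bias assumed as 2**(3 - 1) - 1 = 3
    bit_width = 8
    exponent_bit_width = 3
    mantissa_bit_width = 4
    exponent_bias = 3


class Fp8e3m4WeightPerTensorFloat(Fp8e3m4Mixin, ScaledFloatWeightBase):
    """
    Hypothetical FP8 signed E3M4 weight quantizer with per-tensor absmax-based scaling.
    """
    scaling_per_output_channel = False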