diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py
index c919dabf0..40cff68c0 100644
--- a/src/brevitas/core/function_wrapper/clamp.py
+++ b/src/brevitas/core/function_wrapper/clamp.py
@@ -12,10 +12,7 @@
 
 import brevitas
 from brevitas.core.utils import StatelessBuffer
-from brevitas.function import max_float
 from brevitas.function import tensor_clamp
-from brevitas.utils.float_quant_utils import dec_to_bits
-from brevitas.utils.float_quant_utils import get_minifloat_value
 
 
 class TensorClamp(brevitas.jit.ScriptModule):
@@ -89,117 +86,34 @@ class FloatClamp(brevitas.jit.ScriptModule):
     I.e. setting inf to 1101.111 (E4M3) is not a valid code.
     """
 
+    __constants__ = ['saturating', 'has_inf_values']
+
     def __init__(
             self,
-            exponent_bit_width: int,
-            mantissa_bit_width: int,
-            exponent_bias: int,
+            max_value: float,
             tensor_clamp_impl: Module = TensorClamp(),
-            nan_values: Optional[Tuple[str]] = None,
            inf_values: Optional[Tuple[str]] = None,
            saturating: bool = False) -> None:
         super(FloatClamp, self).__init__()
-        # inf without NaN not possible
-        if inf_values is None and nan_values is None:
-            max_val_impl = StatelessBuffer(
-                max_float(
-                    torch.tensor(exponent_bit_width),
-                    torch.tensor(mantissa_bit_width),
-                    torch.tensor(exponent_bias)))
-        elif nan_values is not None:
-            # we at least have values for NaN, so initiate MaxValInfNaN
-            max_val_impl = MaxFloatInfNaN(
-                exponent_bit_width=exponent_bit_width,
-                mantissa_bit_width=mantissa_bit_width,
-                exponent_bias=exponent_bias,
-                nan_values=nan_values,
-                inf_values=inf_values)
-        else:
-            # no NaN values but inf values
-            raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.')
-
-        # class for clamping to inf/NaN values
-        self.fpx_clamp_impl = FpXClamp(
-            inf_values=inf_values, saturating=saturating, tensor_clamp_impl=tensor_clamp_impl)
-
-        # get max value for the minifloat config, no need to compute it during forward pass
-        self.max_value = max_val_impl()
-
-    @brevitas.jit.script_method
-    def forward(self, inp: Tensor):
-        return self.fpx_clamp_impl(inp, self.max_value)
-
-
-class MaxFloatInfNaN(brevitas.jit.ScriptModule):
-
-    def __init__(
-            self,
-            exponent_bit_width: int,
-            mantissa_bit_width: int,
-            exponent_bias: int,
-            nan_values: Tuple[str],
-            inf_values: Optional[Tuple[str]]) -> None:
-        super(MaxFloatInfNaN, self).__init__()
-        self.exponent_bit_width = StatelessBuffer(torch.tensor(exponent_bit_width))
-        self.mantissa_bit_width = StatelessBuffer(torch.tensor(mantissa_bit_width))
-        self.exponent_bias = StatelessBuffer(torch.tensor(exponent_bias))
-
-        _special_values = nan_values + inf_values if inf_values is not None else nan_values
-
-        # check that NaN/inf values are all mantissa_bit_width long
-        if any(map(lambda x: len(x) > mantissa_bit_width, _special_values)):
-            raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.')
-
-        # move computation of min for forward pass here so it's jit compatible
-        self._min_special_case = torch.tensor(min(map(lambda x: int(x, 2), _special_values)))
+        self.tensor_clamp_impl = tensor_clamp_impl
 
-    @brevitas.jit.script_method
-    def forward(self):
-        # idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1
-        max_value_mantissa = self._min_special_case - 1
-
-        if max_value_mantissa < 0:
-            # all mantissa values are used, so we need to use decrease exponent values
-            exponent = torch.tensor(1).repeat(self.exponent_bit_width() - 1)
-            # add trailing 0 to reach bit width
-            exponent = torch.cat([exponent, torch.tensor([0], dtype=exponent.dtype)])
-            # since we decreased exponent, we can use full mantissa
-            mantissa = torch.tensor(1).repeat(self.mantissa_bit_width())
-        else:
-            # there is a free mantissa code, so use full exponent
-            exponent = torch.tensor(1).repeat(self.exponent_bit_width())
-            # get binary code for max_value_mantissa in the number of mantissa bits
-            mantissa = dec_to_bits(max_value_mantissa, self.mantissa_bit_width())
-
-        # we don't need the sign since we're looking for the max value
-        max_value = get_minifloat_value(
-            exponent=exponent, mantissa=mantissa, exponent_bias=self.exponent_bias())
-        return max_value
-
-
-class FpXClamp(brevitas.jit.ScriptModule):
-
-    def __init__(self, inf_values: Tuple[str], saturating: bool, tensor_clamp_impl: Module) -> None:
-        super(FpXClamp, self).__init__()
-        self.inf_values = inf_values
+        self.max_value = StatelessBuffer(torch.tensor(max_value))
         self.saturating = saturating
-        self.tensor_clamp_impl = tensor_clamp_impl
+        self.has_inf_values = bool(inf_values)
 
     @brevitas.jit.script_method
-    def forward(self, x: Tensor, max_value: Tensor):
-        # NaN values all stay at NaN, so no need to do anything with NaN values
-        # get all positive inf values
+    def forward(self, x: Tensor):
         inf_mask = x.isinf()
-        p_max_val_mask = x > max_value
-        n_max_val_mask = -x > max_value
+        p_max_val_mask = x > self.max_value()
+        n_max_val_mask = -x > self.max_value()
         # first clamp everything to +- max_value, basically the saturating case
-        x = self.tensor_clamp_impl(x, min_val=-max_value, max_val=max_value)
+        x = self.tensor_clamp_impl(x, min_val=-self.max_value(), max_val=self.max_value())
         if not self.saturating:
             # if non-saturating, we need to map values greater than max_val to nan or inf
-            if self.inf_values is not None:
+            if self.has_inf_values:
                 # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf
                 x[p_max_val_mask] = torch.tensor(float('inf'))
                 x[n_max_val_mask] = torch.tensor(float('-inf'))
diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index 8a462fc0e..8557b9974 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -13,6 +13,7 @@
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops import max_float
 from brevitas.function.ops_ste import floor_ste
+from brevitas.utils.float_quant_utils import get_max_value
 
 
 class FloatQuant(brevitas.jit.ScriptModule):
@@ -57,9 +58,8 @@ def __init__(
             scaling_impl = ConstScaling(1., device=device, dtype=dtype)
         if float_clamp_impl is None:
             self.float_clamp_impl = FloatClamp(
-                exponent_bit_width=self.exponent_bit_width(),
-                mantissa_bit_width=self.mantissa_bit_width(),
-                exponent_bias=self.exponent_bias())
+                max_value=get_max_value(
+                    exponent_bit_width, mantissa_bit_width, exponent_bias, None, None))
         # Zero-point is currently hardcoded to 0
         self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype))
         self.float_scaling_impl = float_scaling_impl
diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py
index 8421c3c13..3777e6703 100644
--- a/src/brevitas/quant/experimental/float_base.py
+++ b/src/brevitas/quant/experimental/float_base.py
@@ -11,6 +11,7 @@
 from brevitas.quant.solver import ActQuantSolver
 from brevitas.quant.solver import WeightQuantSolver
 from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum
+from brevitas.utils.float_quant_utils import get_max_value
 
 
 class FloatWeightBase(SolveTensorQuantFloatToIntImplFromEnum):
@@ -51,7 +52,15 @@ def exponent_bias(exponent_bit_width):
         return 2 ** (exponent_bit_width - 1) - 1
 
 
-class Fp8e4m3Mixin(ExponentBiasMixin):
+class MaxFloatInfNaNMixin(ExtendedInjector):
+
+    @value
+    def max_value(exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values):
+        return get_max_value(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values)
+
+
+class Fp8e4m3Mixin(ExponentBiasMixin, MaxFloatInfNaNMixin):
     bit_width = 8
     exponent_bit_width = 4
     mantissa_bit_width = 3
@@ -61,7 +70,7 @@ class Fp8e4m3Mixin(ExponentBiasMixin):
     saturating = True
 
 
-class Fp8e5m2Mixin(ExponentBiasMixin):
+class Fp8e5m2Mixin(ExponentBiasMixin, MaxFloatInfNaNMixin):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
diff --git a/src/brevitas/utils/float_quant_utils.py b/src/brevitas/utils/float_quant_utils.py
index 33515f95e..b90dbf82d 100644
--- a/src/brevitas/utils/float_quant_utils.py
+++ b/src/brevitas/utils/float_quant_utils.py
@@ -1,41 +1,66 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-import torch
-from torch import Tensor
-
 
-def mantissa_bits_to_float(bits: Tensor, frexp_compatible: bool = False) -> float:
+def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float:
     # computes the decimal place value from a given binary tensor
     res = 1.0
     for i, val in enumerate(bits):
         # iterating through from left to right
-        res += ((2 ** -(i + 1)) * val)
+        res += ((2 ** -(i + 1)) * float(val))
     if frexp_compatible:
         return res / 2.
     else:
         return res
 
 
-def get_minifloat_value(exponent: Tensor, mantissa: Tensor, exponent_bias: Tensor) -> Tensor:
+def get_minifloat_value(exponent: str, mantissa: str, exponent_bias: int) -> float:
     """
     Returns the minifloat value for a given exponent, mantissa and exponent_bias.
     It expects the exponent and mantissa in their binary format.
""" - exponent_value = bits_to_dec(exponent) + exponent_value = int(exponent, 2) mantissa_value = mantissa_bits_to_float(mantissa) - return torch.exp2(exponent_value - exponent_bias) * mantissa_value + return 2 ** (exponent_value - exponent_bias) * mantissa_value + + +def get_max_value(exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values): + # Idea: take the smallest NaN/inf value, set max_value to the next smaller one + # inf without NaN not possible + if inf_values is None and nan_values is None: + # no special cases, max_value is using all bits for exponent and mantissa + exponent = '1' * exponent_bit_width + mantissa = '1' * mantissa_bit_width + elif nan_values is not None: + # we at least have values for NaN, so initiate MaxValInfNaN + special_values = nan_values + inf_values if inf_values is not None else nan_values + # check that NaN/inf values are all mantissa_bit_width long + if any(map(lambda x: len(x) > mantissa_bit_width, special_values)): + raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.') -def dec_to_bits(value: Tensor, bits: int) -> Tensor: - # set up mask - mask = 2 ** torch.arange(bits - 1, -1, -1).to(value.device, value.dtype) - # add dimension, bitwise_and gets the bits needed for the value, the rest is converting to byte - return value.unsqueeze(-1).bitwise_and(mask).ne(0).byte() + # get the minimum special case, our max value is the next smaller value + min_special_case = min(map(lambda x: int(x, 2), special_values)) + max_value_mantissa = min_special_case - 1 + + if max_value_mantissa < 0: + # all mantissa values are used, so we need to use decrease exponent values + exponent = '1' * (exponent_bit_width - 1) + # add trailing 0 to reach bit width + exponent += '0' + # since we decreased exponent, we can use full mantissa + mantissa = '1' * mantissa_bit_width + else: + # there is a free mantissa code, so use full exponent + exponent = '1' * exponent_bit_width + # get binary code for max_value_mantissa in the number of mantissa bits + mantissa = format(max_value_mantissa, f'0{mantissa_bit_width}b') + else: + # no NaN values but inf values + raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.') -def bits_to_dec(bits: Tensor) -> Tensor: - # get num of bits used - num_bits = len(bits) - # convert by summing decimal values of set bits - return torch.sum((2 ** torch.arange(num_bits - 1, -1, -1)) * bits) + # we don't need the sign since we're looking for the max value + max_value = get_minifloat_value( + exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias) + return max_value diff --git a/tests/brevitas/core/minifloat_fixtures.py b/tests/brevitas/core/minifloat_fixtures.py index e0f7528a6..48cefc663 100644 --- a/tests/brevitas/core/minifloat_fixtures.py +++ b/tests/brevitas/core/minifloat_fixtures.py @@ -7,24 +7,29 @@ from brevitas.core.function_wrapper import FloatClamp from brevitas.inject.enum import BitWidthImplType from brevitas.quant.experimental.float_base import ExponentBiasMixin +from brevitas.quant.experimental.float_base import MaxFloatInfNaNMixin from brevitas.quant.experimental.float_base import ScaledFloatWeightBase -class Fp8e4m3Base(ExponentBiasMixin, ScaledFloatWeightBase): +class Fp8e4m3Base(ExponentBiasMixin, MaxFloatInfNaNMixin, ScaledFloatWeightBase): bit_width = 8 exponent_bit_width = 4 mantissa_bit_width = 3 float_clamp_impl = FloatClamp + nan_values = None + inf_values = None bit_width_impl_type = BitWidthImplType.CONST # hypothesis extra 
     hypothesis_internal_is_this_a_mock_check = False
 
 
-class Fp8e5m2Base(ExponentBiasMixin, ScaledFloatWeightBase):
+class Fp8e5m2Base(ExponentBiasMixin, MaxFloatInfNaNMixin, ScaledFloatWeightBase):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
     float_clamp_impl = FloatClamp
+    nan_values = None
+    inf_values = None
     bit_width_impl_type = BitWidthImplType.CONST
     # hypothesis extra
     hypothesis_internal_is_this_a_mock_check = False
diff --git a/tests/brevitas/core/test_minifloat.py b/tests/brevitas/core/test_minifloat.py
index a8b2f93b9..2a4f6b000 100644
--- a/tests/brevitas/core/test_minifloat.py
+++ b/tests/brevitas/core/test_minifloat.py
@@ -16,26 +16,26 @@
 @pytest.mark.parametrize(
     'minifloat, expected_max_val', ((format, max_val) for format, max_val in FORMATS.items()))
 def test_max_value(minifloat, expected_max_val):
-    max_val = minifloat.float_clamp_impl.max_value
+    max_val = minifloat.float_clamp_impl.max_value()
 
     assert expected_max_val == max_val
 
 
 @given(inp=float_tensor_random_shape_st())
 def test_clamp(inp, fp8_clamp):
-    max_val = fp8_clamp.float_clamp_impl.max_value
+    max_val = fp8_clamp.float_clamp_impl.max_value()
     # get values that exceed max_val
     over_limit_mask = inp.abs() > max_val
 
     # clamp inp
     inp = fp8_clamp.float_clamp_impl(inp)
-    if fp8_clamp.float_clamp_impl.fpx_clamp_impl.saturating:
+    if fp8_clamp.float_clamp_impl.saturating:
         # should be clamped to +- max val
         assert (inp[over_limit_mask].abs() == max_val).all()
     else:
         # if inf_values, over limit mask should now be all inf
-        if fp8_clamp.float_clamp_impl.fpx_clamp_impl.inf_values is not None:
+        if fp8_clamp.float_clamp_impl.has_inf_values:
             # all values exceeding max_val should be inf
             assert inp[over_limit_mask].isinf().all()
         else:
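
The snippet below is a minimal sketch, not part of the patch, of how the relocated max-value computation behaves. It assumes the usual biases produced by ExponentBiasMixin (7 for E4M3, 15 for E5M2); the NaN/inf code tuples are illustrative OCP-style choices and may differ from what the Fp8e4m3Mixin/Fp8e5m2Mixin quantizers actually declare.

from brevitas.core.function_wrapper import FloatClamp
from brevitas.utils.float_quant_utils import get_max_value

# No reserved codes: the all-ones exponent/mantissa pair is a regular number.
assert get_max_value(4, 3, 7, None, None) == 480.0  # 2 ** (15 - 7) * 1.875

# E4M3-style layout (assumed NaN code '111'): the largest regular code keeps the
# top exponent '1111' but drops the mantissa to '110'.
assert get_max_value(4, 3, 7, ('111',), None) == 448.0  # 2 ** (15 - 7) * 1.75

# E5M2-style layout (assumed NaN codes '01', '11', '10' and inf code '00'): every
# mantissa code at the top exponent is special, so the exponent drops to '11110'.
assert get_max_value(5, 2, 15, ('01', '11', '10'), ('00',)) == 57344.0  # 2 ** (30 - 15) * 1.75

# FloatClamp now simply receives the precomputed bound instead of deriving it.
fp8_clamp = FloatClamp(max_value=get_max_value(4, 3, 7, ('111',), None), saturating=True)

Precomputing the bound once and storing it in a StatelessBuffer keeps string parsing out of the forward pass, leaving only saturating and has_inf_values as script-time constants.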