Fix (minifloat): compute max_value during dependency injection
fabianandresgrob committed Feb 23, 2024
1 parent 1b2a64b commit 49489b2
Showing 6 changed files with 79 additions and 126 deletions.
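In short: before this change, FloatClamp derived max_value internally from the exponent/mantissa configuration at construction time; afterwards, max_value is computed once while the quantizer's dependencies are resolved and handed to FloatClamp as a plain float. A minimal sketch of the two call shapes, assuming the signatures shown in the diffs below:

    # before: the clamp module derived max_value from the format itself
    clamp = FloatClamp(exponent_bit_width=4, mantissa_bit_width=3, exponent_bias=7)

    # after: max_value is computed up front and injected as a float
    max_val = get_max_value(4, 3, 7, None, None)
    clamp = FloatClamp(max_value=max_val)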
108 changes: 11 additions & 97 deletions src/brevitas/core/function_wrapper/clamp.py
@@ -12,10 +12,7 @@

 import brevitas
 from brevitas.core.utils import StatelessBuffer
-from brevitas.function import max_float
 from brevitas.function import tensor_clamp
-from brevitas.utils.float_quant_utils import dec_to_bits
-from brevitas.utils.float_quant_utils import get_minifloat_value


class TensorClamp(brevitas.jit.ScriptModule):
@@ -89,117 +86,34 @@ class FloatClamp(brevitas.jit.ScriptModule):
     I.e. setting inf to 1101.111 (E4M3) is not a valid code.
     """

+    __constants__ = ['saturating', 'has_inf_values']
+
     def __init__(
             self,
-            exponent_bit_width: int,
-            mantissa_bit_width: int,
-            exponent_bias: int,
+            max_value: float,
             tensor_clamp_impl: Module = TensorClamp(),
-            nan_values: Optional[Tuple[str]] = None,
             inf_values: Optional[Tuple[str]] = None,
             saturating: bool = False) -> None:
         super(FloatClamp, self).__init__()

-        # inf without NaN not possible
-        if inf_values is None and nan_values is None:
-            max_val_impl = StatelessBuffer(
-                max_float(
-                    torch.tensor(exponent_bit_width),
-                    torch.tensor(mantissa_bit_width),
-                    torch.tensor(exponent_bias)))
-        elif nan_values is not None:
-            # we at least have values for NaN, so initiate MaxValInfNaN
-            max_val_impl = MaxFloatInfNaN(
-                exponent_bit_width=exponent_bit_width,
-                mantissa_bit_width=mantissa_bit_width,
-                exponent_bias=exponent_bias,
-                nan_values=nan_values,
-                inf_values=inf_values)
-        else:
-            # no NaN values but inf values
-            raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.')
-
-        # class for clamping to inf/NaN values
-        self.fpx_clamp_impl = FpXClamp(
-            inf_values=inf_values, saturating=saturating, tensor_clamp_impl=tensor_clamp_impl)
-
-        # get max value for the minifloat config, no need to compute it during forward pass
-        self.max_value = max_val_impl()
-
-    @brevitas.jit.script_method
-    def forward(self, inp: Tensor):
-        return self.fpx_clamp_impl(inp, self.max_value)
-
-
-class MaxFloatInfNaN(brevitas.jit.ScriptModule):
-
-    def __init__(
-            self,
-            exponent_bit_width: int,
-            mantissa_bit_width: int,
-            exponent_bias: int,
-            nan_values: Tuple[str],
-            inf_values: Optional[Tuple[str]]) -> None:
-        super(MaxFloatInfNaN, self).__init__()
-        self.exponent_bit_width = StatelessBuffer(torch.tensor(exponent_bit_width))
-        self.mantissa_bit_width = StatelessBuffer(torch.tensor(mantissa_bit_width))
-        self.exponent_bias = StatelessBuffer(torch.tensor(exponent_bias))
-
-        _special_values = nan_values + inf_values if inf_values is not None else nan_values
-
-        # check that NaN/inf values are all mantissa_bit_width long
-        if any(map(lambda x: len(x) > mantissa_bit_width, _special_values)):
-            raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.')
-
-        # move computation of min for forward pass here so it's jit compatible
-        self._min_special_case = torch.tensor(min(map(lambda x: int(x, 2), _special_values)))
-        self.tensor_clamp_impl = tensor_clamp_impl
-
-    @brevitas.jit.script_method
-    def forward(self):
-        # idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1
-        max_value_mantissa = self._min_special_case - 1
-
-        if max_value_mantissa < 0:
-            # all mantissa values are used, so we need to use decrease exponent values
-            exponent = torch.tensor(1).repeat(self.exponent_bit_width() - 1)
-            # add trailing 0 to reach bit width
-            exponent = torch.cat([exponent, torch.tensor([0], dtype=exponent.dtype)])
-            # since we decreased exponent, we can use full mantissa
-            mantissa = torch.tensor(1).repeat(self.mantissa_bit_width())
-        else:
-            # there is a free mantissa code, so use full exponent
-            exponent = torch.tensor(1).repeat(self.exponent_bit_width())
-            # get binary code for max_value_mantissa in the number of mantissa bits
-            mantissa = dec_to_bits(max_value_mantissa, self.mantissa_bit_width())
-
-        # we don't need the sign since we're looking for the max value
-        max_value = get_minifloat_value(
-            exponent=exponent, mantissa=mantissa, exponent_bias=self.exponent_bias())
-        return max_value
-
-
-class FpXClamp(brevitas.jit.ScriptModule):
-
-    def __init__(self, inf_values: Tuple[str], saturating: bool, tensor_clamp_impl: Module) -> None:
-        super(FpXClamp, self).__init__()
         self.inf_values = inf_values
+        self.max_value = StatelessBuffer(torch.tensor(max_value))
         self.saturating = saturating
         self.tensor_clamp_impl = tensor_clamp_impl
         self.has_inf_values = bool(inf_values)

     @brevitas.jit.script_method
-    def forward(self, x: Tensor, max_value: Tensor):
-        # NaN values all stay at NaN, so no need to do anything with NaN values
-        # get all positive inf values
+    def forward(self, x: Tensor):
         inf_mask = x.isinf()
-        p_max_val_mask = x > max_value
-        n_max_val_mask = -x > max_value
+        p_max_val_mask = x > self.max_value()
+        n_max_val_mask = -x > self.max_value()

         # first clamp everything to +- max_value, basically the saturating case
-        x = self.tensor_clamp_impl(x, min_val=-max_value, max_val=max_value)
+        x = self.tensor_clamp_impl(x, min_val=-self.max_value(), max_val=self.max_value())

         if not self.saturating:
             # if non-saturating, we need to map values greater than max_val to nan or inf
-            if self.inf_values is not None:
+            if self.has_inf_values:
                 # we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf
                 x[p_max_val_mask] = torch.tensor(float('inf'))
                 x[n_max_val_mask] = torch.tensor(float('-inf'))
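As a quick usage sketch of the rewritten module (assuming the post-commit signature above, and that 448.0 is a suitable bound for the chosen format): max_value now lives in a StatelessBuffer, so the scripted forward only compares and clamps.

    import torch
    from brevitas.core.function_wrapper.clamp import FloatClamp

    clamp = FloatClamp(max_value=448., saturating=True)
    x = torch.tensor([1000., -1000., 0.5])
    print(clamp(x))  # tensor([ 448., -448., 0.5]); out-of-range values saturate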
6 changes: 3 additions & 3 deletions src/brevitas/core/quant/float.py
@@ -13,6 +13,7 @@
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops import max_float
 from brevitas.function.ops_ste import floor_ste
+from brevitas.utils.float_quant_utils import get_max_value


class FloatQuant(brevitas.jit.ScriptModule):
@@ -57,9 +58,8 @@ def __init__(
             scaling_impl = ConstScaling(1., device=device, dtype=dtype)
         if float_clamp_impl is None:
             self.float_clamp_impl = FloatClamp(
-                exponent_bit_width=self.exponent_bit_width(),
-                mantissa_bit_width=self.mantissa_bit_width(),
-                exponent_bias=self.exponent_bias())
+                max_value=get_max_value(
+                    exponent_bit_width, mantissa_bit_width, exponent_bias, None, None))
         # Zero-point is currently hardcoded to 0
         self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype))
         self.float_scaling_impl = float_scaling_impl
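For the default path used here (nan_values and inf_values both None), get_max_value reduces to the all-ones exponent and mantissa. As a worked example with hypothetical E4M3-style parameters (4 exponent bits, 3 mantissa bits, bias 7):

    get_max_value(4, 3, 7, None, None)
    # exponent '1111' -> 15, mantissa '111' -> 1 + 1/2 + 1/4 + 1/8 = 1.875
    # 2 ** (15 - 7) * 1.875 = 480.0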
13 changes: 11 additions & 2 deletions src/brevitas/quant/experimental/float_base.py
@@ -11,6 +11,7 @@
 from brevitas.quant.solver import ActQuantSolver
 from brevitas.quant.solver import WeightQuantSolver
 from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum
+from brevitas.utils.float_quant_utils import get_max_value


class FloatWeightBase(SolveTensorQuantFloatToIntImplFromEnum):
@@ -51,7 +52,15 @@ def exponent_bias(exponent_bit_width):
         return 2 ** (exponent_bit_width - 1) - 1


-class Fp8e4m3Mixin(ExponentBiasMixin):
+class MaxFloatInfNaNMixin(ExtendedInjector):
+
+    @value
+    def max_value(exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values):
+        return get_max_value(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values)
+
+
+class Fp8e4m3Mixin(ExponentBiasMixin, MaxFloatInfNaNMixin):
     bit_width = 8
     exponent_bit_width = 4
     mantissa_bit_width = 3
@@ -61,7 +70,7 @@ class Fp8e4m3Mixin(ExponentBiasMixin):
     saturating = True


-class Fp8e5m2Mixin(ExponentBiasMixin):
+class Fp8e5m2Mixin(ExponentBiasMixin, MaxFloatInfNaNMixin):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
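The @value function above is not called directly: the injector resolves its argument names (exponent_bit_width, nan_values, ...) against the other attributes of the final quantizer, so each class mixing in MaxFloatInfNaNMixin gets a max_value that matches its own format. A minimal sketch of the mechanism, assuming brevitas' ExtendedInjector resolves @value functions on attribute access like the dependencies-style injector it is built on:

    from brevitas.inject import ExtendedInjector
    from brevitas.inject import value

    class Example(ExtendedInjector):
        exponent_bit_width = 4

        @value
        def doubled(exponent_bit_width):  # argument resolved by name from the injector
            return 2 * exponent_bit_width

    print(Example.doubled)  # 8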
61 changes: 43 additions & 18 deletions src/brevitas/utils/float_quant_utils.py
@@ -1,41 +1,66 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause

-import torch
-from torch import Tensor
-

-def mantissa_bits_to_float(bits: Tensor, frexp_compatible: bool = False) -> float:
+def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float:
     # computes the decimal place value from a given binary tensor
     res = 1.0
     for i, val in enumerate(bits):
         # iterating through from left to right
-        res += ((2 ** -(i + 1)) * val)
+        res += ((2 ** -(i + 1)) * float(val))
     if frexp_compatible:
         return res / 2.
     else:
         return res


-def get_minifloat_value(exponent: Tensor, mantissa: Tensor, exponent_bias: Tensor) -> Tensor:
+def get_minifloat_value(exponent: str, mantissa: str, exponent_bias: int) -> float:
     """
     Returns the minifloat value for a given exponent, mantissa and exponent_bias.
     It expects the exponent and mantissa in their binary format.
     """
-    exponent_value = bits_to_dec(exponent)
+    exponent_value = int(exponent, 2)
     mantissa_value = mantissa_bits_to_float(mantissa)
-    return torch.exp2(exponent_value - exponent_bias) * mantissa_value
+    return 2 ** (exponent_value - exponent_bias) * mantissa_value


+def get_max_value(exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values):
+    # Idea: take the smallest NaN/inf value, set max_value to the next smaller one
+    # inf without NaN not possible
+    if inf_values is None and nan_values is None:
+        # no special cases, max_value is using all bits for exponent and mantissa
+        exponent = '1' * exponent_bit_width
+        mantissa = '1' * mantissa_bit_width
+    elif nan_values is not None:
+        # we at least have codes for NaN, so the special codes bound the max value
+        special_values = nan_values + inf_values if inf_values is not None else nan_values
+
+        # check that NaN/inf values are all mantissa_bit_width long
+        if any(map(lambda x: len(x) > mantissa_bit_width, special_values)):
+            raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.')

-def dec_to_bits(value: Tensor, bits: int) -> Tensor:
-    # set up mask
-    mask = 2 ** torch.arange(bits - 1, -1, -1).to(value.device, value.dtype)
-    # add dimension, bitwise_and gets the bits needed for the value, the rest is converting to byte
-    return value.unsqueeze(-1).bitwise_and(mask).ne(0).byte()
+        # get the minimum special case, our max value is the next smaller value
+        min_special_case = min(map(lambda x: int(x, 2), special_values))
+
+        max_value_mantissa = min_special_case - 1
+
+        if max_value_mantissa < 0:
+            # all mantissa codes are taken, so we need to decrease the exponent instead
+            exponent = '1' * (exponent_bit_width - 1)
+            # add trailing 0 to reach bit width
+            exponent += '0'
+            # since we decreased exponent, we can use full mantissa
+            mantissa = '1' * mantissa_bit_width
+        else:
+            # there is a free mantissa code, so use full exponent
+            exponent = '1' * exponent_bit_width
+            # get binary code for max_value_mantissa in the number of mantissa bits
+            mantissa = format(max_value_mantissa, f'0{mantissa_bit_width}b')
+    else:
+        # no NaN values but inf values
+        raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.')

-def bits_to_dec(bits: Tensor) -> Tensor:
-    # get num of bits used
-    num_bits = len(bits)
-    # convert by summing decimal values of set bits
-    return torch.sum((2 ** torch.arange(num_bits - 1, -1, -1)) * bits)
+    # we don't need the sign since we're looking for the max value
+    max_value = get_minifloat_value(
+        exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias)
+    return max_value
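Both branches of get_max_value can be checked by hand. With no special codes the max value uses all bits; otherwise the smallest NaN/inf code bounds it. The following values are illustrative, assuming OCP-style FP8 codes (a single E4M3 NaN mantissa '111'; E5M2 NaN codes '01', '10', '11' plus inf code '00'):

    # no special codes: exponent '1111', mantissa '111' -> 2 ** (15 - 7) * 1.875 = 480.0
    get_max_value(4, 3, 7, None, None)

    # E4M3-style: smallest special code is '111' = 7, so the max mantissa code is '110'
    # exponent '1111', mantissa '110' -> 2 ** (15 - 7) * 1.75 = 448.0
    get_max_value(4, 3, 7, ('111',), None)

    # E5M2-style: '00' is a special code, so no mantissa code is free and the
    # exponent drops to '11110' with full mantissa '11' -> 2 ** (30 - 15) * 1.75 = 57344.0
    get_max_value(5, 2, 15, ('01', '10', '11'), ('00',))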
9 changes: 7 additions & 2 deletions tests/brevitas/core/minifloat_fixtures.py
@@ -7,24 +7,29 @@
 from brevitas.core.function_wrapper import FloatClamp
 from brevitas.inject.enum import BitWidthImplType
 from brevitas.quant.experimental.float_base import ExponentBiasMixin
+from brevitas.quant.experimental.float_base import MaxFloatInfNaNMixin
 from brevitas.quant.experimental.float_base import ScaledFloatWeightBase


-class Fp8e4m3Base(ExponentBiasMixin, ScaledFloatWeightBase):
+class Fp8e4m3Base(ExponentBiasMixin, MaxFloatInfNaNMixin, ScaledFloatWeightBase):
     bit_width = 8
     exponent_bit_width = 4
     mantissa_bit_width = 3
     float_clamp_impl = FloatClamp
+    nan_values = None
+    inf_values = None
     bit_width_impl_type = BitWidthImplType.CONST
     # hypothesis extra
     hypothesis_internal_is_this_a_mock_check = False


-class Fp8e5m2Base(ExponentBiasMixin, ScaledFloatWeightBase):
+class Fp8e5m2Base(ExponentBiasMixin, MaxFloatInfNaNMixin, ScaledFloatWeightBase):
     bit_width = 8
     exponent_bit_width = 5
     mantissa_bit_width = 2
     float_clamp_impl = FloatClamp
+    nan_values = None
+    inf_values = None
     bit_width_impl_type = BitWidthImplType.CONST
     # hypothesis extra
     hypothesis_internal_is_this_a_mock_check = False
8 changes: 4 additions & 4 deletions tests/brevitas/core/test_minifloat.py
@@ -16,26 +16,26 @@
 @pytest.mark.parametrize(
     'minifloat, expected_max_val', ((format, max_val) for format, max_val in FORMATS.items()))
 def test_max_value(minifloat, expected_max_val):
-    max_val = minifloat.float_clamp_impl.max_value
+    max_val = minifloat.float_clamp_impl.max_value()

     assert expected_max_val == max_val


 @given(inp=float_tensor_random_shape_st())
 def test_clamp(inp, fp8_clamp):
-    max_val = fp8_clamp.float_clamp_impl.max_value
+    max_val = fp8_clamp.float_clamp_impl.max_value()
     # get values that exceed max_val
     over_limit_mask = inp.abs() > max_val

     # clamp inp
     inp = fp8_clamp.float_clamp_impl(inp)

-    if fp8_clamp.float_clamp_impl.fpx_clamp_impl.saturating:
+    if fp8_clamp.float_clamp_impl.saturating:
         # should be clamped to +- max val
         assert (inp[over_limit_mask].abs() == max_val).all()
     else:
         # if inf_values, over limit mask should now be all inf
-        if fp8_clamp.float_clamp_impl.fpx_clamp_impl.inf_values is not None:
+        if fp8_clamp.float_clamp_impl.has_inf_values:
             # all values exceeding max_val should be inf
             assert inp[over_limit_mask].isinf().all()
         else: