diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index fcbed7b66..55d048936 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -39,6 +39,8 @@ def __init__( torch.tensor(float(exponent_bit_width), device=device, dtype=dtype)) self.mantissa_bit_width = StatelessBuffer( (torch.tensor(float(mantissa_bit_width), device=device, dtype=dtype))) + if exponent_bias is None: + exponent_bias = 2 ** (exponent_bit_width - 1) - 1 self.exponent_bias = StatelessBuffer( torch.tensor(float(exponent_bias), device=device, dtype=dtype)) self.fp_max_val = StatelessBuffer( diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py new file mode 100644 index 000000000..9997ebbda --- /dev/null +++ b/tests/brevitas/core/test_float_quant.py @@ -0,0 +1,58 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from hypothesis import given +import mock +import pytest +import torch + +from brevitas.core.function_wrapper import RoundSte +from brevitas.core.quant.float import FloatQuant +from brevitas.core.scaling import ConstScaling +from tests.brevitas.core.bit_width_fixture import * # noqa +from tests.brevitas.core.int_quant_fixture import * # noqa +from tests.brevitas.core.shared_quant_fixture import * # noqa +from tests.brevitas.hyp_helper import float_tensor_random_shape_st +from tests.brevitas.hyp_helper import random_minifloat_format +from tests.brevitas.hyp_helper import scalar_float_p_tensor_st +from tests.marker import jit_disabled_for_mock + + +@given(minifloat_format=random_minifloat_format()) +def test_float_quant_defaults(minifloat_format): + bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format + # specifically don't set exponent bias to see if default works + expected_exponent_bias = 2 ** (exponent_bit_width - 1) - 1 + float_quant = FloatQuant( + bit_width=bit_width, + signed=signed, + exponent_bit_width=exponent_bit_width, + mantissa_bit_width=mantissa_bit_width) + assert expected_exponent_bias == float_quant.exponent_bias() + assert isinstance(float_quant.float_to_int_impl, RoundSte) + assert isinstance(float_quant.float_scaling_impl, ConstScaling) + assert isinstance(float_quant.scaling_impl, ConstScaling) + + +@given(minifloat_format=random_minifloat_format()) +def test_minifloat(minifloat_format): + bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format + assert bit_width == exponent_bit_width + mantissa_bit_width + int(signed) + + +@given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format()) +@jit_disabled_for_mock() +def test_int_quant_to_in(inp, minifloat_format): + bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format + exponent_bias = 2 ** (exponent_bit_width - 1) - 1 + float_quant = FloatQuant( + bit_width=bit_width, + signed=signed, + exponent_bit_width=exponent_bit_width, + mantissa_bit_width=mantissa_bit_width, + exponent_bias=exponent_bias) + expected_out, _, _, bit_width_out = float_quant(inp) + + out_quant, scale = float_quant.quantize(inp) + assert bit_width_out == bit_width + assert torch.equal(expected_out, out_quant * scale) diff --git a/tests/brevitas/hyp_helper.py b/tests/brevitas/hyp_helper.py index ed6f1ea87..9905d2f03 100644 --- a/tests/brevitas/hyp_helper.py +++ b/tests/brevitas/hyp_helper.py @@ -13,6 +13,8 @@ import torch from tests.brevitas.common import FP32_BIT_WIDTH +from tests.brevitas.common import MAX_INT_BIT_WIDTH +from tests.brevitas.common import MIN_INT_BIT_WIDTH from tests.conftest import SEED # Remove Hypothesis check for slow tests and function scoped fixtures. @@ -218,3 +220,21 @@ def min_max_tensor_random_shape_st(draw, min_dims=1, max_dims=4, max_size=3, wid min_tensor = torch.tensor(min_list).view(*shape) max_tensor = torch.tensor(max_list).view(*shape) return min_tensor, max_tensor + + +@st.composite +def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with=MAX_INT_BIT_WIDTH): + """" + Generate a minifloat format. Returns bit_width, exponent, mantissa, and signed. + """ + bit_width = draw(st.integers(min_value=min_bit_width, max_value=max_bit_with)) + exponent_bit_width = draw(st.integers(min_value=0, max_value=bit_width)) + signed = draw(st.booleans()) + # if no budget is left, return + if bit_width == exponent_bit_width: + return bit_width, exponent_bit_width, 0, False + elif bit_width == (exponent_bit_width + int(signed)): + return bit_width, exponent_bit_width, 0, signed + mantissa_bit_width = bit_width - exponent_bit_width - int(signed) + + return bit_width, exponent_bit_width, mantissa_bit_width, signed