Fix (tests): add tests for FloatQuant
fabianandresgrob committed Jan 30, 2024
1 parent 56056ba commit 15b5938
Showing 3 changed files with 80 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/brevitas/core/quant/float.py
@@ -39,6 +39,8 @@ def __init__(
            torch.tensor(float(exponent_bit_width), device=device, dtype=dtype))
        self.mantissa_bit_width = StatelessBuffer(
            (torch.tensor(float(mantissa_bit_width), device=device, dtype=dtype)))
        if exponent_bias is None:
            exponent_bias = 2 ** (exponent_bit_width - 1) - 1
        self.exponent_bias = StatelessBuffer(
            torch.tensor(float(exponent_bias), device=device, dtype=dtype))
        self.fp_max_val = StatelessBuffer(
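For reference, the new default matches the IEEE-754 convention: an e-bit exponent field gets a bias of 2**(e - 1) - 1. A minimal standalone sketch of that rule (plain Python; the function name is hypothetical, not part of the commit):

# Hypothetical standalone sketch of the default-bias rule added above.
# IEEE-754 convention: bias = 2**(e - 1) - 1 for an e-bit exponent field.
def default_exponent_bias(exponent_bit_width: int) -> int:
    return 2 ** (exponent_bit_width - 1) - 1

assert default_exponent_bias(4) == 7   # matches FP8 E4M3
assert default_exponent_bias(5) == 15  # matches FP8 E5M2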
58 changes: 58 additions & 0 deletions tests/brevitas/core/test_float_quant.py
@@ -0,0 +1,58 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from hypothesis import given
import mock
import pytest
import torch

from brevitas.core.function_wrapper import RoundSte
from brevitas.core.quant.float import FloatQuant
from brevitas.core.scaling import ConstScaling
from tests.brevitas.core.bit_width_fixture import * # noqa
from tests.brevitas.core.int_quant_fixture import * # noqa
from tests.brevitas.core.shared_quant_fixture import * # noqa
from tests.brevitas.hyp_helper import float_tensor_random_shape_st
from tests.brevitas.hyp_helper import random_minifloat_format
from tests.brevitas.hyp_helper import scalar_float_p_tensor_st
from tests.marker import jit_disabled_for_mock


@given(minifloat_format=random_minifloat_format())
def test_float_quant_defaults(minifloat_format):
    bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
    # specifically don't set exponent bias to see if default works
    expected_exponent_bias = 2 ** (exponent_bit_width - 1) - 1
    float_quant = FloatQuant(
        bit_width=bit_width,
        signed=signed,
        exponent_bit_width=exponent_bit_width,
        mantissa_bit_width=mantissa_bit_width)
    assert expected_exponent_bias == float_quant.exponent_bias()
    assert isinstance(float_quant.float_to_int_impl, RoundSte)
    assert isinstance(float_quant.float_scaling_impl, ConstScaling)
    assert isinstance(float_quant.scaling_impl, ConstScaling)


@given(minifloat_format=random_minifloat_format())
def test_minifloat(minifloat_format):
    bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
    assert bit_width == exponent_bit_width + mantissa_bit_width + int(signed)


@given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
@jit_disabled_for_mock()
def test_int_quant_to_in(inp, minifloat_format):
    bit_width, exponent_bit_width, mantissa_bit_width, signed = minifloat_format
    exponent_bias = 2 ** (exponent_bit_width - 1) - 1
    float_quant = FloatQuant(
        bit_width=bit_width,
        signed=signed,
        exponent_bit_width=exponent_bit_width,
        mantissa_bit_width=mantissa_bit_width,
        exponent_bias=exponent_bias)
    expected_out, _, _, bit_width_out = float_quant(inp)

    out_quant, scale = float_quant.quantize(inp)
    assert bit_width_out == bit_width
    assert torch.equal(expected_out, out_quant * scale)
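As a usage sketch, the round trip exercised by the last test can be reproduced outside the test harness. The E4M3-style parameters below are illustrative assumptions, not values taken from the commit:

import torch

from brevitas.core.quant.float import FloatQuant

# Illustrative E4M3-style minifloat config (assumed values, chosen for the example).
float_quant = FloatQuant(
    bit_width=8, signed=True, exponent_bit_width=4, mantissa_bit_width=3, exponent_bias=7)
x = torch.randn(4, 4)
# The full forward pass returns the dequantized tensor plus metadata;
# the test above unpacks it as (out, _, _, bit_width).
dequant, _, _, bit_width = float_quant(x)
# quantize() returns the pre-scaling values and the scale, so the
# dequantized output should equal quant * scale, as the test asserts.
quant, scale = float_quant.quantize(x)
assert bit_width == 8
assert torch.equal(dequant, quant * scale)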
20 changes: 20 additions & 0 deletions tests/brevitas/hyp_helper.py
@@ -13,6 +13,8 @@
import torch

from tests.brevitas.common import FP32_BIT_WIDTH
from tests.brevitas.common import MAX_INT_BIT_WIDTH
from tests.brevitas.common import MIN_INT_BIT_WIDTH
from tests.conftest import SEED

# Remove Hypothesis check for slow tests and function scoped fixtures.
@@ -218,3 +220,21 @@ def min_max_tensor_random_shape_st(draw, min_dims=1, max_dims=4, max_size=3, wid
    min_tensor = torch.tensor(min_list).view(*shape)
    max_tensor = torch.tensor(max_list).view(*shape)
    return min_tensor, max_tensor


@st.composite
def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_width=MAX_INT_BIT_WIDTH):
    """
    Generate a minifloat format. Returns bit_width, exponent, mantissa, and signed.
    """
    bit_width = draw(st.integers(min_value=min_bit_width, max_value=max_bit_width))
    exponent_bit_width = draw(st.integers(min_value=0, max_value=bit_width))
    signed = draw(st.booleans())
    # if no bit budget is left, return early with no mantissa
    if bit_width == exponent_bit_width:
        return bit_width, exponent_bit_width, 0, False
    elif bit_width == (exponent_bit_width + int(signed)):
        return bit_width, exponent_bit_width, 0, signed
    mantissa_bit_width = bit_width - exponent_bit_width - int(signed)

    return bit_width, exponent_bit_width, mantissa_bit_width, signed
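The early returns above handle exhausted bit budgets so that the invariant checked by test_minifloat always holds: bit_width == exponent_bit_width + mantissa_bit_width + int(signed). A hypothetical non-random rewrite of the same split logic, for illustration:

# Hypothetical deterministic rewrite of the budget split in
# random_minifloat_format: the three fields must sum to bit_width.
def split_minifloat(bit_width: int, exponent_bit_width: int, signed: bool):
    if bit_width == exponent_bit_width:
        # exponent consumes the whole budget: no sign bit, no mantissa
        return bit_width, exponent_bit_width, 0, False
    if bit_width == exponent_bit_width + int(signed):
        # only the sign bit fits on top of the exponent: no mantissa
        return bit_width, exponent_bit_width, 0, signed
    mantissa_bit_width = bit_width - exponent_bit_width - int(signed)
    return bit_width, exponent_bit_width, mantissa_bit_width, signed

bw, e, m, s = split_minifloat(8, 4, True)  # -> (8, 4, 3, True), an E4M3-style split
assert bw == e + m + int(s)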
