Skip to content

Commit

Permalink
Fix (tests): fix filter for NaN/inf values
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Feb 5, 2024
1 parent db97387 commit 1c3c997
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 0 additions & 4 deletions src/brevitas/core/quant/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
from brevitas.function.ops import max_float
from brevitas.function.ops_ste import floor_ste

# max int that can be passed to torch.exp2() without running into inf
MAX_REPRESENTABLE_INT = 127


class FloatQuant(brevitas.jit.ScriptModule):
__constants__ = ['signed']
Expand Down Expand Up @@ -67,7 +64,6 @@ def __init__(
def internal_scale(self, x):
internal_scale = floor_ste(torch.log2(torch.abs(x))) - self.mantissa_bit_width()
internal_scale = torch.clamp_min(internal_scale, self.fp_internal_scale_min())
internal_scale = torch.clamp_max(internal_scale, MAX_REPRESENTABLE_INT)
internal_scale = torch.exp2(internal_scale)
return internal_scale

Expand Down
6 changes: 5 additions & 1 deletion tests/brevitas/core/test_float_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,8 @@ def test_inner_scale(inp, minifloat_format, scale):
True if val == 0. or val.isnan() else False for val in expected_out.flatten()
]).all()
else:
assert torch.equal(out, expected_out)
# filter out NaN values as we can't compare them
# Note: this still checks if NaN appears at the same values
out_nans = out.isnan()
expected_out_nans = expected_out.isnan()
assert torch.equal(out[~out_nans], expected_out[~expected_out_nans])
1 change: 1 addition & 0 deletions tests/brevitas/hyp_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with=
""""
Generate a minifloat format. Returns bit_width, exponent, mantissa, and signed.
"""
# TODO: add support for new minifloat format that comes with FloatQuantTensor
bit_width = draw(st.integers(min_value=min_bit_width, max_value=max_bit_with))
exponent_bit_width = draw(st.integers(min_value=0, max_value=bit_width))
signed = draw(st.booleans())
Expand Down

0 comments on commit 1c3c997

Please sign in to comment.