diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index 1d6391d21..7cd9f047f 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -14,7 +14,7 @@ from brevitas.core.utils import StatelessBuffer from brevitas.function import tensor_clamp from brevitas.function.ops import max_float -from brevitas.utils.quant_utils import MAX_MANTISSA_DICT +from brevitas.utils.torch_utils import MAX_MANTISSA_DICT class TensorClamp(brevitas.jit.ScriptModule): diff --git a/src/brevitas/core/scaling/float_scaling.py b/src/brevitas/core/scaling/float_scaling.py index 86451523b..fc1721d91 100644 --- a/src/brevitas/core/scaling/float_scaling.py +++ b/src/brevitas/core/scaling/float_scaling.py @@ -9,7 +9,7 @@ import brevitas from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_float -from brevitas.utils.quant_utils import MAX_MANTISSA_DICT +from brevitas.utils.torch_utils import MAX_MANTISSA_DICT class FloatScaling(brevitas.jit.ScriptModule): diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 5945d0a8a..bafeb67ef 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -16,8 +16,8 @@ from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector -from brevitas.utils.quant_utils import MAX_MANTISSA_DICT from brevitas.utils.torch_utils import float_internal_scale +from brevitas.utils.torch_utils import MAX_MANTISSA_DICT class InferenceHandler(torch.nn.Module, ABC): diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index be2847b32..6fd519b41 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -221,11 +221,3 @@ def float_to_int_impl_to_enum(module): return FloatToIntImplType.STOCHASTIC_ROUND else: return None - - -def max_mantissa_func(val): - import torch - return torch.sum((2. ** torch.arange(0, -1. * val - 1., -1.))) - - -MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)} diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 2f0d34fba..ea4be5047 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -113,3 +113,11 @@ def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]: padding[2 * group_dim] = group_size - size[group_dim] % group_size padding = list(reversed(padding)) return padding + + +def max_mantissa_func(val): + import torch + return torch.sum((2. ** torch.arange(0, -1. * val - 1., -1.))) + + +MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)} diff --git a/tests/brevitas/core/test_clamp.py b/tests/brevitas/core/test_clamp.py index 8d4a6c117..e5430e140 100644 --- a/tests/brevitas/core/test_clamp.py +++ b/tests/brevitas/core/test_clamp.py @@ -14,7 +14,7 @@ from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight from brevitas.utils.float_quant_utils import get_max_available_float from brevitas.utils.float_quant_utils import get_min_available_float -from brevitas.utils.quant_utils import MAX_MANTISSA_DICT +from brevitas.utils.torch_utils import MAX_MANTISSA_DICT from tests.brevitas.hyp_helper import float_tensor_random_shape_st from .minifloat_fixtures import * diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index b088fd036..a471f7bbf 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -14,8 +14,8 @@ from brevitas.core.scaling import ConstScaling from brevitas.core.scaling import FloatScaling from brevitas.function.ops import max_float -from brevitas.utils.quant_utils import MAX_MANTISSA_DICT from brevitas.utils.torch_utils import float_internal_scale +from brevitas.utils.torch_utils import MAX_MANTISSA_DICT from tests.brevitas.hyp_helper import float_st from tests.brevitas.hyp_helper import float_tensor_random_shape_st from tests.brevitas.hyp_helper import random_minifloat_format