diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py
index e7c5560f8..a9bcf6e6b 100644
--- a/src/brevitas/core/quant/int.py
+++ b/src/brevitas/core/quant/int.py
@@ -8,6 +8,7 @@
 from torch.nn import Module
 
 import brevitas
+from brevitas.core.function_wrapper.misc import Identity
 from brevitas.core.quant.delay import DelayWrapper
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops_ste import round_ste
@@ -138,13 +139,20 @@ def __init__(
             scaling_impl: Module,
             int_scaling_impl: Module,
             zero_point_impl: Module,
-            bit_width_impl: Module):
+            bit_width_impl: Module,
+            scaling_int_quant: Optional[Module] = None):
         super(RescalingIntQuant, self).__init__()
         self.int_quant = int_quant
         self.scaling_impl = scaling_impl
         self.int_scaling_impl = int_scaling_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        # Optional quantizer applied to the scale itself; defaulting to
+        # Identity preserves the pre-existing non-quantized-scale behaviour.
+        if scaling_int_quant is None:
+            self.scaling_int_quant = Identity()
+        else:
+            self.scaling_int_quant = scaling_int_quant
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py
index 0720e595e..97a957fc6 100644
--- a/src/brevitas/core/restrict_val.py
+++ b/src/brevitas/core/restrict_val.py
@@ -179,3 +179,36 @@ def forward(self, x: torch.Tensor):
         x = self.float_to_int_impl(x)
         x = self.power_of_two(x)
         return x
+
+
+class QuantRestrictValue(brevitas.jit.ScriptModule):
+    """Restrict-value implementation that quantizes the restricted value.
+
+    ``restrict_value_float_to_int_impl`` is expected to return a tuple whose
+    first element is the (de)quantized value (see ``forward``).
+    """
+
+    def __init__(self, restrict_value_float_to_int_impl: Module):
+        super(QuantRestrictValue, self).__init__()
+        self.float_to_int_impl = restrict_value_float_to_int_impl
+
+    def restrict_init_float(self, x: float):
+        # Init values pass through unrestricted: quantization is applied
+        # only in forward. Returning a value (not a module) matches the
+        # restrict_init_float/restrict_init_tensor contract of the other
+        # restrictors in this file.
+        return x
+
+    def restrict_init_tensor(self, x: torch.Tensor):
+        return x
+
+    def restrict_init_module(self):
+        # No extra restriction module is needed at init time.
+        return Identity()
+
+    def restrict_init_inplace_module(self):
+        return Identity()
+
+    def retrocompatibility_op(self, x):
+        # No-op on reloaded values, consistent with the init hooks above.
+        return x
+
+    @brevitas.jit.script_method
+    def forward(self, x: torch.Tensor):
+        o, *_ = self.float_to_int_impl(x)
+        return o
diff --git a/src/brevitas/core/scaling/quant.py b/src/brevitas/core/scaling/quant.py
new file mode 100644
index 000000000..0938886ba
--- /dev/null
+++ b/src/brevitas/core/scaling/quant.py
@@ -0,0 +1,18 @@
+import torch.nn as nn
+
+
+class QuantScaling(nn.Module):
+    """Scaling implementation whose scale factor is itself quantized.
+
+    Wraps a quantizer whose forward returns a tuple with the scale as its
+    second element; the input is rescaled by that quantized scale.
+    """
+
+    def __init__(self, scale_rescaling_int_quant) -> None:
+        super().__init__()
+        self.scale_rescaling_int_quant = scale_rescaling_int_quant
+
+    def forward(self, x):
+        # Second tuple element is the scale produced by the inner quantizer.
+        _, scale, *_ = self.scale_rescaling_int_quant(x)
+        return x * scale