Feat (core): quant scale support
Giuseppe5 committed Oct 2, 2024
1 parent ee157bc · commit c34bca2
Showing 3 changed files with 46 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/brevitas/core/quant/int.py
@@ -8,6 +8,7 @@
 from torch.nn import Module
 
 import brevitas
+from brevitas.core.function_wrapper.misc import Identity
 from brevitas.core.quant.delay import DelayWrapper
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops_ste import round_ste
@@ -138,13 +139,18 @@ def __init__(
             scaling_impl: Module,
             int_scaling_impl: Module,
             zero_point_impl: Module,
-            bit_width_impl: Module):
+            bit_width_impl: Module,
+            scaling_int_quant: Optional[Module] = None):
         super(RescalingIntQuant, self).__init__()
         self.int_quant = int_quant
         self.scaling_impl = scaling_impl
         self.int_scaling_impl = int_scaling_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        if scaling_int_quant is None:
+            self.scaling_int_quant = Identity()
+        else:
+            self.scaling_int_quant = scaling_int_quant
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
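
For context, the new `scaling_int_quant` argument is optional and falls back to a pass-through, so existing quantizer configurations are unaffected. Below is a minimal standalone sketch of that fallback pattern; the names here are illustrative only, and how `scaling_int_quant` is applied inside `RescalingIntQuant.forward` is not visible in this diff.

# Illustrative sketch only, not Brevitas code: a missing scale quantizer is
# replaced by an identity pass-through so behaviour stays unchanged.
from typing import Optional

import torch
import torch.nn as nn


class Identity(nn.Module):
    # Stand-in for brevitas.core.function_wrapper.misc.Identity
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class ScaleQuantHolder(nn.Module):
    # Hypothetical holder mirroring the new RescalingIntQuant argument
    def __init__(self, scaling_int_quant: Optional[nn.Module] = None) -> None:
        super().__init__()
        self.scaling_int_quant = Identity() if scaling_int_quant is None else scaling_int_quant

    def forward(self, scale: torch.Tensor) -> torch.Tensor:
        # With no quantizer supplied, the scale passes through untouched.
        return self.scaling_int_quant(scale)
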
27 changes: 27 additions & 0 deletions src/brevitas/core/restrict_val.py
@@ -179,3 +179,30 @@ def forward(self, x: torch.Tensor):
         x = self.float_to_int_impl(x)
         x = self.power_of_two(x)
         return x
+
+
+class QuantRestrictValue(brevitas.jit.ScriptModule):
+
+    def __init__(self, restrict_value_float_to_int_impl: Module):
+        super(QuantRestrictValue, self).__init__()
+        self.float_to_int_impl = restrict_value_float_to_int_impl
+
+    def restrict_init_float(self, x: float):
+        return Identity()
+
+    def restrict_init_tensor(self, x: torch.Tensor):
+        return Identity()
+
+    def restrict_init_module(self):
+        return Identity()
+
+    def restrict_init_inplace_module(self):
+        return Identity()
+
+    def retrocompatibility_op(self, x):
+        return Identity()
+
+    @brevitas.jit.script_method
+    def forward(self, x: torch.Tensor):
+        o, *_ = self.float_to_int_impl(x)
+        return o
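
`QuantRestrictValue` wraps an already-built quantizer and, in `forward`, propagates only the first element of the tuple that quantizer returns (the quantized value), discarding scale, zero-point and bit-width; the `restrict_init_*` hooks all return an `Identity` transform. A hypothetical usage sketch, assuming a Brevitas checkout that includes this commit; `FakeIntQuant` is a stand-in that mimics `RescalingIntQuant`'s `(value, scale, zero_point, bit_width)` return convention and is not Brevitas code.

# Hypothetical usage sketch, assuming a checkout that contains this commit.
import torch
import torch.nn as nn

from brevitas.core.restrict_val import QuantRestrictValue


class FakeIntQuant(nn.Module):
    # Toy symmetric 8-bit fake-quantizer returning the expected 4-tuple

    def forward(self, x: torch.Tensor):
        scale = x.abs().max() / 127.
        y = torch.clamp(torch.round(x / scale), -127., 127.) * scale
        return y, scale, torch.zeros(()), torch.tensor(8.)


restrict = QuantRestrictValue(FakeIntQuant())
out = restrict(torch.randn(16))  # only the fake-quantized value is returned
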
12 changes: 12 additions & 0 deletions src/brevitas/core/scaling/quant.py
@@ -0,0 +1,12 @@
+import torch.nn as nn
+
+class QuantScaling(nn.Module):
+    def __init__(self, scale_rescaling_int_quant) -> None:
+        super().__init__()
+        self.scale_rescaling_int_quant = scale_rescaling_int_quant
+
+    def forward(self, x):
+        _, scale, *_ = self.scale_rescaling_int_quant(x)
+        # print(x.shape, scale.shape)
+        print(x, scale)
+        return x * scale
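
`QuantScaling` re-scales its input by the scale produced by the wrapped quantizer, i.e. the second element of the returned tuple. A hypothetical usage sketch, again assuming this commit is installed; `fake_quant` is a stand-in for the real quantizer that would be wired in, not Brevitas code.

# Hypothetical usage sketch, assuming a checkout that contains this commit.
import torch

from brevitas.core.scaling.quant import QuantScaling


def fake_quant(x: torch.Tensor):
    # Toy symmetric 8-bit fake-quantizer returning the expected 4-tuple
    scale = x.abs().max() / 127.
    y = torch.clamp(torch.round(x / scale), -127., 127.) * scale
    return y, scale, torch.zeros(()), torch.tensor(8.)


quant_scaling = QuantScaling(fake_quant)
out = quant_scaling(torch.randn(8))  # input multiplied by the fake quantizer's scale
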
