Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hierarchical scales #1038

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/brevitas/core/quant/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch.nn import Module

import brevitas
from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.quant.delay import DelayWrapper
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops_ste import round_ste
Expand Down Expand Up @@ -138,13 +139,18 @@ def __init__(
scaling_impl: Module,
int_scaling_impl: Module,
zero_point_impl: Module,
bit_width_impl: Module):
bit_width_impl: Module,
scaling_int_quant: Optional[Module] = None):
super(RescalingIntQuant, self).__init__()
self.int_quant = int_quant
self.scaling_impl = scaling_impl
self.int_scaling_impl = int_scaling_impl
self.zero_point_impl = zero_point_impl
self.msb_clamp_bit_width_impl = bit_width_impl
if scaling_int_quant is None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused

self.scaling_int_quant = Identity()
else:
self.scaling_int_quant = scaling_int_quant

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
Expand Down
27 changes: 27 additions & 0 deletions src/brevitas/core/restrict_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,30 @@ def forward(self, x: torch.Tensor):
x = self.float_to_int_impl(x)
x = self.power_of_two(x)
return x


class QuantRestrictValue(brevitas.jit.ScriptModule):
    """Restrict a value by routing it through a wrapped float-to-int module.

    ``forward`` calls the wrapped module and keeps only the first element of
    whatever it returns. Every ``restrict_init_*`` hook, as well as
    ``retrocompatibility_op``, ignores its argument (if any) and hands back a
    freshly constructed ``Identity`` module.
    """

    def __init__(self, restrict_value_float_to_int_impl: Module):
        """Store the module that ``forward`` delegates to.

        Args:
            restrict_value_float_to_int_impl: module called on the input of
                ``forward``; its result is unpacked and the first element is
                returned.
        """
        super(QuantRestrictValue, self).__init__()
        self.float_to_int_impl = restrict_value_float_to_int_impl

    def restrict_init_float(self, x: float):
        """Return a new ``Identity`` module; ``x`` is not used."""
        # NOTE(review): this returns a module rather than a float -- confirm
        # that callers of this hook expect that.
        return Identity()

    def restrict_init_tensor(self, x: torch.Tensor):
        """Return a new ``Identity`` module; ``x`` is not used."""
        # NOTE(review): this returns a module rather than a tensor -- confirm
        # that callers of this hook expect that.
        return Identity()

    def restrict_init_module(self):
        """Return a new ``Identity`` module."""
        return Identity()

    def restrict_init_inplace_module(self):
        """Return a new ``Identity`` module."""
        return Identity()

    def retrocompatibility_op(self, x):
        """Return a new ``Identity`` module; ``x`` is not used."""
        return Identity()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor):
        """Apply the wrapped module to ``x`` and return its first output."""
        # The wrapped module returns several values; only the leading one is
        # kept and the remainder is discarded.
        restricted, *_ = self.float_to_int_impl(x)
        return restricted
12 changes: 12 additions & 0 deletions src/brevitas/core/scaling/quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch.nn as nn

class QuantScaling(nn.Module):
    """Multiply an input by the scale produced by a wrapped quantizer.

    The wrapped callable is invoked on the input itself and must return a
    sequence with at least two elements; the second element is used as the
    scale factor.
    """

    def __init__(self, scale_rescaling_int_quant) -> None:
        """Store the quantizer that ``forward`` queries for a scale.

        Args:
            scale_rescaling_int_quant: callable applied to the input of
                ``forward``; the second element of its result is the scale.
        """
        super().__init__()
        self.scale_rescaling_int_quant = scale_rescaling_int_quant

    def forward(self, x):
        """Return ``x`` multiplied by the scale computed from ``x``.

        Args:
            x: value passed to the wrapped quantizer and then rescaled.

        Returns:
            ``x * scale``, where ``scale`` is the second element returned by
            the wrapped quantizer.
        """
        # Only the second output is needed; the first and any trailing
        # outputs are discarded. No debug printing here: this runs on every
        # forward pass.
        _first, scale, *_rest = self.scale_rescaling_int_quant(x)
        return x * scale
Loading