Skip to content

Commit

Permalink
Typing
Browse files: browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 5, 2024
1 parent 9ac99f0 commit e00a088
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/nn/mixin/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def quant_output_bit_width(self):
else:
return None

- def unpack_input(self, inp: Union[Tensor, QuantTensor]):
+ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
self._set_global_is_quant_layer(True)
# Hack to recognize a QuantTensor that has decayed to a tuple
# when used as input to tracing (e.g. during ONNX export)
Expand All @@ -174,7 +174,7 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]):
inp = inp.rename(None)
return inp

- def pack_output(self, quant_output: QuantTensor):
+ def pack_output(self, quant_output: QuantTensor) -> Union[Tensor, QuantTensor]:
if not self.training and self.cache_inference_quant_out and isinstance(quant_output,
QuantTensor):
self._cached_out = _CachedIO(quant_output.detach(), self.cache_quant_io_metadata_only)
Expand Down
15 changes: 7 additions & 8 deletions src/brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from abc import ABCMeta
from abc import abstractmethod
- from typing import List, Optional, Tuple
+ from typing import Optional, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -94,7 +94,7 @@ def bit_width(self):
bit_width_ = self.__call__(self.tracked_parameter_list[0]).bit_width
return bit_width_

- def forward(self, x: torch.Tensor) -> QuantTensor:
+ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, bit_width = impl(x)
Expand All @@ -115,13 +115,13 @@ def pre_zero_point(self):
out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple
return pre_zero_point

- def forward(self, x: torch.Tensor) -> QuantTensor:
+ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x)
return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
else: # quantization disabled
- return QuantTensor(x, training=self.training)
+ return x


class DecoupledWeightQuantWithInputProxyFromInjector(DecoupledWeightQuantProxyFromInjector):
Expand All @@ -145,9 +145,8 @@ def pre_scale(self):
def pre_zero_point(self):
raise NotImplementedError

- def forward(
- self, x: torch.Tensor, input_bit_width: torch.Tensor,
- input_is_signed: bool) -> QuantTensor:
+ def forward(self, x: torch.Tensor, input_bit_width: torch.Tensor,
+ input_is_signed: bool) -> Union[Tensor, QuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed)
Expand Down Expand Up @@ -199,7 +198,7 @@ def forward(
self,
x: Tensor,
input_scale: Optional[Tensor] = None,
- input_bit_width: Optional[Tensor] = None) -> QuantTensor:
+ input_bit_width: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
if self.requires_input_scale and input_scale is None:
Expand Down
6 changes: 3 additions & 3 deletions src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def bit_width(self):
scale = self.__call__(self._zero_hw_sentinel()).bit_width
return scale

- def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
+ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
if self.fused_activation_quant_proxy is not None:
y = x
if isinstance(y, QuantTensor):
Expand Down Expand Up @@ -188,7 +188,7 @@ def bit_width(self):

class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

- def forward(self, x: QuantTensor):
+ def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]:
if self.is_quant_enabled:
out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width)
out_value, out_scale, out_zp, out_bit_width = out_tuple
Expand All @@ -206,7 +206,7 @@ def bit_width(self):
bit_width = self.__call__(empty_imp).bit_width
return bit_width

- def forward(self, x: QuantTensor):
+ def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]:
if self.is_quant_enabled:
if self.export_mode:
out_tuple = self.export_handler(
Expand Down

0 comments on commit e00a088

Please sign in to comment.