Commit 4162ef3

review

Giuseppe5 committed Aug 31, 2024
1 parent 12bf43f commit 4162ef3
Showing 8 changed files with 72 additions and 21 deletions.
9 changes: 8 additions & 1 deletion src/brevitas/proxy/float_parameter_quant.py
@@ -1,14 +1,17 @@
+from abc import ABC
 from typing import Any, Optional, Tuple, Union

 from torch import Tensor
 import torch.nn as nn

 from brevitas.inject import BaseInjector as Injector
 from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase
 from brevitas.quant_tensor import FloatQuantTensor
+from brevitas.utils.quant_utils import _CachedIOFloat


-class WeightFloatQuantProxyFromInjectorBase(WeightQuantProxyFromInjectorBase):
+class WeightFloatQuantProxyFromInjectorBase(WeightQuantProxyFromInjectorBase, ABC):

     def scale(self):
         if not self.is_quant_enabled:
@@ -83,6 +86,10 @@ def is_fnuz(self):

 class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):

+    def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
+        super().__init__(quant_layer, quant_injector)
+        self.cache_class = _CachedIOFloat
+
     def create_quant_tensor(self, qt_args: Tuple[Any]) -> FloatQuantTensor:
         out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
         return FloatQuantTensor(
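To make the new wiring concrete: a hedged usage sketch (not part of the commit) exercising the cached float weight path end to end. `Fp8e4m3WeightPerTensorFloat` and the layer shape are illustrative assumptions; `cache_inference_quant_weight` and `_cached_weight` are the attributes the shared forward path in parameter_quant.py (below) actually reads.

```python
# Hedged sketch, assuming a brevitas install and an FP8 weight quantizer.
# The cache attributes are the ones referenced in parameter_quant.py below;
# the quantizer name is an illustrative assumption.
import torch

from brevitas.nn import QuantLinear
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat

layer = QuantLinear(16, 8, bias=False, weight_quant=Fp8e4m3WeightPerTensorFloat)
layer.weight_quant.cache_inference_quant_weight = True
layer.eval()  # caching is skipped in training mode, per forward() below
_ = layer(torch.randn(2, 16))

# With this commit, the proxy builds its cache through self.cache_class,
# which is now _CachedIOFloat for float weight proxies.
print(type(layer.weight_quant._cached_weight))
```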
15 changes: 11 additions & 4 deletions src/brevitas/proxy/float_runtime_quant.py
@@ -1,11 +1,16 @@
-from typing import Any, Optional, Tuple
+from abc import ABC
+from typing import Any, Optional, Tuple, Union

+import torch
 import torch.nn as nn

 from brevitas.inject import BaseInjector as Injector
 from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase
 from brevitas.quant_tensor import FloatQuantTensor
 from brevitas.utils.quant_utils import _CachedIOFloat


-class ActFloatQuantProxyFromInjectorBase(ActQuantProxyFromInjectorBase):
+class ActFloatQuantProxyFromInjectorBase(ActQuantProxyFromInjectorBase, ABC):

     def scale(self, force_eval=True):
         return self.retrieve_attribute('scale', force_eval)
@@ -56,12 +61,14 @@ def is_fnuz(self):

 class ActFloatQuantProxyFromInjector(ActFloatQuantProxyFromInjectorBase):

-    def __init__(self, quant_layer, quant_injector):
+    def __init__(self, quant_layer: nn.Module, quant_injector: Injector):
         super().__init__(quant_layer, quant_injector)
         self.cache_class = _CachedIOFloat

     def create_quant_tensor(
-            self, qt_args: Tuple[Any], x: Optional[FloatQuantTensor] = None) -> FloatQuantTensor:
+            self,
+            qt_args: Union[torch.Tensor, Tuple[Any]],
+            x: Optional[FloatQuantTensor] = None) -> FloatQuantTensor:
         if x is None:
             out = FloatQuantTensor(*qt_args, signed=self.is_signed, training=self.training)
         else:
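The widened `qt_args` annotation encodes two call shapes, documented later in runtime_quant.py: a tuple when the fused quantizer actually ran, or a bare `Tensor` plus the incoming `FloatQuantTensor` as `x` for passthrough activations. A minimal standalone sketch of that dispatch, with stub types standing in for the brevitas classes:

```python
# Standalone sketch of the Union[torch.Tensor, Tuple[Any]] dispatch.
# FloatQuantTensorStub and its field subset are illustrative stand-ins.
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

import torch


@dataclass
class FloatQuantTensorStub:
    value: torch.Tensor
    scale: torch.Tensor
    zero_point: torch.Tensor


def create_quant_tensor_sketch(
        qt_args: Union[torch.Tensor, Tuple[Any, ...]],
        x: Optional[FloatQuantTensorStub] = None) -> FloatQuantTensorStub:
    if x is None:
        # Standard path: the quantizer returned (value, scale, zero_point, ...).
        value, scale, zero_point = qt_args[:3]
        return FloatQuantTensorStub(value, scale, zero_point)
    # Passthrough path: qt_args is a plain Tensor; reuse x's quant metadata.
    return FloatQuantTensorStub(qt_args, x.scale, x.zero_point)
```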
8 changes: 8 additions & 0 deletions src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -1,11 +1,19 @@
 from typing import Any, Tuple

+import torch.nn as nn
+
+from brevitas.inject import BaseInjector as Injector
 from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjectorBase
 from brevitas.quant_tensor import GroupwiseFloatQuantTensor
+from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat


 class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):

+    def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
+        super().__init__(quant_layer, quant_injector)
+        self.cache_class = _CachedIOGroupwiseFloat
+
     @property
     def group_dim(self):
         return self.quant_injector.group_dim
6 changes: 4 additions & 2 deletions src/brevitas/proxy/groupwise_float_runtime_quant.py
@@ -1,4 +1,6 @@
-from typing import Any, Optional, Tuple
+from typing import Any, Optional, Tuple, Union

+import torch
+
 from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase
 from brevitas.quant_tensor import GroupwiseFloatQuantTensor
@@ -21,7 +23,7 @@ def group_size(self):

     def create_quant_tensor(
             self,
-            qt_args: Tuple[Any],
+            qt_args: Union[torch.Tensor, Tuple[Any]],
             x: Optional[GroupwiseFloatQuantTensor] = None) -> GroupwiseFloatQuantTensor:
         if x is None:
             value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
12 changes: 10 additions & 2 deletions src/brevitas/proxy/groupwise_int_parameter_quant.py
@@ -1,11 +1,19 @@
-from typing import Any, List
+from typing import Any, Tuple

+import torch.nn as nn
+
+from brevitas.inject import BaseInjector as Injector
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
 from brevitas.quant_tensor import GroupwiseIntQuantTensor
+from brevitas.utils.quant_utils import _CachedIOGroupwiseInt


 class GroupwiseWeightQuantProxyFromInjector(WeightQuantProxyFromInjector):

+    def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
+        super().__init__(quant_layer, quant_injector)
+        self.cache_class = _CachedIOGroupwiseInt
+
     @property
     def group_dim(self):
         return self.quant_injector.group_dim
@@ -14,7 +22,7 @@ def group_dim(self):
     def group_size(self):
         return self.quant_injector.group_size

-    def create_quant_tensor(self, qt_args: List[Any]) -> GroupwiseIntQuantTensor:
+    def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor:
         out, scale, zero_point, bit_width = qt_args
         return GroupwiseIntQuantTensor(
             out,
6 changes: 4 additions & 2 deletions src/brevitas/proxy/groupwise_int_runtime_quant.py
@@ -1,4 +1,6 @@
-from typing import Any, Optional, Tuple
+from typing import Any, Optional, Tuple, Union

+import torch
+
 from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
 from brevitas.quant_tensor import GroupwiseIntQuantTensor
@@ -21,7 +23,7 @@ def group_size(self):

     def create_quant_tensor(
             self,
-            qt_args: Tuple[Any],
+            qt_args: Union[torch.Tensor, Tuple[Any]],
             x: Optional[GroupwiseIntQuantTensor] = None) -> GroupwiseIntQuantTensor:
         if x is None:
             value, scale, zero_point, bit_width, = qt_args
12 changes: 9 additions & 3 deletions src/brevitas/proxy/parameter_quant.py
@@ -116,11 +116,17 @@ def create_quant_tensor(self, qt_args: Tuple[Any]) -> Union[Tensor, QuantTensor]

     def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
         if self.is_quant_enabled:
-            if self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
+            # If quant is enabled the priority is:
+            # - export mode
+            # - cached weight
+            # - quantization flow
+            if self.export_mode:
+                out = self.export_handler(x)
+                out = self.create_quant_tensor(out)
+            elif self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
                 out = self._cached_weight.quant_tensor
             else:
-                impl = self.export_handler if self.export_mode else self.tensor_quant
-                out = impl(x)
+                out = self.tensor_quant(x)
                 out = self.create_quant_tensor(out)
             if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
                 self._cached_weight = self.cache_class(
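The practical effect of the reordering: export mode now wins even when a weight is already cached, whereas the old `impl = self.export_handler if self.export_mode else self.tensor_quant` line was unreachable once a cached weight existed. A standalone mock of the new control flow, with attribute names taken from the diff:

```python
# Standalone mock of the reordered forward() above, not brevitas code.
class WeightProxyMock:

    def __init__(self, tensor_quant, export_handler):
        self.tensor_quant = tensor_quant
        self.export_handler = export_handler
        self.export_mode = False
        self._cached_weight = None  # the real cache wraps a quant_tensor
        self.cache_inference_quant_weight_metadata_only = False

    def create_quant_tensor(self, out):
        return out  # stand-in: the real proxy wraps out in a QuantTensor

    def forward(self, x):
        # Priority: export mode, then cached weight, then quantization flow.
        if self.export_mode:
            return self.create_quant_tensor(self.export_handler(x))
        if (self._cached_weight is not None and
                not self.cache_inference_quant_weight_metadata_only):
            return self._cached_weight
        return self.create_quant_tensor(self.tensor_quant(x))
```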
25 changes: 18 additions & 7 deletions src/brevitas/proxy/runtime_quant.py
@@ -5,6 +5,7 @@
 from abc import abstractmethod
 from typing import Any, Optional, Tuple, Union

+import torch
 from torch import nn
 from torch import Tensor
 from torch.nn import Identity
@@ -142,7 +143,14 @@ def init_tensor_quant(self):
         self.fused_activation_quant_proxy = None

     @abstractmethod
-    def create_quant_tensor(self, qt_args, x=None):
+    def create_quant_tensor(
+            self,
+            qt_args: Union[torch.Tensor, Tuple[Any]],
+            x: Optional[QuantTensor] = None) -> QuantTensor:
+        # Supports the following:
+        # - qt_args as tuple of Tensors and bools = standard quant activations
+        # - qt_args as Tensor and x as QuantTensor = passthrough activation
+        # In both cases, the output is a QuantTensor
         raise NotImplementedError

     def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
@@ -160,21 +168,24 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         elif not self.is_quant_enabled:
-            # A tuple helps later with control flows
-            # The second None value is used later
-            y = (self.fused_activation_quant_proxy.activation_impl(y), None)
+            y = self.fused_activation_quant_proxy.activation_impl(y)
         else:
             y = self.fused_activation_quant_proxy(y)
-        # If y is an empty IntQuantTensor, we need to check if this is a passthrough proxy,
-        # otherwise return a simple Tensor
+
+        # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
+        # otherwise return a simple Tensor
+        # If the second value (i.e., scale) is None, then quant is disabled
         if isinstance(y, tuple) and y[1] is not None:
             out = self.create_quant_tensor(y)
         elif self.is_passthrough_act and isinstance(x, QuantTensor):
-            # preserve scale/zp/bit/sign even without output quant
-            y = y[0]
+            # preserve quant_metadata
+            if isinstance(y, tuple):
+                y = y[0]
             out = self.create_quant_tensor(y, x=x)
         else:
-            out = y[0]
+            if isinstance(y, tuple):
+                y = y[0]
+            out = y

         if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor):
             cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only)
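To summarize the new activation flow: three outcomes are possible, and the `(value, None)` tuple wrapper for disabled proxies is gone, so the later branches must tolerate both tuples and bare tensors. A standalone distillation of the branch structure shown above (the callables and flags are parameters here; in the proxy they are attributes):

```python
# Standalone distillation of forward() above; the branch structure mirrors
# the diff, while the callable/flag parameters are illustrative stand-ins.
from typing import Any, Callable, Optional, Union

import torch


def act_forward_sketch(
        y: torch.Tensor,
        quant_enabled: bool,
        passthrough: bool,
        activation_impl: Callable,
        fused_quant: Callable,
        create_quant_tensor: Callable,
        x_is_quant_tensor: bool = False,
        x: Optional[Any] = None):
    if not quant_enabled:
        # No (value, None) tuple wrapper anymore: just apply the activation.
        y = activation_impl(y)
    else:
        y = fused_quant(y)  # returns a tuple whose second element is the scale
    if isinstance(y, tuple) and y[1] is not None:
        return create_quant_tensor(y)  # scale present: quantization ran
    if passthrough and x_is_quant_tensor:
        if isinstance(y, tuple):
            y = y[0]
        return create_quant_tensor(y, x=x)  # preserve input quant metadata
    if isinstance(y, tuple):
        y = y[0]
    return y  # plain Tensor
```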
