Skip to content

Commit

Permalink
Fix (example/quantizers): correct zero point handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jan 23, 2024
1 parent 6baf630 commit 5912644
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/brevitas_examples/common/generative/quant_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import brevitas
from brevitas.core.function_wrapper.shape import PermuteDims
from brevitas.core.utils import SliceTensor
from brevitas.core.zero_point import _ScaleShiftZeroPoint
from brevitas.function.ops_ste import abs_binary_sign_grad


class OverSubChannelBlockView(brevitas.jit.ScriptModule):
Expand Down Expand Up @@ -108,6 +110,32 @@ def forward(self, x) -> Tensor:
return x


class RuntimeDynamicStatsZeroPoint(brevitas.jit.ScriptModule):
    """Derive a zero point at call time from statistics of the incoming tensor.

    The module chains four steps: reshape the input with the supplied view
    module, reduce it with the supplied statistics module, view the result
    to ``dynamic_scaling_broadcastable_shape``, and hand it (after
    ``abs_binary_sign_grad``) to a ``_ScaleShiftZeroPoint`` together with
    the caller-provided scale and bit width.
    """
    __constants__ = ['dynamic_scaling_broadcastable_shape']

    def __init__(
            self,
            zero_point_stats_impl: nn.Module,
            int_quant: nn.Module,
            quantize_zero_point: bool,
            dynamic_scaling_broadcastable_shape: Tuple[int, ...],
            zero_point_stats_input_view_shape_impl: nn.Module) -> None:
        """Store the sub-modules and the target shape used by ``forward``.

        Args:
            zero_point_stats_impl: module that reduces the reshaped input to
                the statistic the zero point is built from.
            int_quant: forwarded unchanged to ``_ScaleShiftZeroPoint``.
            quantize_zero_point: forwarded unchanged to
                ``_ScaleShiftZeroPoint``.
            dynamic_scaling_broadcastable_shape: shape the statistic is
                viewed to before the scale/shift step.
            zero_point_stats_input_view_shape_impl: module that reshapes the
                raw input before the statistic is computed.
        """
        super().__init__()
        # Plain (non-module) attribute; declared in __constants__ above.
        self.dynamic_scaling_broadcastable_shape = dynamic_scaling_broadcastable_shape
        # Sub-modules are registered in the same relative order as before so
        # that module / state_dict ordering is unaffected.
        self.zero_point_stats_input_view_shape_impl = zero_point_stats_input_view_shape_impl
        self.zero_point_stats_impl = zero_point_stats_impl
        self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)

    @brevitas.jit.script_method
    def forward(self, x, scale, bit_width) -> Tensor:
        """Return the zero point computed from ``x`` for the given scale and bit width.

        Args:
            x: tensor the statistic is computed from.
            scale: passed through to the scale/shift zero-point module.
            bit_width: passed through to the scale/shift zero-point module.

        Returns:
            The output of ``self.scale_shift_zero_point``.
        """
        reshaped_input = self.zero_point_stats_input_view_shape_impl(x)
        raw_stats = self.zero_point_stats_impl(reshaped_input)
        broadcastable_stats = raw_stats.view(self.dynamic_scaling_broadcastable_shape)
        # NOTE(review): presumably an absolute value with a sign-based
        # straight-through gradient -- confirm in brevitas.function.ops_ste.
        rectified_stats = abs_binary_sign_grad(broadcastable_stats)
        return self.scale_shift_zero_point(rectified_stats, scale, bit_width)


class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):

def __init__(self, group_size: int, group_dim: int, scaling_stats_impl: nn.Module) -> None:
Expand Down
1 change: 1 addition & 0 deletions src/brevitas_examples/common/generative/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,5 +162,6 @@ class ShiftedUint8DynamicActPerTensorFloat(DynamicActProxyMixin, ShiftedUint8Act
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverTensorView
scaling_stats_op = 'min_max'
zero_point_impl = RuntimeDynamicStatsZeroPoint
zero_point_stats_impl = NegativeMinOrZero
dynamic_scaling_broadcastable_shape = this.scaling_shape

0 comments on commit 5912644

Please sign in to comment.