class RuntimeDynamicStatsZeroPoint(brevitas.jit.ScriptModule):
    """Derive a zero-point at call time from statistics of the incoming tensor.

    The forward pass chains five steps: a view of the input through
    ``zero_point_stats_input_view_shape_impl``, a statistics reduction through
    ``zero_point_stats_impl``, a ``view`` to
    ``dynamic_scaling_broadcastable_shape``, ``abs_binary_sign_grad``, and
    finally ``_ScaleShiftZeroPoint`` together with the supplied scale and
    bit-width.

    Args:
        zero_point_stats_impl: module that computes the statistic the
            zero-point is derived from.
        int_quant: integer quantizer forwarded to ``_ScaleShiftZeroPoint``.
        quantize_zero_point: flag forwarded to ``_ScaleShiftZeroPoint``.
        dynamic_scaling_broadcastable_shape: shape the computed statistic is
            viewed as before further processing.
        zero_point_stats_input_view_shape_impl: module that reshapes the input
            ahead of the statistics computation.
    """

    # Exposed as a TorchScript constant so ``forward`` can pass it to ``view``.
    __constants__ = ['dynamic_scaling_broadcastable_shape']

    def __init__(
            self,
            zero_point_stats_impl: nn.Module,
            int_quant: nn.Module,
            quantize_zero_point: bool,
            dynamic_scaling_broadcastable_shape: Tuple[int, ...],
            zero_point_stats_input_view_shape_impl: nn.Module) -> None:
        super(RuntimeDynamicStatsZeroPoint, self).__init__()
        # Submodules are attached in the same order as before so registration
        # order (and hence state_dict key order) is unaffected.
        self.zero_point_stats_input_view_shape_impl = zero_point_stats_input_view_shape_impl
        self.zero_point_stats_impl = zero_point_stats_impl
        self.dynamic_scaling_broadcastable_shape = dynamic_scaling_broadcastable_shape
        scale_shift = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)
        self.scale_shift_zero_point = scale_shift

    @brevitas.jit.script_method
    def forward(self, x, scale, bit_width) -> Tensor:
        # Reshape the input into the layout expected by the statistics module.
        stats_input = self.zero_point_stats_input_view_shape_impl(x)
        # Reduce to the raw statistic, then view it with the broadcastable shape.
        raw_stats = self.zero_point_stats_impl(stats_input)
        stats = raw_stats.view(self.dynamic_scaling_broadcastable_shape)
        # NOTE(review): presumably an absolute value with a straight-through
        # style gradient -- confirm against brevitas.function.ops_ste.
        stats = abs_binary_sign_grad(stats)
        zero_point = self.scale_shift_zero_point(stats, scale, bit_width)
        return zero_point
e667a857d..40d101c02 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -162,5 +162,6 @@ class ShiftedUint8DynamicActPerTensorFloat(DynamicActProxyMixin, ShiftedUint8Act scaling_impl = RuntimeDynamicStatsScaling scaling_stats_input_view_shape_impl = OverTensorView scaling_stats_op = 'min_max' + zero_point_impl = RuntimeDynamicStatsZeroPoint zero_point_stats_impl = NegativeMinOrZero dynamic_scaling_broadcastable_shape = this.scaling_shape