diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py index f1dfc7796..e175e4445 100644 --- a/src/brevitas/core/function_wrapper/shape.py +++ b/src/brevitas/core/function_wrapper/shape.py @@ -165,6 +165,11 @@ def __init__(self, expanded_groupwise_shape, group_size, group_dim) -> None: @brevitas.jit.script_method def forward(self, x: torch.Tensor): + # This one is a bit tricky but we could end up here: + # - If we quantize the zero point, which will already have expanded shape matching the scale (although no padding, but we don't need the padding) + # - Groupwise HQO quantization, where weight will already have been padded and expanded + if len(x.shape) == len(self.expanded_groupwise_shape): + return x y = torch.nn.functional.pad( x, padding(x, self.group_size, self.group_dim), mode='constant', value=0.) y = y.view(self.expanded_groupwise_shape) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index 461aeb3e6..29d4d06e8 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -692,7 +692,8 @@ def parameter_search(self, xl, x): self.set_local_loss_mode(False) qt_value = self.input_view_shape_impl(quant_tensor.value) qt_scale = self.input_view_shape_impl(quant_tensor.scale) - qt_int = self.input_view_shape_impl(quant_tensor.int()) + qt_zp = self.input_view_shape_impl(quant_tensor.zero_point) + qt_int = qt_value / qt_scale + qt_zp loss = torch.abs(qt_value - x).mean() best_candidate = torch.where(loss < best_loss, candidate, best_candidate) if loss >= best_loss: @@ -700,6 +701,9 @@ def parameter_search(self, xl, x): best_loss = torch.min(loss, best_loss) W_e = shrink_lp_op(x - qt_value, self.beta, self.lp_norm) + # Compared to the original formulation, the value we're looking for is: + # - scaled by qt_scale + # - opposite sign val = self.input_view_shape_impl((x - W_e) - qt_int * qt_scale) if self.stats_reduce_dim is None: diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 10f7ce259..f74a91933 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -98,13 +98,16 @@ 'sym': Int8WeightPerChannelFixedPoint}, 'per_group': { 'sym': MXInt8Weight, 'asym': ShiftedMXUInt8Weight}}, + 'hqo': { + 'per_group': { + 'asym': MXHQO}}, 'mse': { 'per_tensor': { 'sym': Int8WeightPerTensorFixedPointMSE}, 'per_channel': { - 'sym': Int8WeightPerChannelFixedPointMSE}}, - 'per_group': { - 'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}}}, + 'sym': Int8WeightPerChannelFixedPointMSE}, + 'per_group': { + 'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}}}}, 'float': { 'float_scale': { 'stats': {