From 32322a572b9e2d5c3868622afc01d193e4fed5d5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 2 Sep 2024 00:10:53 +0100 Subject: [PATCH] HQO --- src/brevitas/core/stats/stats_op.py | 221 ++++++++++++++++++ src/brevitas/core/zero_point.py | 2 +- src/brevitas/graph/gptq.py | 1 - src/brevitas/quant/base.py | 43 +++- src/brevitas/quant/scaled_int.py | 25 ++ src/brevitas/quant/shifted_scaled_int.py | 69 +++++- .../common/generative/quant_blocks.py | 3 + .../common/generative/quantize.py | 20 +- .../common/generative/quantizers.py | 12 +- .../imagenet_classification/ptq/ptq_common.py | 29 ++- .../ptq/ptq_evaluate.py | 2 +- src/brevitas_examples/llm/main.py | 2 +- 12 files changed, 399 insertions(+), 30 deletions(-) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index fac729326..a49dd1858 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -10,6 +10,8 @@ import brevitas from brevitas import config +from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.function_wrapper.ops_ste import ScalarClampMinSte from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_int from brevitas.quant_tensor import _unpack_quant_tensor @@ -544,3 +546,222 @@ def forward(self, x): x = self.input_view_shape_impl(x) self.internal_candidate = self.mse_init_op(x).detach() return self.internal_candidate + + +class HalfQuadraticOptimizerScale(torch.nn.Module): + # References: + # https://mobiusml.github.io/hqq_blog/ + # https://github.com/mobiusml/hqq?tab=readme-ov-file + + def __init__( + self, + proxy_module, + hqo_init_op_scale, + keepdim: bool, + inner_stats_input_view_shape_impl: torch.nn.Module, + scaling_min_val: Optional[float] = None, + stats_reduce_dim: Optional[int] = None, + int_scaling_impl=None, + bit_width_impl=None, + hqo_beta_scale: float = 1e5, + hqo_kappa_scale: float = 1.01, + hqo_lp_norm_scale: float = .7, + hqo_iters_scale: int = 1000): + 
super(HalfQuadraticOptimizerScale, self).__init__() + self.hqo_init_op = hqo_init_op_scale + self.input_view_shape_impl = inner_stats_input_view_shape_impl + self.proxy_forward = proxy_module.forward + self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) + self.internal_candidate = None + self.hqo_iters = hqo_iters_scale + self.stats_reduce_dim = stats_reduce_dim + self.local_loss_mode: bool = False + + self.beta = hqo_beta_scale + self.kappa = hqo_kappa_scale + self.lp_norm = hqo_lp_norm_scale + + self.int_scaling_impl = int_scaling_impl + self.msb_clamp_bit_width_impl = bit_width_impl + if scaling_min_val is not None and scaling_min_val != 0: + self.clamp_min_ste = ScalarClampMinSte(scaling_min_val) + else: + self.clamp_min_ste = Identity() + self.keepdim = keepdim + + def parameter_search(self, xl, x): + best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype) + candidate = xl + best_candidate = candidate + beta = self.beta + with torch.no_grad(): + for i in range(0, self.hqo_iters): + self.internal_candidate = candidate + self.set_local_loss_mode(True) + quant_tensor = self.proxy_forward(x).detach() + self.set_local_loss_mode(False) + loss = torch.abs(quant_tensor.value - x).mean() + + best_candidate = torch.where(loss < best_loss, candidate, best_candidate) + if loss >= best_loss: + break + best_loss = torch.min(loss, best_loss) + W_e = shrink_lp_op(x - quant_tensor.value, beta, self.lp_norm) + zero_point = quant_tensor.zero_point + num = self.input_view_shape_impl(x - W_e).detach() + den = self.input_view_shape_impl( + torch.round(quant_tensor.value / quant_tensor.scale) - zero_point).detach() + mask = (num != 0.) & (den != 0.) 
+ if self.stats_reduce_dim is None: + candidate = masked_median(num / den, mask) + else: + candidate = masked_median( + num / den, mask, dim=self.stats_reduce_dim, keepdim=self.keepdim) + candidate = self.clamp_min_ste(candidate) + bit_width = self.msb_clamp_bit_width_impl() + int_threshold = self.int_scaling_impl(bit_width) + candidate = candidate * int_threshold + candidate[torch.isnan(candidate)] = self.internal_candidate[torch.isnan(candidate)] + candidate[torch.isinf(candidate)] = self.internal_candidate[torch.isinf(candidate)] + beta *= self.kappa + return best_candidate + + def optimize(self, x): + x_view = self.input_view_shape_impl(x) + + init = self.hqo_init_op(x_view).detach() + best_candidate = self.parameter_search(init, x_view) + + # Save for evaluation by other modules (e.g. zp) invoking local loss mode + self.internal_candidate = best_candidate.detach() + torch.cuda.empty_cache() + return best_candidate + + def forward(self, x): + if not self.local_loss_mode: + with torch.no_grad(): + return self.optimize(x) + else: + # This is invoked for the zero-point whenever scale is being optimized first + if self.internal_candidate is None: + x = self.input_view_shape_impl(x) + self.internal_candidate = self.hqo_init_op(x).detach() + return self.internal_candidate + + +class HalfQuadraticOptimizerZeroPoint(torch.nn.Module): + # References: + # https://mobiusml.github.io/hqq_blog/ + # https://github.com/mobiusml/hqq?tab=readme-ov-file + + def __init__( + self, + proxy_module, + keepdim: bool, + hqo_init_op_zp: torch.nn.Module, + inner_stats_input_view_shape_impl: torch.nn.Module, + stats_reduce_dim: Optional[int] = None, + hqo_beta_zp: float = 1e0, + hqo_kappa_zp: float = 1.01, + hqo_lp_norm_zp: float = .5, + hqo_iters_zp: int = 1000): + super(HalfQuadraticOptimizerZeroPoint, self).__init__() + self.hqo_init_op_zp = hqo_init_op_zp + self.input_view_shape_impl = inner_stats_input_view_shape_impl + self.proxy_forward = proxy_module.forward + 
self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) + self.internal_candidate = None + self.stats_reduce_dim = stats_reduce_dim + self.local_loss_mode: bool = False + self.beta = hqo_beta_zp + self.kappa = hqo_kappa_zp + self.lp_norm = hqo_lp_norm_zp + self.hqo_iters = hqo_iters_zp + self.keepdim = keepdim + + def parameter_search(self, xl, x): + best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype) + candidate = xl + best_candidate = candidate + with torch.no_grad(): + for i in range(0, self.hqo_iters): + self.internal_candidate = candidate + self.set_local_loss_mode(True) + quant_tensor = self.proxy_forward(x).detach() + self.set_local_loss_mode(False) + qt_value = self.input_view_shape_impl(quant_tensor.value) + qt_scale = self.input_view_shape_impl(quant_tensor.scale) + qt_int = self.input_view_shape_impl(quant_tensor.int()) + loss = torch.abs(qt_value - x).mean() + best_candidate = torch.where(loss < best_loss, candidate, best_candidate) + if loss >= best_loss: + break + best_loss = torch.min(loss, best_loss) + W_e = shrink_lp_op(x - qt_value, self.beta, self.lp_norm) + + val = self.input_view_shape_impl((x - W_e) - qt_int * qt_scale) + + if self.stats_reduce_dim is None: + candidate = torch.mean(val) + else: + candidate = torch.mean(val, dim=self.stats_reduce_dim, keepdim=self.keepdim) + self.beta *= self.kappa + return best_candidate + + def optimize(self, x): + x_view = self.input_view_shape_impl(x) + + init = self.hqo_init_op_zp(x_view).detach() + + best_candidate = self.parameter_search(init, x) + + # Save for evaluation by other modules (e.g. 
zp) invoking local loss mode + self.internal_candidate = best_candidate.detach() + torch.cuda.empty_cache() + return best_candidate + + def forward(self, x): + if not self.local_loss_mode: + with torch.no_grad(): + return self.optimize(x) + else: + # This is invoked for the zero-point whenever scale is being optimized first + if self.internal_candidate is None: + x = self.input_view_shape_impl(x) + self.internal_candidate = self.hqo_init_op_zp(x).detach() + return self.internal_candidate + + +def masked_median(x, mask, dim=None, keepdim=False): + """Compute the median of tensor x along dim, ignoring values where mask is False. + x and mask need to be broadcastable. + + Args: + x (Tensor): Tensor to compute median of. + mask (BoolTensor): Same shape as x with True where x is valid and False + where x should be masked. Mask should not be all False in any column of + dimension dim to avoid NaNs from zero division. + dim (int, optional): Dimension to take median of. Defaults to None (median over all elements). + + Returns: + Tensor: Same shape as x, except dimension dim reduced. 
+ """ + # uncomment this assert for safety but might impact performance + # assert ( + # mask.sum(dim=dim).ne(0).all() + # ), "mask should not be all False in any column, causes zero division" + x_nan = x.float().masked_fill(~mask, float("nan")) + if dim is None: + x_median = x_nan.nanmedian() + else: + x_median, _ = x_nan.nanmedian(dim=dim, keepdim=keepdim) + return x_median + + +# Shrinking operator +def shrink_lp_op(x: Tensor, beta: float, lp_norm: float) -> Tensor: + if lp_norm == 1: + return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta) + else: + return torch.sign(x) * torch.nn.functional.relu( + torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1)) diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py index 872435ec7..3f80f1dd4 100644 --- a/src/brevitas/core/zero_point.py +++ b/src/brevitas/core/zero_point.py @@ -281,7 +281,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): output_dict = super(ParameterFromStatsFromParameterZeroPoint, self).state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars) # Avoid saving the init value - if not self.init_done: + if not self.init_done and not config._FULL_STATE_DICT: del output_dict[prefix + 'value'] return output_dict diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 31d31433b..fa5ad65b9 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -256,7 +256,6 @@ def single_layer_update(self, percdamp=.01): return finally: del self.H - for i1 in range(0, self.columns, self.blocksize): i2 = min(i1 + self.blocksize, self.columns) count = i2 - i1 diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 18351a05b..6f08a1bed 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -36,6 +36,8 @@ from brevitas.core.stats import MSE from brevitas.core.stats import NegativeMinOrZero from brevitas.core.stats import NegativePercentileOrZero +from 
brevitas.core.stats.stats_op import HalfQuadraticOptimizerScale +from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint from brevitas.core.utils import SingleArgStatelessBuffer from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint @@ -452,7 +454,7 @@ class MSEAsymmetricScaleSubInjector(MSESubInjectorBase): stats_impl = MSE stats_reduce_dim = (this << 1).stats_reduce_dim device = (this << 1).device - type = (this << 1).type + dtype = (this << 1).dtype class MSEZeroPointSubInjector(MSESubInjectorBase): @@ -464,7 +466,7 @@ class MSEZeroPointSubInjector(MSESubInjectorBase): stats_impl = MSE stats_reduce_dim = (this << 1).stats_reduce_dim device = (this << 1).device - type = (this << 1).type + dtype = (this << 1).dtype class MSEAsymmetricScale(ExtendedInjector): @@ -514,3 +516,40 @@ class MSEWeightZeroPoint(MSEZeroPoint): class MSEActZeroPoint(MSEZeroPoint): zero_point_impl = ParameterFromRuntimeZeroPoint + + +class HQOZeroPoint(ExtendedInjector): + + hqo_init_op_zp = NegativeMinOrZero + inner_stats_input_view_shape_impl = this.zero_point_stats_input_view_shape_impl + stats_impl_zp = HalfQuadraticOptimizerZeroPoint + + @value + def zero_point_stats_impl(): + return this.stats_impl_zp + + +class HQOScale(ExtendedInjector): + scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS + inner_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl + stats_impl_scale = HalfQuadraticOptimizerScale + + @value + def scaling_stats_impl(): + return this.stats_impl_scale + + +class HQOAsymmetricScale(HQOScale): + hqo_init_op_scale = AbsMinMax + + +class HQOSymmetricScale(HQOScale): + hqo_init_op_scale = AbsMax + + +class HQOActZeroPoint(HQOZeroPoint): + zero_point_impl = ParameterFromRuntimeZeroPoint + + +class HQOWeightZeroPoint(HQOZeroPoint): + zero_point_impl = ParameterFromStatsFromParameterZeroPoint diff --git a/src/brevitas/quant/scaled_int.py 
b/src/brevitas/quant/scaled_int.py index 0f67300c3..b5f9174d7 100644 --- a/src/brevitas/quant/scaled_int.py +++ b/src/brevitas/quant/scaled_int.py @@ -3,6 +3,7 @@ from brevitas.core.function_wrapper import TensorClamp from brevitas.quant.base import * +from brevitas.quant.base import HQOSymmetricScale from brevitas.quant.solver.act import ActQuantSolver from brevitas.quant.solver.bias import BiasQuantSolver from brevitas.quant.solver.trunc import TruncQuantSolver @@ -443,3 +444,27 @@ class Int8AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareZeroCenterWeight >>> conv.quant_weight() """ bit_width = 8 + + +class Int8WeightPerTensorFloatHQO(HQOSymmetricScale, Int8WeightPerTensorFloat): + """ + 8-bit narrow per-tensor signed int weight quantizer with per-tensor floating-point scale factor computed + from HQO local loss. + + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=Int8WeightPerTensorFloatHQO) + """ + pass + + +class Int8WeightPerChannelFloatHQO(HQOSymmetricScale, Int8WeightPerChannelFloat): + """ + 8-bit narrow per-channel signed int weight quantizer with per-channel floating-point scale factor computed + from HQO local loss. + + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=Int8WeightPerChannelFloatHQO) + """ + pass diff --git a/src/brevitas/quant/shifted_scaled_int.py b/src/brevitas/quant/shifted_scaled_int.py index 936737571..d18150a10 100644 --- a/src/brevitas/quant/shifted_scaled_int.py +++ b/src/brevitas/quant/shifted_scaled_int.py @@ -1,10 +1,12 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause +from brevitas.inject.enum import ScalingPerOutputType +from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector from brevitas.quant.base import * +from brevitas.quant.base import HQOActZeroPoint +from brevitas.quant.base import HQOZeroPoint from brevitas.quant.solver.act import ActQuantSolver -from brevitas.quant.solver.bias import BiasQuantSolver -from brevitas.quant.solver.trunc import TruncQuantSolver from brevitas.quant.solver.weight import WeightQuantSolver __all__ = [ @@ -15,7 +17,10 @@ 'ShiftedUint8ActPerTensorFixedPointMSE', 'ShiftedUint8ActPerTensorFloatMSE', 'ShiftedUint8WeightPerTensorFloatMSE', - 'ShiftedUint8WeightPerChannelFloatMSE'] + 'ShiftedUint8WeightPerChannelFloatMSE', + 'ShiftedUint8ActPerTensorFloatHQO', + 'ShiftedUint8WeightPerChannelFloatHQO', + 'ShiftedUint8WeightPerTensorFloatHQO'] class ShiftedUint8ActPerTensorFixedPoint(ShiftedParamFromPercentileUintQuant, @@ -138,3 +143,61 @@ class ShiftedUint8WeightPerChannelFloatMSE(MSEAsymmetricScale, >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloat) """ pass + + +class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTensorFloat): + """ + 8-bit per-tensor unsigned int weight quantizer with floating-point per-tensor scale factor and integer + zero point. Zero-point is initialized from HQO local loss. + + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerTensorFloatHQO) + """ + quantize_zero_point = False + + +class ShiftedUint8WeightPerChannelFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerChannelFloat): + """ + 8-bit per-channel unsigned int weight quantizer with floating-point per-channel scale factor and integer + zero point. Zero-point is initialized from HQO local loss. 
+ + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloatHQO) + """ + quantize_zero_point = False + + +class ShiftedUint8WeightPerGroupFloatHQO(ShiftedUint8WeightPerChannelFloatHQO): + """ + 8-bit per-group unsigned int weight quantizer with floating-point per-group scale factor and integer + zero point. Zero-point is initialized from HQO local loss. + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(64, 5, bias=False, weight_quant=ShiftedUint8WeightPerGroupFloatHQO) + """ + group_size = 32 + scaling_per_output_type = ScalingPerOutputType.GROUP + proxy_class = GroupwiseWeightQuantProxyFromInjector + + +class ShiftedUint8ActPerTensorFloatHQO(HQOActZeroPoint, ShiftedUint8ActPerTensorFloat): + """ + 8-bit per-tensor unsigned int activations quantizer with floating-point scale factor and + integer zero point. Zero-point is a learned parameter initialized from + HQO local loss. + + Examples: + >>> from brevitas.nn import QuantReLU + >>> act = QuantReLU(act_quant=ShiftedUint8ActPerTensorFloatHQO) + """ + quantize_zero_point = False + + +class ShiftedUint8WeightGroupQuantFloat(ShiftedUint8WeightPerChannelFloat): + """ + Block / group / vector signed asymmetric weight quantizer with float scales and zero-points. 
+ """ + proxy_class = GroupwiseWeightQuantProxyFromInjector + scaling_per_output_type = ScalingPerOutputType.GROUP diff --git a/src/brevitas_examples/common/generative/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py index 18149578d..13cf82d46 100644 --- a/src/brevitas_examples/common/generative/quant_blocks.py +++ b/src/brevitas_examples/common/generative/quant_blocks.py @@ -10,9 +10,12 @@ import torch.nn as nn import brevitas +import brevitas.config as config from brevitas.core.function_wrapper.shape import PermuteDims +from brevitas.core.utils import inplace_tensor_add from brevitas.core.utils import SliceTensor from brevitas.core.zero_point import _ScaleShiftZeroPoint +from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint from brevitas.function.ops_ste import abs_binary_sign_grad diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 57670f6f6..f3293b631 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -37,14 +37,20 @@ from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE from brevitas.quant.scaled_int import Int8WeightPerChannelFloat +from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO from brevitas.quant.scaled_int import Int8WeightPerChannelFloatMSE from brevitas.quant.scaled_int import Int8WeightPerTensorFloat +from brevitas.quant.scaled_int import Int8WeightPerTensorFloatHQO from brevitas.quant.scaled_int import Int8WeightPerTensorFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightGroupQuantFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat +from 
brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerGroupFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear @@ -56,7 +62,6 @@ from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat -from brevitas_examples.common.generative.quantizers import ShiftedUintWeightAsymmetricGroupQuant WEIGHT_QUANT_MAP = { 'int': { @@ -68,14 +73,23 @@ 'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat}, 'per_group': { 'sym': IntWeightSymmetricGroupQuant, - 'asym': ShiftedUintWeightAsymmetricGroupQuant}}, + 'asym': ShiftedUint8WeightGroupQuantFloat}}, 'mse': { 'per_tensor': { 'sym': Int8WeightPerTensorFloatMSE, 'asym': ShiftedUint8WeightPerTensorFloatMSE}, 'per_channel': { 'sym': Int8WeightPerChannelFloatMSE, - 'asym': ShiftedUint8WeightPerChannelFloatMSE},},}, + 'asym': ShiftedUint8WeightPerChannelFloatMSE}}, + 'hqo': { + 'per_tensor': { + 'sym': Int8WeightPerTensorFloatHQO, + 'asym': ShiftedUint8WeightPerTensorFloatHQO}, + 'per_channel': { + 'sym': Int8WeightPerChannelFloatHQO, + 'asym': ShiftedUint8WeightPerChannelFloatHQO}, + 'per_group': { + 'asym': ShiftedUint8WeightPerGroupFloatHQO}},}, 'po2_scale': { 'stats': { 'per_tensor': { diff --git a/src/brevitas_examples/common/generative/quantizers.py 
b/src/brevitas_examples/common/generative/quantizers.py index 1f41e136a..c3c99a96f 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -10,7 +10,9 @@ from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling from brevitas.core.stats import AbsMinMax from brevitas.core.stats import NegativeMinOrZero +from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE +from brevitas.core.zero_point import StatsFromParameterZeroPoint from brevitas.inject import ExtendedInjector from brevitas.inject import this from brevitas.inject import value @@ -21,11 +23,13 @@ from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector from brevitas.proxy.groupwise_int_runtime_quant import GroupwiseActQuantProxyFromInjector from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector +from brevitas.quant.base import HQOWeightZeroPoint from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Int8WeightPerChannelFloat +from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat @@ -54,14 +58,6 @@ class Fp8e4m3WeightSymmetricGroupQuant(Fp8e4m3WeightPerChannelFloat): scaling_per_output_type = ScalingPerOutputType.GROUP -class ShiftedUintWeightAsymmetricGroupQuant(ShiftedUint8WeightPerChannelFloat): - """ - Block / group / vector signed asymmetric weight quantizer with float scales and zero-points. 
- """ - proxy_class = GroupwiseWeightQuantProxyFromInjector - scaling_per_output_type = ScalingPerOutputType.GROUP - - class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat): """ Symmetric quantizer with per tensor dynamic scale. diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 3c6b82243..2d7caa22b 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -22,7 +22,6 @@ from brevitas.graph.target.flexml import quantize_flexml from brevitas.inject import value import brevitas.nn as qnn -from brevitas.quant.experimental.float import Fp8e4m3Act from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloatMSE from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat @@ -38,17 +37,22 @@ from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE from brevitas.quant.scaled_int import Int8WeightPerChannelFloat +from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO from brevitas.quant.scaled_int import Int8WeightPerChannelFloatMSE from brevitas.quant.scaled_int import Int8WeightPerTensorFloat +from brevitas.quant.scaled_int import Int8WeightPerTensorFloatHQO from brevitas.quant.scaled_int import Int8WeightPerTensorFloatMSE from brevitas.quant.scaled_int import Int16Bias from brevitas.quant.scaled_int import Int32Bias from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFixedPoint from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE from brevitas.quant.shifted_scaled_int import 
ShiftedUint8WeightPerChannelFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat @@ -90,7 +94,14 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'asym': ShiftedUint8WeightPerTensorFloatMSE}, 'per_channel': { 'sym': Int8WeightPerChannelFloatMSE, - 'asym': ShiftedUint8WeightPerChannelFloatMSE},},}, + 'asym': ShiftedUint8WeightPerChannelFloatMSE}}, + 'hqo': { + 'per_tensor': { + 'sym': Int8WeightPerTensorFloatHQO, + 'asym': ShiftedUint8WeightPerTensorFloatHQO}, + 'per_channel': { + 'sym': Int8WeightPerChannelFloatHQO, + 'asym': ShiftedUint8WeightPerChannelFloatHQO}}}, 'po2_scale': { 'stats': { 'per_tensor': { @@ -359,8 +370,11 @@ def kwargs_prefix(prefix, weight_kwargs): per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][ act_scale_type][act_param_method]['per_tensor'][act_quant_type] act_quant = act_quant.let(**act_bit_width_dict) + act_quant = act_quant.let(**{'dtype': dtype, 'device': device}) sym_act_quant = sym_act_quant.let(**act_bit_width_dict) + sym_act_quant = sym_act_quant.let(**{'dtype': dtype, 'device': device}) per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict) + per_tensor_act_quant = per_tensor_act_quant.let(**{'dtype': dtype, 'device': device}) else: act_quant = None sym_act_quant = None @@ -374,19 +388,14 @@ def kwargs_prefix(prefix, weight_kwargs): if weight_quant_type == 'asym': weight_quant = 
weight_quant.let(zero_point_impl=ParameterFromStatsFromParameterZeroPoint) if act_quant is not None: - act_quant = act_quant.let( - **{ - 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device}) + act_quant = act_quant.let(**{'high_percentile_q': act_quant_percentile}) if act_quant_type == 'asym' and act_quant_percentile is not None: act_quant = act_quant.let(**{'low_percentile_q': 100 - act_quant_percentile}) if sym_act_quant is not None: - sym_act_quant = sym_act_quant.let( - **{ - 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device}) + sym_act_quant = sym_act_quant.let(**{'high_percentile_q': act_quant_percentile}) if per_tensor_act_quant is not None: per_tensor_act_quant = per_tensor_act_quant.let( - **{ - 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device}) + **{'high_percentile_q': act_quant_percentile}) if act_quant_type == 'asym' and act_quant_percentile is not None: per_tensor_act_quant = per_tensor_act_quant.let( **{'low_percentile_q': 100 - act_quant_percentile}) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 7e2bf6ee5..eabb2fb17 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -126,7 +126,7 @@ def parse_type(v, default_type): parser.add_argument( '--weight-quant-calibration-type', default='stats', - choices=['stats', 'mse'], + choices=['stats', 'mse', 'hqo'], help='Weight quantization calibration type (default: stats)') parser.add_argument( '--act-equalization', diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 05d84f647..35b4c7119 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -63,7 +63,7 @@ '--weight-param-method', type=str, default='stats', - choices=['stats', 'mse'], + choices=['stats', 'mse', 
'hqo'], help='How scales/zero-point are determined. Default: stats.') parser.add_argument( '--weight-scale-precision',