From 948154e898ac07fea4358d764b947890e8d2ac7b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 16 Apr 2024 12:25:58 +0100 Subject: [PATCH] HQO --- src/brevitas/core/stats/stats_op.py | 215 ++++++++++++++++++ src/brevitas/quant/base.py | 52 +++++ src/brevitas/quant/scaled_int.py | 25 ++ src/brevitas/quant/shifted_scaled_int.py | 25 ++ .../imagenet_classification/ptq/ptq_common.py | 15 +- 5 files changed, 331 insertions(+), 1 deletion(-) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index fac729326..4cf42f6a8 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -10,6 +10,8 @@ import brevitas from brevitas import config +from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.function_wrapper.ops_ste import ScalarClampMinSte from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_int from brevitas.quant_tensor import _unpack_quant_tensor @@ -544,3 +546,216 @@ def forward(self, x): x = self.input_view_shape_impl(x) self.internal_candidate = self.mse_init_op(x).detach() return self.internal_candidate + + +class HalfQuadraticOptimizerScale(torch.nn.Module): + # References: + # https://mobiusml.github.io/hqq_blog/ + # https://github.com/mobiusml/hqq?tab=readme-ov-file + + def __init__( + self, + proxy_module, + mse_init_op, + inner_stats_input_view_shape_impl: torch.nn.Module, + scaling_min_val: Optional[float] = None, + stats_reduce_dim: Optional[int] = None, + int_scaling_impl=None, + bit_width_impl=None, + beta: float = 1e6, + kappa: float = 1.07, + lp_norm: float = .7, + hqo_iters: int = 1000): + super(HalfQuadraticOptimizerScale, self).__init__() + self.mse_init_op = mse_init_op + self.input_view_shape_impl = inner_stats_input_view_shape_impl + self.proxy_forward = proxy_module.forward + self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) + self.internal_candidate = None + self.hqo_iters = hqo_iters + self.stats_reduce_dim = stats_reduce_dim + self.local_loss_mode: bool = False + + self.beta = beta + self.kappa = kappa + self.lp_norm = lp_norm + + self.int_scaling_impl = int_scaling_impl + self.msb_clamp_bit_width_impl = bit_width_impl + if scaling_min_val is not None and scaling_min_val != 0: + self.clamp_min_ste = ScalarClampMinSte(scaling_min_val) + else: + self.clamp_min_ste = Identity() + + def parameter_search(self, xl, x): + best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype) + candidate = xl + candidate = self.input_view_shape_impl(candidate) + best_candidate = candidate + beta = self.beta + with torch.no_grad(): + for i in range(0, self.hqo_iters): + self.internal_candidate = candidate + self.set_local_loss_mode(True) + quant_tensor = self.proxy_forward(x).detach() + self.set_local_loss_mode(False) + loss = torch.abs(quant_tensor.value - x).mean() + + best_candidate = torch.where(loss < best_loss, candidate, best_candidate) + if loss >= best_loss: + break + best_loss = torch.min(loss, best_loss) + W_e = shrink_lp_op(x - quant_tensor.value, beta, self.lp_norm) + zero_point = quant_tensor.zero_point + num = self.input_view_shape_impl(x - W_e).detach() + den = self.input_view_shape_impl(quant_tensor.int() - zero_point).detach() + mask = (num != 0.) & (den != 0.) + if self.stats_reduce_dim is None: + candidate = masked_median(num / den, mask) + else: + candidate = masked_median(num / den, mask, dim=self.stats_reduce_dim) + candidate = self.input_view_shape_impl(candidate) + candidate[torch.isnan(candidate)] = self.internal_candidate[torch.isnan(candidate)] + candidate = self.clamp_min_ste(candidate) + bit_width = self.msb_clamp_bit_width_impl() + int_threshold = self.int_scaling_impl(bit_width) + candidate = candidate * int_threshold + beta *= self.kappa + + return best_candidate + + def optimize(self, x): + x_view = self.input_view_shape_impl(x) + + init = self.mse_init_op(x_view).detach() + best_candidate = self.parameter_search(init, x) + + # Save for evaluation by other modules (e.g. zp) invoking local loss mode + self.internal_candidate = best_candidate.detach() + torch.cuda.empty_cache() + return best_candidate + + def forward(self, x): + if not self.local_loss_mode: + with torch.no_grad(): + return self.optimize(x) + else: + # This is invoked for the zero-point whenever scale is being optimized first + if self.internal_candidate is None: + x = self.input_view_shape_impl(x) + self.internal_candidate = self.mse_init_op(x).detach() + return self.internal_candidate + + +class HalfQuadraticOptimizerZeroPoint(torch.nn.Module): + # References: + # https://mobiusml.github.io/hqq_blog/ + # https://github.com/mobiusml/hqq?tab=readme-ov-file + + def __init__( + self, + proxy_module, + mse_init_op: torch.nn.Module, + inner_stats_input_view_shape_impl: torch.nn.Module, + stats_reduce_dim: Optional[int] = None, + beta: float = 1., + kappa: float = 1.01, + lp_norm: float = 1., + hqo_iters: int = 1000): + super(HalfQuadraticOptimizerZeroPoint, self).__init__() + self.mse_init_op = mse_init_op + self.input_view_shape_impl = inner_stats_input_view_shape_impl + self.proxy_forward = proxy_module.forward + self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) + self.internal_candidate = None + self.stats_reduce_dim = stats_reduce_dim + self.local_loss_mode: bool = False + self.beta = beta + self.kappa = kappa + self.lp_norm = lp_norm + self.hqo_iters = hqo_iters + + def parameter_search(self, xl, x): + best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype) + candidate = xl + candidate = self.input_view_shape_impl(candidate) + best_candidate = candidate + with torch.no_grad(): + for i in range(0, self.hqo_iters): + self.internal_candidate = candidate + self.set_local_loss_mode(True) + quant_tensor = self.proxy_forward(x).detach() + self.set_local_loss_mode(False) + loss = torch.abs(quant_tensor.value - x).mean() + best_candidate = torch.where(loss < best_loss, candidate, best_candidate) + if loss >= best_loss: + break + best_loss = torch.min(loss, best_loss) + W_e = shrink_lp_op(x - quant_tensor.value, self.beta, self.lp_norm) + + val = self.input_view_shape_impl((x - W_e) - + quant_tensor.int() * quant_tensor.scale) + if self.stats_reduce_dim is None: + candidate = torch.mean(val) + else: + candidate = torch.mean(val, dim=self.stats_reduce_dim) + candidate = self.input_view_shape_impl(candidate) + self.beta *= self.kappa + return best_candidate + + def optimize(self, x): + x_view = self.input_view_shape_impl(x) + + init = self.mse_init_op(x_view).detach() + best_candidate = self.parameter_search(init, x) + + # Save for evaluation by other modules (e.g. zp) invoking local loss mode + self.internal_candidate = best_candidate.detach() + torch.cuda.empty_cache() + return best_candidate + + def forward(self, x): + if not self.local_loss_mode: + with torch.no_grad(): + return self.optimize(x) + else: + # This is invoked for the zero-point whenever scale is being optimized first + if self.internal_candidate is None: + x = self.input_view_shape_impl(x) + self.internal_candidate = self.mse_init_op(x).detach() + return self.internal_candidate + + +def masked_median(x, mask, dim=None): + """Compute the median of tensor x along dim, ignoring values where mask is False. + x and mask need to be broadcastable. + + Args: + x (Tensor): Tensor to compute median of. + mask (BoolTensor): Same shape as x with True where x is valid and False + where x should be masked. Mask should not be all False in any column of + dimension dim to avoid NaNs from zero division. + dim (int, optional): Dimension to take median of. Defaults to 0. + + Returns: + Tensor: Same shape as x, except dimension dim reduced. + """ + # uncomment this assert for safety but might impact performance + # assert ( + # mask.sum(dim=dim).ne(0).all() + # ), "mask should not be all False in any column, causes zero division" + x_nan = x.float().masked_fill(~mask, float("nan")) + if dim is None: + x_median = x_nan.nanmedian() + else: + x_median, _ = x_nan.nanmedian(dim=dim) + return x_median + + +# Shrinking operator +def shrink_lp_op(x: Tensor, beta: float, lp_norm: float) -> Tensor: + if lp_norm == 1: + return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta) + else: + return torch.sign(x) * torch.nn.functional.relu( + torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1)) diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index b509e6c16..7c65e60e1 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -36,6 +36,8 @@ from brevitas.core.stats import MSE from brevitas.core.stats import NegativeMinOrZero from brevitas.core.stats import NegativePercentileOrZero +from brevitas.core.stats.stats_op import HalfQuadraticOptimizerScale +from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint from brevitas.core.utils import SingleArgStatelessBuffer from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint @@ -502,3 +504,53 @@ class MSEWeightZeroPoint(MSEZeroPoint): class MSEActZeroPoint(MSEZeroPoint): zero_point_impl = ParameterFromRuntimeZeroPoint + + +class HQOZeroPoint(ExtendedInjector): + """ + We leverage a sub-injector to avoid a name clash between scale and zero-point. + """ + + zero_point_impl = ParameterFromStatsFromParameterZeroPoint + + # per_channel = this.scaling_per_output_channel + # proxy_module = this.proxy_module + mse_init_op = NegativeMinOrZero + stats_impl_lol = HalfQuadraticOptimizerZeroPoint + # stats_reduce_dim = this.stats_reduce_dim + zero_point_stats_input_view_shape_impl = nn.Identity() + + @value + def zero_point_stats_impl(): + return this.stats_impl_lol + + @value + def inner_stats_input_view_shape_impl(scaling_per_output_channel): + if scaling_per_output_channel: + return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS + else: + return StatsInputViewShapeImpl.OVER_TENSOR + + +class HQOSymmetricScale(ExtendedInjector): + """ + We leverage a sub-injector to avoid a name clash between scale and zero-point. + """ + + scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS + scaling_stats_input_view_shape_impl = nn.Identity() + # per_channel = this.scaling_per_output_channel + # proxy_module = this.proxy_module + mse_init_op = AbsMax + stats_impl = HalfQuadraticOptimizerScale + # stats_reduce_dim = this.stats_reduce_dim + @value + def scaling_stats_impl(): + return this.stats_impl + + @value + def inner_stats_input_view_shape_impl(scaling_per_output_channel): + if scaling_per_output_channel: + return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS + else: + return StatsInputViewShapeImpl.OVER_TENSOR diff --git a/src/brevitas/quant/scaled_int.py b/src/brevitas/quant/scaled_int.py index 0f67300c3..b5f9174d7 100644 --- a/src/brevitas/quant/scaled_int.py +++ b/src/brevitas/quant/scaled_int.py @@ -3,6 +3,7 @@ from brevitas.core.function_wrapper import TensorClamp from brevitas.quant.base import * +from brevitas.quant.base import HQOSymmetricScale from brevitas.quant.solver.act import ActQuantSolver from brevitas.quant.solver.bias import BiasQuantSolver from brevitas.quant.solver.trunc import TruncQuantSolver @@ -443,3 +444,27 @@ class Int8AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareZeroCenterWeight >>> conv.quant_weight() """ bit_width = 8 + + +class Int8WeightPerTensorFloatHQO(HQOSymmetricScale, Int8WeightPerTensorFloat): + """ + 8-bit narrow per-tensor signed int weight quantizer with per-tensor floating-point scale factor computed + from HQO local loss. + + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=Int8WeightPerTensorFloatHQO) + """ + pass + + +class Int8WeightPerChannelFloatHQO(HQOSymmetricScale, Int8WeightPerChannelFloat): + """ + 8-bit narrow per-tensor signed int weight quantizer with per-tensor floating-point scale factor computed + from HQO local loss. + + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=Int8WeightPerChannelFloatHQO) + """ + pass diff --git a/src/brevitas/quant/shifted_scaled_int.py b/src/brevitas/quant/shifted_scaled_int.py index 936737571..f581e1187 100644 --- a/src/brevitas/quant/shifted_scaled_int.py +++ b/src/brevitas/quant/shifted_scaled_int.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from brevitas.quant.base import * +from brevitas.quant.base import HQOZeroPoint from brevitas.quant.solver.act import ActQuantSolver from brevitas.quant.solver.bias import BiasQuantSolver from brevitas.quant.solver.trunc import TruncQuantSolver @@ -138,3 +139,27 @@ class ShiftedUint8WeightPerChannelFloatMSE(MSEAsymmetricScale, >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloat) """ pass + + +class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTensorFloat): + """ + 8-bit per-tensor unsigned int weight quantizer with floating-point per-channel scale factor and integer + zero point. Both zero-point and scale factors are learned parameters initialized from HQO local losses. + + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerTensorFloatHQO) + """ + quantize_zero_point = False + + +class ShiftedUint8WeightPerChannelFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerChannelFloat): + """ + 8-bit per-tensor unsigned int weight quantizer with floating-point per-channel scale factor and integer + zero point. Both zero-point and scale factors are learned parameters initialized from HQO local losses. + + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloatHQO) + """ + quantize_zero_point = False diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index c4bc616e7..3150a061b 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -37,8 +37,10 @@ from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE from brevitas.quant.scaled_int import Int8WeightPerChannelFloat +from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO from brevitas.quant.scaled_int import Int8WeightPerChannelFloatMSE from brevitas.quant.scaled_int import Int8WeightPerTensorFloat +from brevitas.quant.scaled_int import Int8WeightPerTensorFloatHQO from brevitas.quant.scaled_int import Int8WeightPerTensorFloatMSE from brevitas.quant.scaled_int import Int16Bias from brevitas.quant.scaled_int import Int32Bias @@ -46,8 +48,10 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE from brevitas_examples.imagenet_classification.ptq.learned_round_utils import learned_round_iterator from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data @@ -71,7 +75,14 @@ 'asym': ShiftedUint8WeightPerTensorFloatMSE}, 'per_channel': { 'sym': Int8WeightPerChannelFloatMSE, - 'asym': ShiftedUint8WeightPerChannelFloatMSE},},}, + 'asym': ShiftedUint8WeightPerChannelFloatMSE}}, + 'hqo': { + 'per_tensor': { + 'sym': Int8WeightPerTensorFloatHQO, + 'asym': ShiftedUint8WeightPerTensorFloatHQO}, + 'per_channel': { + 'sym': Int8WeightPerChannelFloatHQO, + 'asym': ShiftedUint8WeightPerChannelFloatHQO}}}, 'po2_scale': { 'stats': { 'per_tensor': { @@ -557,5 +568,7 @@ def check_positive_int(*args): We check that every inputted value is positive, and an integer. """ for arg in args: + if arg is None: + continue assert arg > 0.0 assert math.isclose(arg % 1, 0.0)