diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index 4cf42f6a8..6ea60952e 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -556,29 +556,29 @@ class HalfQuadraticOptimizerScale(torch.nn.Module): def __init__( self, proxy_module, - mse_init_op, + hqo_init_op_scale, inner_stats_input_view_shape_impl: torch.nn.Module, scaling_min_val: Optional[float] = None, stats_reduce_dim: Optional[int] = None, int_scaling_impl=None, bit_width_impl=None, - beta: float = 1e6, - kappa: float = 1.07, - lp_norm: float = .7, - hqo_iters: int = 1000): + hqo_beta_scale: float = 1e5, + hqo_kappa_scale: float = 1.01, + hqo_lp_norm_scale: float = .7, + hqo_iters_scale: int = 1000): super(HalfQuadraticOptimizerScale, self).__init__() - self.mse_init_op = mse_init_op + self.hqo_init_op = hqo_init_op_scale self.input_view_shape_impl = inner_stats_input_view_shape_impl self.proxy_forward = proxy_module.forward self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) self.internal_candidate = None - self.hqo_iters = hqo_iters + self.hqo_iters = hqo_iters_scale self.stats_reduce_dim = stats_reduce_dim self.local_loss_mode: bool = False - self.beta = beta - self.kappa = kappa - self.lp_norm = lp_norm + self.beta = hqo_beta_scale + self.kappa = hqo_kappa_scale + self.lp_norm = hqo_lp_norm_scale self.int_scaling_impl = int_scaling_impl self.msb_clamp_bit_width_impl = bit_width_impl @@ -615,19 +615,19 @@ def parameter_search(self, xl, x): else: candidate = masked_median(num / den, mask, dim=self.stats_reduce_dim) candidate = self.input_view_shape_impl(candidate) - candidate[torch.isnan(candidate)] = self.internal_candidate[torch.isnan(candidate)] candidate = self.clamp_min_ste(candidate) bit_width = self.msb_clamp_bit_width_impl() int_threshold = self.int_scaling_impl(bit_width) candidate = candidate * int_threshold + candidate[torch.isnan(candidate)] = self.internal_candidate[torch.isnan(candidate)] + candidate[torch.isinf(candidate)] = self.internal_candidate[torch.isinf(candidate)] beta *= self.kappa - return best_candidate def optimize(self, x): x_view = self.input_view_shape_impl(x) - init = self.mse_init_op(x_view).detach() + init = self.hqo_init_op(x_view).detach() best_candidate = self.parameter_search(init, x) # Save for evaluation by other modules (e.g. zp) invoking local loss mode @@ -643,7 +643,7 @@ def forward(self, x): # This is invoked for the zero-point whenever scale is being optimized first if self.internal_candidate is None: x = self.input_view_shape_impl(x) - self.internal_candidate = self.mse_init_op(x).detach() + self.internal_candidate = self.hqo_init_op(x).detach() return self.internal_candidate @@ -655,25 +655,29 @@ class HalfQuadraticOptimizerZeroPoint(torch.nn.Module): def __init__( self, proxy_module, - mse_init_op: torch.nn.Module, + hqo_init_op_zp: torch.nn.Module, inner_stats_input_view_shape_impl: torch.nn.Module, stats_reduce_dim: Optional[int] = None, - beta: float = 1., - kappa: float = 1.01, - lp_norm: float = 1., - hqo_iters: int = 1000): + inner_expanded_zero_point_shape=None, + reshaped_zero_point_shape=None, + hqo_beta_zp: float = 1e0, + hqo_kappa_zp: float = 1.01, + hqo_lp_norm_zp: float = .7, + hqo_iters_zp: int = 1000): super(HalfQuadraticOptimizerZeroPoint, self).__init__() - self.mse_init_op = mse_init_op + self.hqo_init_op_zp = hqo_init_op_zp self.input_view_shape_impl = inner_stats_input_view_shape_impl self.proxy_forward = proxy_module.forward self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) self.internal_candidate = None self.stats_reduce_dim = stats_reduce_dim self.local_loss_mode: bool = False - self.beta = beta - self.kappa = kappa - self.lp_norm = lp_norm - self.hqo_iters = hqo_iters + self.beta = hqo_beta_zp + self.kappa = hqo_kappa_zp + self.lp_norm = hqo_lp_norm_zp + self.hqo_iters = hqo_iters_zp + self.inner_expanded_zero_point_shape = inner_expanded_zero_point_shape + self.reshaped_zero_point_shape = reshaped_zero_point_shape def parameter_search(self, xl, x): best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype) @@ -695,10 +699,12 @@ def parameter_search(self, xl, x): val = self.input_view_shape_impl((x - W_e) - quant_tensor.int() * quant_tensor.scale) + if self.inner_expanded_zero_point_shape is not None: + val = val.reshape(self.inner_expanded_zero_point_shape) if self.stats_reduce_dim is None: candidate = torch.mean(val) else: - candidate = torch.mean(val, dim=self.stats_reduce_dim) + candidate = torch.mean(val, dim=self.stats_reduce_dim, keepdim=True) candidate = self.input_view_shape_impl(candidate) self.beta *= self.kappa return best_candidate @@ -706,7 +712,10 @@ def parameter_search(self, xl, x): def optimize(self, x): x_view = self.input_view_shape_impl(x) - init = self.mse_init_op(x_view).detach() + init = self.hqo_init_op_zp(x_view).detach() + if self.reshaped_zero_point_shape is not None: + x = x.reshape(self.reshaped_zero_point_shape) + best_candidate = self.parameter_search(init, x) # Save for evaluation by other modules (e.g. zp) invoking local loss mode @@ -722,11 +731,11 @@ def forward(self, x): # This is invoked for the zero-point whenever scale is being optimized first if self.internal_candidate is None: x = self.input_view_shape_impl(x) - self.internal_candidate = self.mse_init_op(x).detach() + self.internal_candidate = self.hqo_init_op_zp(x).detach() return self.internal_candidate -def masked_median(x, mask, dim=None): +def masked_median(x, mask, dim=None, keepdim=False): """Compute the median of tensor x along dim, ignoring values where mask is False. x and mask need to be broadcastable. @@ -748,7 +757,7 @@ def masked_median(x, mask, dim=None): if dim is None: x_median = x_nan.nanmedian() else: - x_median, _ = x_nan.nanmedian(dim=dim) + x_median, _ = x_nan.nanmedian(dim=dim, keepdim=keepdim) return x_median diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py index 872435ec7..3f80f1dd4 100644 --- a/src/brevitas/core/zero_point.py +++ b/src/brevitas/core/zero_point.py @@ -281,7 +281,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): output_dict = super(ParameterFromStatsFromParameterZeroPoint, self).state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars) # Avoid saving the init value - if not self.init_done: + if not self.init_done and not config._FULL_STATE_DICT: del output_dict[prefix + 'value'] return output_dict diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 31d31433b..fa5ad65b9 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -256,7 +256,6 @@ def single_layer_update(self, percdamp=.01): return finally: del self.H - for i1 in range(0, self.columns, self.blocksize): i2 = min(i1 + self.blocksize, self.columns) count = i2 - i1 diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 7c65e60e1..f7dfe8699 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -445,6 +445,8 @@ class MSEAsymmetricScaleSubInjector(MSESubInjectorBase): mse_init_op = AbsMinMax stats_impl = MSE stats_reduce_dim = (this << 1).stats_reduce_dim + device = (this << 1).device + dtype = (this << 1).dtype class MSEZeroPointSubInjector(MSESubInjectorBase): @@ -455,6 +457,8 @@ class MSEZeroPointSubInjector(MSESubInjectorBase): mse_search_method = 'grid' stats_impl = MSE stats_reduce_dim = (this << 1).stats_reduce_dim + device = (this << 1).device + dtype = (this << 1).dtype class MSEAsymmetricScale(ExtendedInjector): @@ -507,22 +511,14 @@ class MSEActZeroPoint(MSEZeroPoint): class HQOZeroPoint(ExtendedInjector): - """ - We leverage a sub-injector to avoid a name clash between scale and zero-point. - """ - zero_point_impl = ParameterFromStatsFromParameterZeroPoint - - # per_channel = this.scaling_per_output_channel - # proxy_module = this.proxy_module - mse_init_op = NegativeMinOrZero - stats_impl_lol = HalfQuadraticOptimizerZeroPoint - # stats_reduce_dim = this.stats_reduce_dim + hqo_init_op_zp = NegativeMinOrZero + stats_impl_zp = HalfQuadraticOptimizerZeroPoint zero_point_stats_input_view_shape_impl = nn.Identity() @value def zero_point_stats_impl(): - return this.stats_impl_lol + return this.stats_impl_zp @value def inner_stats_input_view_shape_impl(scaling_per_output_channel): @@ -532,21 +528,15 @@ def inner_stats_input_view_shape_impl(scaling_per_output_channel): return StatsInputViewShapeImpl.OVER_TENSOR -class HQOSymmetricScale(ExtendedInjector): - """ - We leverage a sub-injector to avoid a name clash between scale and zero-point. - """ - +class HQOScale(ExtendedInjector): scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS scaling_stats_input_view_shape_impl = nn.Identity() - # per_channel = this.scaling_per_output_channel - # proxy_module = this.proxy_module - mse_init_op = AbsMax - stats_impl = HalfQuadraticOptimizerScale - # stats_reduce_dim = this.stats_reduce_dim + + stats_impl_scale = HalfQuadraticOptimizerScale + @value def scaling_stats_impl(): - return this.stats_impl + return this.stats_impl_scale @value def inner_stats_input_view_shape_impl(scaling_per_output_channel): @@ -554,3 +544,19 @@ def inner_stats_input_view_shape_impl(scaling_per_output_channel): return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS else: return StatsInputViewShapeImpl.OVER_TENSOR + + +class HQOAsymmetricScale(HQOScale): + hqo_init_op_scale = AbsMinMax + + +class HQOSymmetricScale(HQOScale): + hqo_init_op_scale = AbsMax + + +class HQOActZeroPoint(HQOZeroPoint): + zero_point_impl = ParameterFromRuntimeZeroPoint + + +class HQOWeightZeroPoint(HQOZeroPoint): + zero_point_impl = ParameterFromStatsFromParameterZeroPoint diff --git a/src/brevitas/quant/shifted_scaled_int.py b/src/brevitas/quant/shifted_scaled_int.py index f581e1187..900471c98 100644 --- a/src/brevitas/quant/shifted_scaled_int.py +++ b/src/brevitas/quant/shifted_scaled_int.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: BSD-3-Clause from brevitas.quant.base import * +from brevitas.quant.base import HQOActZeroPoint +from brevitas.quant.base import HQOAsymmetricScale from brevitas.quant.base import HQOZeroPoint from brevitas.quant.solver.act import ActQuantSolver from brevitas.quant.solver.bias import BiasQuantSolver @@ -16,7 +18,8 @@ 'ShiftedUint8ActPerTensorFixedPointMSE', 'ShiftedUint8ActPerTensorFloatMSE', 'ShiftedUint8WeightPerTensorFloatMSE', - 'ShiftedUint8WeightPerChannelFloatMSE'] + 'ShiftedUint8WeightPerChannelFloatMSE', + 'ShiftedUint8ActPerTensorFloatHQO'] class ShiftedUint8ActPerTensorFixedPoint(ShiftedParamFromPercentileUintQuant, @@ -163,3 +166,16 @@ class ShiftedUint8WeightPerChannelFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerCh >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloatHQO) """ quantize_zero_point = False + + +class ShiftedUint8ActPerTensorFloatHQO(HQOActZeroPoint, ShiftedUint8ActPerTensorFloat): + """ + 8-bit per-tensor unsigned int activations quantizer with floating-point scale factor and + integer zero point. Both zero-point and scale factors are learned parameters initialized from + HQO local loss. + + Examples: + >>> from brevitas.nn import QuantReLU + >>> act = QuantReLU(act_quant=ShiftedUint8ActPerTensorFloatHQO) + """ + quantize_zero_point = False diff --git a/src/brevitas_examples/common/generative/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py index e403deecf..1b8ba9ecc 100644 --- a/src/brevitas_examples/common/generative/quant_blocks.py +++ b/src/brevitas_examples/common/generative/quant_blocks.py @@ -10,9 +10,12 @@ import torch.nn as nn import brevitas +import brevitas.config as config from brevitas.core.function_wrapper.shape import PermuteDims +from brevitas.core.utils import inplace_tensor_add from brevitas.core.utils import SliceTensor from brevitas.core.zero_point import _ScaleShiftZeroPoint +from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint from brevitas.function.ops_ste import abs_binary_sign_grad @@ -58,12 +61,22 @@ class ExpandReshapeZeroPointWrapper(brevitas.jit.ScriptModule): __constants__ = ['expanded_zero_point_shape', 'reshaped_zero_point_shape'] def __init__( - self, wrapped_zero_point_impl, expanded_zero_point_shape, reshaped_zero_point_shape): + self, + wrapped_zero_point_impl, + expanded_zero_point_shape, + reshaped_zero_point_shape, + dtype, + device): super(ExpandReshapeZeroPointWrapper, self).__init__() + assert isinstance(wrapped_zero_point_impl, ParameterFromStatsFromParameterZeroPoint) self.wrapped_zero_point_impl = wrapped_zero_point_impl self.expanded_zero_point_shape = expanded_zero_point_shape self.reshaped_zero_point_shape = reshaped_zero_point_shape self.slice_tensor = SliceTensor() + self.init_done = False + self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool) + self.wrapped_zero_point_impl.value = torch.nn.Parameter( + torch.full(reshaped_zero_point_shape, 0.0, dtype=dtype, device=device)) def unexpanded_zero_point(self, unexpanded_scale, bit_width): """ @@ -76,17 +89,27 @@ def unexpanded_zero_point(self, unexpanded_scale, bit_width): @brevitas.jit.script_method def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor): - # We have to break into wrapped_zero_point_impl since we need to expand and reshape - # Before we call into scale_shift_zero_point - zero_point_stats = self.wrapped_zero_point_impl.parameter_list_stats() - zero_point_stats = zero_point_stats.expand(self.expanded_zero_point_shape).contiguous() - # contiguous() above is to avoid an unsafe_view below - zero_point_stats = zero_point_stats.reshape(self.reshaped_zero_point_shape) - # slice tensor when required by partial quantization - zero_point_stats = self.slice_tensor(zero_point_stats) - zero_point = self.wrapped_zero_point_impl.scale_shift_zero_point( - -zero_point_stats, scale, bit_width) - return zero_point + if self.init_done: + value = self.slice_tensor(-self.wrapped_zero_point_impl.value) + return self.wrapped_zero_point_impl.scale_shift_zero_point(value, scale, bit_width) + else: + # We have to break into wrapped_zero_point_impl since we need to expand and reshape + # Before we call into scale_shift_zero_point + zero_point_stats = self.wrapped_zero_point_impl.parameter_list_stats() + zero_point_stats = zero_point_stats.expand(self.expanded_zero_point_shape).contiguous() + # contiguous() above is to avoid an unsafe_view below + zero_point_stats = zero_point_stats.reshape(self.reshaped_zero_point_shape) + # slice tensor when required by partial quantization + zero_point_stats = self.slice_tensor(zero_point_stats) + if self.local_loss_mode: + return self.wrapped_zero_point_impl.scale_shift_zero_point( + -zero_point_stats, scale, bit_width) + inplace_tensor_add(self.wrapped_zero_point_impl.value.detach(), zero_point_stats) + # self.wrapped_zero_point_impl.value.data = zero_point_stats + zero_point = self.wrapped_zero_point_impl.scale_shift_zero_point( + -zero_point_stats, scale, bit_width) + self.init_done = True + return zero_point # TODO: restore JIT compatibility diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 31ab57361..d9aca3888 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -22,14 +22,18 @@ from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE from brevitas.quant.scaled_int import Int8WeightPerChannelFloat +from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO from brevitas.quant.scaled_int import Int8WeightPerChannelFloatMSE from brevitas.quant.scaled_int import Int8WeightPerTensorFloat +from brevitas.quant.scaled_int import Int8WeightPerTensorFloatHQO from brevitas.quant.scaled_int import Int8WeightPerTensorFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear @@ -41,6 +45,7 @@ from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import ShiftedUintWeightAsymmetricGroupQuant +from brevitas_examples.common.generative.quantizers import ShiftedUintWeightAsymmetricGroupQuantHQO WEIGHT_QUANT_MAP = { 'int': { @@ -59,7 +64,16 @@ 'asym': ShiftedUint8WeightPerTensorFloatMSE}, 'per_channel': { 'sym': Int8WeightPerChannelFloatMSE, - 'asym': ShiftedUint8WeightPerChannelFloatMSE},},}, + 'asym': ShiftedUint8WeightPerChannelFloatMSE}}, + 'hqo': { + 'per_tensor': { + 'sym': Int8WeightPerTensorFloatHQO, + 'asym': ShiftedUint8WeightPerTensorFloatHQO}, + 'per_channel': { + 'sym': Int8WeightPerChannelFloatHQO, + 'asym': ShiftedUint8WeightPerChannelFloatHQO}, + 'per_group': { + 'asym': ShiftedUintWeightAsymmetricGroupQuantHQO}},}, 'po2_scale': { 'stats': { 'per_tensor': { diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py index 76e2e4099..f1dc03dfc 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -10,6 +10,7 @@ from brevitas.core.scaling import ParameterFromStatsFromParameterScaling from brevitas.core.stats import AbsMinMax from brevitas.core.stats import NegativeMinOrZero +from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint from brevitas.inject import ExtendedInjector @@ -107,6 +108,29 @@ class ShiftedUintWeightAsymmetricGroupQuant(IntWeightSymmetricGroupQuant): signed = False +class ShiftedUintWeightAsymmetricGroupQuantHQO(IntWeightSymmetricGroupQuant): + """ + Block / group / vector signed asymmetric weight quantizer with float scales and zero-points. + """ + zero_point_input_shape = this.scaling_input_shape + reshaped_zero_point_shape = this.reshaped_scaling_shape + zero_point_shape = this.scaling_shape + inner_expanded_zero_point_shape = this.expanded_scaling_shape + expanded_zero_point_shape = this.expanded_scaling_shape + zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl + zero_point_stats_input_concat_dim = 0 + zero_point_impl = ExpandReshapeZeroPointWrapper + zero_point_stats_impl = HalfQuadraticOptimizerZeroPoint + hqo_init_op_zp = NegativeMinOrZero + scaling_stats_impl = AbsMinMax + keepdim = True + # zero-point is converted to a parameter right away + wrapped_zero_point_impl = ParameterFromStatsFromParameterZeroPoint + quantize_zero_point = False + signed = False + inner_stats_input_view_shape_impl = torch.nn.Identity() + + class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat): """ Symmetric quantizer with per tensor dynamic scale. diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 3150a061b..af47458a9 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -21,9 +21,7 @@ from brevitas.graph.target.flexml import quantize_flexml from brevitas.inject import value import brevitas.nn as qnn -from brevitas.quant.experimental.float import Fp8e4m3Act from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat -from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloatMSE from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat @@ -46,6 +44,7 @@ from brevitas.quant.scaled_int import Int32Bias from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFixedPoint from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO @@ -112,7 +111,7 @@ 'float_scale': { 'stats': { 'per_tensor': { - 'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}}, + 'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloatHQO}}, 'mse': { 'per_tensor': { 'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE}}}, @@ -337,8 +336,11 @@ def kwargs_prefix(prefix, weight_kwargs): per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][ 'per_tensor'][act_quant_type] act_quant = act_quant.let(**act_bit_width_dict) + act_quant = act_quant.let(**{'dtype': dtype, 'device': device}) sym_act_quant = sym_act_quant.let(**act_bit_width_dict) + sym_act_quant = sym_act_quant.let(**{'dtype': dtype, 'device': device}) per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict) + per_tensor_act_quant = per_tensor_act_quant.let(**{'dtype': dtype, 'device': device}) else: act_quant = None sym_act_quant = None @@ -352,19 +354,14 @@ def kwargs_prefix(prefix, weight_kwargs): if weight_quant_type == 'asym': weight_quant = weight_quant.let(zero_point_impl=ParameterFromStatsFromParameterZeroPoint) if act_quant is not None: - act_quant = act_quant.let( - **{ - 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device}) + act_quant = act_quant.let(**{'high_percentile_q': act_quant_percentile}) if act_quant_type == 'asym' and act_quant_percentile is not None: act_quant = act_quant.let(**{'low_percentile_q': 100 - act_quant_percentile}) if sym_act_quant is not None: - sym_act_quant = sym_act_quant.let( - **{ - 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device}) + sym_act_quant = sym_act_quant.let(**{'high_percentile_q': act_quant_percentile}) if per_tensor_act_quant is not None: per_tensor_act_quant = per_tensor_act_quant.let( - **{ - 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device}) + **{'high_percentile_q': act_quant_percentile}) if act_quant_type == 'asym' and act_quant_percentile is not None: per_tensor_act_quant = per_tensor_act_quant.let( **{'low_percentile_q': 100 - act_quant_percentile}) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 1f7c06a2b..b325e8bd0 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -126,7 +126,7 @@ def parse_type(v, default_type): parser.add_argument( '--weight-quant-calibration-type', default='stats', - choices=['stats', 'mse'], + choices=['stats', 'mse', 'hqo'], help='Weight quantization calibration type (default: stats)') parser.add_argument( '--act-equalization',