From e982285b09e3c6be248adf18e4619413814e7a7f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 2 Oct 2023 19:51:08 +0100 Subject: [PATCH] Fix for bitwidth, support MSE --- src/brevitas/quant/experimental/float.py | 14 +++++++++ .../benchmark/ptq_benchmark_torchvision.py | 27 +++++++++++++---- .../imagenet_classification/ptq/ptq_common.py | 29 ++++++++++++------- 3 files changed, 54 insertions(+), 16 deletions(-) diff --git a/src/brevitas/quant/experimental/float.py b/src/brevitas/quant/experimental/float.py index 63e032b35..dd429290b 100644 --- a/src/brevitas/quant/experimental/float.py +++ b/src/brevitas/quant/experimental/float.py @@ -124,3 +124,17 @@ class Fp8e5m2ActPerChannelFloat2dMSE(Fp8e5m2Mixin, MSESymmetricScale, ScaledFloa """ scaling_per_output_channel = True scaling_stats_permute_dims = (1, 0, 2, 3) + + +class Fp8e4m3WeightPerChannelFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatWeightBase): + """ + FP8 signed E3M4 weight quantizer with per-channel absmax-based scaling. + """ + scaling_per_output_channel = True + + +class Fp8e4m3WeightPerTensorFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatWeightBase): + """ + FP8 signed E3M4 weight quantizer with per-tensor absmax-based scaling. + """ + scaling_per_output_channel = False diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py index 3ddcdac53..2cdf9bd15 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py @@ -38,6 +38,17 @@ config.IGNORE_MISSING_KEYS = True +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + class hashabledict(dict): def __hash__(self): @@ -66,11 +77,11 @@ def unique(sequence): 'act_exponent_bit_width': [3], 'weight_bit_width': [8], # Weight Bit Width 'act_bit_width': [8], # Act bit width - 'bias_bit_width': [32], # Bias Bit-Width for Po2 scale + 'bias_bit_width': [None], # Bias Bit-Width for Po2 scale 'weight_quant_granularity': ['per_channel'], # Scaling Per Output Channel 'act_quant_type': ['sym'], # Act Quant Type - 'act_param_method': ['mse'], # Act Param Method - 'weight_param_method': ['stats'], # Weight Quant Type + 'act_param_method': ['stats'], # Act Param Method + 'weight_param_method': ['mse'], # Weight Quant Type 'bias_corr': [True], # Bias Correction 'graph_eq_iterations': [20], # Graph Equalization 'graph_eq_merge_bias': [True], # Merge bias for Graph Equalization @@ -102,8 +113,12 @@ def unique(sequence): '--batch-size-validation', default=256, type=int, help='Minibatch size for validation') parser.add_argument('--calibration-samples', default=1000, type=int, help='Calibration size') for option_name, option_value in OPTIONS_DEFAULT.items(): - parser.add_argument( - f'--{option_name}', default=option_value, nargs="+", type=type(option_value[0])) + if isinstance(option_value[0], bool): + type_args = str2bool + else: + type_args = type(option_value[0]) + parser.add_argument(f'--{option_name}', default=option_value, nargs="+", type=type_args) +print(parser) def main(): @@ -138,7 +153,7 @@ def ptq_torchvision_models(args): configs = unique(configs) - if args.idx > len(configs): + if args.idx > len(configs) - 1: return config_namespace = SimpleNamespace(**configs[args.idx]) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 672a3e322..bcd102a6d 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -22,8 +22,11 @@ import brevitas.nn as qnn from brevitas.quant.experimental.float import Fp8e4m3Act from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat +from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloatMSE from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat +from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat +from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloatMSE from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint @@ -85,7 +88,12 @@ 'per_tensor': { 'sym': Fp8e4m3WeightPerTensorFloat}, 'per_channel': { - 'sym': Fp8e4m3WeightPerChannelFloat}}}}} + 'sym': Fp8e4m3WeightPerChannelFloat}}, + 'mse': { + 'per_tensor': { + 'sym': Fp8e4m3WeightPerTensorFloatMSE}, + 'per_channel': { + 'sym': Fp8e4m3WeightPerChannelFloatMSE}}}}} INPUT_QUANT_MAP = { 'int': { @@ -108,7 +116,10 @@ 'float_scale': { 'stats': { 'per_tensor': { - 'sym': Fp8e4m3ActPerTensorFloat},}}},} + 'sym': Fp8e4m3ActPerTensorFloat}}, + 'mse': { + 'per_tensor': { + 'sym': Fp8e4m3ActPerTensorFloat},}}}} def quantize_model( @@ -184,7 +195,7 @@ def layerwise_bit_width_fn_weight(module): # Missing fix for name_shadowing for all variables weight_bit_width_dict = {} act_bit_width_dict = {} - if weight_quant_format == 'int' and backend == 'layerwise': + if quant_format == 'int' and backend == 'layerwise': weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act @@ -192,7 +203,7 @@ def layerwise_bit_width_fn_weight(module): weight_bit_width_dict['weight_bit_width'] = weight_bit_width act_bit_width_dict['act_bit_width'] = act_bit_width - if weight_quant_format == 'float' and backend == 'layerwise': + if quant_format == 'float' and backend == 'layerwise': weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act weight_bit_width_dict['weight_mantissa_bit_width'] = layerwise_bit_width_fn_weight_mantissa @@ -271,15 +282,13 @@ def kwargs_prefix(prefix, weight_kwargs): weight_bit_width_dict = {'bit_width': weight_bit_width} if weight_quant_format == 'float': - weight_bit_width_dict = { - 'exponent_bit_width': weight_exponent_bit_width, - 'mantissa_bit_width': weight_mantissa_bit_width} + weight_bit_width_dict['exponent_bit_width'] = weight_exponent_bit_width + weight_bit_width_dict['mantissa_bit_width'] = weight_mantissa_bit_width act_bit_width_dict = {'bit_width': act_bit_width} if act_quant_format == 'float': - act_bit_width_dict = { - 'exponent_bit_width': act_exponent_bit_width, - 'mantissa_bit_width': act_mantissa_bit_width} + act_bit_width_dict['exponent_bit_width'] = act_exponent_bit_width + act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width # Retrieve base input, weight, and bias quantizers bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width]