diff --git a/src/brevitas/quant/experimental/float.py b/src/brevitas/quant/experimental/float.py
index 63e032b35..8e86936c2 100644
--- a/src/brevitas/quant/experimental/float.py
+++ b/src/brevitas/quant/experimental/float.py
@@ -124,3 +124,17 @@ class Fp8e5m2ActPerChannelFloat2dMSE(Fp8e5m2Mixin, MSESymmetricScale, ScaledFloa
     """
     scaling_per_output_channel = True
     scaling_stats_permute_dims = (1, 0, 2, 3)
+
+
+class Fp8e4m3WeightPerChannelFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatWeightBase):
+    """
+    FP8 signed E4M3 weight quantizer with per-channel MSE-based scaling.
+    """
+    scaling_per_output_channel = True
+
+
+class Fp8e4m3WeightPerTensorFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatWeightBase):
+    """
+    FP8 signed E4M3 weight quantizer with per-tensor MSE-based scaling.
+    """
+    scaling_per_output_channel = False
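Aside: a minimal usage sketch for the two quantizers added above. The layer type and shapes are illustrative, not part of this patch:

    import torch
    import brevitas.nn as qnn
    from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE

    # MSE-based scaling searches for the scale that minimizes the weight
    # quantization error, rather than deriving it from tensor statistics.
    layer = qnn.QuantLinear(16, 32, weight_quant=Fp8e4m3WeightPerChannelFloatMSE)
    out = layer(torch.randn(1, 16))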
diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
index 831ded35b..b88db3e3a 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
@@ -2,10 +2,10 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import argparse
+from functools import partial
 from itertools import product
 import os
 import random
-from time import sleep
 from types import SimpleNamespace
 
 import numpy as np
@@ -38,54 +38,66 @@ config.IGNORE_MISSING_KEYS = True
 
+
+def parse_type(v, default_type):
+    if v == 'None':
+        return None
+    else:
+        return default_type(v)
+
+
+def parse_bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+class hashabledict(dict):
+
+    def __hash__(self):
+        return hash(tuple(sorted(self.items())))
+
+
+def unique(sequence):
+    seen = set()
+    return [x for x in sequence if not (x in seen or seen.add(x))]
+
+
 # Torchvision models with top1 accuracy
 TORCHVISION_TOP1_MAP = {
     'resnet18': 69.758,
     'mobilenet_v2': 71.898,
     'vit_b_32': 75.912,}
 
-OPTIONS = {
-    'model_name': TORCHVISION_TOP1_MAP.keys(),
-    'target_backend': ['fx', 'layerwise', 'flexml'],  # Target backend
-    'scale_factor_type': ['float', 'po2'],  # Scale factor type
-    'weight_bit_width': [8, 4],  # Weight Bit Width
-    'act_bit_width': [8, 4],  # Act bit width
-    'bias_bit_width': [None, 32, 16],  # Bias Bit-Width for Po2 scale
-    'weight_quant_granularity': ['per_tensor', 'per_channel'],  # Scaling Per Output Channel
-    'act_quant_type': ['asym', 'sym'],  # Act Quant Type
-    'weight_param_method': ['stats', 'mse'],  # Weight Quant Type
-    'act_param_method': ['stats', 'mse'],  # Act Param Method
-    'bias_corr': [True],  # Bias Correction
-    'graph_eq_iterations': [0, 20],  # Graph Equalization
-    'graph_eq_merge_bias': [False, True],  # Merge bias for Graph Equalization
-    'act_equalization': ['fx', 'layerwise', None],  # Perform Activation Equalization (Smoothquant)
-    'learned_round': [False, True],  # Enable/Disable Learned Round
-    'gptq': [False, True],  # Enable/Disable GPTQ
-    'gptq_act_order': [False, True],  # Use act_order heuristics for GPTQ
-    'gpfq': [False, True],  # Enable/Disable GPFQ
-    'gpfq_p': [0.25, 0.75],  # GPFQ P
-    'act_quant_percentile': [99.9, 99.99, 99.999],  # Activation Quantization Percentile
-    'uint_sym_act_for_unsigned_values': [True],  # Whether to use unsigned act quant when possible
-}
-
 OPTIONS_DEFAULT = {
-    'target_backend': ['fx'],  # Target backend
-    'scale_factor_type': ['float'],  # Scale factor type
+    'model_name': list(TORCHVISION_TOP1_MAP.keys()),
+    'quant_format': ['int'],  # Quantization type (INT vs Float)
+    'target_backend': ['layerwise'],  # Target backend
+    'scale_factor_type': ['float_scale'],  # Scale factor type
+    'weight_mantissa_bit_width': [4],
+    'weight_exponent_bit_width': [3],
+    'act_mantissa_bit_width': [4],
+    'act_exponent_bit_width': [3],
     'weight_bit_width': [8],  # Weight Bit Width
     'act_bit_width': [8],  # Act bit width
     'bias_bit_width': [32],  # Bias Bit-Width for Po2 scale
     'weight_quant_granularity': ['per_channel'],  # Scaling Per Output Channel
     'act_quant_type': ['sym'],  # Act Quant Type
-    'act_param_method': ['mse'],  # Act Param Method
-    'weight_param_method': ['stats'],  # Weight Quant Type
+    'act_param_method': ['stats'],  # Act Param Method
+    'weight_param_method': ['mse'],  # Weight Param Method
     'bias_corr': [True],  # Bias Correction
     'graph_eq_iterations': [20],  # Graph Equalization
     'graph_eq_merge_bias': [True],  # Merge bias for Graph Equalization
-    'act_equalization': [None],  # Perform Activation Equalization (Smoothquant)
+    'act_equalization': ['layerwise'],  # Perform Activation Equalization (Smoothquant)
     'learned_round': [False],  # Enable/Disable Learned Round
     'gptq': [True],  # Enable/Disable GPTQ
     'gpfq': [False],  # Enable/Disable GPFQ
-    'gpfq_p': [0.25],  # GPFQ P
+    'gpfq_p': [0.75],  # GPFQ P
     'gptq_act_order': [False],  # Use act_order heuristics for GPTQ
     'act_quant_percentile': [99.999],  # Activation Quantization Percentile
     'uint_sym_act_for_unsigned_values': [True],  # Whether to use unsigned act quant when possible
@@ -108,8 +120,12 @@ parser.add_argument(
     '--batch-size-validation', default=256, type=int, help='Minibatch size for validation')
 parser.add_argument('--calibration-samples', default=1000, type=int, help='Calibration size')
-parser.add_argument(
-    '--options-to-exclude', choices=OPTIONS.keys(), nargs="+", default=[], help='Calibration size')
+for option_name, option_value in OPTIONS_DEFAULT.items():
+    if isinstance(option_value[0], bool):
+        type_args = parse_bool
+    else:
+        type_args = partial(parse_type, default_type=type(option_value[0]))
+    parser.add_argument(f'--{option_name}', default=option_value, nargs="+", type=type_args)
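Aside: a sketch of what the generated per-option flags accept, assuming the helpers above are in scope. Each flag takes one or more values, and the literal string 'None' maps back to Python None so sweeps can include a disabled setting:

    from functools import partial

    parse_int = partial(parse_type, default_type=int)
    assert parse_int('32') == 32        # '--bias_bit_width 32'
    assert parse_int('None') is None    # '--bias_bit_width None'
    assert parse_bool('true') is True   # '--bias_corr true'
    # e.g. '--bias_bit_width None 32' parses to [None, 32]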
@@ -118,11 +134,9 @@ def main():
     np.random.seed(SEED)
     torch.manual_seed(SEED)
 
-    for option in args.options_to_exclude:
-        OPTIONS[option] = OPTIONS_DEFAULT[option]
-
     args.gpu = get_gpu_index(args.idx)
     print("Iter {}, GPU {}".format(args.idx, args.gpu))
+
     try:
         ptq_torchvision_models(args)
     except Exception as E:
@@ -131,42 +145,25 @@ def main():
 
 
 def ptq_torchvision_models(args):
     # Generate all possible combinations, including invalid ones
-    # Split stats and mse due to the act_quant_percentile value
-    if 'stats' in OPTIONS['act_param_method']:
-        percentile_options = OPTIONS.copy()
-        percentile_options['act_param_method'] = ['stats']
-    else:
-        percentile_options = None
+    options = {k: getattr(args, k) for k in OPTIONS_DEFAULT.keys()}
 
-    if 'mse' in OPTIONS['act_param_method']:
-        mse_options = OPTIONS.copy()
-        mse_options['act_param_method'] = ['mse']
-        mse_options['act_quant_percentile'] = [None]
-    else:
-        mse_options = None
-
-    # Combine MSE and Percentile combinations, if they are defined
-    combinations = []
-    if mse_options is not None:
-        combinations += list(product(*mse_options.values()))
-    if percentile_options is not None:
-        combinations += list(product(*percentile_options.values()))
-    # Combine the two sets of combinations
-    # Generate Namespace for each configuration
-    configs = [
-        SimpleNamespace(**{k: v
-                           for k, v in zip(OPTIONS.keys(), combination)})
-        for combination in combinations]
-    # Define which configurations are not valid
-    configs = list(map(validate_config, configs))
-    # Drop invalid configurations
-    configs = list(config for config in configs if config.is_valid)
-
-    if args.idx > len(configs):
+    combinations = list(product(*options.values()))
+
+    configs = []
+    for combination in combinations:
+        config_namespace = SimpleNamespace(
+            **{k: v for k, v in zip(OPTIONS_DEFAULT.keys(), combination)})
+        config_namespace = validate_config(config_namespace)
+        if config_namespace.is_valid:
+            configs.append(hashabledict(**config_namespace.__dict__))
+
+    configs = unique(configs)
+
+    if args.idx > len(configs) - 1:
         return
 
-    config_namespace = configs[args.idx]
+    config_namespace = SimpleNamespace(**configs[args.idx])
     print(config_namespace)
 
     fp_accuracy = TORCHVISION_TOP1_MAP[config_namespace.model_name]
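Aside: why hashabledict plus unique. validate_config normalizes fields that are irrelevant for a given configuration (e.g. act_quant_percentile is reset to None under 'mse'), so distinct grid points can collapse into identical configs; hashing the dicts lets unique() drop such duplicates while preserving order. Illustrative values:

    a = hashabledict(act_param_method='mse', act_quant_percentile=None)
    b = hashabledict(act_param_method='mse', act_quant_percentile=None)
    assert hash(a) == hash(b)
    assert unique([a, b]) == [a]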
@@ -221,8 +218,13 @@ def ptq_torchvision_models(args):
     # Define the quantized model
     quant_model = quantize_model(
         model,
+        quant_format=config_namespace.quant_format,
         backend=config_namespace.target_backend,
         act_bit_width=config_namespace.act_bit_width,
+        weight_mantissa_bit_width=config_namespace.weight_mantissa_bit_width,
+        weight_exponent_bit_width=config_namespace.weight_exponent_bit_width,
+        act_mantissa_bit_width=config_namespace.act_mantissa_bit_width,
+        act_exponent_bit_width=config_namespace.act_exponent_bit_width,
         weight_bit_width=config_namespace.weight_bit_width,
         weight_param_method=config_namespace.weight_param_method,
         act_param_method=config_namespace.act_param_method,
@@ -298,7 +300,7 @@ def validate_config(config_namespace):
     # Flexml supports only per-tensor scale factors, power of two scale factors
     if config_namespace.target_backend == 'flexml' and (
             config_namespace.weight_quant_granularity == 'per_channel' or
-            config_namespace.scale_factor_type == 'float32'):
+            config_namespace.scale_factor_type == 'float_scale'):
         is_valid = False
     # Merge bias can be enabled only when graph equalization is enabled
     if config_namespace.graph_eq_iterations == 0 and config_namespace.graph_eq_merge_bias:
@@ -311,15 +313,33 @@
     if not config_namespace.gptq and config_namespace.gptq_act_order:
         is_valid = False
 
-    # If GPFQ is disabled, we execute only one configuration for p==0.25
-    if not config_namespace.gpfq and config_namespace.gpfq_p == 0.75:
-        is_valid = False
-
     if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx':
         is_valid = False
     if config_namespace.act_bit_width < config_namespace.weight_bit_width:
         is_valid = False
 
+    if config_namespace.act_param_method == 'mse':
+        config_namespace.act_quant_percentile = None
+
+    if not config_namespace.gpfq:
+        config_namespace.gpfq_p = None
+
+    if config_namespace.quant_format == 'int':
+        config_namespace.weight_mantissa_bit_width = None
+        config_namespace.weight_exponent_bit_width = None
+        config_namespace.act_mantissa_bit_width = None
+        config_namespace.act_exponent_bit_width = None
+
+    if config_namespace.quant_format == 'float':
+        config_namespace.act_quant_type = 'sym'
+        config_namespace.weight_quant_type = 'sym'
+        if config_namespace.weight_exponent_bit_width + config_namespace.weight_mantissa_bit_width != config_namespace.weight_bit_width - 1:
+            is_valid = False
+        if config_namespace.act_exponent_bit_width + config_namespace.act_mantissa_bit_width != config_namespace.act_bit_width - 1:
+            is_valid = False
+
     config_namespace.is_valid = is_valid
     return config_namespace
diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh b/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh
index bcc1fec09..c70912fa0 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh
+++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh
@@ -1 +1,25 @@
-python ptq_benchmark_torchvision.py $1 --calibration-dir /scratch/datasets/imagenet_symlink/calibration --validation-dir /scratch/datasets/imagenet_symlink/val --options-to-exclude graph_eq_merge_bias graph_eq_iterations
+python ptq_benchmark_torchvision.py $1 --calibration-dir /scratch/datasets/imagenet_symlink/calibration --validation-dir /scratch/datasets/imagenet_symlink/val \
+--quant_format float \
+--scale_factor_type float_scale \
+--weight_bit_width 2 3 4 5 6 7 8 \
+--act_bit_width 2 3 4 5 6 7 8 \
+--weight_mantissa_bit_width 1 2 3 4 5 6 \
+--weight_exponent_bit_width 1 2 3 4 5 6 \
+--act_mantissa_bit_width 1 2 3 4 5 6 \
+--act_exponent_bit_width 1 2 3 4 5 6 \
+--bias_bit_width None \
+--weight_quant_granularity per_channel per_tensor \
+--act_quant_type sym \
+--weight_param_method stats \
+--act_param_method mse \
+--bias_corr True \
+--graph_eq_iterations 20 \
+--graph_eq_merge_bias True \
+--act_equalization layerwise \
+--learned_round False \
+--gptq False \
+--gptq_act_order False \
+--gpfq False \
+--gpfq_p None \
+--uint_sym_act_for_unsigned_values False \
+--act_quant_percentile None
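Aside: the float-format validity check above encodes the minifloat layout sign (1 bit) + exponent + mantissa == total bit width, which prunes most of the exponent/mantissa grid swept by this script. With the benchmark defaults (exponent 3, mantissa 4):

    assert 1 + 3 + 4 == 8  # a consistent 8-bit minifloat layout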
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
index 8c250d32f..20250f7dc 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -18,7 +18,15 @@
 from brevitas.graph.quantize import layerwise_quantize
 from brevitas.graph.quantize import quantize
 from brevitas.graph.target.flexml import quantize_flexml
+from brevitas.inject import value
 import brevitas.nn as qnn
+from brevitas.quant.experimental.float import Fp8e4m3Act
+from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
+from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloatMSE
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloatMSE
 from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
 from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE
 from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
@@ -49,45 +57,69 @@ BIAS_BIT_WIDTH_MAP = {32: Int32Bias, 16: Int16Bias, None: None}
 
 WEIGHT_QUANT_MAP = {
+    'int': {
+        'float_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Int8WeightPerTensorFloat, 'asym': ShiftedUint8WeightPerTensorFloat},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat}},
+            'mse': {
+                'per_tensor': {
+                    'sym': Int8WeightPerTensorFloatMSE,
+                    'asym': ShiftedUint8WeightPerTensorFloatMSE},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFloatMSE,
+                    'asym': ShiftedUint8WeightPerChannelFloatMSE},},},
+        'po2_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Int8WeightPerTensorFixedPoint},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFixedPoint},},
+            'mse': {
+                'per_tensor': {
+                    'sym': Int8WeightPerTensorFixedPointMSE},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFixedPointMSE}},}},
     'float': {
-        'stats': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFloat, 'asym': ShiftedUint8WeightPerTensorFloat},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat}},
-        'mse': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFloatMSE, 'asym': ShiftedUint8WeightPerTensorFloatMSE},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFloatMSE, 'asym': ShiftedUint8WeightPerChannelFloatMSE},
-        },},
-    'po2': {
-        'stats': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFixedPoint},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFixedPoint},},
-        'mse': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFixedPointMSE},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFixedPointMSE}},}}
+        'float_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Fp8e4m3WeightPerTensorFloat},
+                'per_channel': {
+                    'sym': Fp8e4m3WeightPerChannelFloat}},
+            'mse': {
+                'per_tensor': {
+                    'sym': Fp8e4m3WeightPerTensorFloatMSE},
+                'per_channel': {
+                    'sym': Fp8e4m3WeightPerChannelFloatMSE}}}}}
 
 INPUT_QUANT_MAP = {
+    'int': {
+        'float_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}},
+            'mse': {
+                'per_tensor': {
+                    'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE}}},
+        'po2_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Int8ActPerTensorFixedPoint, 'asym': ShiftedUint8ActPerTensorFixedPoint},
+            },
+            'mse': {
+                'per_tensor': {
+                    'sym': Int8ActPerTensorFixedPointMSE}},}},
     'float': {
-        'stats': {
-            'per_tensor': {
-                'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}},
-        'mse': {
-            'per_tensor': {
-                'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE}}},
-    'po2': {
-        'stats': {
-            'per_tensor': {
-                'sym': Int8ActPerTensorFixedPoint, 'asym': ShiftedUint8ActPerTensorFixedPoint},},
-        'mse': {
-            'per_tensor': {
-                'sym': Int8ActPerTensorFixedPointMSE}},}}
+        'float_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Fp8e4m3ActPerTensorFloat}},
+            'mse': {
+                'per_tensor': {
+                    'sym': Fp8e4m3ActPerTensorFloatMSE},}}}}
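Aside: the quantizer class is now resolved with a leading quant_format key. An illustrative lookup against the maps above (class names as imported in this file):

    weight_quant = WEIGHT_QUANT_MAP['float']['float_scale']['mse']['per_channel']['sym']
    assert weight_quant is Fp8e4m3WeightPerChannelFloatMSE

    act_quant = INPUT_QUANT_MAP['int']['po2_scale']['stats']['per_tensor']['asym']
    assert act_quant is ShiftedUint8ActPerTensorFixedPoint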
 
 
 def quantize_model(
@@ -100,7 +132,14 @@ def quantize_model(
     act_quant_percentile,
     act_quant_type,
     scale_factor_type,
+    quant_format,
     layerwise_first_last_bit_width=8,
+    layerwise_first_last_mantissa_bit_width=4,
+    layerwise_first_last_exponent_bit_width=3,
+    weight_mantissa_bit_width=4,
+    weight_exponent_bit_width=3,
+    act_mantissa_bit_width=4,
+    act_exponent_bit_width=3,
     weight_narrow_range=False,
     weight_param_method='stats',
     act_param_method='stats',
@@ -113,35 +152,90 @@
     weight_scale_type = scale_factor_type
     act_scale_type = scale_factor_type
 
-    weight_quant_granularity = weight_quant_granularity
+    weight_quant_format = quant_format
+    act_quant_format = quant_format
 
-    def bit_width_fn(module, other_bit_width):
+    def layerwise_bit_width_fn(module, base_bit_width, first_last_bit_width):
         if isinstance(module, torch.nn.Conv2d) and module.in_channels == 3:
-            return layerwise_first_last_bit_width
+            return first_last_bit_width
         elif isinstance(module, torch.nn.Linear) and module.out_features == 1000:
-            return layerwise_first_last_bit_width
+            return first_last_bit_width
         else:
-            return other_bit_width
+            return base_bit_width
+
+    @value
+    def layerwise_bit_width_fn_act_exponent(module):
+        return layerwise_bit_width_fn(
+            module, act_exponent_bit_width, layerwise_first_last_exponent_bit_width)
+
+    @value
+    def layerwise_bit_width_fn_act_mantissa(module):
+        return layerwise_bit_width_fn(
+            module, act_mantissa_bit_width, layerwise_first_last_mantissa_bit_width)
+
+    @value
+    def layerwise_bit_width_fn_weight_exponent(module):
+        return layerwise_bit_width_fn(
+            module, weight_exponent_bit_width, layerwise_first_last_exponent_bit_width)
+
+    @value
+    def layerwise_bit_width_fn_weight_mantissa(module):
+        return layerwise_bit_width_fn(
+            module, weight_mantissa_bit_width, layerwise_first_last_mantissa_bit_width)
+
+    @value
+    def layerwise_bit_width_fn_act(module):
+        return layerwise_bit_width_fn(module, act_bit_width, layerwise_first_last_bit_width)
+
+    @value
+    def layerwise_bit_width_fn_weight(module):
+        return layerwise_bit_width_fn(module, weight_bit_width, layerwise_first_last_bit_width)
+
+    # TODO: extend the layerwise bit-width overrides to backends != layerwise
+    # TODO: fix name shadowing across these variables
+    weight_bit_width_dict = {}
+    act_bit_width_dict = {}
+    if backend == 'layerwise':
+        weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
+        act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
+        if quant_format == 'float':
+            weight_bit_width_dict['weight_mantissa_bit_width'] = layerwise_bit_width_fn_weight_mantissa
+            weight_bit_width_dict['weight_exponent_bit_width'] = layerwise_bit_width_fn_weight_exponent
+            act_bit_width_dict['act_mantissa_bit_width'] = layerwise_bit_width_fn_act_mantissa
+            act_bit_width_dict['act_exponent_bit_width'] = layerwise_bit_width_fn_act_exponent
+    else:
+        weight_bit_width_dict['weight_bit_width'] = weight_bit_width
+        act_bit_width_dict['act_bit_width'] = act_bit_width
+        if quant_format == 'float':
+            weight_bit_width_dict['weight_mantissa_bit_width'] = weight_mantissa_bit_width
+            weight_bit_width_dict['weight_exponent_bit_width'] = weight_exponent_bit_width
+            act_bit_width_dict['act_mantissa_bit_width'] = act_mantissa_bit_width
+            act_bit_width_dict['act_exponent_bit_width'] = act_exponent_bit_width
+
 
-    weight_bit_width_or_lambda = weight_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
-        module, weight_bit_width)
-    act_bit_width_or_lambda = act_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
-        module, act_bit_width)
     quant_layer_map, quant_layerwise_layer_map, quant_act_map, quant_identity_map = create_quant_maps(dtype=dtype,
         uint_sym_act_for_unsigned_values=uint_sym_act_for_unsigned_values,
         bias_bit_width=bias_bit_width,
-        weight_bit_width=weight_bit_width_or_lambda,
         weight_param_method=weight_param_method,
         weight_scale_type=weight_scale_type,
         weight_quant_type=weight_quant_type,
         weight_quant_granularity=weight_quant_granularity,
         weight_narrow_range=weight_narrow_range,
-        act_bit_width=act_bit_width_or_lambda,
+        weight_quant_format=weight_quant_format,
+        act_quant_format=act_quant_format,
         act_scale_type=act_scale_type,
         act_param_method=act_param_method,
         act_quant_type=act_quant_type,
         act_quant_granularity=act_quant_granularity,
-        act_quant_percentile=act_quant_percentile)
+        act_quant_percentile=act_quant_percentile,
+        **weight_bit_width_dict,
+        **act_bit_width_dict)
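Aside: a self-contained sketch of the first/last-layer rule that the @value functions above wrap. @value registers the function with Brevitas' dependency injection, so it is invoked with the module being quantized and can return a per-module width:

    import torch

    def layerwise_bit_width_fn(module, base_bit_width, first_last_bit_width):
        # Same rule as in quantize_model: the RGB stem conv and the
        # 1000-class head keep the (usually higher) first/last width.
        if isinstance(module, torch.nn.Conv2d) and module.in_channels == 3:
            return first_last_bit_width
        elif isinstance(module, torch.nn.Linear) and module.out_features == 1000:
            return first_last_bit_width
        return base_bit_width

    assert layerwise_bit_width_fn(torch.nn.Conv2d(3, 64, 3), 4, 8) == 8
    assert layerwise_bit_width_fn(torch.nn.Conv2d(64, 64, 3), 4, 8) == 4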
 
     if backend != 'layerwise':
         # Fx and flexml backend requires three mappings for quantization
@@ -166,7 +260,13 @@ def create_quant_maps(
         weight_quant_type,
         weight_quant_granularity,
         weight_narrow_range,
+        weight_quant_format,
+        act_quant_format,
         uint_sym_act_for_unsigned_values=True,
+        weight_mantissa_bit_width=None,
+        weight_exponent_bit_width=None,
+        act_mantissa_bit_width=None,
+        act_exponent_bit_width=None,
         act_bit_width=None,
         act_scale_type=None,
         act_param_method=None,
@@ -180,19 +280,34 @@
 
     def kwargs_prefix(prefix, weight_kwargs):
         return {prefix + k: v for k, v in weight_kwargs.items()}
 
+    weight_bit_width_dict = {'bit_width': weight_bit_width}
+    if weight_quant_format == 'float':
+        weight_bit_width_dict['exponent_bit_width'] = weight_exponent_bit_width
+        weight_bit_width_dict['mantissa_bit_width'] = weight_mantissa_bit_width
+
+    act_bit_width_dict = {'bit_width': act_bit_width}
+    if act_quant_format == 'float':
+        act_bit_width_dict['exponent_bit_width'] = act_exponent_bit_width
+        act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width
+
     # Retrieve base input, weight, and bias quantizers
     bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width]
-    weight_quant = WEIGHT_QUANT_MAP[weight_scale_type][weight_param_method][
+    weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_type][weight_param_method][
         weight_quant_granularity][weight_quant_type]
+    weight_quant = weight_quant.let(**weight_bit_width_dict)
+
     if act_bit_width is not None:
-        act_quant = INPUT_QUANT_MAP[act_scale_type][act_param_method][act_quant_granularity][
-            act_quant_type]
+        act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
+            act_quant_granularity][act_quant_type]
         # Some activations in MHA should always be symmetric
-        sym_act_quant = INPUT_QUANT_MAP[act_scale_type][act_param_method][act_quant_granularity][
-            'sym']
+        sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
+            act_quant_granularity]['sym']
         # Linear layers with 2d input should always be per tensor
-        per_tensor_act_quant = INPUT_QUANT_MAP[act_scale_type][act_param_method]['per_tensor'][
-            act_quant_type]
+        per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
+            'per_tensor'][act_quant_type]
+        act_quant = act_quant.let(**act_bit_width_dict)
+        sym_act_quant = sym_act_quant.let(**act_bit_width_dict)
+        per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict)
     else:
         act_quant = None
         sym_act_quant = None
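Aside: .let() returns a new quantizer with the given attributes overridden, which is how the requested bit widths reach the injector without subclassing. A sketch with a stock quantizer, under the assumption that plain injector attributes resolve to their bound values:

    from brevitas.quant.scaled_int import Int8WeightPerTensorFloat

    Int4WeightPerTensorFloat = Int8WeightPerTensorFloat.let(bit_width=4)
    assert Int4WeightPerTensorFloat.bit_width == 4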
@@ -221,31 +336,22 @@ def kwargs_prefix(prefix, weight_kwargs):
         per_tensor_act_quant = per_tensor_act_quant.let(
             **{'low_percentile_q': 100 - act_quant_percentile})
 
-    weight_quant_and_bit_width = {
-        'weight_quant': weight_quant, 'weight_bit_width': weight_bit_width}
+    weight_quant_dict = {'weight_quant': weight_quant}
 
     quant_wbiol_kwargs = {
-        **weight_quant_and_bit_width,
-        'dtype': dtype,
-        'return_quant_tensor': False,
-        'bias_quant': bias_quant}
+        **weight_quant_dict, 'dtype': dtype, 'return_quant_tensor': False, 'bias_quant': bias_quant}
 
     # yapf: disable
     quant_mha_kwargs = {
-        **kwargs_prefix('in_proj_', weight_quant_and_bit_width),
-        **kwargs_prefix('out_proj_', weight_quant_and_bit_width),
+        **kwargs_prefix('in_proj_', weight_quant_dict),
+        **kwargs_prefix('out_proj_', weight_quant_dict),
         'in_proj_bias_quant': bias_quant,
         'softmax_input_quant': None,
         'attn_output_weights_quant': sym_act_quant,
-        'attn_output_weights_bit_width': act_bit_width,
         'q_scaled_quant': sym_act_quant,
-        'q_scaled_bit_width': act_bit_width,
         'k_transposed_quant': sym_act_quant,
-        'k_transposed_bit_width': act_bit_width,
         'v_quant': sym_act_quant,
-        'v_bit_width': act_bit_width,
         'out_proj_input_quant': act_quant,
-        'out_proj_input_bit_width': act_bit_width,
         'out_proj_bias_quant': bias_quant,
         'out_proj_output_quant': None,
         # activation equalization requires packed_in_proj
@@ -256,13 +362,9 @@ def kwargs_prefix(prefix, weight_kwargs):
     # yapf: enable
 
     # Layerwise is basic quant kwargs + input_quant
-    layerwise_quant_wbiol_kwargs = {
-        **quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant, 'input_bit_width': act_bit_width}
+    layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant}
 
-    layerwise_quant_mha_kwargs = {
-        **quant_mha_kwargs,
-        'in_proj_input_quant': per_tensor_act_quant,
-        'in_proj_input_bit_width': act_bit_width}
+    layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': per_tensor_act_quant}
 
     quant_layer_map = {
         torch.nn.Linear: (qnn.QuantLinear, quant_wbiol_kwargs),
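Aside: an end-to-end sketch of the updated entry point with FP8. Keyword arguments mirror the call in ptq_benchmark_torchvision.py above; values are illustrative, and any required arguments not visible in this patch (e.g. dtype) are elided:

    from torchvision.models import resnet18

    quant_model = quantize_model(
        resnet18(),
        quant_format='float',
        backend='layerwise',
        scale_factor_type='float_scale',
        weight_bit_width=8,
        act_bit_width=8,
        weight_exponent_bit_width=4,
        weight_mantissa_bit_width=3,
        act_exponent_bit_width=4,
        act_mantissa_bit_width=3,
        weight_param_method='mse',
        act_param_method='stats',
        bias_bit_width=32,
        weight_quant_granularity='per_channel',
        act_quant_percentile=99.999,
        act_quant_type='sym')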